Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
OnnxRuntimeSessionToolCUDA.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
7 
9  : asg::AsgTool(name)
10 {
11 }
12 
14 {
15  // Get the Onnx Runtime service.
16  ATH_CHECK(m_onnxRuntimeSvc.retrieve());
17 
18  ATH_MSG_INFO(" OnnxRuntime release: " << OrtGetApiBase()->GetVersionString());
19  // Create the session options.
20  Ort::SessionOptions sessionOptions;
21  sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED );
22  sessionOptions.DisablePerSessionThreads(); // use global thread pool.
23 
24  // TODO: add more cuda options to the interface
25  // https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#cc
26  // Options: https://onnxruntime.ai/docs/api/c/struct_ort_c_u_d_a_provider_options.html
27  OrtCUDAProviderOptions cuda_options;
28  cuda_options.device_id = m_deviceId;
29  cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive;
30  cuda_options.gpu_mem_limit = std::numeric_limits<size_t>::max();
31 
32  // memorry arena options for cuda memory shrinkage
33  // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/utils.cc#L7
34  if (m_enableMemoryShrinkage) {
35  Ort::ArenaCfg arena_cfg{0, 0, 1024, 0};
36  // other options are not available in this release.
37  // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/test_inference.cc#L2802C21-L2802C21
38  // arena_cfg.max_mem = 0; // let ORT pick default max memory
39  // arena_cfg.arena_extend_strategy = 0; // 0: kNextPowerOfTwo, 1: kSameAsRequested
40  // arena_cfg.initial_chunk_size_bytes = 1024;
41  // arena_cfg.max_dead_bytes_per_chunk = 0;
42  // arena_cfg.initial_growth_chunk_size_bytes = 256;
43  // arena_cfg.max_power_of_two_extend_bytes = 1L << 24;
44 
45  cuda_options.default_memory_arena_cfg = arena_cfg;
46  }
47 
48  sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
49 
50  // Create the session.
51  ATH_MSG_INFO("Asking model from: " << m_modelFileName.value());
52  std::string modelFilePath = PathResolver::find_file(m_modelFileName.value(), "CALIBPATH", PathResolver::RecursiveSearch);
53  ATH_MSG_INFO("Loading model from: " << modelFilePath);
54  m_session = std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), modelFilePath.c_str(), sessionOptions);
55 
56  return StatusCode::SUCCESS;
57 }
58 
60 {
61  return *m_session;
62 }
PathResolver::RecursiveSearch
@ RecursiveSearch
Definition: PathResolver.h:28
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
PathResolver::find_file
static std::string find_file(const std::string &logical_file_name, const std::string &search_path, SearchType search_type=LocalSearch)
Definition: PathResolver.cxx:251
max
constexpr double max()
Definition: ap_fixedTest.cxx:33
AthOnnx::OnnxRuntimeSessionToolCUDA::session
virtual Ort::Session & session() const override final
Return a reference to the Onnx Runtime session created by initialize().
Definition: OnnxRuntimeSessionToolCUDA.cxx:59
asg
Definition: DataHandleTestTool.h:28
python.oracle.Session
Session
Definition: oracle.py:78
OnnxRuntimeSessionToolCUDA.h
AthOnnx::OnnxRuntimeSessionToolCUDA::initialize
virtual StatusCode initialize() override final
Initialize the tool.
Definition: OnnxRuntimeSessionToolCUDA.cxx:13
AthOnnx::OnnxRuntimeSessionToolCUDA::OnnxRuntimeSessionToolCUDA
OnnxRuntimeSessionToolCUDA()=delete
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
PathResolver.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240