ATLAS Offline Software
OnnxRuntimeSessionToolCUDA.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 /*
5  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
6 */
7 
9 
11  : asg::AsgTool(name)
12 {
 // Intentionally empty: construction only forwards `name` to the
 // asg::AsgTool base class; nothing else is set up here.
13 }
14 
16 {
17  // Get the Onnx Runtime service.
18  ATH_CHECK(m_onnxRuntimeSvc.retrieve());
19 
20  ATH_MSG_INFO(" OnnxRuntime release: " << OrtGetApiBase()->GetVersionString());
21  // Create the session options.
22  Ort::SessionOptions sessionOptions;
23  sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED );
24  sessionOptions.DisablePerSessionThreads(); // use global thread pool.
25 
26  // TODO: add more cuda options to the interface
27  // https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#cc
28  // Options: https://onnxruntime.ai/docs/api/c/struct_ort_c_u_d_a_provider_options.html
29  OrtCUDAProviderOptions cuda_options;
30  cuda_options.device_id = m_deviceId;
31  cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive;
32  cuda_options.gpu_mem_limit = std::numeric_limits<size_t>::max();
33 
34  // memorry arena options for cuda memory shrinkage
35  // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/utils.cc#L7
36  if (m_enableMemoryShrinkage) {
37  Ort::ArenaCfg arena_cfg{0, 0, 1024, 0};
38  // other options are not available in this release.
39  // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/test_inference.cc#L2802C21-L2802C21
40  // arena_cfg.max_mem = 0; // let ORT pick default max memory
41  // arena_cfg.arena_extend_strategy = 0; // 0: kNextPowerOfTwo, 1: kSameAsRequested
42  // arena_cfg.initial_chunk_size_bytes = 1024;
43  // arena_cfg.max_dead_bytes_per_chunk = 0;
44  // arena_cfg.initial_growth_chunk_size_bytes = 256;
45  // arena_cfg.max_power_of_two_extend_bytes = 1L << 24;
46 
47  cuda_options.default_memory_arena_cfg = arena_cfg;
48  }
49 
50  sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
51 
52  // Create the session.
53  m_session = std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), m_modelFileName.value().c_str(), sessionOptions);
54 
55  return StatusCode::SUCCESS;
56 }
57 
59 {
60  return *m_session;
61 }
max
#define max(a, b)
Definition: cfImp.cxx:41
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
AthOnnx::OnnxRuntimeSessionToolCUDA::session
virtual Ort::Session & session() const override final
Create Onnx Runtime session.
Definition: OnnxRuntimeSessionToolCUDA.cxx:58
asg
Definition: DataHandleTestTool.h:28
python.oracle.Session
Session
Definition: oracle.py:78
OnnxRuntimeSessionToolCUDA.h
AthOnnx::OnnxRuntimeSessionToolCUDA::initialize
virtual StatusCode initialize() override final
Initialize the tool.
Definition: OnnxRuntimeSessionToolCUDA.cxx:15
AthOnnx::OnnxRuntimeSessionToolCUDA::OnnxRuntimeSessionToolCUDA
OnnxRuntimeSessionToolCUDA()=delete
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:195