ATLAS Offline Software
Loading...
Searching...
No Matches
OnnxRuntimeSessionToolCUDA.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
7
12
14{
// Initialize the tool: retrieve the ONNX Runtime service, build session options that
// attach the CUDA execution provider on device m_deviceId, resolve the model file
// path and create m_session. Returns StatusCode::SUCCESS at the end; the service
// retrieval is guarded by ATH_CHECK.
// NOTE(review): the Ort:: calls below are not wrapped in a StatusCode check; failures
// there presumably surface as exceptions rather than a failed StatusCode — confirm.
15 // Get the Onnx Runtime service.
16 ATH_CHECK(m_onnxRuntimeSvc.retrieve());
17
18 ATH_MSG_INFO(" OnnxRuntime release: " << OrtGetApiBase()->GetVersionString());
19 // Create the session options.
20 Ort::SessionOptions sessionOptions;
21 sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED );
22 sessionOptions.DisablePerSessionThreads(); // use global thread pool.
23
24 // TODO: add more cuda options to the interface
25 // https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#cc
26 // Options: https://onnxruntime.ai/docs/api/c/struct_ort_c_u_d_a_provider_options.html
27 OrtCUDAProviderOptions cuda_options;
// GPU selected by the m_deviceId property.
28 cuda_options.device_id = m_deviceId;
// Use the exhaustive cuDNN convolution-algorithm search.
29 cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive;
// Largest representable value, i.e. no explicit GPU memory cap is requested.
30 cuda_options.gpu_mem_limit = std::numeric_limits<size_t>::max();
31
32 // memory arena options for CUDA memory shrinkage
33 // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/utils.cc#L7
// The four constructor arguments presumably map to (max_mem, arena_extend_strategy,
// initial_chunk_size_bytes, max_dead_bytes_per_chunk), matching the order of the
// commented-out assignments below — confirm against the ORT release in use.
35 Ort::ArenaCfg arena_cfg{0, 0, 1024, 0};
36 // other options are not available in this release.
37 // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/test_inference.cc#L2802C21-L2802C21
38 // arena_cfg.max_mem = 0; // let ORT pick default max memory
39 // arena_cfg.arena_extend_strategy = 0; // 0: kNextPowerOfTwo, 1: kSameAsRequested
40 // arena_cfg.initial_chunk_size_bytes = 1024;
41 // arena_cfg.max_dead_bytes_per_chunk = 0;
42 // arena_cfg.initial_growth_chunk_size_bytes = 256;
43 // arena_cfg.max_power_of_two_extend_bytes = 1L << 24;
44
45 cuda_options.default_memory_arena_cfg = arena_cfg;
// NOTE(review): arena_cfg is local to the scope that closes on the next line, while
// default_memory_arena_cfg only keeps a raw OrtArenaCfg* taken from it. If the
// Ort::ArenaCfg wrapper releases the underlying OrtArenaCfg in its destructor
// (presumably, as an owning wrapper), that pointer dangles before
// AppendExecutionProvider_CUDA and the Ort::Session construction below read it.
// Consider declaring arena_cfg at function scope so it outlives session creation
// — TODO confirm against the ORT C++ API.
46 }
47
// Attach the CUDA execution provider configured above to the session options.
48 sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
49
50 // Create the session.
// The m_modelFileName property is resolved to a concrete path with
// PathResolver::find_calib_file before loading.
51 ATH_MSG_INFO("Asking model from: " << m_modelFileName.value());
52 std::string modelFilePath = PathResolver::find_calib_file(m_modelFileName.value());
53 ATH_MSG_INFO("Loading model from: " << modelFilePath);
// The session is built on the environment owned by the ONNX Runtime service.
54 m_session = std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), modelFilePath.c_str(), sessionOptions);
55
56 return StatusCode::SUCCESS;
57}
58
60{
// Return a reference to the ONNX Runtime session created in initialize().
// m_session is dereferenced without a null check, so this assumes initialize()
// has already completed successfully.
61 return *m_session;
62}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_INFO(x)
ServiceHandle< IOnnxRuntimeSvc > m_onnxRuntimeSvc
The ONNX Runtime service.
Gaudi::Property< int > m_deviceId
The device ID to use.
virtual Ort::Session & session() const override final
Return the ONNX Runtime session created in initialize().
Gaudi::Property< std::string > m_modelFileName
virtual StatusCode initialize() override final
Initialize the tool.
std::unique_ptr< Ort::Session > m_session
static std::string find_calib_file(const std::string &logical_file_name)
AsgTool(const std::string &name)
Constructor specifying the tool instance's name.
Definition AsgTool.cxx:58