ATLAS Offline Software
Loading...
Searching...
No Matches
OnnxRuntimeSessionToolCUDA.h
Go to the documentation of this file.
1// Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
2
3#ifndef OnnxRuntimeSessionToolCUDA_H
4#define OnnxRuntimeSessionToolCUDA_H
5
6#include "AsgTools/AsgTool.h"
11
12#include <string>
13
14namespace AthOnnx {
15 // @class OnnxRuntimeSessionToolCUDA
16 //
17 // @brief Tool to create Onnx Runtime session with CUDA backend
18 //
19 // @author Xiangyang Ju <xiangyang.ju@cern.ch>
21 {
23 public:
25 OnnxRuntimeSessionToolCUDA( const std::string& name);
26 virtual ~OnnxRuntimeSessionToolCUDA() = default;
27
29 virtual StatusCode initialize() override final;
30
32 virtual Ort::Session& session() const override final;
33
35 virtual bool supportsAsync() const override final;
36
38 int deviceId() const { return m_deviceId; }
39
40 protected:
44
45 private:
46 Gaudi::Property<std::string> m_modelFileName{this, "ModelFileName", "", "The model file name"};
48 Gaudi::Property<int> m_deviceId{this, "DeviceId", 0, "Device ID to use"};
49 Gaudi::Property<bool> m_enableMemoryShrinkage{this, "EnableMemoryShrinkage", false, "Enable automatic memory shrinkage"};
50
52 ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{this, "OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc/OnnxRuntimeSvc", "The Onnx runtime service"};
53 std::unique_ptr<Ort::Session> m_session;
54 };
55}
56
57#endif
#define ASG_TOOL_CLASS(CLASSNAME, INT1)
OnnxRuntimeSessionToolCUDA(const std::string &name)
Standard constructor.
virtual ~OnnxRuntimeSessionToolCUDA()=default
OnnxRuntimeSessionToolCUDA(const OnnxRuntimeSessionToolCUDA &)=delete
virtual bool supportsAsync() const override final
Check if asynchronous inference is supported (yes, it is).
OnnxRuntimeSessionToolCUDA & operator=(const OnnxRuntimeSessionToolCUDA &)=delete
ServiceHandle< IOnnxRuntimeSvc > m_onnxRuntimeSvc
The Onnx Runtime service used by this tool.
Gaudi::Property< int > m_deviceId
The device ID to use.
virtual Ort::Session & session() const override final
Return the Onnx Runtime session.
Gaudi::Property< std::string > m_modelFileName
int deviceId() const
Device ID passed to the CUDA provider (needed to build Ort::MemoryInfo for IoBinding).
virtual StatusCode initialize() override final
Initialize the tool.
std::unique_ptr< Ort::Session > m_session
Base class for the dual-use tool implementation classes.
Definition AsgTool.h:47
Namespace holding all of the Onnx Runtime example code.