ATLAS Offline Software
OnnxRuntimeInferenceTool.h
Go to the documentation of this file.
1 // Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
2 
3 #ifndef OnnxRuntimeInferenceTool_H
4 #define OnnxRuntimeInferenceTool_H
5 
#include <cstdint>
#include <string>
#include <vector>

#include "AsgTools/AsgTool.h"
#include "AsgTools/ToolHandle.h"
#include "AsgServices/ServiceHandle.h"
#include "AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h"
#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h"
#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h"
12 
13 namespace AthOnnx {
14  // @class OnnxRuntimeInferenceTool
15  //
16  // @brief Tool to create Onnx Runtime session with CPU backend
17  //
18  // @author Xiangyang Ju <xiangyang.ju@cern.ch>
20  {
22  public:
24  OnnxRuntimeInferenceTool( const std::string& name );
25  virtual ~OnnxRuntimeInferenceTool() = default;
26 
28  virtual StatusCode initialize() override;
29 
30 
31  virtual void setBatchSize(int64_t batchSize) override final;
32  virtual int64_t getBatchSize(int64_t inputDataSize, int idx = 0) const override final;
33 
34  virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const override final;
35 
36  virtual void printModelInfo() const override final;
37 
38  protected:
42 
43  private:
45 
46  ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"};
47  ToolHandle<IOnnxRuntimeSessionTool> m_onnxSessionTool{
48  this, "ORTSessionTool",
49  "AthOnnx::OnnxRuntimeSessionToolCPU"
50  };
51  std::vector<std::string> m_inputNodeNames;
52  std::vector<std::string> m_outputNodeNames;
53  };
54 } // namespace AthOnnx
55 
56 #endif
AthOnnx::OnnxRuntimeInferenceTool::inference
virtual StatusCode inference(std::vector< Ort::Value > &inputTensors, std::vector< Ort::Value > &outputTensors) const override final
perform inference
Definition: OnnxRuntimeInferenceTool.cxx:72
AthOnnx::IOnnxRuntimeSvc
Service used for managing global objects used by Onnx Runtime.
Definition: IOnnxRuntimeSvc.h:25
AthOnnx::IOnnxRuntimeInferenceTool
Interface class for creating Onnx Runtime sessions.
Definition: IOnnxRuntimeInferenceTool.h:48
AthOnnx::OnnxRuntimeInferenceTool::initialize
virtual StatusCode initialize() override
Initialize the tool.
Definition: OnnxRuntimeInferenceTool.cxx:15
asg::AsgTool
Base class for the dual-use tool implementation classes.
Definition: AsgTool.h:47
IOnnxRuntimeInferenceTool.h
AthOnnx::OnnxRuntimeInferenceTool::m_onnxRuntimeSvc
ServiceHandle< IOnnxRuntimeSvc > m_onnxRuntimeSvc
Definition: OnnxRuntimeInferenceTool.h:46
Value
tag-value pair class.
Definition: Value.h:39
const
bool const RAWDATA *ch2 const
Definition: LArRodBlockPhysicsV0.cxx:562
protected
#define protected
Definition: DetDescrConditionsDict_dict_fixes.cxx:14
AthOnnx::OnnxRuntimeInferenceTool::m_inputNodeNames
std::vector< std::string > m_inputNodeNames
Definition: OnnxRuntimeInferenceTool.h:51
AthOnnx::OnnxRuntimeInferenceTool::setBatchSize
virtual void setBatchSize(int64_t batchSize) override final
set batch size.
Definition: OnnxRuntimeInferenceTool.cxx:42
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
vector
Definition: MultiHisto.h:13
IOnnxRuntimeSessionTool.h
private
#define private
Definition: DetDescrConditionsDict_dict_fixes.cxx:13
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:195
ServiceHandle.h
AthOnnx::OnnxRuntimeInferenceTool::~OnnxRuntimeInferenceTool
virtual ~OnnxRuntimeInferenceTool()=default
IOnnxRuntimeSvc.h
AthOnnx::OnnxRuntimeInferenceTool::printModelInfo
virtual void printModelInfo() const override final
Definition: OnnxRuntimeInferenceTool.cxx:86
calibdata.delete
list delete
Definition: calibdata.py:46
AthOnnx::OnnxRuntimeInferenceTool::m_outputNodeNames
std::vector< std::string > m_outputNodeNames
Definition: OnnxRuntimeInferenceTool.h:52
ASG_TOOL_CLASS
#define ASG_TOOL_CLASS(CLASSNAME, INT1)
Definition: AsgToolMacros.h:68
AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool
OnnxRuntimeInferenceTool()=delete
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
AthOnnx::OnnxRuntimeInferenceTool::m_onnxSessionTool
ToolHandle< IOnnxRuntimeSessionTool > m_onnxSessionTool
Definition: OnnxRuntimeInferenceTool.h:47
AthOnnx::OnnxRuntimeInferenceTool::getNodeInfo
StatusCode getNodeInfo()
Definition: OnnxRuntimeInferenceTool.cxx:28
ToolHandle.h
AsgTool.h
AthOnnx::OnnxRuntimeInferenceTool
Definition: OnnxRuntimeInferenceTool.h:20
AthOnnx::OnnxRuntimeInferenceTool::getBatchSize
virtual int64_t getBatchSize(int64_t inputDataSize, int idx=0) const override final
Determine the batch size from the input data size.
Definition: OnnxRuntimeInferenceTool.cxx:62
ServiceHandle
Definition: ClusterMakerTool.h:37
AthOnnx
Namespace holding all of the Onnx Runtime example code.
Definition: EvaluateModel.cxx:13