ATLAS Offline Software
OnnxRuntimeInferenceTool.h
Go to the documentation of this file.
1 // Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
2 
3 #ifndef OnnxRuntimeInferenceTool_H
4 #define OnnxRuntimeInferenceTool_H
5 
6 #include "AsgTools/AsgTool.h"
9 
13 #include "AsgTools/ToolHandle.h"
14 
15 namespace AthOnnx {
16  // @class OnnxRuntimeInferenceTool
17  //
18  // @brief Tool to run Onnx Runtime inference, using the session provided by an IOnnxRuntimeSessionTool (AthOnnx::OnnxRuntimeSessionToolCPU by default)
19  //
20  // @author Xiangyang Ju <xiangyang.ju@cern.ch>
22  {
24  public:
26  OnnxRuntimeInferenceTool( const std::string& name );
27  virtual ~OnnxRuntimeInferenceTool() = default;
28 
30  virtual StatusCode initialize() override; // Initialize the tool.
31 
32 
33  virtual void setBatchSize(int64_t batchSize) override final; // Set the batch size.
34  virtual int64_t getBatchSize(int64_t inputDataSize, int idx = 0) const override final; // Determine the batch size from the input data size (meaning of idx: presumably an input-node index — confirm).
35 
36  virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const override final; // Perform inference on pre-built Ort::Value input tensors, filling outputTensors.
37 
38  virtual void printModelInfo() const override final;
39 
40  virtual StatusCode inference(AthInfer::InputDataMap& inputData, AthInfer::OutputDataMap& outputData) const override final; // Perform inference on AthInfer input/output data maps.
41 
42  protected:
46 
47  private:
49 
50  ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"}; // Service managing global objects used by Onnx Runtime.
51  ToolHandle<IOnnxRuntimeSessionTool> m_onnxSessionTool{ // Session tool, declared under the name "ORTSessionTool"; default type AthOnnx::OnnxRuntimeSessionToolCPU.
52  this, "ORTSessionTool",
53  "AthOnnx::OnnxRuntimeSessionToolCPU"
54  };
55  std::vector<std::string> m_inputNodeNames; // Model input node names (presumably filled by getNodeInfo() — confirm).
56  std::vector<std::string> m_outputNodeNames; // Model output node names (presumably filled by getNodeInfo() — confirm).
57  };
58 } // namespace AthOnnx
59 
60 #endif
AthOnnx::OnnxRuntimeInferenceTool::inference
virtual StatusCode inference(std::vector< Ort::Value > &inputTensors, std::vector< Ort::Value > &outputTensors) const override final
Perform inference.
Definition: OnnxRuntimeInferenceTool.cxx:72
AthOnnx::IOnnxRuntimeSvc
Service used for managing global objects used by Onnx Runtime.
Definition: IOnnxRuntimeSvc.h:25
AthOnnx::IOnnxRuntimeInferenceTool
Interface class for creating Onnx Runtime sessions.
Definition: IOnnxRuntimeInferenceTool.h:48
AthInfer
Definition: IAthInferenceTool.h:12
AthOnnx::OnnxRuntimeInferenceTool::initialize
virtual StatusCode initialize() override
Initialize the tool.
Definition: OnnxRuntimeInferenceTool.cxx:15
asg::AsgTool
Base class for the dual-use tool implementation classes.
Definition: AsgTool.h:47
ASG_TOOL_CLASS2
#define ASG_TOOL_CLASS2(CLASSNAME, INT1, INT2)
Definition: AsgToolMacros.h:77
IOnnxRuntimeInferenceTool.h
AthOnnx::OnnxRuntimeInferenceTool::m_onnxRuntimeSvc
ServiceHandle< IOnnxRuntimeSvc > m_onnxRuntimeSvc
Definition: OnnxRuntimeInferenceTool.h:50
Value
tag-value pair class.
Definition: Value.h:39
const
bool const RAWDATA *ch2 const
Definition: LArRodBlockPhysicsV0.cxx:560
protected
#define protected
Definition: DetDescrConditionsDict_dict_fixes.cxx:14
AthOnnx::OnnxRuntimeInferenceTool::m_inputNodeNames
std::vector< std::string > m_inputNodeNames
Definition: OnnxRuntimeInferenceTool.h:55
AthOnnx::OnnxRuntimeInferenceTool::setBatchSize
virtual void setBatchSize(int64_t batchSize) override final
Set the batch size.
Definition: OnnxRuntimeInferenceTool.cxx:42
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
vector
Definition: MultiHisto.h:13
AthInfer::OutputDataMap
std::map< std::string, InferenceData > OutputDataMap
Definition: IAthInferenceTool.h:17
IAthInferenceTool.h
IOnnxRuntimeSessionTool.h
private
#define private
Definition: DetDescrConditionsDict_dict_fixes.cxx:13
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:221
ServiceHandle.h
AthOnnx::OnnxRuntimeInferenceTool::~OnnxRuntimeInferenceTool
virtual ~OnnxRuntimeInferenceTool()=default
IOnnxRuntimeSvc.h
AthOnnx::OnnxRuntimeInferenceTool::printModelInfo
virtual void printModelInfo() const override final
Definition: OnnxRuntimeInferenceTool.cxx:86
calibdata.delete
list delete
Definition: calibdata.py:46
AthOnnx::OnnxRuntimeInferenceTool::m_outputNodeNames
std::vector< std::string > m_outputNodeNames
Definition: OnnxRuntimeInferenceTool.h:56
AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool
OnnxRuntimeInferenceTool()=delete
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
AthOnnx::OnnxRuntimeInferenceTool::m_onnxSessionTool
ToolHandle< IOnnxRuntimeSessionTool > m_onnxSessionTool
Definition: OnnxRuntimeInferenceTool.h:51
AthOnnx::OnnxRuntimeInferenceTool::getNodeInfo
StatusCode getNodeInfo()
Definition: OnnxRuntimeInferenceTool.cxx:28
ToolHandle.h
AsgTool.h
AthInfer::InputDataMap
std::map< std::string, InferenceData > InputDataMap
Definition: IAthInferenceTool.h:16
AthOnnx::OnnxRuntimeInferenceTool
Definition: OnnxRuntimeInferenceTool.h:22
AthInfer::IAthInferenceTool
Definition: IAthInferenceTool.h:24
AthOnnx::OnnxRuntimeInferenceTool::getBatchSize
virtual int64_t getBatchSize(int64_t inputDataSize, int idx=0) const override final
Method for determining the batch size from the input data size.
Definition: OnnxRuntimeInferenceTool.cxx:62
ServiceHandle
Definition: ClusterMakerTool.h:37
AthOnnx
Namespace holding all of the Onnx Runtime example code.
Definition: EvaluateModel.cxx:11