ATLAS Offline Software
Loading...
Searching...
No Matches
OnnxRuntimeInferenceTool.h
Go to the documentation of this file.
1// Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
3#ifndef OnnxRuntimeInferenceTool_H
4#define OnnxRuntimeInferenceTool_H
5
6#include "AsgTools/AsgTool.h"
9
13#include "AsgTools/ToolHandle.h"
14
15namespace AthOnnx {
16 // @class OnnxRuntimeInferenceTool
17 //
19 // @brief Tool to run Onnx Runtime inference, using a session provided by an ORT session tool (CPU session tool by default)
19 //
20 // @author Xiangyang Ju <xiangyang.ju@cern.ch>
22 {
24 public:
26 OnnxRuntimeInferenceTool( const std::string& name );
27 virtual ~OnnxRuntimeInferenceTool() = default;
28
30 virtual StatusCode initialize() override;
31
32
33 virtual void setBatchSize(int64_t batchSize) override final;
34 virtual int64_t getBatchSize(int64_t inputDataSize, int idx = 0) const override final;
35
36 virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const override final;
37
38 virtual void printModelInfo() const override final;
39
40 virtual StatusCode inference(AthInfer::InputDataMap& inputData, AthInfer::OutputDataMap& outputData) const override final;
41
46
47 private:
48 StatusCode getNodeInfo();
49
50 ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{this, "OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc/OnnxRuntimeSvc", "The Onnx runtime service"};
51 ToolHandle<IOnnxRuntimeSessionTool> m_onnxSessionTool{
52 this, "ORTSessionTool",
53 "AthOnnx::OnnxRuntimeSessionToolCPU/OnnxRuntimeSessionTool",
54 "The Onnx session tool"
55 };
56 std::vector<std::string> m_inputNodeNames;
57 std::vector<std::string> m_outputNodeNames;
58 };
59} // namespace AthOnnx
60
61#endif
#define ASG_TOOL_CLASS2(CLASSNAME, INT1, INT2)
#define protected
Interface class for creating Onnx Runtime sessions.
Service used for managing global objects used by Onnx Runtime.
ToolHandle< IOnnxRuntimeSessionTool > m_onnxSessionTool
virtual StatusCode inference(std::vector< Ort::Value > &inputTensors, std::vector< Ort::Value > &outputTensors) const override final
Perform inference.
std::vector< std::string > m_inputNodeNames
virtual int64_t getBatchSize(int64_t inputDataSize, int idx=0) const override final
Method for determining the batch size from the input data size.
OnnxRuntimeInferenceTool(const std::string &name)
Standard constructor.
virtual StatusCode initialize() override
Initialize the tool.
virtual void printModelInfo() const override final
ServiceHandle< IOnnxRuntimeSvc > m_onnxRuntimeSvc
std::vector< std::string > m_outputNodeNames
virtual void setBatchSize(int64_t batchSize) override final
Set the batch size.
virtual ~OnnxRuntimeInferenceTool()=default
tag-value pair class.
Definition Value.h:39
Base class for the dual-use tool implementation classes.
Definition AsgTool.h:47
Namespace holding all of the Onnx Runtime example code.
STL namespace.
#define private