ATLAS Offline Software
Loading...
Searching...
No Matches
OnnxRuntimeInferenceTool.h
Go to the documentation of this file.
1// Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
2
3#ifndef OnnxRuntimeInferenceTool_H
4#define OnnxRuntimeInferenceTool_H
5
6#include "AsgTools/AsgTool.h"
9
13#include "AsgTools/ToolHandle.h"
14
15// Forward declaration
17
18namespace AthOnnx {
19 // @class OnnxRuntimeInferenceTool
20 //
21 // @brief Tool that runs Onnx Runtime inference; the session comes from the configurable "ORTSessionTool" (CPU session tool by default)
22 //
23 // @author Xiangyang Ju <xiangyang.ju@cern.ch>
25 {
27 public:
29 OnnxRuntimeInferenceTool( const std::string& name ); // Standard constructor.
30 virtual ~OnnxRuntimeInferenceTool() = default;
31
33 virtual StatusCode initialize() override; // Initialize the tool.
34
35
36 virtual void setBatchSize(int64_t batchSize) override final; // Set the batch size.
37 virtual int64_t getBatchSize(int64_t inputDataSize, int idx = 0) const override final; // Determine the batch size from the data size; idx presumably selects the input node (TODO confirm in .cxx).
38
39 virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const override final; // Perform inference on pre-built Ort::Value tensors.
40
41 virtual void printModelInfo() const override final; // NOTE(review): presumably prints model/node information; confirm in .cxx.
42
43 virtual StatusCode inference(AthInfer::InputDataMap& inputData, AthInfer::OutputDataMap& outputData) const override final; // Perform inference using AthInfer input/output data maps.
44
49
50 private:
51 StatusCode getNodeInfo(); // NOTE(review): presumably fills m_inputNodeNames / m_outputNodeNames from the model; confirm in .cxx.
52
53 ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{this, "OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc/OnnxRuntimeSvc", "The Onnx runtime service"};
54 ToolHandle<IOnnxRuntimeSessionTool> m_onnxSessionTool{ // Session provider; defaults to the CPU session tool (configurable via "ORTSessionTool").
55 this, "ORTSessionTool",
56 "AthOnnx::OnnxRuntimeSessionToolCPU/OnnxRuntimeSessionTool",
57 "The Onnx session tool"
58 };
59 std::vector<std::string> m_inputNodeNames; // Model input node names.
60 std::vector<std::string> m_outputNodeNames; // Model output node names.
61
62 // pointer to parent AthAsynchronousAlgorithm if one exists
64 };
65} // namespace AthOnnx
66
67#endif
#define ASG_TOOL_CLASS2(CLASSNAME, INT1, INT2)
#define protected
An algorithm that can be suspended while work is offloaded to an accelerator.
Interface class for creating Onnx Runtime sessions.
Service used for managing global objects used by Onnx Runtime.
ToolHandle< IOnnxRuntimeSessionTool > m_onnxSessionTool
virtual StatusCode inference(std::vector< Ort::Value > &inputTensors, std::vector< Ort::Value > &outputTensors) const override final
perform inference
std::vector< std::string > m_inputNodeNames
virtual int64_t getBatchSize(int64_t inputDataSize, int idx=0) const override final
methods for determining batch size from the data size
OnnxRuntimeInferenceTool(const std::string &name)
Standard constructor.
virtual StatusCode initialize() override
Initialize the tool.
virtual void printModelInfo() const override final
const AthAsynchronousAlgorithm * m_parentAsyncAlg
ServiceHandle< IOnnxRuntimeSvc > m_onnxRuntimeSvc
std::vector< std::string > m_outputNodeNames
virtual void setBatchSize(int64_t batchSize) override final
set batch size.
virtual ~OnnxRuntimeInferenceTool()=default
tag-value pair class.
Definition Value.h:39
Base class for the dual-use tool implementation classes.
Definition AsgTool.h:47
Namespace holding all of the Onnx Runtime example code.
STL namespace.
#define private