ATLAS Offline Software
Loading...
Searching...
No Matches
OnnxRuntimeInferenceTool.h
Go to the documentation of this file.
1// Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
2
3#ifndef OnnxRuntimeInferenceTool_H
4#define OnnxRuntimeInferenceTool_H
5
6#include "AsgTools/AsgTool.h"
9
13#include "AsgTools/ToolHandle.h"
14
15// Forward declaration
17
18namespace AthOnnx {
19 // @class OnnxRuntimeInferenceTool
20 //
21 // @brief Tool to run Onnx Runtime inference, using a session tool (CPU backend by default)
22 //
23 // @author Xiangyang Ju <xiangyang.ju@cern.ch>
25 {
// NOTE(review): the class header line and the ASG_TOOL_CLASS2 line are not
// present in this listing; only the body of OnnxRuntimeInferenceTool is shown.
27 public:
// Standard constructor.
29 OnnxRuntimeInferenceTool( const std::string& name );
30 virtual ~OnnxRuntimeInferenceTool() = default;
31
// Initialize the tool.
33 virtual StatusCode initialize() override;
34
35
// Set batch size.
36 virtual void setBatchSize(int64_t batchSize) override final;
// Methods for determining batch size from the data size.
// idx defaults to 0; presumably it selects which input node is used -- TODO confirm in the implementation.
37 virtual int64_t getBatchSize(int64_t inputDataSize, int idx = 0) const override final;
38
// Perform inference: takes a vector of Ort::Value input tensors and a vector of
// Ort::Value output tensors, both by non-const reference.
39 virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const override final;
40
// NOTE(review): no brief is given for this method in the listing; presumably it
// prints details of the loaded model -- confirm against the implementation.
41 virtual void printModelInfo() const override final;
42
// Overload of inference() that takes AthInfer::InputDataMap / AthInfer::OutputDataMap
// instead of vectors of Ort::Value.
43 virtual StatusCode inference(AthInfer::InputDataMap& inputData, AthInfer::OutputDataMap& outputData) const override final;
44
49
50 private:
// Private helper returning a StatusCode; presumably fills m_inputNodeNames and
// m_outputNodeNames -- TODO confirm in the implementation.
51 StatusCode getNodeInfo();
52
// Handle to the Onnx runtime service (property "OnnxRuntimeSvc").
53 ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{this, "OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc/OnnxRuntimeSvc", "The Onnx runtime service"};
// Handle to the Onnx session tool (property "ORTSessionTool"); the default value
// names the CPU session tool.
54 ToolHandle<IOnnxRuntimeSessionTool> m_onnxSessionTool{
55 this, "ORTSessionTool",
56 "AthOnnx::OnnxRuntimeSessionToolCPU/OnnxRuntimeSessionTool",
57 "The Onnx session tool"
58 };
// Names of the model's input and output nodes (as suggested by the member names).
59 std::vector<std::string> m_inputNodeNames;
60 std::vector<std::string> m_outputNodeNames;
61
62 // Pointer to the parent AthAsynchronousAlgorithm, if one exists (documents m_parentAsyncAlg; its declaration line is missing from this listing)
64 };
65} // namespace AthOnnx
66
67#endif
#define ASG_TOOL_CLASS2(CLASSNAME, INT1, INT2)
An algorithm that can be suspended while work is offloaded to an accelerator.
Interface class for creating Onnx Runtime sessions.
Service used for managing global objects used by Onnx Runtime.
ToolHandle< IOnnxRuntimeSessionTool > m_onnxSessionTool
virtual StatusCode inference(std::vector< Ort::Value > &inputTensors, std::vector< Ort::Value > &outputTensors) const override final
Perform inference.
std::vector< std::string > m_inputNodeNames
virtual int64_t getBatchSize(int64_t inputDataSize, int idx=0) const override final
Methods for determining batch size from the data size.
OnnxRuntimeInferenceTool(const std::string &name)
Standard constructor.
virtual StatusCode initialize() override
Initialize the tool.
virtual void printModelInfo() const override final
const AthAsynchronousAlgorithm * m_parentAsyncAlg
ServiceHandle< IOnnxRuntimeSvc > m_onnxRuntimeSvc
std::vector< std::string > m_outputNodeNames
virtual void setBatchSize(int64_t batchSize) override final
Set batch size.
virtual ~OnnxRuntimeInferenceTool()=default
tag-value pair class.
Definition Value.h:39
Base class for the dual-use tool implementation classes.
Definition AsgTool.h:47
Namespace holding all of the Onnx Runtime example code.
STL namespace.
#define protected
Definition testRead.cxx:26
#define private
Definition testRead.cxx:27