ATLAS Offline Software
AthOnnx::OnnxRuntimeInferenceTool Class Reference

#include <OnnxRuntimeInferenceTool.h>


Public Member Functions

 OnnxRuntimeInferenceTool (const std::string &name)
 Standard constructor.
virtual ~OnnxRuntimeInferenceTool ()=default
virtual StatusCode initialize () override
 Initialize the tool.
virtual void setBatchSize (int64_t batchSize) override final
 set batch size.
virtual int64_t getBatchSize (int64_t inputDataSize, int idx=0) const override final
 Determine the batch size from the input data size.
virtual StatusCode inference (std::vector< Ort::Value > &inputTensors, std::vector< Ort::Value > &outputTensors) const override final
 perform inference
virtual void printModelInfo () const override final
virtual StatusCode inference (AthInfer::InputDataMap &inputData, AthInfer::OutputDataMap &outputData) const override final
virtual void print () const
 Print the state of the tool.
ServiceHandle< StoreGateSvc > & evtStore ()
 The standard StoreGateSvc (event store). Returns a ServiceHandle that behaves like a pointer to the StoreGateSvc.
const ServiceHandle< StoreGateSvc > & detStore () const
 The standard StoreGateSvc/DetectorStore. Returns a ServiceHandle that behaves like a pointer to the StoreGateSvc.
virtual StatusCode sysInitialize () override
 Perform system initialization for an algorithm.
virtual StatusCode sysStart () override
 Handle START transition.
virtual std::vector< Gaudi::DataHandle * > inputHandles () const override
 Return this algorithm's input handles.
virtual std::vector< Gaudi::DataHandle * > outputHandles () const override
 Return this algorithm's output handles.
Gaudi::Details::PropertyBase & declareProperty (Gaudi::Property< T, V, H > &t)
void updateVHKA (Gaudi::Details::PropertyBase &)
MsgStream & msg () const
bool msgLvl (const MSG::Level lvl) const
template<typename T>
StatusCode addInput (std::vector< Ort::Value > &inputTensors, std::vector< T > &data, unsigned idx=0, int64_t batchSize=-1) const
 add the input data to the input tensors
template<typename T>
StatusCode addOutput (std::vector< Ort::Value > &outputTensors, std::vector< T > &data, unsigned idx=0, int64_t batchSize=-1) const
 add the output data to the output tensors
Additional helper functions, not directly mimicking Athena
template<class T>
const T * getProperty (const std::string &name) const
 Get one of the tool's properties.
const std::string & msg_level_name () const __attribute__((deprecated))
 A deprecated function for getting the message level's name.
const std::string & getName (const void *ptr) const
 Get the name of an object that is / should be in the event store.
SG::sgkey_t getKey (const void *ptr) const
 Get the (hashed) key of an object that is in the event store.

Protected Member Functions

 OnnxRuntimeInferenceTool ()=delete
 OnnxRuntimeInferenceTool (const OnnxRuntimeInferenceTool &)=delete
OnnxRuntimeInferenceTool & operator= (const OnnxRuntimeInferenceTool &)=delete
void renounceArray (SG::VarHandleKeyArray &handlesArray)
 remove all handles from I/O resolution
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce (T &h)
void extraDeps_update_handler (Gaudi::Details::PropertyBase &ExtraDeps)
 Add StoreName to extra input/output deps as needed.

Protected Attributes

unsigned m_numInputs
unsigned m_numOutputs
std::vector< std::vector< int64_t > > m_inputShapes
std::vector< std::vector< int64_t > > m_outputShapes

Private Types

typedef ServiceHandle< StoreGateSvc > StoreGateSvc_t

Private Member Functions

StatusCode getNodeInfo ()
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T, V, H > &hndl, const SG::VarHandleKeyType &)
 specialization for handling Gaudi::Property<SG::VarHandleKey>
template<typename T>
Ort::Value createTensor (std::vector< T > &data, const std::vector< int64_t > &dataShape, int64_t batchSize) const

Private Attributes

ServiceHandle< IOnnxRuntimeSvc > m_onnxRuntimeSvc {this, "OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc/OnnxRuntimeSvc", "The Onnx runtime service"}
ToolHandle< IOnnxRuntimeSessionTool > m_onnxSessionTool
std::vector< std::string > m_inputNodeNames
std::vector< std::string > m_outputNodeNames
StoreGateSvc_t m_evtStore
 Pointer to StoreGate (event store by default)
StoreGateSvc_t m_detStore
 Pointer to StoreGate (detector store by default)
std::vector< SG::VarHandleKeyArray * > m_vhka
bool m_varHandleArraysDeclared

Detailed Description

Definition at line 21 of file OnnxRuntimeInferenceTool.h.

Member Typedef Documentation

◆ StoreGateSvc_t

typedef ServiceHandle<StoreGateSvc> AthCommonDataStore< AthCommonMsg< AlgTool > >::StoreGateSvc_t
private inherited

Definition at line 388 of file AthCommonDataStore.h.

Constructor & Destructor Documentation

◆ OnnxRuntimeInferenceTool() [1/3]

AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool ( const std::string & name)

Standard constructor.

Definition at line 8 of file OnnxRuntimeInferenceTool.cxx.

◆ ~OnnxRuntimeInferenceTool()

virtual AthOnnx::OnnxRuntimeInferenceTool::~OnnxRuntimeInferenceTool ( )
virtual default

◆ OnnxRuntimeInferenceTool() [2/3]

AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool ( )
protected delete

◆ OnnxRuntimeInferenceTool() [3/3]

AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool ( const OnnxRuntimeInferenceTool & )
protected delete

Member Function Documentation

◆ addInput()

template<typename T>
StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addInput ( std::vector< Ort::Value > & inputTensors,
std::vector< T > & data,
unsigned idx = 0,
int64_t batchSize = -1 ) const
inherited

add the input data to the input tensors

Parameters
inputTensors: the input tensor container
data: the input data
idx: the index of the input node
batchSize: the batch size
Returns
StatusCode::SUCCESS if the input data is added successfully

Definition at line 24 of file IOnnxRuntimeInferenceTool.h.

◆ addOutput()

template<typename T>
StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addOutput ( std::vector< Ort::Value > & outputTensors,
std::vector< T > & data,
unsigned idx = 0,
int64_t batchSize = -1 ) const
inherited

add the output data to the output tensors

Parameters
outputTensors: the output tensor container
data: the output data
idx: the index of the output node
batchSize: the batch size
Returns
StatusCode::SUCCESS if the output data is added successfully

Definition at line 35 of file IOnnxRuntimeInferenceTool.h.
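The data arguments of addInput and addOutput are non-const std::vector<T> references, which suggests the tensors refer to the caller's buffers instead of copying them; the ONNX Runtime Ort::Value::CreateTensor overloads that take a user-supplied buffer behave this way. The sketch below models that ownership rule with a hypothetical TensorView type (not part of Athena or ONNX Runtime), assuming a fully resolved shape with no -1 entries. It is an illustration of the lifetime constraint, not the tool's implementation.

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical non-owning tensor view (not an Athena or ONNX Runtime type).
// It keeps a pointer into the caller's vector plus a shape and never copies
// or frees the data, so the vector must outlive the view.
template <typename T>
class TensorView {
public:
  TensorView(std::vector<T>& data, std::vector<int64_t> shape)
      : m_data(data.data()), m_shape(std::move(shape)) {
    // The element count implied by the shape must match the buffer size.
    const int64_t expected =
        std::accumulate(m_shape.begin(), m_shape.end(), int64_t{1},
                        std::multiplies<int64_t>());
    if (expected != static_cast<int64_t>(data.size())) {
      throw std::invalid_argument("data size does not match the tensor shape");
    }
  }

  T* data() const { return m_data; }
  const std::vector<int64_t>& shape() const { return m_shape; }

private:
  T* m_data;                     // borrowed, not owned
  std::vector<int64_t> m_shape;  // fully resolved shape (no -1 entries)
};
```

Because such a view only borrows the buffer, an output vector has to be sized before the call (batch size times the product of the non-batch output dimensions) and kept alive until the results have been read.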

◆ createTensor()

template<typename T>
Ort::Value AthOnnx::IOnnxRuntimeInferenceTool::createTensor ( std::vector< T > & data,
const std::vector< int64_t > & dataShape,
int64_t batchSize ) const
private inherited

Definition at line 4 of file IOnnxRuntimeInferenceTool.h.

◆ declareGaudiProperty()

Gaudi::Details::PropertyBase & AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T, V, H > & hndl,
const SG::VarHandleKeyType &  )
inline private inherited

specialization for handling Gaudi::Property<SG::VarHandleKey>

Definition at line 156 of file AthCommonDataStore.h.

◆ declareProperty()

Gaudi::Details::PropertyBase & AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( Gaudi::Property< T, V, H > & t)
inline inherited

Definition at line 145 of file AthCommonDataStore.h.

◆ detStore()

const ServiceHandle< StoreGateSvc > & AthCommonDataStore< AthCommonMsg< AlgTool > >::detStore ( ) const
inline inherited

The standard StoreGateSvc/DetectorStore. Returns a ServiceHandle that behaves like a pointer to the StoreGateSvc.

Definition at line 95 of file AthCommonDataStore.h.

◆ evtStore()

ServiceHandle< StoreGateSvc > & AthCommonDataStore< AthCommonMsg< AlgTool > >::evtStore ( )
inline inherited

The standard StoreGateSvc (event store). Returns a ServiceHandle that behaves like a pointer to the StoreGateSvc.

Definition at line 85 of file AthCommonDataStore.h.

◆ extraDeps_update_handler()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::extraDeps_update_handler ( Gaudi::Details::PropertyBase & ExtraDeps)
protected inherited

Add StoreName to extra input/output deps as needed.

Uses the logic of the VarHandleKey to parse the DataObjID keys supplied via the ExtraInputs and ExtraOutputs properties, adding the StoreName if it is not given explicitly.
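As a rough illustration of "adding the StoreName", the sketch below assumes the "StoreName+key" spelling that Athena VarHandleKeys accept (for example "StoreGateSvc+MyKey") and a default store named "StoreGateSvc". The helper addDefaultStoreName is hypothetical; the real handler delegates this parsing to VarHandleKey and may handle more cases.

```cpp
#include <string>

// Hypothetical helper: prefix a key with a store name when none is given,
// assuming the "StoreName+key" convention used by VarHandleKey.
std::string addDefaultStoreName(const std::string& key,
                                const std::string& defaultStore = "StoreGateSvc") {
  if (key.find('+') != std::string::npos) {
    return key;  // a store name was already given explicitly
  }
  return defaultStore + "+" + key;
}
```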

◆ getBatchSize()

int64_t AthOnnx::OnnxRuntimeInferenceTool::getBatchSize ( int64_t dataSize,
int idx = 0 ) const
final override virtual

Determine the batch size from the data size.

Parameters
dataSize: the size of the input data, e.g. std::vector<T>::size()
idx: the index of the input node
Returns
the batch size, i.e. dataSize divided by the product of the remaining (non-batch) dimensions.

Implements AthOnnx::IOnnxRuntimeInferenceTool.

Definition at line 60 of file OnnxRuntimeInferenceTool.cxx.
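The rule stated for getBatchSize can be sketched as a free function, assuming the batch dimension is axis 0 of the node shape (reported as -1 when dynamic). The name batchSizeFromDataSize and the -1 error return are assumptions for illustration; the actual tool's error handling is not documented here.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch: infer the batch size from a flat data size and a node shape whose
// first (batch) dimension may be dynamic (-1). Returns -1 when the data size
// is not a whole multiple of one entry, or a non-batch dimension is dynamic.
int64_t batchSizeFromDataSize(int64_t dataSize, const std::vector<int64_t>& nodeShape) {
  int64_t perEntry = 1;
  for (std::size_t i = 1; i < nodeShape.size(); ++i) {
    perEntry *= nodeShape[i];  // product of the non-batch dimensions
  }
  if (perEntry <= 0 || dataSize % perEntry != 0) {
    return -1;  // assumed error convention for this sketch
  }
  return dataSize / perEntry;
}
```

For a node of shape {-1, 3}, twelve input values give a batch size of 4.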

◆ getKey()

SG::sgkey_t asg::AsgTool::getKey ( const void * ptr) const
inherited

Get the (hashed) key of an object that is in the event store.

This is a bit of a special one. StoreGateSvc and xAOD::TEvent both provide ways for getting the SG::sgkey_t key for an object that is in the store, based on a bare pointer. But they provide different interfaces for doing so.

In order to allow tools to efficiently perform this operation, they can use this helper function.

See also
asg::AsgTool::getName
Parameters
ptr: The bare pointer to the object that the event store should know about
Returns
The hashed key of the object in the store. If not found, an invalid (zero) key.

Definition at line 119 of file AsgTool.cxx.

◆ getName()

const std::string & asg::AsgTool::getName ( const void * ptr) const
inherited

Get the name of an object that is / should be in the event store.

This is a bit of a special one. StoreGateSvc and xAOD::TEvent both provide ways for getting the std::string name for an object that is in the store, based on a bare pointer. But they provide different interfaces for doing so.

In order to allow tools to efficiently perform this operation, they can use this helper function.

See also
asg::AsgTool::getKey
Parameters
ptr: The bare pointer to the object that the event store should know about
Returns
The string name of the object in the store. If not found, an empty string.

Definition at line 106 of file AsgTool.cxx.
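The "not found" contracts of getName and getKey can be modelled with a toy pointer-indexed store. ToyStore is hypothetical and stands in for neither StoreGateSvc nor xAOD::TEvent, and its sgkey_t alias is only a stand-in for SG::sgkey_t, whose real type may differ. One design point it makes visible: since getName returns a const std::string&, the empty "not found" result has to live somewhere with static lifetime.

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for an event store that records, per object address,
// the name and hashed key it was registered under.
class ToyStore {
public:
  using sgkey_t = std::uint32_t;  // assumption: stand-in for SG::sgkey_t

  void record(const void* ptr, std::string name, sgkey_t key) {
    m_entries[ptr] = Entry{std::move(name), key};
  }

  // Documented contract: an empty string when the pointer is unknown.
  const std::string& getName(const void* ptr) const {
    static const std::string empty;  // static, so the returned reference stays valid
    auto it = m_entries.find(ptr);
    return it != m_entries.end() ? it->second.name : empty;
  }

  // Documented contract: an invalid (zero) key when the pointer is unknown.
  sgkey_t getKey(const void* ptr) const {
    auto it = m_entries.find(ptr);
    return it != m_entries.end() ? it->second.key : 0;
  }

private:
  struct Entry {
    std::string name;
    sgkey_t key;
  };
  std::unordered_map<const void*, Entry> m_entries;
};
```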

◆ getNodeInfo()

StatusCode AthOnnx::OnnxRuntimeInferenceTool::getNodeInfo ( )
private

Definition at line 26 of file OnnxRuntimeInferenceTool.cxx.

◆ getProperty()

template<class T>
const T * asg::AsgTool::getProperty ( const std::string & name) const
inherited

Get one of the tool's properties.

◆ inference() [1/2]

StatusCode AthOnnx::OnnxRuntimeInferenceTool::inference ( AthInfer::InputDataMap & inputData,
AthInfer::OutputDataMap & outputData ) const
final override virtual

Implements AthInfer::IAthInferenceTool.

Definition at line 118 of file OnnxRuntimeInferenceTool.cxx.

◆ inference() [2/2]

StatusCode AthOnnx::OnnxRuntimeInferenceTool::inference ( std::vector< Ort::Value > & inputTensors,
std::vector< Ort::Value > & outputTensors ) const
final override virtual

perform inference

Parameters
inputTensors: the input tensor container
outputTensors: the output tensor container
Returns
StatusCode::SUCCESS if the inference is performed successfully

Implements AthOnnx::IOnnxRuntimeInferenceTool.

Definition at line 70 of file OnnxRuntimeInferenceTool.cxx.

◆ initialize()

StatusCode AthOnnx::OnnxRuntimeInferenceTool::initialize ( void )
override virtual

Initialize the tool.

Reimplemented from asg::AsgTool.

Definition at line 13 of file OnnxRuntimeInferenceTool.cxx.

◆ inputHandles()

virtual std::vector< Gaudi::DataHandle * > AthCommonDataStore< AthCommonMsg< AlgTool > >::inputHandles ( ) const
override virtual inherited

Return this algorithm's input handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ msg()

MsgStream & AthCommonMsg< AlgTool >::msg ( ) const
inline inherited

Definition at line 24 of file AthCommonMsg.h.

◆ msg_level_name()

const std::string & asg::AsgTool::msg_level_name ( ) const
inherited

A deprecated function for getting the message level's name.

Instead of using this weirdly named function, user code should get the string name of the current minimum message level (in case it really needs it) with:

MSG::name( msg().level() )

This function's name doesn't follow the ATLAS coding rules, and as such will be removed in the not too distant future.

Returns
The string name of the current minimum message level that's printed

Definition at line 101 of file AsgTool.cxx.

◆ msgLvl()

bool AthCommonMsg< AlgTool >::msgLvl ( const MSG::Level lvl) const
inline inherited

Definition at line 30 of file AthCommonMsg.h.

◆ operator=()

OnnxRuntimeInferenceTool & AthOnnx::OnnxRuntimeInferenceTool::operator= ( const OnnxRuntimeInferenceTool & )
protected delete

◆ outputHandles()

virtual std::vector< Gaudi::DataHandle * > AthCommonDataStore< AthCommonMsg< AlgTool > >::outputHandles ( ) const
override virtual inherited

Return this algorithm's output handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ print()

◆ printModelInfo()

void AthOnnx::OnnxRuntimeInferenceTool::printModelInfo ( ) const
final override virtual

Implements AthOnnx::IOnnxRuntimeInferenceTool.

Definition at line 84 of file OnnxRuntimeInferenceTool.cxx.

◆ renounce()

std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > AthCommonDataStore< AthCommonMsg< AlgTool > >::renounce ( T & h)
inline protected inherited

Definition at line 380 of file AthCommonDataStore.h.

◆ renounceArray()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::renounceArray ( SG::VarHandleKeyArray & handlesArray)
inline protected inherited

remove all handles from I/O resolution

Definition at line 364 of file AthCommonDataStore.h.

◆ setBatchSize()

void AthOnnx::OnnxRuntimeInferenceTool::setBatchSize ( int64_t batchSize)
final override virtual

Set the batch size.

If the model has a dynamic batch size, batchSize is written into the batch dimension of both the input shapes and the output shapes.

Implements AthOnnx::IOnnxRuntimeInferenceTool.

Definition at line 40 of file OnnxRuntimeInferenceTool.cxx.
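A minimal sketch of the setBatchSize behaviour, assuming that a dynamic batch dimension appears as -1 on axis 0 of each shape. The vectors play the role of the documented m_inputShapes and m_outputShapes members; the function name applyBatchSize is hypothetical.

```cpp
#include <cstdint>
#include <initializer_list>
#include <vector>

// Sketch: replace dynamic (-1) batch dimensions with a concrete batch size in
// every input and output shape. Assumes the batch axis is axis 0.
void applyBatchSize(std::vector<std::vector<int64_t>>& inputShapes,
                    std::vector<std::vector<int64_t>>& outputShapes,
                    int64_t batchSize) {
  for (auto* shapes : {&inputShapes, &outputShapes}) {
    for (auto& shape : *shapes) {
      if (!shape.empty() && shape[0] == -1) {
        shape[0] = batchSize;  // fixed dimensions are left untouched
      }
    }
  }
}
```

In this sketch the -1 marker is overwritten, so a second call with a different batch size has no effect; keeping an unmodified copy of the original shapes would avoid that. Whether the real tool does so is not stated here.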

◆ sysInitialize()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysInitialize ( )
override virtual inherited

Perform system initialization for an algorithm.

We override this to declare all the elements of handle key arrays at the end of initialization. See comments on updateVHKA.

Reimplemented in asg::AsgMetadataTool, AthCheckedComponent< AthAlgTool >, AthCheckedComponent<::AthAlgTool >, and DerivationFramework::CfAthAlgTool.

◆ sysStart()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysStart ( )
override virtual inherited

Handle START transition.

We override this in order to make sure that conditions handle keys can cache a pointer to the conditions container.

◆ updateVHKA()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::updateVHKA ( Gaudi::Details::PropertyBase & )
inline inherited

Definition at line 308 of file AthCommonDataStore.h.

Member Data Documentation

◆ m_detStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_detStore
private inherited

Pointer to StoreGate (detector store by default)

Definition at line 393 of file AthCommonDataStore.h.

◆ m_evtStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_evtStore
private inherited

Pointer to StoreGate (event store by default)

Definition at line 390 of file AthCommonDataStore.h.

◆ m_inputNodeNames

std::vector<std::string> AthOnnx::OnnxRuntimeInferenceTool::m_inputNodeNames
private

Definition at line 56 of file OnnxRuntimeInferenceTool.h.

◆ m_inputShapes

std::vector<std::vector<int64_t> > AthOnnx::IOnnxRuntimeInferenceTool::m_inputShapes
protected inherited

Definition at line 104 of file IOnnxRuntimeInferenceTool.h.

◆ m_numInputs

unsigned AthOnnx::IOnnxRuntimeInferenceTool::m_numInputs
protected inherited

Definition at line 102 of file IOnnxRuntimeInferenceTool.h.

◆ m_numOutputs

unsigned AthOnnx::IOnnxRuntimeInferenceTool::m_numOutputs
protected inherited

Definition at line 103 of file IOnnxRuntimeInferenceTool.h.

◆ m_onnxRuntimeSvc

ServiceHandle<IOnnxRuntimeSvc> AthOnnx::OnnxRuntimeInferenceTool::m_onnxRuntimeSvc {this, "OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc/OnnxRuntimeSvc", "The Onnx runtime service"}
private

Definition at line 50 of file OnnxRuntimeInferenceTool.h.

◆ m_onnxSessionTool

ToolHandle<IOnnxRuntimeSessionTool> AthOnnx::OnnxRuntimeInferenceTool::m_onnxSessionTool
private
Initial value:

Definition at line 51 of file OnnxRuntimeInferenceTool.h.

◆ m_outputNodeNames

std::vector<std::string> AthOnnx::OnnxRuntimeInferenceTool::m_outputNodeNames
private

Definition at line 57 of file OnnxRuntimeInferenceTool.h.

◆ m_outputShapes

std::vector<std::vector<int64_t> > AthOnnx::IOnnxRuntimeInferenceTool::m_outputShapes
protected inherited

Definition at line 105 of file IOnnxRuntimeInferenceTool.h.

◆ m_varHandleArraysDeclared

bool AthCommonDataStore< AthCommonMsg< AlgTool > >::m_varHandleArraysDeclared
private inherited

Definition at line 399 of file AthCommonDataStore.h.

◆ m_vhka

std::vector<SG::VarHandleKeyArray*> AthCommonDataStore< AthCommonMsg< AlgTool > >::m_vhka
private inherited

Definition at line 398 of file AthCommonDataStore.h.


The documentation for this class was generated from the following files: