ATLAS Offline Software
AthOnnx::OnnxRuntimeInferenceTool Class Reference

#include <OnnxRuntimeInferenceTool.h>

Public Member Functions

 OnnxRuntimeInferenceTool (const std::string &name)
 Standard constructor.
 
virtual ~OnnxRuntimeInferenceTool ()=default
 
virtual StatusCode initialize () override
 Initialize the tool.
 
virtual void setBatchSize (int64_t batchSize) override final
 Set the batch size.
 
virtual int64_t getBatchSize (int64_t inputDataSize, int idx=0) const override final
 Determine the batch size from the data size.
 
virtual StatusCode inference (std::vector< Ort::Value > &inputTensors, std::vector< Ort::Value > &outputTensors) const override final
 Perform inference.
 
virtual void printModelInfo () const override final
 
virtual StatusCode inference (AthInfer::InputDataMap &inputData, AthInfer::OutputDataMap &outputData) const override final
 
virtual void print () const
 Print the state of the tool.
 
ServiceHandle< StoreGateSvc > & evtStore ()
 The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.
 
const ServiceHandle< StoreGateSvc > & evtStore () const
 The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.
 
const ServiceHandle< StoreGateSvc > & detStore () const
 The standard StoreGateSvc/DetectorStore. Returns (kind of) a pointer to the StoreGateSvc.
 
virtual StatusCode sysInitialize () override
 Perform system initialization for an algorithm.
 
virtual StatusCode sysStart () override
 Handle START transition.
 
virtual std::vector< Gaudi::DataHandle * > inputHandles () const override
 Return this algorithm's input handles.
 
virtual std::vector< Gaudi::DataHandle * > outputHandles () const override
 Return this algorithm's output handles.
 
Gaudi::Details::PropertyBase & declareProperty (Gaudi::Property< T > &t)
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleKey &hndl, const std::string &doc, const SG::VarHandleKeyType &)
 Declare a new Gaudi property.
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleBase &hndl, const std::string &doc, const SG::VarHandleType &)
 Declare a new Gaudi property.
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleKeyArray &hndArr, const std::string &doc, const SG::VarHandleKeyArrayType &)
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, T &property, const std::string &doc, const SG::NotHandleType &)
 Declare a new Gaudi property.
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, T &property, const std::string &doc="none")
 Declare a new Gaudi property.
 
void updateVHKA (Gaudi::Details::PropertyBase &)
 
MsgStream & msg () const
 
MsgStream & msg (const MSG::Level lvl) const
 
bool msgLvl (const MSG::Level lvl) const
 
template<typename T >
StatusCode addInput (std::vector< Ort::Value > &inputTensors, std::vector< T > &data, unsigned idx=0, int64_t batchSize=-1) const
 Add the input data to the input tensors.
 
template<typename T >
StatusCode addOutput (std::vector< Ort::Value > &outputTensors, std::vector< T > &data, unsigned idx=0, int64_t batchSize=-1) const
 Add the output data to the output tensors.
 

Protected Member Functions

 OnnxRuntimeInferenceTool ()=delete
 
 OnnxRuntimeInferenceTool (const OnnxRuntimeInferenceTool &)=delete
 
OnnxRuntimeInferenceTool & operator= (const OnnxRuntimeInferenceTool &)=delete
 
void renounceArray (SG::VarHandleKeyArray &handlesArray)
 Remove all handles from I/O resolution.
 
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce (T &h)
 
void extraDeps_update_handler (Gaudi::Details::PropertyBase &ExtraDeps)
 Add StoreName to extra input/output deps as needed.
 

Protected Attributes

unsigned m_numInputs
 
unsigned m_numOutputs
 
std::vector< std::vector< int64_t > > m_inputShapes
 
std::vector< std::vector< int64_t > > m_outputShapes
 

Private Types

typedef ServiceHandle< StoreGateSvc > StoreGateSvc_t
 

Private Member Functions

StatusCode getNodeInfo ()
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &hndl, const SG::VarHandleKeyType &)
 Specialization for handling Gaudi::Property<SG::VarHandleKey>.
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &hndl, const SG::VarHandleKeyArrayType &)
 Specialization for handling Gaudi::Property<SG::VarHandleKeyArray>.
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &hndl, const SG::VarHandleType &)
 Specialization for handling Gaudi::Property<SG::VarHandleBase>.
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &t, const SG::NotHandleType &)
 Specialization for handling everything that is not a Gaudi::Property<SG::VarHandleKey>, Gaudi::Property<SG::VarHandleKeyArray> or Gaudi::Property<SG::VarHandleBase>.
 
template<typename T >
Ort::Value createTensor (std::vector< T > &data, const std::vector< int64_t > &dataShape, int64_t batchSize) const
 

Private Attributes

ServiceHandle< IOnnxRuntimeSvc > m_onnxRuntimeSvc {"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"}
 
ToolHandle< IOnnxRuntimeSessionTool > m_onnxSessionTool
 
std::vector< std::string > m_inputNodeNames
 
std::vector< std::string > m_outputNodeNames
 
StoreGateSvc_t m_evtStore
 Pointer to StoreGate (event store by default).
 
StoreGateSvc_t m_detStore
 Pointer to StoreGate (detector store by default).
 
std::vector< SG::VarHandleKeyArray * > m_vhka
 
bool m_varHandleArraysDeclared
 

Detailed Description

Definition at line 21 of file OnnxRuntimeInferenceTool.h.

Member Typedef Documentation

◆ StoreGateSvc_t

typedef ServiceHandle<StoreGateSvc> AthCommonDataStore< AthCommonMsg< AlgTool > >::StoreGateSvc_t
private inherited

Definition at line 388 of file AthCommonDataStore.h.

Constructor & Destructor Documentation

◆ OnnxRuntimeInferenceTool() [1/3]

AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool ( const std::string &  name)

Standard constructor.

Definition at line 8 of file OnnxRuntimeInferenceTool.cxx.

9  : asg::AsgTool ( name )
10 {
11  declareProperty("OnnxSessionTool", m_onnxSessionTool, "The Onnx session tool");
12  declareProperty("OnnxRuntimeSvc", m_onnxRuntimeSvc, "The Onnx runtime service");
13 }

◆ ~OnnxRuntimeInferenceTool()

virtual AthOnnx::OnnxRuntimeInferenceTool::~OnnxRuntimeInferenceTool ( )
virtual default

◆ OnnxRuntimeInferenceTool() [2/3]

AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool ( )
protected delete

◆ OnnxRuntimeInferenceTool() [3/3]

AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool ( const OnnxRuntimeInferenceTool & )
protected delete

Member Function Documentation

◆ addInput()

template<typename T >
StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addInput ( std::vector< Ort::Value > &  inputTensors,
std::vector< T > &  data,
unsigned  idx = 0,
int64_t  batchSize = -1 
) const
inherited

Add the input data to the input tensors.

Parameters
inputTensors: the input tensor container
data: the input data
idx: the index of the input node
batchSize: the batch size
Returns
StatusCode::SUCCESS if the input data is added successfully.

◆ addOutput()

template<typename T >
StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addOutput ( std::vector< Ort::Value > &  outputTensors,
std::vector< T > &  data,
unsigned  idx = 0,
int64_t  batchSize = -1 
) const
inherited

Add the output data to the output tensors.

Parameters
outputTensors: the output tensor container
data: the output data
idx: the index of the output node
batchSize: the batch size
Returns
StatusCode::SUCCESS if the output data is added successfully.
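Neither entry above spells out how a flat data vector and batchSize become a tensor shape. The following is a minimal standalone sketch under stated assumptions, not the AthOnnx createTensor logic: it assumes every dynamic (-1) dimension of the node shape is replaced by a positive batchSize, and that the data length must then equal the product of the resolved dimensions. resolveShape and elementCount are hypothetical helpers; no ONNX Runtime types are used.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <vector>

// Hypothetical helper: number of elements in a fully resolved shape.
std::int64_t elementCount(const std::vector<std::int64_t>& shape) {
    return std::accumulate(shape.begin(), shape.end(), std::int64_t{1},
                           std::multiplies<std::int64_t>());
}

// Hypothetical helper (an assumption, not the AthOnnx implementation):
// replace each dynamic (-1) dimension with batchSize, then require the flat
// buffer to hold exactly as many elements as the resolved shape describes.
std::optional<std::vector<std::int64_t>>
resolveShape(std::vector<std::int64_t> shape, std::int64_t batchSize,
             std::size_t dataSize) {
    for (auto& dim : shape) {
        if (dim == -1) {
            if (batchSize <= 0) {
                return std::nullopt;  // dynamic dimension left unresolved
            }
            dim = batchSize;
        }
    }
    if (elementCount(shape) != static_cast<std::int64_t>(dataSize)) {
        return std::nullopt;  // buffer length does not match the shape
    }
    return shape;
}
```

For a node shape {-1, 4}, a batch size of 3 and a 12-element buffer resolve to {3, 4}; a 10-element buffer, or a batch size of -1, is rejected.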

◆ createTensor()

template<typename T >
Ort::Value AthOnnx::IOnnxRuntimeInferenceTool::createTensor ( std::vector< T > &  data,
const std::vector< int64_t > &  dataShape,
int64_t  batchSize 
) const
private inherited

◆ declareGaudiProperty() [1/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > &  hndl,
const SG::VarHandleKeyArrayType & 
)
inline private inherited

Specialization for handling Gaudi::Property<SG::VarHandleKeyArray>.

Definition at line 170 of file AthCommonDataStore.h.

172  {
173  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
174  hndl.value(),
175  hndl.documentation());
176 
177  }

◆ declareGaudiProperty() [2/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > &  hndl,
const SG::VarHandleKeyType & 
)
inline private inherited

Specialization for handling Gaudi::Property<SG::VarHandleKey>.

Definition at line 156 of file AthCommonDataStore.h.

158  {
159  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
160  hndl.value(),
161  hndl.documentation());
162 
163  }

◆ declareGaudiProperty() [3/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > &  hndl,
const SG::VarHandleType & 
)
inline private inherited

Specialization for handling Gaudi::Property<SG::VarHandleBase>.

Definition at line 184 of file AthCommonDataStore.h.

186  {
187  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
188  hndl.value(),
189  hndl.documentation());
190  }

◆ declareGaudiProperty() [4/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > &  t,
const SG::NotHandleType & 
)
inline private inherited

Specialization for handling everything that is not a Gaudi::Property<SG::VarHandleKey>, Gaudi::Property<SG::VarHandleKeyArray> or Gaudi::Property<SG::VarHandleBase>.

Definition at line 199 of file AthCommonDataStore.h.

200  {
201  return PBASE::declareProperty(t);
202  }

◆ declareProperty() [1/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
SG::VarHandleBase &  hndl,
const std::string &  doc,
const SG::VarHandleType & 
)
inline inherited

Declare a new Gaudi property.

Parameters
name: Name of the property.
hndl: Object holding the property value.
doc: Documentation string for the property.

This is the version for types that derive from SG::VarHandleBase. The property value object is put on the input and output lists as appropriate; then we forward to the base class.

Definition at line 245 of file AthCommonDataStore.h.

249  {
250  this->declare(hndl.vhKey());
251  hndl.vhKey().setOwner(this);
252 
253  return PBASE::declareProperty(name,hndl,doc);
254  }

◆ declareProperty() [2/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
SG::VarHandleKey &  hndl,
const std::string &  doc,
const SG::VarHandleKeyType & 
)
inline inherited

Declare a new Gaudi property.

Parameters
name: Name of the property.
hndl: Object holding the property value.
doc: Documentation string for the property.

This is the version for types that derive from SG::VarHandleKey. The property value object is put on the input and output lists as appropriate; then we forward to the base class.

Definition at line 221 of file AthCommonDataStore.h.

225  {
226  this->declare(hndl);
227  hndl.setOwner(this);
228 
229  return PBASE::declareProperty(name,hndl,doc);
230  }

◆ declareProperty() [3/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
SG::VarHandleKeyArray &  hndArr,
const std::string &  doc,
const SG::VarHandleKeyArrayType & 
)
inline inherited

Definition at line 259 of file AthCommonDataStore.h.

263  {
264 
265  // std::ostringstream ost;
266  // ost << Algorithm::name() << " VHKA declareProp: " << name
267  // << " size: " << hndArr.keys().size()
268  // << " mode: " << hndArr.mode()
269  // << " vhka size: " << m_vhka.size()
270  // << "\n";
271  // debug() << ost.str() << endmsg;
272 
273  hndArr.setOwner(this);
274  m_vhka.push_back(&hndArr);
275 
276  Gaudi::Details::PropertyBase* p = PBASE::declareProperty(name, hndArr, doc);
277  if (p != 0) {
278  p->declareUpdateHandler(&AthCommonDataStore<PBASE>::updateVHKA, this);
279  } else {
280  ATH_MSG_ERROR("unable to call declareProperty on VarHandleKeyArray "
281  << name);
282  }
283 
284  return p;
285 
286  }

◆ declareProperty() [4/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
T &  property,
const std::string &  doc,
const SG::NotHandleType & 
)
inline inherited

Declare a new Gaudi property.

Parameters
name: Name of the property.
property: Object holding the property value.
doc: Documentation string for the property.

This is the generic version, for types that do not derive from SG::VarHandleKey. It just forwards to the base class version of declareProperty.

Definition at line 333 of file AthCommonDataStore.h.

337  {
338  return PBASE::declareProperty(name, property, doc);
339  }

◆ declareProperty() [5/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
T &  property,
const std::string &  doc = "none" 
)
inline inherited

Declare a new Gaudi property.

Parameters
name: Name of the property.
property: Object holding the property value.
doc: Documentation string for the property.

This dispatches to either the generic declareProperty or the one for VarHandle/Key/KeyArray.

Definition at line 352 of file AthCommonDataStore.h.

355  {
356  typedef typename SG::HandleClassifier<T>::type htype;
357  return declareProperty (name, property, doc, htype());
358  }
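The dispatch in declareProperty() [5/6] is compile-time tag dispatch: SG::HandleClassifier<T>::type names an empty tag type, and overload resolution on a default-constructed tag picks the matching declareProperty overload. A self-contained sketch of the same technique follows. KeyBase, KeyArrayBase, the three tag structs, Classifier and declare are hypothetical stand-ins, not the SG classes, and the classifier is a simplified assumption because the full definition of SG::HandleClassifier is not shown on this page.

```cpp
#include <string>
#include <type_traits>

// Hypothetical stand-ins for SG::VarHandleKey and SG::VarHandleKeyArray.
struct KeyBase {};
struct KeyArrayBase {};

// Empty tag types, analogous to SG::VarHandleKeyType and friends.
struct KeyTag {};
struct KeyArrayTag {};
struct NotHandleTag {};

// Simplified, assumed classifier: picks a tag type from T's base classes.
template <class T>
struct Classifier {
    using type = std::conditional_t<
        std::is_base_of_v<KeyArrayBase, T>, KeyArrayTag,
        std::conditional_t<std::is_base_of_v<KeyBase, T>, KeyTag, NotHandleTag>>;
};

// One overload per tag; the tag argument carries no data, only its type.
template <class T> std::string declare(T&, KeyTag) { return "key"; }
template <class T> std::string declare(T&, KeyArrayTag) { return "key array"; }
template <class T> std::string declare(T&, NotHandleTag) { return "plain property"; }

// Entry point in the style of declareProperty() [5/6]: classify T at compile
// time, then forward to the overload for that tag.
template <class T>
std::string declare(T& property) {
    using htype = typename Classifier<T>::type;
    return declare(property, htype());
}

// Hypothetical example property types.
struct TrackKey : KeyBase {};
struct TrackKeyArray : KeyArrayBase {};
```

Because the choice is made entirely by the type system, adding a new kind of handle only requires a new tag and one more overload; the single-argument entry point does not change.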

◆ declareProperty() [6/6]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( Gaudi::Property< T > &  t)
inline inherited

Definition at line 145 of file AthCommonDataStore.h.

145  {
146  typedef typename SG::HandleClassifier<T>::type htype;
147  return AthCommonDataStore<PBASE>::declareGaudiProperty(t, htype());
148  }

◆ detStore()

const ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::detStore ( ) const
inline inherited

The standard StoreGateSvc/DetectorStore. Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 95 of file AthCommonDataStore.h.

95 { return m_detStore; }

◆ evtStore() [1/2]

ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::evtStore ( )
inline inherited

The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 85 of file AthCommonDataStore.h.

85 { return m_evtStore; }

◆ evtStore() [2/2]

const ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::evtStore ( ) const
inline inherited

The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 90 of file AthCommonDataStore.h.

90 { return m_evtStore; }

◆ extraDeps_update_handler()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::extraDeps_update_handler ( Gaudi::Details::PropertyBase &  ExtraDeps)
protected inherited

Add StoreName to extra input/output deps as needed.

Uses the logic of the VarHandleKey to parse the DataObjID keys supplied via the ExtraInputs and ExtraOutputs properties, adding the StoreName if it is not explicitly given.

◆ getBatchSize()

int64_t AthOnnx::OnnxRuntimeInferenceTool::getBatchSize ( int64_t  dataSize,
int  idx = 0 
) const
final override virtual

Determine the batch size from the data size.

Parameters
dataSize: the size of the input data, e.g. std::vector<T>::size()
idx: the index of the input node
Returns
the batch size, i.e. dataSize divided by the product of the remaining dimensions of input node idx, or -1 if the shape of that node has no dynamic (-1) dimension.

Implements AthOnnx::IOnnxRuntimeInferenceTool.

Definition at line 62 of file OnnxRuntimeInferenceTool.cxx.

63 {
64  auto tensorSize = AthOnnxUtils::getTensorSize(m_inputShapes[idx]);
65  if (tensorSize < 0) {
66  return inputDataSize / abs(tensorSize);
67  } else {
68  return -1;
69  }
70 }
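A standalone sketch of the arithmetic in getBatchSize(). It assumes, since the body of AthOnnxUtils::getTensorSize is not shown here, that getTensorSize returns the plain product of the dimensions, so a single -1 makes the result negative and its absolute value is the element count of one sample. tensorSize and batchSizeFor are hypothetical names.

```cpp
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-in for AthOnnxUtils::getTensorSize, assumed to be the
// plain product of the dimensions (one -1 makes the result negative).
std::int64_t tensorSize(const std::vector<std::int64_t>& shape) {
    return std::accumulate(shape.begin(), shape.end(), std::int64_t{1},
                           std::multiplies<std::int64_t>());
}

// Same arithmetic as getBatchSize(): with one dynamic dimension, |tensorSize|
// is the size of a single sample, so dividing the flat data size by it gives
// the batch size. A non-negative product yields -1, as in the tool.
std::int64_t batchSizeFor(const std::vector<std::int64_t>& shape,
                          std::int64_t dataSize) {
    const std::int64_t size = tensorSize(shape);
    if (size < 0) {
        return dataSize / std::abs(size);
    }
    return -1;
}
```

For an input shape {-1, 3, 4} and 24 data elements the batch size is 24 / 12 = 2. Under the same assumption, a shape with two dynamic dimensions has a positive product, so this logic returns -1 for it as well.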

◆ getKey()

SG::sgkey_t asg::AsgTool::getKey ( const void *  ptr) const
inherited

Get the (hashed) key of an object that is in the event store.

This is a bit of a special one. StoreGateSvc and xAOD::TEvent both provide ways for getting the SG::sgkey_t key for an object that is in the store, based on a bare pointer. But they provide different interfaces for doing so.

In order to allow tools to efficiently perform this operation, they can use this helper function.

See also
asg::AsgTool::getName
Parameters
ptr: The bare pointer to the object that the event store should know about
Returns
The hashed key of the object in the store. If not found, an invalid (zero) key.

Definition at line 119 of file AsgTool.cxx.

119  {
120 
121 #ifdef XAOD_STANDALONE
122  // In case we use @c xAOD::TEvent, we have a direct function call
123  // for this.
124  return evtStore()->event()->getKey( ptr );
125 #else
126  const SG::DataProxy* proxy = evtStore()->proxy( ptr );
127  return ( proxy == nullptr ? 0 : proxy->sgkey() );
128 #endif // XAOD_STANDALONE
129  }

◆ getName()

const std::string & asg::AsgTool::getName ( const void *  ptr) const
inherited

Get the name of an object that is / should be in the event store.

This is a bit of a special one. StoreGateSvc and xAOD::TEvent both provide ways for getting the std::string name for an object that is in the store, based on a bare pointer. But they provide different interfaces for doing so.

In order to allow tools to efficiently perform this operation, they can use this helper function.

See also
asg::AsgTool::getKey
Parameters
ptr: The bare pointer to the object that the event store should know about
Returns
The string name of the object in the store. If not found, an empty string.

Definition at line 106 of file AsgTool.cxx.

106  {
107 
108 #ifdef XAOD_STANDALONE
109  // In case we use @c xAOD::TEvent, we have a direct function call
110  // for this.
111  return evtStore()->event()->getName( ptr );
112 #else
113  const SG::DataProxy* proxy = evtStore()->proxy( ptr );
114  static const std::string dummy = "";
115  return ( proxy == nullptr ? dummy : proxy->name() );
116 #endif // XAOD_STANDALONE
117  }

◆ getNodeInfo()

StatusCode AthOnnx::OnnxRuntimeInferenceTool::getNodeInfo ( )
private

Definition at line 28 of file OnnxRuntimeInferenceTool.cxx.

29 {
30  auto& session = m_onnxSessionTool->session();
31  // obtain the model information
32  m_numInputs = session.GetInputCount();
33  m_numOutputs = session.GetOutputCount();
34 
37 
38  return StatusCode::SUCCESS;
39 }

◆ getProperty()

template<class T >
const T* asg::AsgTool::getProperty ( const std::string &  name) const
inherited

Get one of the tool's properties.

◆ inference() [1/2]

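StatusCode AthOnnx::OnnxRuntimeInferenceTool::inference ( AthInfer::InputDataMap &  inputData,
AthInfer::OutputDataMap &  outputData 
) const
final override virtual

Perform inference on name-keyed input and output data maps. Each entry pairs a shape with a std::vector<float> or std::vector<int64_t> buffer; output buffers are resized to match their shapes before the tensor-based inference() is called.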

Implements AthInfer::IAthInferenceTool.

Definition at line 120 of file OnnxRuntimeInferenceTool.cxx.

121 {
122  // Create input tensors.
123  std::vector<Ort::Value> inputTensors;
124  for (auto& [inputName, inputInfo] : inputData) {
125  const std::vector<int64_t>& shape = inputInfo.first;
126  if (std::holds_alternative<std::vector<float>>(inputInfo.second)) {
127  auto& data = std::get<std::vector<float>>(inputInfo.second);
128  inputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
129  } else if (std::holds_alternative<std::vector<int64_t>>(inputInfo.second)) {
130  auto& data = std::get<std::vector<int64_t>>(inputInfo.second);
131  inputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
132  } else {
133  ATH_MSG_ERROR("Unsupported data type");
134  return StatusCode::FAILURE;
135  }
136  }
137 
138  // Create output tensors.
139  std::vector<Ort::Value> outputTensors;
140  outputTensors.reserve(outputData.size());
141  for (auto& [outputName, outputInfo] : outputData) {
142  auto& shape = outputInfo.first;
143  auto tensorSize = std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
144 
145  if (std::holds_alternative<std::vector<float>>(outputInfo.second)) {
146  auto& data = std::get<std::vector<float>>(outputInfo.second);
147  data.resize(tensorSize);
148  outputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
149  } else if (std::holds_alternative<std::vector<int64_t>>(outputInfo.second)) {
150  auto& data = std::get<std::vector<int64_t>>(outputInfo.second);
151  data.resize(tensorSize);
152  outputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
153  } else {
154  ATH_MSG_ERROR("Unsupported data type");
155  return StatusCode::FAILURE;
156  }
157  }
158 
159  ATH_CHECK(inference(inputTensors, outputTensors));
160 
161  return StatusCode::SUCCESS;
162 }
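The output half of inference() sizes each buffer from its shape before wrapping it in a tensor. A hedged standalone sketch of that step follows. The real AthInfer::OutputDataMap typedef is not shown on this page, so Buffer and DataMap are hypothetical stand-ins inferred from how the loop uses .first (the shape) and .second (a variant of vectors). The sketch uses std::visit in place of the std::holds_alternative chain, seeds std::accumulate with an int64_t so the product is not computed in int, and rejects shapes that still contain a dynamic dimension, because resize() takes an unsigned size and a negative product would wrap to a huge value.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <numeric>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-ins for the AthInfer data maps:
// node name -> (shape, buffer of one supported element type).
using Buffer = std::variant<std::vector<float>, std::vector<std::int64_t>>;
using DataMap =
    std::map<std::string, std::pair<std::vector<std::int64_t>, Buffer>>;

// Resize every output buffer to the element count implied by its shape.
// Returns false (leaving later entries untouched) at the first shape that
// still contains a dynamic, i.e. negative, dimension.
bool prepareOutputs(DataMap& outputs) {
    for (auto& entry : outputs) {
        const std::vector<std::int64_t>& shape = entry.second.first;
        const bool dynamic = std::any_of(shape.begin(), shape.end(),
                                         [](std::int64_t d) { return d < 0; });
        if (dynamic) {
            return false;
        }
        // Seed with int64_t so the product is not accumulated in int.
        const std::int64_t n =
            std::accumulate(shape.begin(), shape.end(), std::int64_t{1},
                            std::multiplies<std::int64_t>());
        // Apply the same resize to whichever vector type the variant holds.
        std::visit([n](auto& vec) { vec.resize(static_cast<std::size_t>(n)); },
                   entry.second.second);
    }
    return true;
}
```

Calling setBatchSize() first, so that no -1 remains in the output shapes, is what lets a loop like this size the buffers safely.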

◆ inference() [2/2]

StatusCode AthOnnx::OnnxRuntimeInferenceTool::inference ( std::vector< Ort::Value > &  inputTensors,
std::vector< Ort::Value > &  outputTensors 
) const
final override virtual

Perform inference.

Parameters
inputTensors: the input tensor container
outputTensors: the output tensor container
Returns
StatusCode::SUCCESS if the inference is performed successfully.

Implements AthOnnx::IOnnxRuntimeInferenceTool.

Definition at line 72 of file OnnxRuntimeInferenceTool.cxx.

73 {
74  assert (inputTensors.size() == m_numInputs);
75  assert (outputTensors.size() == m_numOutputs);
76 
77  // Run the model.
78  AthOnnxUtils::inferenceWithIOBinding(
79  m_onnxSessionTool->session(),
80  m_inputNodeNames, inputTensors,
81  m_outputNodeNames, outputTensors);
82 
83  return StatusCode::SUCCESS;
84 }

◆ initialize()

StatusCode AthOnnx::OnnxRuntimeInferenceTool::initialize ( )
override virtual

Initialize the tool.

Reimplemented from asg::AsgTool.

Definition at line 15 of file OnnxRuntimeInferenceTool.cxx.

16 {
17  // Get the Onnx Runtime service.
18  ATH_CHECK(m_onnxRuntimeSvc.retrieve());
19 
20  // Create the session.
21  ATH_CHECK(m_onnxSessionTool.retrieve());
22 
24 
25  return StatusCode::SUCCESS;
26 }

◆ inputHandles()

virtual std::vector<Gaudi::DataHandle*> AthCommonDataStore< AthCommonMsg< AlgTool > >::inputHandles ( ) const
override virtual inherited

Return this algorithm's input handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ msg() [1/2]

MsgStream& AthCommonMsg< AlgTool >::msg ( ) const
inline inherited

Definition at line 24 of file AthCommonMsg.h.

24  {
25  return this->msgStream();
26  }

◆ msg() [2/2]

MsgStream& AthCommonMsg< AlgTool >::msg ( const MSG::Level  lvl) const
inline inherited

Definition at line 27 of file AthCommonMsg.h.

27  {
28  return this->msgStream(lvl);
29  }

◆ msg_level_name()

const std::string & asg::AsgTool::msg_level_name ( ) const
inherited

A deprecated function for getting the message level's name.

Instead of using this weirdly named function, user code should get the string name of the current minimum message level (in case it really needs it) with:

MSG::name( msg().level() )

This function's name doesn't follow the ATLAS coding rules, and as such will be removed in the not too distant future.

Returns
The string name of the current minimum message level that's printed

Definition at line 101 of file AsgTool.cxx.

101  {
102 
103  return MSG::name( msg().level() );
104  }

◆ msgLvl()

bool AthCommonMsg< AlgTool >::msgLvl ( const MSG::Level  lvl) const
inline inherited

Definition at line 30 of file AthCommonMsg.h.

30  {
31  return this->msgLevel(lvl);
32  }

◆ operator=()

OnnxRuntimeInferenceTool& AthOnnx::OnnxRuntimeInferenceTool::operator= ( const OnnxRuntimeInferenceTool & )
protected delete

◆ outputHandles()

virtual std::vector<Gaudi::DataHandle*> AthCommonDataStore< AthCommonMsg< AlgTool > >::outputHandles ( ) const
override virtual inherited

Return this algorithm's output handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ print()

void asg::AsgTool::print ( ) const
virtual inherited

◆ printModelInfo()

void AthOnnx::OnnxRuntimeInferenceTool::printModelInfo ( ) const
final override virtual

Implements AthOnnx::IOnnxRuntimeInferenceTool.

Definition at line 86 of file OnnxRuntimeInferenceTool.cxx.

87 {
88  ATH_MSG_INFO("Number of inputs: " << m_numInputs);
89  ATH_MSG_INFO("Number of outputs: " << m_numOutputs);
90 
91  ATH_MSG_INFO("Input node names: ");
92  for (const auto& name : m_inputNodeNames) {
93  ATH_MSG_INFO("\t" << name);
94  }
95 
96  ATH_MSG_INFO("Output node names: ");
97  for (const auto& name : m_outputNodeNames) {
98  ATH_MSG_INFO("\t" << name);
99  }
100 
101  ATH_MSG_INFO("Input shapes: ");
102  for (const auto& shape : m_inputShapes) {
103  std::string shapeStr = "\t";
104  for (const auto& dim : shape) {
105  shapeStr += std::to_string(dim) + " ";
106  }
107  ATH_MSG_INFO(shapeStr);
108  }
109 
110  ATH_MSG_INFO("Output shapes: ");
111  for (const auto& shape : m_outputShapes) {
112  std::string shapeStr = "\t";
113  for (const auto& dim : shape) {
114  shapeStr += std::to_string(dim) + " ";
115  }
116  ATH_MSG_INFO(shapeStr);
117  }
118 }

◆ renounce()

std::enable_if_t<std::is_void_v<std::result_of_t<decltype(&T::renounce)(T)> > && !std::is_base_of_v<SG::VarHandleKeyArray, T> && std::is_base_of_v<Gaudi::DataHandle, T>, void> AthCommonDataStore< AthCommonMsg< AlgTool > >::renounce ( T &  h)
inline protected inherited

Definition at line 380 of file AthCommonDataStore.h.

381  {
382  h.renounce();
383  PBASE::renounce (h);
384  }
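The return type of renounce() is a SFINAE constraint: the template only takes part in overload resolution when T::renounce returns void, T does not derive from SG::VarHandleKeyArray (key arrays are handled by renounceArray() instead), and T derives from Gaudi::DataHandle. std::result_of is deprecated in C++17 and removed in C++20. A sketch of the same three conditions written with std::invoke_result_t, using hypothetical stand-in types (DataHandleBase, KeyArrayBase, Handle) and a hypothetical function name renounceHandle instead of the Gaudi/SG classes:

```cpp
#include <type_traits>

// Hypothetical stand-ins for Gaudi::DataHandle and SG::VarHandleKeyArray.
struct DataHandleBase {};
struct KeyArrayBase {};

// Example handle type that satisfies all three conditions below.
struct Handle : DataHandleBase {
    bool renounced = false;
    void renounce() { renounced = true; }
};

// The conditions from renounce() above, with std::invoke_result_t (C++17)
// replacing std::result_of_t. The function exists for T only when T::renounce
// returns void, T is not a key array, and T is a data handle.
template <class T>
std::enable_if_t<std::is_void_v<std::invoke_result_t<decltype(&T::renounce), T>> &&
                 !std::is_base_of_v<KeyArrayBase, T> &&
                 std::is_base_of_v<DataHandleBase, T>>
renounceHandle(T& h) {
    h.renounce();
}
```

A type that fails any condition does not produce an error inside the body; the overload simply disappears, leaving other overloads (such as an array-specific one) to be chosen.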

◆ renounceArray()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::renounceArray ( SG::VarHandleKeyArray &  handlesArray)
inline protected inherited

Remove all handles from I/O resolution.

Definition at line 364 of file AthCommonDataStore.h.

364  {
365  handlesArray.renounce();
366  }

◆ setBatchSize()

void AthOnnx::OnnxRuntimeInferenceTool::setBatchSize ( int64_t  batchSize)
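final override virtual

Set the batch size.

If the model has a dynamic batch dimension (a leading dimension of -1), that dimension is set to batchSize in both the input and the output shapes. Non-positive values are rejected with an error message and leave the shapes unchanged.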

Implements AthOnnx::IOnnxRuntimeInferenceTool.

Definition at line 42 of file OnnxRuntimeInferenceTool.cxx.

43 {
44  if (batchSize <= 0) {
45  ATH_MSG_ERROR("Batch size should be positive");
46  return;
47  }
48 
49  for (auto& shape : m_inputShapes) {
50  if (shape[0] == -1) {
51  shape[0] = batchSize;
52  }
53  }
54 
55  for (auto& shape : m_outputShapes) {
56  if (shape[0] == -1) {
57  shape[0] = batchSize;
58  }
59  }
60 }
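A standalone sketch of the same rule, using the hypothetical name fillBatchDimension and adding a guard against empty (scalar) shapes, for which the original loop would read shape[0] of an empty vector. Because the -1 is overwritten, only the first successful call changes the shapes; a later call with a different value has no effect.

```cpp
#include <cstdint>
#include <vector>

using Shapes = std::vector<std::vector<std::int64_t>>;

// Hypothetical name; same rule as setBatchSize(): a leading -1 becomes
// batchSize, non-positive sizes are rejected, fixed dimensions are untouched.
// Unlike the original loop, empty (scalar) shapes are skipped explicitly.
bool fillBatchDimension(Shapes& shapes, std::int64_t batchSize) {
    if (batchSize <= 0) {
        return false;  // the tool logs an error and returns at this point
    }
    for (auto& shape : shapes) {
        if (!shape.empty() && shape[0] == -1) {
            shape[0] = batchSize;
        }
    }
    return true;
}
```

Since setBatchSize() is applied to m_inputShapes and m_outputShapes together, this sketch would be called once for each of the two containers.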

◆ sysInitialize()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysInitialize ( )
override virtual inherited

Perform system initialization for an algorithm.

We override this to declare all the elements of handle key arrays at the end of initialization. See comments on updateVHKA.

Reimplemented in DerivationFramework::CfAthAlgTool, AthCheckedComponent< AthAlgTool >, AthCheckedComponent<::AthAlgTool >, and asg::AsgMetadataTool.

◆ sysStart()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysStart ( )
override virtual inherited

Handle START transition.

We override this in order to make sure that conditions handle keys can cache a pointer to the conditions container.

◆ updateVHKA()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::updateVHKA ( Gaudi::Details::PropertyBase &  )
inline inherited

Definition at line 308 of file AthCommonDataStore.h.

308  {
309  // debug() << "updateVHKA for property " << p.name() << " " << p.toString()
310  // << " size: " << m_vhka.size() << endmsg;
311  for (auto &a : m_vhka) {
312  std::vector<SG::VarHandleKey*> keys = a->keys();
313  for (auto k : keys) {
314  k->setOwner(this);
315  }
316  }
317  }
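declareProperty() [3/6] records each key array in m_vhka, and updateVHKA(), registered as the property's update handler, re-stamps the owner on every key each array currently holds, so keys created later from configuration still get an owner. A minimal sketch of that bookkeeping with hypothetical types (Owner, Key, KeyArray): ownership is a plain pointer here instead of IDataHandleHolder, and update() is called by hand rather than by the Gaudi property machinery.

```cpp
#include <vector>

// Hypothetical minimal types: Owner plays the role of the tool, KeyArray of
// an SG::VarHandleKeyArray and Key of an SG::VarHandleKey.
struct Owner;

struct Key {
    const Owner* owner = nullptr;
};

struct KeyArray {
    std::vector<Key> keys;
};

struct Owner {
    std::vector<KeyArray*> arrays;  // analogous to m_vhka

    // Analogous to declareProperty() [3/6]: remember the array for later.
    void declare(KeyArray& array) { arrays.push_back(&array); }

    // Analogous to updateVHKA(): after any array has been (re)configured,
    // stamp this object as the owner of every key it now contains.
    void update() {
        for (KeyArray* array : arrays) {
            for (Key& key : array->keys) {
                key.owner = this;
            }
        }
    }
};
```

Storing pointers to the arrays, rather than to individual keys, is what makes this work: the key list may be replaced wholesale by configuration, but the array object itself stays put.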

Member Data Documentation

◆ m_detStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_detStore
private inherited

Pointer to StoreGate (detector store by default)

Definition at line 393 of file AthCommonDataStore.h.

◆ m_evtStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_evtStore
private inherited

Pointer to StoreGate (event store by default)

Definition at line 390 of file AthCommonDataStore.h.

◆ m_inputNodeNames

std::vector<std::string> AthOnnx::OnnxRuntimeInferenceTool::m_inputNodeNames
private

Definition at line 55 of file OnnxRuntimeInferenceTool.h.

◆ m_inputShapes

std::vector<std::vector<int64_t> > AthOnnx::IOnnxRuntimeInferenceTool::m_inputShapes
protected inherited

Definition at line 104 of file IOnnxRuntimeInferenceTool.h.

◆ m_numInputs

unsigned AthOnnx::IOnnxRuntimeInferenceTool::m_numInputs
protected inherited

Definition at line 102 of file IOnnxRuntimeInferenceTool.h.

◆ m_numOutputs

unsigned AthOnnx::IOnnxRuntimeInferenceTool::m_numOutputs
protected inherited

Definition at line 103 of file IOnnxRuntimeInferenceTool.h.

◆ m_onnxRuntimeSvc

ServiceHandle<IOnnxRuntimeSvc> AthOnnx::OnnxRuntimeInferenceTool::m_onnxRuntimeSvc {"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"}
private

Definition at line 50 of file OnnxRuntimeInferenceTool.h.

◆ m_onnxSessionTool

ToolHandle<IOnnxRuntimeSessionTool> AthOnnx::OnnxRuntimeInferenceTool::m_onnxSessionTool
private
Initial value:
{
this, "ORTSessionTool",
"AthOnnx::OnnxRuntimeSessionToolCPU"
}

Definition at line 51 of file OnnxRuntimeInferenceTool.h.

◆ m_outputNodeNames

std::vector<std::string> AthOnnx::OnnxRuntimeInferenceTool::m_outputNodeNames
private

Definition at line 56 of file OnnxRuntimeInferenceTool.h.

◆ m_outputShapes

std::vector<std::vector<int64_t> > AthOnnx::IOnnxRuntimeInferenceTool::m_outputShapes
protected inherited

Definition at line 105 of file IOnnxRuntimeInferenceTool.h.

◆ m_varHandleArraysDeclared

bool AthCommonDataStore< AthCommonMsg< AlgTool > >::m_varHandleArraysDeclared
private inherited

Definition at line 399 of file AthCommonDataStore.h.

◆ m_vhka

std::vector<SG::VarHandleKeyArray*> AthCommonDataStore< AthCommonMsg< AlgTool > >::m_vhka
private inherited

Definition at line 398 of file AthCommonDataStore.h.


The documentation for this class was generated from the following files:
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
lumiFormat.outputName
string outputName
Definition: lumiFormat.py:65
AthOnnxUtils::getInputNodeInfo
void getInputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition: OnnxUtils.cxx:33
AthOnnxUtils::getOutputNodeInfo
void getOutputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition: OnnxUtils.cxx:41
a
TList * a
Definition: liststreamerinfos.cxx:10
h
AthOnnx::OnnxRuntimeInferenceTool::m_outputNodeNames
std::vector< std::string > m_outputNodeNames
Definition: OnnxRuntimeInferenceTool.h:56
AthOnnx::IOnnxRuntimeInferenceTool::m_numOutputs
unsigned m_numOutputs
Definition: IOnnxRuntimeInferenceTool.h:103
AthCommonMsg< AlgTool >::msg
MsgStream & msg() const
Definition: AthCommonMsg.h:24
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
SG::VarHandleBase::vhKey
SG::VarHandleKey & vhKey()
Return a non-const reference to the HandleKey.
Definition: StoreGate/src/VarHandleBase.cxx:623
AthOnnx::OnnxRuntimeInferenceTool::m_onnxSessionTool
ToolHandle< IOnnxRuntimeSessionTool > m_onnxSessionTool
Definition: OnnxRuntimeInferenceTool.h:51
python.Bindings.keys
keys
Definition: Control/AthenaPython/python/Bindings.py:798
AthOnnx::OnnxRuntimeInferenceTool::getNodeInfo
StatusCode getNodeInfo()
Definition: OnnxRuntimeInferenceTool.cxx:28
AthOnnx::IOnnxRuntimeInferenceTool::m_outputShapes
std::vector< std::vector< int64_t > > m_outputShapes
Definition: IOnnxRuntimeInferenceTool.h:105
AthOnnxUtils::createTensor
Ort::Value createTensor(std::vector< T > &data, const std::vector< int64_t > &dataShape)
Definition: OnnxUtils.h:78
SG::DataProxy
Definition: DataProxy.h:45
AthCommonDataStore::declareGaudiProperty
Gaudi::Details::PropertyBase & declareGaudiProperty(Gaudi::Property< T > &hndl, const SG::VarHandleKeyType &)
specialization for handling Gaudi::Property<SG::VarHandleKey>
Definition: AthCommonDataStore.h:156
fitman.k
k
Definition: fitman.py:528
AthOnnx::IOnnxRuntimeInferenceTool::m_numInputs
unsigned m_numInputs
Definition: IOnnxRuntimeInferenceTool.h:102