Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Public Member Functions | Protected Member Functions | Protected Attributes | Private Types | Private Member Functions | Private Attributes | List of all members
AthOnnx::OnnxRuntimeInferenceTool Class Reference

#include <OnnxRuntimeInferenceTool.h>

Inheritance diagram for AthOnnx::OnnxRuntimeInferenceTool:
Collaboration diagram for AthOnnx::OnnxRuntimeInferenceTool:

Public Member Functions

 OnnxRuntimeInferenceTool (const std::string &name)
 Standard constructor. More...
 
virtual ~OnnxRuntimeInferenceTool ()=default
 
virtual StatusCode initialize () override
 Initialize the tool. More...
 
virtual void setBatchSize (int64_t batchSize) override final
 set batch size. More...
 
virtual int64_t getBatchSize (int64_t inputDataSize, int idx=0) const override final
 methods for determining batch size from the data size More...
 
virtual StatusCode inference (std::vector< Ort::Value > &inputTensors, std::vector< Ort::Value > &outputTensors) const override final
 perform inference More...
 
virtual void printModelInfo () const override final
 
virtual StatusCode inference (AthInfer::InputDataMap &inputData, AthInfer::OutputDataMap &outputData) const override final
 
virtual void print () const
 Print the state of the tool. More...
 
ServiceHandle< StoreGateSvc > & evtStore ()
 The standard StoreGateSvc (event store) Returns (kind of) a pointer to the StoreGateSvc. More...
 
const ServiceHandle< StoreGateSvc > & evtStore () const
 The standard StoreGateSvc (event store) Returns (kind of) a pointer to the StoreGateSvc. More...
 
const ServiceHandle< StoreGateSvc > & detStore () const
 The standard StoreGateSvc/DetectorStore Returns (kind of) a pointer to the StoreGateSvc. More...
 
virtual StatusCode sysInitialize () override
 Perform system initialization for an algorithm. More...
 
virtual StatusCode sysStart () override
 Handle START transition. More...
 
virtual std::vector< Gaudi::DataHandle * > inputHandles () const override
 Return this algorithm's input handles. More...
 
virtual std::vector< Gaudi::DataHandle * > outputHandles () const override
 Return this algorithm's output handles. More...
 
Gaudi::Details::PropertyBase & declareProperty (Gaudi::Property< T > &t)
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleKey &hndl, const std::string &doc, const SG::VarHandleKeyType &)
 Declare a new Gaudi property. More...
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleBase &hndl, const std::string &doc, const SG::VarHandleType &)
 Declare a new Gaudi property. More...
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleKeyArray &hndArr, const std::string &doc, const SG::VarHandleKeyArrayType &)
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, T &property, const std::string &doc, const SG::NotHandleType &)
 Declare a new Gaudi property. More...
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, T &property, const std::string &doc="none")
 Declare a new Gaudi property. More...
 
void updateVHKA (Gaudi::Details::PropertyBase &)
 
MsgStream & msg () const
 
MsgStream & msg (const MSG::Level lvl) const
 
bool msgLvl (const MSG::Level lvl) const
 
template<typename T >
StatusCode addInput (std::vector< Ort::Value > &inputTensors, std::vector< T > &data, unsigned idx=0, int64_t batchSize=-1) const
 add the input data to the input tensors More...
 
template<typename T >
StatusCode addOutput (std::vector< Ort::Value > &outputTensors, std::vector< T > &data, unsigned idx=0, int64_t batchSize=-1) const
 add the output data to the output tensors More...
 

Protected Member Functions

 OnnxRuntimeInferenceTool ()=delete
 
 OnnxRuntimeInferenceTool (const OnnxRuntimeInferenceTool &)=delete
 
OnnxRuntimeInferenceTool & operator= (const OnnxRuntimeInferenceTool &)=delete
 
void renounceArray (SG::VarHandleKeyArray &handlesArray)
 remove all handles from I/O resolution More...
 
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce (T &h)
 
void extraDeps_update_handler (Gaudi::Details::PropertyBase &ExtraDeps)
 Add StoreName to extra input/output deps as needed. More...
 

Protected Attributes

unsigned m_numInputs
 
unsigned m_numOutputs
 
std::vector< std::vector< int64_t > > m_inputShapes
 
std::vector< std::vector< int64_t > > m_outputShapes
 

Private Types

typedef ServiceHandle< StoreGateSvc > StoreGateSvc_t
 

Private Member Functions

StatusCode getNodeInfo ()
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &hndl, const SG::VarHandleKeyType &)
 specialization for handling Gaudi::Property<SG::VarHandleKey> More...
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &hndl, const SG::VarHandleKeyArrayType &)
 specialization for handling Gaudi::Property<SG::VarHandleKeyArray> More...
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &hndl, const SG::VarHandleType &)
 specialization for handling Gaudi::Property<SG::VarHandleBase> More...
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &t, const SG::NotHandleType &)
 specialization for handling everything that's not a Gaudi::Property<SG::VarHandleKey> or a <SG::VarHandleKeyArray> More...
 
template<typename T >
Ort::Value createTensor (std::vector< T > &data, const std::vector< int64_t > &dataShape, int64_t batchSize) const
 

Private Attributes

ServiceHandle< IOnnxRuntimeSvc > m_onnxRuntimeSvc {this, "OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc/OnnxRuntimeSvc", "The Onnx runtime service"}
 
ToolHandle< IOnnxRuntimeSessionTool > m_onnxSessionTool
 
std::vector< std::string > m_inputNodeNames
 
std::vector< std::string > m_outputNodeNames
 
StoreGateSvc_t m_evtStore
 Pointer to StoreGate (event store by default) More...
 
StoreGateSvc_t m_detStore
 Pointer to StoreGate (detector store by default) More...
 
std::vector< SG::VarHandleKeyArray * > m_vhka
 
bool m_varHandleArraysDeclared
 

Detailed Description

Definition at line 21 of file OnnxRuntimeInferenceTool.h.

Member Typedef Documentation

◆ StoreGateSvc_t

typedef ServiceHandle<StoreGateSvc> AthCommonDataStore< AthCommonMsg< AlgTool > >::StoreGateSvc_t
privateinherited

Definition at line 388 of file AthCommonDataStore.h.

Constructor & Destructor Documentation

◆ OnnxRuntimeInferenceTool() [1/3]

AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool ( const std::string &  name)

Standard constructor.

Definition at line 8 of file OnnxRuntimeInferenceTool.cxx.

9  : asg::AsgTool ( name )
10 {
11 }

◆ ~OnnxRuntimeInferenceTool()

virtual AthOnnx::OnnxRuntimeInferenceTool::~OnnxRuntimeInferenceTool ( )
virtualdefault

◆ OnnxRuntimeInferenceTool() [2/3]

AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool ( )
protecteddelete

◆ OnnxRuntimeInferenceTool() [3/3]

AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool ( const OnnxRuntimeInferenceTool & )
protecteddelete

Member Function Documentation

◆ addInput()

template<typename T >
StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addInput ( std::vector< Ort::Value > &  inputTensors,
std::vector< T > &  data,
unsigned  idx = 0,
int64_t  batchSize = -1 
) const
inherited

add the input data to the input tensors

Parameters
inputTensors: the input tensor container
data: the input data
idx: the index of the input node
batchSize: the batch size
Returns
StatusCode::SUCCESS if the input data is added successfully

◆ addOutput()

template<typename T >
StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addOutput ( std::vector< Ort::Value > &  outputTensors,
std::vector< T > &  data,
unsigned  idx = 0,
int64_t  batchSize = -1 
) const
inherited

add the output data to the output tensors

Parameters
outputTensors: the output tensor container
data: the output data
idx: the index of the output node
batchSize: the batch size
Returns
StatusCode::SUCCESS if the output data is added successfully

◆ createTensor()

template<typename T >
Ort::Value AthOnnx::IOnnxRuntimeInferenceTool::createTensor ( std::vector< T > &  data,
const std::vector< int64_t > &  dataShape,
int64_t  batchSize 
) const
privateinherited

◆ declareGaudiProperty() [1/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > &  hndl,
const SG::VarHandleKeyArrayType & 
)
inlineprivateinherited

specialization for handling Gaudi::Property<SG::VarHandleKeyArray>

Definition at line 170 of file AthCommonDataStore.h.

172  {
173  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
174  hndl.value(),
175  hndl.documentation());
176 
177  }

◆ declareGaudiProperty() [2/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > &  hndl,
const SG::VarHandleKeyType & 
)
inlineprivateinherited

specialization for handling Gaudi::Property<SG::VarHandleKey>

Definition at line 156 of file AthCommonDataStore.h.

158  {
159  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
160  hndl.value(),
161  hndl.documentation());
162 
163  }

◆ declareGaudiProperty() [3/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > &  hndl,
const SG::VarHandleType & 
)
inlineprivateinherited

specialization for handling Gaudi::Property<SG::VarHandleBase>

Definition at line 184 of file AthCommonDataStore.h.

186  {
187  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
188  hndl.value(),
189  hndl.documentation());
190  }

◆ declareGaudiProperty() [4/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > &  t,
const SG::NotHandleType & 
)
inlineprivateinherited

specialization for handling everything that's not a Gaudi::Property<SG::VarHandleKey> or a <SG::VarHandleKeyArray>

Definition at line 199 of file AthCommonDataStore.h.

200  {
201  return PBASE::declareProperty(t);
202  }

◆ declareProperty() [1/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
SG::VarHandleBase &  hndl,
const std::string &  doc,
const SG::VarHandleType & 
)
inlineinherited

Declare a new Gaudi property.

Parameters
name: Name of the property.
hndl: Object holding the property value.
doc: Documentation string for the property.

This is the version for types that derive from SG::VarHandleBase. The property value object is put on the input and output lists as appropriate; then we forward to the base class.

Definition at line 245 of file AthCommonDataStore.h.

249  {
250  this->declare(hndl.vhKey());
251  hndl.vhKey().setOwner(this);
252 
253  return PBASE::declareProperty(name,hndl,doc);
254  }

◆ declareProperty() [2/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
SG::VarHandleKey &  hndl,
const std::string &  doc,
const SG::VarHandleKeyType & 
)
inlineinherited

Declare a new Gaudi property.

Parameters
name: Name of the property.
hndl: Object holding the property value.
doc: Documentation string for the property.

This is the version for types that derive from SG::VarHandleKey. The property value object is put on the input and output lists as appropriate; then we forward to the base class.

Definition at line 221 of file AthCommonDataStore.h.

225  {
226  this->declare(hndl);
227  hndl.setOwner(this);
228 
229  return PBASE::declareProperty(name,hndl,doc);
230  }

◆ declareProperty() [3/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
SG::VarHandleKeyArray &  hndArr,
const std::string &  doc,
const SG::VarHandleKeyArrayType & 
)
inlineinherited

Definition at line 259 of file AthCommonDataStore.h.

263  {
264 
265  // std::ostringstream ost;
266  // ost << Algorithm::name() << " VHKA declareProp: " << name
267  // << " size: " << hndArr.keys().size()
268  // << " mode: " << hndArr.mode()
269  // << " vhka size: " << m_vhka.size()
270  // << "\n";
271  // debug() << ost.str() << endmsg;
272 
273  hndArr.setOwner(this);
274  m_vhka.push_back(&hndArr);
275 
276  Gaudi::Details::PropertyBase* p = PBASE::declareProperty(name, hndArr, doc);
277  if (p != 0) {
278  p->declareUpdateHandler(&AthCommonDataStore<PBASE>::updateVHKA, this);
279  } else {
280  ATH_MSG_ERROR("unable to call declareProperty on VarHandleKeyArray "
281  << name);
282  }
283 
284  return p;
285 
286  }

◆ declareProperty() [4/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
T &  property,
const std::string &  doc,
const SG::NotHandleType & 
)
inlineinherited

Declare a new Gaudi property.

Parameters
name: Name of the property.
property: Object holding the property value.
doc: Documentation string for the property.

This is the generic version, for types that do not derive from SG::VarHandleKey. It just forwards to the base class version of declareProperty.

Definition at line 333 of file AthCommonDataStore.h.

337  {
338  return PBASE::declareProperty(name, property, doc);
339  }

◆ declareProperty() [5/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
T &  property,
const std::string &  doc = "none" 
)
inlineinherited

Declare a new Gaudi property.

Parameters
name: Name of the property.
property: Object holding the property value.
doc: Documentation string for the property.

This dispatches to either the generic declareProperty or the one for VarHandle/Key/KeyArray.

Definition at line 352 of file AthCommonDataStore.h.

355  {
356  typedef typename SG::HandleClassifier<T>::type htype;
357  return declareProperty (name, property, doc, htype());
358  }

◆ declareProperty() [6/6]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( Gaudi::Property< T > &  t)
inlineinherited

Definition at line 145 of file AthCommonDataStore.h.

145  {
146  typedef typename SG::HandleClassifier<T>::type htype;
148  }

◆ detStore()

const ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::detStore ( ) const
inlineinherited

The standard StoreGateSvc/DetectorStore Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 95 of file AthCommonDataStore.h.

95 { return m_detStore; }

◆ evtStore() [1/2]

ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::evtStore ( )
inlineinherited

The standard StoreGateSvc (event store) Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 85 of file AthCommonDataStore.h.

85 { return m_evtStore; }

◆ evtStore() [2/2]

const ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::evtStore ( ) const
inlineinherited

The standard StoreGateSvc (event store) Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 90 of file AthCommonDataStore.h.

90 { return m_evtStore; }

◆ extraDeps_update_handler()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::extraDeps_update_handler ( Gaudi::Details::PropertyBase &  ExtraDeps)
protectedinherited

Add StoreName to extra input/output deps as needed.

Use the logic of the VarHandleKey to parse the DataObjID keys supplied via the ExtraInputs and ExtraOutputs properties, adding the StoreName if it is not explicitly given.

◆ getBatchSize()

int64_t AthOnnx::OnnxRuntimeInferenceTool::getBatchSize ( int64_t  dataSize,
int  idx = 0 
) const
finaloverridevirtual

methods for determining batch size from the data size

Parameters
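dataSize: the size of the input data, like std::vector<T>::size()
idx: the index of the input node
Returns
the batch size, which equals dataSize divided by the size of the remaining dimensions. The implementation shown below returns -1 when getTensorSize() yields a non-negative value for this input.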

Implements AthOnnx::IOnnxRuntimeInferenceTool.

Definition at line 60 of file OnnxRuntimeInferenceTool.cxx.

61 {
62  auto tensorSize = AthOnnxUtils::getTensorSize(m_inputShapes[idx]);
63  if (tensorSize < 0) {
64  return inputDataSize / abs(tensorSize);
65  } else {
66  return -1;
67  }
68 }

◆ getKey()

SG::sgkey_t asg::AsgTool::getKey ( const void *  ptr) const
inherited

Get the (hashed) key of an object that is in the event store.

This is a bit of a special one. StoreGateSvc and xAOD::TEvent both provide ways for getting the SG::sgkey_t key for an object that is in the store, based on a bare pointer. But they provide different interfaces for doing so.

In order to allow tools to efficiently perform this operation, they can use this helper function.

See also
asg::AsgTool::getName
Parameters
ptr: The bare pointer to the object that the event store should know about
Returns
The hashed key of the object in the store. If not found, an invalid (zero) key.

Definition at line 119 of file AsgTool.cxx.

119  {
120 
121 #ifdef XAOD_STANDALONE
122  // In case we use @c xAOD::TEvent, we have a direct function call
123  // for this.
124  return evtStore()->event()->getKey( ptr );
125 #else
126  const SG::DataProxy* proxy = evtStore()->proxy( ptr );
127  return ( proxy == nullptr ? 0 : proxy->sgkey() );
128 #endif // XAOD_STANDALONE
129  }

◆ getName()

const std::string & asg::AsgTool::getName ( const void *  ptr) const
inherited

Get the name of an object that is / should be in the event store.

This is a bit of a special one. StoreGateSvc and xAOD::TEvent both provide ways for getting the std::string name for an object that is in the store, based on a bare pointer. But they provide different interfaces for doing so.

In order to allow tools to efficiently perform this operation, they can use this helper function.

See also
asg::AsgTool::getKey
Parameters
ptr: The bare pointer to the object that the event store should know about
Returns
The string name of the object in the store. If not found, an empty string.

Definition at line 106 of file AsgTool.cxx.

106  {
107 
108 #ifdef XAOD_STANDALONE
109  // In case we use @c xAOD::TEvent, we have a direct function call
110  // for this.
111  return evtStore()->event()->getName( ptr );
112 #else
113  const SG::DataProxy* proxy = evtStore()->proxy( ptr );
114  static const std::string dummy = "";
115  return ( proxy == nullptr ? dummy : proxy->name() );
116 #endif // XAOD_STANDALONE
117  }

◆ getNodeInfo()

StatusCode AthOnnx::OnnxRuntimeInferenceTool::getNodeInfo ( )
private

Definition at line 26 of file OnnxRuntimeInferenceTool.cxx.

27 {
28  auto& session = m_onnxSessionTool->session();
29  // obtain the model information
30  m_numInputs = session.GetInputCount();
31  m_numOutputs = session.GetOutputCount();
32 
35 
36  return StatusCode::SUCCESS;
37 }

◆ getProperty()

template<class T >
const T* asg::AsgTool::getProperty ( const std::string &  name) const
inherited

Get one of the tool's properties.

◆ inference() [1/2]

StatusCode AthOnnx::OnnxRuntimeInferenceTool::inference ( AthInfer::InputDataMap &  inputData,
AthInfer::OutputDataMap &  outputData 
) const
finaloverridevirtual

Implements AthInfer::IAthInferenceTool.

Definition at line 118 of file OnnxRuntimeInferenceTool.cxx.

119 {
120  // Create input tensors.
121  std::vector<Ort::Value> inputTensors;
122  for (auto& [inputName, inputInfo] : inputData) {
123  const std::vector<int64_t>& shape = inputInfo.first;
124  if (std::holds_alternative<std::vector<float>>(inputInfo.second)) {
125  auto& data = std::get<std::vector<float>>(inputInfo.second);
126  inputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
127  } else if (std::holds_alternative<std::vector<int64_t>>(inputInfo.second)) {
128  auto& data = std::get<std::vector<int64_t>>(inputInfo.second);
129  inputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
130  } else {
131  ATH_MSG_ERROR("Unsupported data type");
132  return StatusCode::FAILURE;
133  }
134  }
135 
136  // Create output tensors.
137  std::vector<Ort::Value> outputTensors;
138  outputTensors.reserve(inputData.size());
139  for (auto& [outputName, outputInfo] : outputData) {
140  auto& shape = outputInfo.first;
141  auto tensorSize = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
142 
143  if (std::holds_alternative<std::vector<float>>(outputInfo.second)) {
144  auto& data = std::get<std::vector<float>>(outputInfo.second);
145  data.resize(tensorSize);
146  outputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
147  } else if (std::holds_alternative<std::vector<int64_t>>(outputInfo.second)) {
148  auto& data = std::get<std::vector<int64_t>>(outputInfo.second);
149  data.resize(tensorSize);
150  outputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
151  } else {
152  ATH_MSG_ERROR("Unsupported data type");
153  return StatusCode::FAILURE;
154  }
155  }
156 
157  ATH_CHECK(inference(inputTensors, outputTensors));
158 
159  return StatusCode::SUCCESS;
160 }

◆ inference() [2/2]

StatusCode AthOnnx::OnnxRuntimeInferenceTool::inference ( std::vector< Ort::Value > &  inputTensors,
std::vector< Ort::Value > &  outputTensors 
) const
finaloverridevirtual

perform inference

Parameters
inputTensors: the input tensor container
outputTensors: the output tensor container
Returns
StatusCode::SUCCESS if the inference is performed successfully

Implements AthOnnx::IOnnxRuntimeInferenceTool.

Definition at line 70 of file OnnxRuntimeInferenceTool.cxx.

71 {
72  assert (inputTensors.size() == m_numInputs);
73  assert (outputTensors.size() == m_numOutputs);
74 
75  // Run the model.
76  AthOnnxUtils::inferenceWithIOBinding(
77  m_onnxSessionTool->session(),
78  m_inputNodeNames, inputTensors,
79  m_outputNodeNames, outputTensors);
80 
81  return StatusCode::SUCCESS;
82 }

◆ initialize()

StatusCode AthOnnx::OnnxRuntimeInferenceTool::initialize ( )
overridevirtual

Initialize the tool.

Reimplemented from asg::AsgTool.

Definition at line 13 of file OnnxRuntimeInferenceTool.cxx.

14 {
15  // Get the Onnx Runtime service.
16  ATH_CHECK(m_onnxRuntimeSvc.retrieve());
17 
18  // Create the session.
19  ATH_CHECK(m_onnxSessionTool.retrieve());
20 
22 
23  return StatusCode::SUCCESS;
24 }

◆ inputHandles()

virtual std::vector<Gaudi::DataHandle*> AthCommonDataStore< AthCommonMsg< AlgTool > >::inputHandles ( ) const
overridevirtualinherited

Return this algorithm's input handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ msg() [1/2]

MsgStream& AthCommonMsg< AlgTool >::msg ( ) const
inlineinherited

Definition at line 24 of file AthCommonMsg.h.

24  {
25  return this->msgStream();
26  }

◆ msg() [2/2]

MsgStream& AthCommonMsg< AlgTool >::msg ( const MSG::Level  lvl) const
inlineinherited

Definition at line 27 of file AthCommonMsg.h.

27  {
28  return this->msgStream(lvl);
29  }

◆ msg_level_name()

const std::string & asg::AsgTool::msg_level_name ( ) const
inherited

A deprecated function for getting the message level's name.

Instead of using this weirdly named function, user code should get the string name of the current minimum message level (in case they really need it...) with:

MSG::name( msg().level() )

This function's name doesn't follow the ATLAS coding rules, and as such will be removed in the not too distant future.

Returns
The string name of the current minimum message level that's printed

Definition at line 101 of file AsgTool.cxx.

101  {
102 
103  return MSG::name( msg().level() );
104  }

◆ msgLvl()

bool AthCommonMsg< AlgTool >::msgLvl ( const MSG::Level  lvl) const
inlineinherited

Definition at line 30 of file AthCommonMsg.h.

30  {
31  return this->msgLevel(lvl);
32  }

◆ operator=()

OnnxRuntimeInferenceTool& AthOnnx::OnnxRuntimeInferenceTool::operator= ( const OnnxRuntimeInferenceTool & )
protecteddelete

◆ outputHandles()

virtual std::vector<Gaudi::DataHandle*> AthCommonDataStore< AthCommonMsg< AlgTool > >::outputHandles ( ) const
overridevirtualinherited

Return this algorithm's output handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ print()

void asg::AsgTool::print ( ) const
virtualinherited

◆ printModelInfo()

void AthOnnx::OnnxRuntimeInferenceTool::printModelInfo ( ) const
finaloverridevirtual

Implements AthOnnx::IOnnxRuntimeInferenceTool.

Definition at line 84 of file OnnxRuntimeInferenceTool.cxx.

85 {
86  ATH_MSG_INFO("Number of inputs: " << m_numInputs);
87  ATH_MSG_INFO("Number of outputs: " << m_numOutputs);
88 
89  ATH_MSG_INFO("Input node names: ");
90  for (const auto& name : m_inputNodeNames) {
91  ATH_MSG_INFO("\t" << name);
92  }
93 
94  ATH_MSG_INFO("Output node names: ");
95  for (const auto& name : m_outputNodeNames) {
96  ATH_MSG_INFO("\t" << name);
97  }
98 
99  ATH_MSG_INFO("Input shapes: ");
100  for (const auto& shape : m_inputShapes) {
101  std::string shapeStr = "\t";
102  for (const auto& dim : shape) {
103  shapeStr += std::to_string(dim) + " ";
104  }
105  ATH_MSG_INFO(shapeStr);
106  }
107 
108  ATH_MSG_INFO("Output shapes: ");
109  for (const auto& shape : m_outputShapes) {
110  std::string shapeStr = "\t";
111  for (const auto& dim : shape) {
112  shapeStr += std::to_string(dim) + " ";
113  }
114  ATH_MSG_INFO(shapeStr);
115  }
116 }

◆ renounce()

std::enable_if_t<std::is_void_v<std::result_of_t<decltype(&T::renounce)(T)> > && !std::is_base_of_v<SG::VarHandleKeyArray, T> && std::is_base_of_v<Gaudi::DataHandle, T>, void> AthCommonDataStore< AthCommonMsg< AlgTool > >::renounce ( T &  h)
inlineprotectedinherited

Definition at line 380 of file AthCommonDataStore.h.

381  {
382  h.renounce();
383  PBASE::renounce (h);
384  }

◆ renounceArray()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::renounceArray ( SG::VarHandleKeyArray &  handlesArray)
inlineprotectedinherited

remove all handles from I/O resolution

Definition at line 364 of file AthCommonDataStore.h.

364  {
365  handlesArray.renounce();
366  }

◆ setBatchSize()

void AthOnnx::OnnxRuntimeInferenceTool::setBatchSize ( int64_t  batchSize)
finaloverridevirtual

set batch size.

If the model has dynamic batch size, the batchSize value will be set to both input shapes and output shapes

Implements AthOnnx::IOnnxRuntimeInferenceTool.

Definition at line 40 of file OnnxRuntimeInferenceTool.cxx.

41 {
42  if (batchSize <= 0) {
43  ATH_MSG_ERROR("Batch size should be positive");
44  return;
45  }
46 
47  for (auto& shape : m_inputShapes) {
48  if (shape[0] == -1) {
49  shape[0] = batchSize;
50  }
51  }
52 
53  for (auto& shape : m_outputShapes) {
54  if (shape[0] == -1) {
55  shape[0] = batchSize;
56  }
57  }
58 }

◆ sysInitialize()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysInitialize ( )
overridevirtualinherited

Perform system initialization for an algorithm.

We override this to declare all the elements of handle key arrays at the end of initialization. See comments on updateVHKA.

Reimplemented in DerivationFramework::CfAthAlgTool, AthCheckedComponent< AthAlgTool >, AthCheckedComponent<::AthAlgTool >, and asg::AsgMetadataTool.

◆ sysStart()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysStart ( )
overridevirtualinherited

Handle START transition.

We override this in order to make sure that conditions handle keys can cache a pointer to the conditions container.

◆ updateVHKA()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::updateVHKA ( Gaudi::Details::PropertyBase &  )
inlineinherited

Definition at line 308 of file AthCommonDataStore.h.

308  {
309  // debug() << "updateVHKA for property " << p.name() << " " << p.toString()
310  // << " size: " << m_vhka.size() << endmsg;
311  for (auto &a : m_vhka) {
312  std::vector<SG::VarHandleKey*> keys = a->keys();
313  for (auto k : keys) {
314  k->setOwner(this);
315  }
316  }
317  }

Member Data Documentation

◆ m_detStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_detStore
privateinherited

Pointer to StoreGate (detector store by default)

Definition at line 393 of file AthCommonDataStore.h.

◆ m_evtStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_evtStore
privateinherited

Pointer to StoreGate (event store by default)

Definition at line 390 of file AthCommonDataStore.h.

◆ m_inputNodeNames

std::vector<std::string> AthOnnx::OnnxRuntimeInferenceTool::m_inputNodeNames
private

Definition at line 56 of file OnnxRuntimeInferenceTool.h.

◆ m_inputShapes

std::vector<std::vector<int64_t> > AthOnnx::IOnnxRuntimeInferenceTool::m_inputShapes
protectedinherited

Definition at line 104 of file IOnnxRuntimeInferenceTool.h.

◆ m_numInputs

unsigned AthOnnx::IOnnxRuntimeInferenceTool::m_numInputs
protectedinherited

Definition at line 102 of file IOnnxRuntimeInferenceTool.h.

◆ m_numOutputs

unsigned AthOnnx::IOnnxRuntimeInferenceTool::m_numOutputs
protectedinherited

Definition at line 103 of file IOnnxRuntimeInferenceTool.h.

◆ m_onnxRuntimeSvc

ServiceHandle<IOnnxRuntimeSvc> AthOnnx::OnnxRuntimeInferenceTool::m_onnxRuntimeSvc {this, "OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc/OnnxRuntimeSvc", "The Onnx runtime service"}
private

Definition at line 50 of file OnnxRuntimeInferenceTool.h.

◆ m_onnxSessionTool

ToolHandle<IOnnxRuntimeSessionTool> AthOnnx::OnnxRuntimeInferenceTool::m_onnxSessionTool
private
Initial value:
{
this, "ORTSessionTool",
"AthOnnx::OnnxRuntimeSessionToolCPU",
"The Onnx session tool"
}

Definition at line 51 of file OnnxRuntimeInferenceTool.h.

◆ m_outputNodeNames

std::vector<std::string> AthOnnx::OnnxRuntimeInferenceTool::m_outputNodeNames
private

Definition at line 57 of file OnnxRuntimeInferenceTool.h.

◆ m_outputShapes

std::vector<std::vector<int64_t> > AthOnnx::IOnnxRuntimeInferenceTool::m_outputShapes
protectedinherited

Definition at line 105 of file IOnnxRuntimeInferenceTool.h.

◆ m_varHandleArraysDeclared

bool AthCommonDataStore< AthCommonMsg< AlgTool > >::m_varHandleArraysDeclared
privateinherited

Definition at line 399 of file AthCommonDataStore.h.

◆ m_vhka

std::vector<SG::VarHandleKeyArray*> AthCommonDataStore< AthCommonMsg< AlgTool > >::m_vhka
privateinherited

Definition at line 398 of file AthCommonDataStore.h.


The documentation for this class was generated from the following files:
AthOnnx::OnnxRuntimeInferenceTool::inference
virtual StatusCode inference(std::vector< Ort::Value > &inputTensors, std::vector< Ort::Value > &outputTensors) const override final
perform inference
Definition: OnnxRuntimeInferenceTool.cxx:70
AthOnnx::IOnnxRuntimeInferenceTool::m_inputShapes
std::vector< std::vector< int64_t > > m_inputShapes
Definition: IOnnxRuntimeInferenceTool.h:104
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
asg::AsgTool
Base class for the dual-use tool implementation classes.
Definition: AsgTool.h:47
AthOnnxUtils::inferenceWithIOBinding
void inferenceWithIOBinding(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, const std::vector< Ort::Value > &outputData)
Definition: OnnxUtils.cxx:49
yodamerge_tmp.dim
dim
Definition: yodamerge_tmp.py:239
StateLessPT_NewConfig.proxy
proxy
Definition: StateLessPT_NewConfig.py:395
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
AthCommonDataStore::declareProperty
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T > &t)
Definition: AthCommonDataStore.h:145
AthOnnx::OnnxRuntimeInferenceTool::m_onnxRuntimeSvc
ServiceHandle< IOnnxRuntimeSvc > m_onnxRuntimeSvc
Definition: OnnxRuntimeInferenceTool.h:50
AthCommonDataStore< AthCommonMsg< AlgTool > >::m_evtStore
StoreGateSvc_t m_evtStore
Pointer to StoreGate (event store by default)
Definition: AthCommonDataStore.h:390
AthCommonDataStore< AthCommonMsg< AlgTool > >::m_vhka
std::vector< SG::VarHandleKeyArray * > m_vhka
Definition: AthCommonDataStore.h:398
read_hist_ntuple.t
t
Definition: read_hist_ntuple.py:5
dbg::ptr
void * ptr(T *p)
Definition: SGImplSvc.cxx:74
python.iconfTool.models.loaders.level
level
Definition: loaders.py:20
SG::VarHandleKeyArray::setOwner
virtual void setOwner(IDataHandleHolder *o)=0
IDTPMcnv.htype
htype
Definition: IDTPMcnv.py:29
AthCommonDataStore< AthCommonMsg< AlgTool > >::evtStore
ServiceHandle< StoreGateSvc > & evtStore()
The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.
Definition: AthCommonDataStore.h:85
python.utils.AtlRunQueryDQUtils.p
p
Definition: AtlRunQueryDQUtils.py:210
AthCommonDataStore
Definition: AthCommonDataStore.h:52
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
AthOnnx::OnnxRuntimeInferenceTool::m_inputNodeNames
std::vector< std::string > m_inputNodeNames
Definition: OnnxRuntimeInferenceTool.h:56
python.xAODType.dummy
dummy
Definition: xAODType.py:4
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
MSG::name
const std::string & name(Level lvl)
Convenience function for translating message levels to strings.
Definition: MsgLevel.cxx:19
AthCommonDataStore< AthCommonMsg< AlgTool > >::m_detStore
StoreGateSvc_t m_detStore
Pointer to StoreGate (detector store by default)
Definition: AthCommonDataStore.h:393
AthOnnxUtils::getTensorSize
int64_t getTensorSize(const std::vector< int64_t > &dataShape)
Definition: OnnxUtils.cxx:73
SG::VarHandleKeyArray::renounce
virtual void renounce()=0
SG::HandleClassifier::type
std::conditional< std::is_base_of< SG::VarHandleKeyArray, T >::value, VarHandleKeyArrayType, type2 >::type type
Definition: HandleClassifier.h:54
merge_scale_histograms.doc
string doc
Definition: merge_scale_histograms.py:9
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
lumiFormat.outputName
string outputName
Definition: lumiFormat.py:65
runIDAlign.accumulate
accumulate
Update flags based on parser line args.
Definition: runIDAlign.py:63
AthOnnxUtils::getInputNodeInfo
void getInputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition: OnnxUtils.cxx:33
AthOnnxUtils::getOutputNodeInfo
void getOutputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition: OnnxUtils.cxx:41
a
TList * a
Definition: liststreamerinfos.cxx:10
h
AthOnnx::OnnxRuntimeInferenceTool::m_outputNodeNames
std::vector< std::string > m_outputNodeNames
Definition: OnnxRuntimeInferenceTool.h:57
AthOnnx::IOnnxRuntimeInferenceTool::m_numOutputs
unsigned m_numOutputs
Definition: IOnnxRuntimeInferenceTool.h:103
AthCommonMsg< AlgTool >::msg
MsgStream & msg() const
Definition: AthCommonMsg.h:24
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
SG::VarHandleBase::vhKey
SG::VarHandleKey & vhKey()
Return a non-const reference to the HandleKey.
Definition: StoreGate/src/VarHandleBase.cxx:629
AthOnnx::OnnxRuntimeInferenceTool::m_onnxSessionTool
ToolHandle< IOnnxRuntimeSessionTool > m_onnxSessionTool
Definition: OnnxRuntimeInferenceTool.h:51
python.Bindings.keys
keys
Definition: Control/AthenaPython/python/Bindings.py:798
AthOnnx::OnnxRuntimeInferenceTool::getNodeInfo
StatusCode getNodeInfo()
Definition: OnnxRuntimeInferenceTool.cxx:26
AthOnnx::IOnnxRuntimeInferenceTool::m_outputShapes
std::vector< std::vector< int64_t > > m_outputShapes
Definition: IOnnxRuntimeInferenceTool.h:105
AthOnnxUtils::createTensor
Ort::Value createTensor(std::vector< T > &data, const std::vector< int64_t > &dataShape)
Definition: OnnxUtils.h:78
SG::DataProxy
Definition: DataProxy.h:45
AthCommonDataStore::declareGaudiProperty
Gaudi::Details::PropertyBase & declareGaudiProperty(Gaudi::Property< T > &hndl, const SG::VarHandleKeyType &)
specialization for handling Gaudi::Property<SG::VarHandleKey>
Definition: AthCommonDataStore.h:156
fitman.k
k
Definition: fitman.py:528
AthOnnx::IOnnxRuntimeInferenceTool::m_numInputs
unsigned m_numInputs
Definition: IOnnxRuntimeInferenceTool.h:102