ATLAS Offline Software
AthOnnx::EvaluateModelWithAthInfer Class Reference

Algorithm demonstrating the usage of the ONNX Runtime C++ API. More...

#include <EvaluateModelWithAthInfer.h>


Public Member Functions

virtual StatusCode sysInitialize () override
 Override sysInitialize.
virtual bool isClonable () const override
 Specify if the algorithm is clonable.
virtual StatusCode sysExecute (const EventContext &ctx) override
 Execute an algorithm.
virtual const DataObjIDColl & extraOutputDeps () const override
 Return the list of extra output dependencies.
virtual bool filterPassed (const EventContext &ctx) const
 Get filter decision:
virtual void setFilterPassed (bool state, const EventContext &ctx) const
 Set filter decision:
ServiceHandle< StoreGateSvc > & evtStore ()
 The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.
const ServiceHandle< StoreGateSvc > & detStore () const
 The standard StoreGateSvc/DetectorStore. Returns (kind of) a pointer to the StoreGateSvc.
virtual StatusCode sysStart () override
 Handle START transition.
virtual std::vector< Gaudi::DataHandle * > inputHandles () const override
 Return this algorithm's input handles.
virtual std::vector< Gaudi::DataHandle * > outputHandles () const override
 Return this algorithm's output handles.
Gaudi::Details::PropertyBase & declareProperty (Gaudi::Property< T, V, H > &t)
void updateVHKA (Gaudi::Details::PropertyBase &)
MsgStream & msg () const
bool msgLvl (const MSG::Level lvl) const
Function(s) inherited from AthAlgorithm
virtual StatusCode initialize () override
 Function initialising the algorithm.
virtual StatusCode execute (const EventContext &ctx) const override
 Function executing the algorithm for a single event.

Protected Member Functions

void renounceArray (SG::VarHandleKeyArray &handlesArray)
 remove all handles from I/O resolution
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce (T &h)
void extraDeps_update_handler (Gaudi::Details::PropertyBase &ExtraDeps)
 Add StoreName to extra input/output deps as needed.

Private Types

typedef ServiceHandle< StoreGateSvc > StoreGateSvc_t

Private Member Functions

Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T, V, H > &hndl, const SG::VarHandleKeyType &)
 specialization for handling Gaudi::Property<SG::VarHandleKey>

Private Attributes

DataObjIDColl m_extendedExtraObjects
 Extra output dependency collection, extended by AthAlgorithmDHUpdate to add symlinks.
StoreGateSvc_t m_evtStore
 Pointer to StoreGate (event store by default).
StoreGateSvc_t m_detStore
 Pointer to StoreGate (detector store by default).
std::vector< SG::VarHandleKeyArray * > m_vhka
bool m_varHandleArraysDeclared
Algorithm properties
Gaudi::Property< std::string > m_pixelFileName
 Name of the input pixel file to load.
Gaudi::Property< int > m_batchSize {this, "BatchSize", 1, "No. of elements/example in a batch"}
 The following properties need to be considered if the .onnx model is evaluated in batch mode.
ToolHandle< AthInfer::IAthInferenceTool > m_onnxTool
 Tool handle for the ONNX inference session.
std::vector< std::vector< std::vector< float > > > m_input_tensor_values_notFlat

Detailed Description

Algorithm demonstrating the usage of the ONNX Runtime C++ API.

In most cases this should be preferred over the C API...

Author
Debottam Bakshi Gupta <Debottam.Bakshi.Gupta@cern.ch>
Attila Krasznahorkay <Attila.Krasznahorkay@cern.ch>

Definition at line 29 of file EvaluateModelWithAthInfer.h.

Member Typedef Documentation

◆ StoreGateSvc_t

typedef ServiceHandle<StoreGateSvc> AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::StoreGateSvc_t
private, inherited

Definition at line 388 of file AthCommonDataStore.h.

Member Function Documentation

◆ declareGaudiProperty()

Gaudi::Details::PropertyBase & AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::declareGaudiProperty ( Gaudi::Property< T, V, H > & hndl,
const SG::VarHandleKeyType &  )
inline, private, inherited

specialization for handling Gaudi::Property<SG::VarHandleKey>

Definition at line 156 of file AthCommonDataStore.h.

158 {
160 hndl.value(),
161 hndl.documentation());
162
163 }
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T, V, H > &t)

◆ declareProperty()

Gaudi::Details::PropertyBase & AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::declareProperty ( Gaudi::Property< T, V, H > & t)
inline, inherited

Definition at line 145 of file AthCommonDataStore.h.

145 {
146 typedef typename SG::HandleClassifier<T>::type htype;
148 }
Gaudi::Details::PropertyBase & declareGaudiProperty(Gaudi::Property< T, V, H > &hndl, const SG::VarHandleKeyType &)
specialization for handling Gaudi::Property<SG::VarHandleKey>

◆ detStore()

const ServiceHandle< StoreGateSvc > & AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::detStore ( ) const
inline, inherited

The standard StoreGateSvc/DetectorStore. Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 95 of file AthCommonDataStore.h.

◆ evtStore()

ServiceHandle< StoreGateSvc > & AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::evtStore ( )
inline, inherited

The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 85 of file AthCommonDataStore.h.

◆ execute()

StatusCode AthOnnx::EvaluateModelWithAthInfer::execute ( const EventContext & ctx) const
override, virtual

Function executing the algorithm for a single event.

Definition at line 61 of file EvaluateModelWithAthInfer.cxx.

62 {
63 // We know we have at least one image, otherwise we would have errored out
64 // earlier
65 const std::size_t n_batches =
67 const auto n_rows = std::int64_t(m_input_tensor_values_notFlat[0].size());
68 const auto n_cols = std::int64_t(m_input_tensor_values_notFlat[0][0].size());
69
70 for (std::size_t batch_idx = 0; batch_idx < n_batches; ++batch_idx) {
71 // prepare inputs
72 std::vector<float> inputDataVector;
73 inputDataVector.reserve(m_batchSize.value() * n_rows * n_cols);
74 for (const std::vector<std::vector<float>>& imageData :
76 std::views::drop(batch_idx * m_batchSize.value()) |
77 std::views::take(m_batchSize.value())) {
78 std::vector<float> flatten =
80 inputDataVector.insert(inputDataVector.end(), flatten.begin(),
81 flatten.end());
82 }
83
84 std::vector<int64_t> inputShape = {m_batchSize.value(), n_rows, n_cols};
85
86 AthInfer::InputDataMap inputData;
87 inputData["flatten_input:0"] =
88 std::make_pair(inputShape, std::move(inputDataVector));
89
90 const std::int64_t n_scores = 10;
91 AthInfer::OutputDataMap outputData;
92 outputData["dense_1/Softmax:0"] = std::make_pair(
93 std::vector<int64_t>{m_batchSize, n_scores}, std::vector<float>{});
94
95 ATH_CHECK(m_onnxTool->inference(inputData, outputData));
96
97 auto const& outputScores =
98 std::get<std::vector<float>>(outputData["dense_1/Softmax:0"].second);
99
100 if (outputScores.size() != std::size_t(n_scores * m_batchSize.value())) {
101 ATH_MSG_ERROR("Got back " << outputScores.size()
102 << " scores when it should have been "
103 << n_scores << " * " << m_batchSize.value()
104 << " = " << n_scores * m_batchSize.value());
105 return StatusCode::FAILURE;
106 }
107
108 for (int img_idx = 0; img_idx < m_batchSize.value(); img_idx++) {
109 std::span scores(outputScores.begin() + img_idx * n_scores,
110 outputScores.begin() + (img_idx + 1) * n_scores);
111 ATH_MSG_DEBUG("Scores for img " << img_idx << " of batch " << batch_idx
112 << ": "
113 << EvaluateUtils::spanToString(scores));
114 const auto max_elem = std::ranges::max_element(scores);
115 ATH_MSG_DEBUG("Class: " << max_elem - scores.begin()
116 << " has the highest score: " << *max_elem
117 << " in img " << img_idx << " of batch "
118 << batch_idx);
119 }
120 }
121 return StatusCode::SUCCESS;
122}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_DEBUG(x)
size_t size() const
Number of registered mappings.
std::vector< std::vector< std::vector< float > > > m_input_tensor_values_notFlat
ToolHandle< AthInfer::IAthInferenceTool > m_onnxTool
Tool handle for the ONNX inference session.
Gaudi::Property< int > m_batchSize
The following properties need to be considered if the .onnx model is evaluated in batch mode.
std::map< std::string, InferenceData > OutputDataMap
std::map< std::string, InferenceData > InputDataMap
std::vector< T > flattenNestedVectors(const std::vector< std::vector< T > > &features)
Definition OnnxUtils.h:24
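
The batching and score-selection steps of execute() can be sketched outside Athena as follows. This is a minimal illustration under stated assumptions, not the algorithm's own code: `Image`, `flattenBatch` and `bestClass` are hypothetical names, and plain iterators stand in for the `std::views::drop` / `std::views::take` pipeline and `std::span` so that the sketch also builds as C++17.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical alias: one image stored as rows of pixel values.
using Image = std::vector<std::vector<float>>;

// Concatenate the rows of the images belonging to batch `batchIdx`
// into one row-major buffer of batchSize * n_rows * n_cols floats.
std::vector<float> flattenBatch(const std::vector<Image>& images,
                                std::size_t batchIdx,
                                std::size_t batchSize) {
  std::vector<float> flat;
  const std::size_t first = batchIdx * batchSize;
  for (std::size_t i = first; i < first + batchSize && i < images.size(); ++i) {
    for (const auto& row : images[i]) {
      flat.insert(flat.end(), row.begin(), row.end());
    }
  }
  return flat;
}

// Return the index of the highest score for image `imgIdx`, given a flat
// output buffer laid out as [batch, nScores].
std::size_t bestClass(const std::vector<float>& scores,
                      std::size_t imgIdx,
                      std::size_t nScores) {
  assert((imgIdx + 1) * nScores <= scores.size());
  const auto begin = scores.begin() + static_cast<std::ptrdiff_t>(imgIdx * nScores);
  const auto end = begin + static_cast<std::ptrdiff_t>(nScores);
  return static_cast<std::size_t>(
      std::distance(begin, std::max_element(begin, end)));
}
```

For a batch of `m_batchSize` images the flat buffer pairs with the shape `{batch, n_rows, n_cols}`, which is the layout the listing above hands to the inference tool.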

◆ extraDeps_update_handler()

void AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::extraDeps_update_handler ( Gaudi::Details::PropertyBase & ExtraDeps)
protected, inherited

Add StoreName to extra input/output deps as needed.

Uses the logic of the VarHandleKey to parse the DataObjID keys supplied via the ExtraInputs and ExtraOutputs properties, adding the StoreName if it is not explicitly given.

◆ extraOutputDeps()

const DataObjIDColl & AthCommonAlgorithm< Gaudi::Algorithm >::extraOutputDeps ( ) const
override, virtual, inherited

Return the list of extra output dependencies.

This list is extended to include symlinks implied by inheritance relations.

Definition at line 89 of file AthCommonAlgorithm.cxx.

54{
55 // If we didn't find any symlinks to add, just return the collection
56 // from the base class. Otherwise, return the extended collection.
57 if (!m_extendedExtraObjects.empty()) {
59 }
61}
Common base class for algorithms.

◆ filterPassed()

virtual bool AthCommonAlgorithm< Gaudi::Algorithm >::filterPassed ( const EventContext & ctx) const
inline, virtual, inherited

Get filter decision:

Definition at line 93 of file AthCommonAlgorithm.h.

93 {
94 return execState( ctx ).filterPassed();
95 }
virtual bool filterPassed(const EventContext &ctx) const
Get filter decision:

◆ initialize()

StatusCode AthOnnx::EvaluateModelWithAthInfer::initialize ( )
override, virtual

Function initialising the algorithm.

Definition at line 18 of file EvaluateModelWithAthInfer.cxx.

18 {
19 if (m_batchSize.value() < 1) {
20 ATH_MSG_ERROR("Requested an invalid batch size: " << m_batchSize.value());
21 return StatusCode::FAILURE;
22 }
23
24 // Fetch tools
25 ATH_CHECK(m_onnxTool.retrieve());
26
27 // read input file, and the target file for comparison.
28 std::string pixelFilePath =
30 ATH_MSG_INFO("Using pixel file: " << pixelFilePath);
31
32 try {
36 "Total no. of samples: " << m_input_tensor_values_notFlat.size());
37 } catch (const std::exception& e) {
38 ATH_MSG_ERROR(e.what());
39 return StatusCode::FAILURE;
40 }
41
42 if (std::size_t(m_batchSize.value()) > m_input_tensor_values_notFlat.size()) {
43 ATH_MSG_ERROR("The batch size requested ("
44 << m_batchSize.value()
45 << ") is greater than the number of available "
46 "samples ("
47 << m_input_tensor_values_notFlat.size() << ")");
48 return StatusCode::FAILURE;
49 }
50
51 if (m_input_tensor_values_notFlat.size() % m_batchSize.value() != 0) {
52 ATH_MSG_ERROR("The number of samples ("
54 << ") is not a multiple of the requested batch size ("
55 << m_batchSize.value() << ")");
56 return StatusCode::FAILURE;
57 }
58 return StatusCode::SUCCESS;
59}
#define ATH_MSG_INFO(x)
Gaudi::Property< std::string > m_pixelFileName
Name of the input pixel file to load.
static std::string find_calib_file(const std::string &logical_file_name)
std::vector< std::vector< std::vector< float > > > read_mnist_pixel_notFlat(const std::string &full_path)
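
The three batch-size checks performed in initialize() can be summarised in a small standalone function. This is a sketch only: `checkBatching` is a hypothetical helper that returns an error message instead of calling ATH_MSG_ERROR and returning StatusCode::FAILURE, and the message wording is abbreviated.

```cpp
#include <cstddef>
#include <optional>
#include <string>

// Hypothetical helper mirroring the order of the checks in initialize().
// Returns std::nullopt when the configuration is acceptable, otherwise a
// short description of the first problem found.
std::optional<std::string> checkBatching(int batchSize, std::size_t nSamples) {
  if (batchSize < 1) {
    return "invalid batch size: " + std::to_string(batchSize);
  }
  const auto bs = static_cast<std::size_t>(batchSize);
  if (bs > nSamples) {
    return "batch size (" + std::to_string(bs) +
           ") is greater than the number of samples (" +
           std::to_string(nSamples) + ")";
  }
  if (nSamples % bs != 0) {
    return "number of samples (" + std::to_string(nSamples) +
           ") is not a multiple of the batch size (" + std::to_string(bs) + ")";
  }
  return std::nullopt;
}
```

Requiring the sample count to be an exact multiple of the batch size means every batch is full, which is what allows execute() to use a fixed input shape for each inference call.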

◆ inputHandles()

virtual std::vector< Gaudi::DataHandle * > AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::inputHandles ( ) const
override, virtual, inherited

Return this algorithm's input handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ isClonable()

virtual bool AthCommonAlgorithm< Gaudi::Algorithm >::isClonable ( ) const
inline, override, virtual, inherited

Specify if the algorithm is clonable.

Only relevant for non-reentrant algorithms. Actual number of clones needs to be set via the "Cardinality" property.

Reimplemented in AFP_DigiTop, AlgB, AlgT, BCM_Digitization, CscDigitBuilder, CscDigitToCscRDO, G4AtlasAlg, G4RunAlg, HGTD_Digitization, HiveAlgBase, InDet::GNNSeedingTrackMaker, InDet::SCT_Clusterization, InDet::SiSPGNNTrackMaker, InDet::SiSPSeededTrackFinder, InDet::SiTrackerSpacePointFinder, ISF::SimKernelMT, ITk::StripDigitization, ITkPixelCablingAlg, ITkStripCablingAlg, LArHitEMapMaker, LArTTL1Maker, LUCID_DigiTop, LVL1::L1TopoSimulation, MergeCalibHits, MergeGenericMuonSimHitColl, MergeHijingPars, MergeMcEventCollection, MergeTrackRecordCollection, MergeTruthJets, MergeTruthParticles, MuonDigitizer, PileUpMTAlg, PixelDigitization, RoIBResultToxAOD, SCT_ByteStreamErrorsTestAlg, SCT_CablingCondAlgFromCoraCool, SCT_CablingCondAlgFromText, SCT_ConditionsParameterTestAlg, SCT_ConditionsSummaryTestAlg, SCT_ConfigurationConditionsTestAlg, SCT_Digitization, SCT_FlaggedConditionTestAlg, SCT_LinkMaskingTestAlg, SCT_MajorityConditionsTestAlg, SCT_ModuleVetoTestAlg, SCT_MonitorConditionsTestAlg, SCT_PrepDataToxAOD, SCT_RawDataToxAOD, SCT_ReadCalibChipDataTestAlg, SCT_ReadCalibDataTestAlg, SCT_RODVetoTestAlg, SCT_SensorsTestAlg, SCT_SiliconConditionsTestAlg, SCT_StripVetoTestAlg, SCT_TdaqEnabledTestAlg, SCT_TestCablingAlg, SCTEventFlagWriter, SCTRawDataProvider, SCTSiLorentzAngleTestAlg, SCTSiPropertiesTestAlg, SGInputLoader, Simulation::BeamEffectsAlg, TileHitVecToCnt, TileMuonFitter, TilePulseForTileMuonReceiver, TileRawChannelMaker, TRTDigitization, and ZDC_DigiTop.

Definition at line 68 of file AthCommonAlgorithm.h.

68 {
69 return true;
70 }

◆ msg()

MsgStream & AthCommonMsg< Gaudi::Algorithm >::msg ( ) const
inline, inherited

Definition at line 24 of file AthCommonMsg.h.

24 {
25 return this->msgStream();
26 }

◆ msgLvl()

bool AthCommonMsg< Gaudi::Algorithm >::msgLvl ( const MSG::Level lvl) const
inline, inherited

Definition at line 30 of file AthCommonMsg.h.

30 {
31 return this->msgLevel(lvl);
32 }

◆ outputHandles()

virtual std::vector< Gaudi::DataHandle * > AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::outputHandles ( ) const
override, virtual, inherited

Return this algorithm's output handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ renounce()

std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::renounce ( T & h)
inline, protected, inherited

Definition at line 380 of file AthCommonDataStore.h.

381 {
382 h.renounce();
384 }
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce(T &h)
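
The long return type of renounce() is a SFINAE constraint: the overload participates only for types that have a void renounce() member, derive from Gaudi::DataHandle, and are not key arrays. The sketch below reproduces that pattern with hypothetical stand-in types (`DataHandleBase`, `HandleKeyArrayBase`, `SingleKey`, `KeyArray`, `renounceOne`, `can_renounce_one`). It is not the Athena implementation, and it tests the return type with `decltype` on a call expression rather than `std::result_of_t`.

```cpp
#include <type_traits>
#include <utility>

// Hypothetical stand-ins for Gaudi::DataHandle and SG::VarHandleKeyArray.
struct DataHandleBase {};
struct HandleKeyArrayBase {};

struct SingleKey : DataHandleBase {
  bool renounced = false;
  void renounce() { renounced = true; }
};

struct KeyArray : HandleKeyArrayBase {
  bool renounced = false;
  void renounce() { renounced = true; }
};

// Enabled only when T::renounce() returns void, T derives from
// DataHandleBase, and T is not a key array.
template <class T>
std::enable_if_t<std::is_void_v<decltype(std::declval<T&>().renounce())> &&
                     !std::is_base_of_v<HandleKeyArrayBase, T> &&
                     std::is_base_of_v<DataHandleBase, T>,
                 void>
renounceOne(T& h) {
  h.renounce();
}

// Detection trait: true when renounceOne(T&) is a viable call.
template <class T, class = void>
struct can_renounce_one : std::false_type {};

template <class T>
struct can_renounce_one<
    T, std::void_t<decltype(renounceOne(std::declval<T&>()))>>
    : std::true_type {};
```

With this constraint in place, a separate function such as renounceArray() can handle key arrays without the two overload sets ever being ambiguous.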

◆ renounceArray()

void AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::renounceArray ( SG::VarHandleKeyArray & handlesArray)
inline, protected, inherited

remove all handles from I/O resolution

Definition at line 364 of file AthCommonDataStore.h.

364 {
366 }

◆ setFilterPassed()

virtual void AthCommonAlgorithm< Gaudi::Algorithm >::setFilterPassed ( bool state,
const EventContext & ctx ) const
inline, virtual, inherited

Set filter decision:

Reimplemented in AthFilterAlgorithm.

Definition at line 99 of file AthCommonAlgorithm.h.

99 {
101 }
virtual void setFilterPassed(bool state, const EventContext &ctx) const
Set filter decision:

◆ sysExecute()

StatusCode AthCommonAlgorithm< Gaudi::Algorithm >::sysExecute ( const EventContext & ctx)
override, virtual, inherited

Execute an algorithm.

We override this in order to work around an issue with the Algorithm base class storing the event context in a member variable that can cause crashes in MT jobs.

Reimplemented in AthAnalysisAlgorithm.

Definition at line 80 of file AthCommonAlgorithm.cxx.

41{
42 return BaseAlg::sysExecute (ctx);
43}

◆ sysInitialize()

StatusCode AthCommonAlgorithm< Gaudi::Algorithm >::sysInitialize ( )
override, virtual, inherited

Override sysInitialize.

Override sysInitialize from the base class.

Loop through all output handles and, if they are WriteCondHandles, automatically register them and this algorithm with the CondSvc.

Reimplemented from AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >.

Reimplemented in AthAnalysisAlgorithm, AthFilterAlgorithm, AthHistogramAlgorithm, HypoBase, InputMakerBase, and PyAthena::Alg.

Definition at line 60 of file AthCommonAlgorithm.cxx.

71 {
73
74 if (sc.isFailure()) {
75 return sc;
76 }
77
78 ServiceHandle<ICondSvc> cs("CondSvc",name());
79 for (auto h : outputHandles()) {
80 if (h->isCondition() && h->mode() == Gaudi::DataHandle::Writer) {
81 // do this inside the loop so we don't create the CondSvc until needed
82 if ( cs.retrieve().isFailure() ) {
83 ATH_MSG_WARNING("no CondSvc found: won't autoreg WriteCondHandles");
85 }
86 if (cs->regHandle(this,*h).isFailure()) {
88 ATH_MSG_ERROR("unable to register WriteCondHandle " << h->fullKey()
89 << " with CondSvc");
90 }
91 }
92 }
93 return sc;
94}
#define ATH_MSG_WARNING(x)
virtual StatusCode sysInitialize() override
virtual std::vector< Gaudi::DataHandle * > outputHandles() const override

◆ sysStart()

virtual StatusCode AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::sysStart ( )
override, virtual, inherited

Handle START transition.

We override this in order to make sure that conditions handle keys can cache a pointer to the conditions container.

◆ updateVHKA()

void AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::updateVHKA ( Gaudi::Details::PropertyBase & )
inline, inherited

Definition at line 308 of file AthCommonDataStore.h.

308 {
309 // debug() << "updateVHKA for property " << p.name() << " " << p.toString()
310 // << " size: " << m_vhka.size() << endmsg;
311 for (auto &a : m_vhka) {
313 for (auto k : keys) {
314 k->setOwner(this);
315 }
316 }
317 }

Member Data Documentation

◆ m_batchSize

Gaudi::Property<int> AthOnnx::EvaluateModelWithAthInfer::m_batchSize {this, "BatchSize", 1, "No. of elements/example in a batch"}
private

The following properties need to be considered if the .onnx model is evaluated in batch mode.

Definition at line 55 of file EvaluateModelWithAthInfer.h.

55{this, "BatchSize", 1, "No. of elements/example in a batch"};

◆ m_detStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::m_detStore
private, inherited

Pointer to StoreGate (detector store by default).

Definition at line 393 of file AthCommonDataStore.h.

◆ m_evtStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::m_evtStore
private, inherited

Pointer to StoreGate (event store by default).

Definition at line 390 of file AthCommonDataStore.h.

◆ m_extendedExtraObjects

DataObjIDColl AthCommonAlgorithm< Gaudi::Algorithm >::m_extendedExtraObjects
private, inherited

Extra output dependency collection, extended by AthAlgorithmDHUpdate to add symlinks.

Empty if no symlinks were found.

Definition at line 108 of file AthCommonAlgorithm.h.

◆ m_input_tensor_values_notFlat

std::vector<std::vector<std::vector<float> > > AthOnnx::EvaluateModelWithAthInfer::m_input_tensor_values_notFlat
private

Definition at line 62 of file EvaluateModelWithAthInfer.h.

◆ m_onnxTool

ToolHandle< AthInfer::IAthInferenceTool > AthOnnx::EvaluateModelWithAthInfer::m_onnxTool
private
Initial value:
{
this, "ORTInferenceTool", "AthOnnx::OnnxRuntimeInferenceTool"
}

Tool handle for the ONNX inference session.

Definition at line 58 of file EvaluateModelWithAthInfer.h.

58 {
59 this, "ORTInferenceTool", "AthOnnx::OnnxRuntimeInferenceTool"
60 };

◆ m_pixelFileName

Gaudi::Property< std::string > AthOnnx::EvaluateModelWithAthInfer::m_pixelFileName
private
Initial value:
{ this, "InputDataPixel",
"dev/MLTest/2020-03-31/t10k-images-idx3-ubyte",
"Name of the input pixel file to load" }

Name of the input pixel file to load.

Definition at line 50 of file EvaluateModelWithAthInfer.h.

50 { this, "InputDataPixel",
51 "dev/MLTest/2020-03-31/t10k-images-idx3-ubyte",
52 "Name of the input pixel file to load" };

◆ m_varHandleArraysDeclared

bool AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::m_varHandleArraysDeclared
private, inherited

Definition at line 399 of file AthCommonDataStore.h.

◆ m_vhka

std::vector<SG::VarHandleKeyArray*> AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::m_vhka
private, inherited

Definition at line 398 of file AthCommonDataStore.h.


The documentation for this class was generated from the following files: