ATLAS Offline Software
AthInfer::ExampleMLInferenceWithTriton Class Reference

Algorithm demonstrating the usage of the Triton Client API.

#include <ExampleMLInferenceWithTriton.h>


Public Member Functions

virtual StatusCode sysInitialize () override
 Override sysInitialize.
virtual bool isClonable () const override
 Specify if the algorithm is clonable.
virtual unsigned int cardinality () const override
 Cardinality (maximum number of clones that can exist); the special value 0 means that the algorithm is reentrant.
virtual StatusCode sysExecute (const EventContext &ctx) override
 Execute an algorithm.
virtual const DataObjIDColl & extraOutputDeps () const override
 Return the list of extra output dependencies.
virtual bool filterPassed (const EventContext &ctx) const
virtual void setFilterPassed (bool state, const EventContext &ctx) const
ServiceHandle< StoreGateSvc > & evtStore ()
 The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.
const ServiceHandle< StoreGateSvc > & detStore () const
 The standard StoreGateSvc/DetectorStore. Returns (kind of) a pointer to the StoreGateSvc.
virtual StatusCode sysStart () override
 Handle START transition.
virtual std::vector< Gaudi::DataHandle * > inputHandles () const override
 Return this algorithm's input handles.
virtual std::vector< Gaudi::DataHandle * > outputHandles () const override
 Return this algorithm's output handles.
Gaudi::Details::PropertyBase & declareProperty (Gaudi::Property< T, V, H > &t)
void updateVHKA (Gaudi::Details::PropertyBase &)
MsgStream & msg () const
bool msgLvl (const MSG::Level lvl) const
Function(s) inherited from AthAlgorithm
virtual StatusCode initialize () override
 Function initialising the algorithm.
virtual StatusCode execute (const EventContext &ctx) const override
 Function executing the algorithm for a single event.

Protected Member Functions

void renounceArray (SG::VarHandleKeyArray &handlesArray)
 remove all handles from I/O resolution
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce (T &h)
void extraDeps_update_handler (Gaudi::Details::PropertyBase &ExtraDeps)
 Add StoreName to extra input/output deps as needed.

Private Types

typedef ServiceHandle< StoreGateSvc > StoreGateSvc_t

Private Member Functions

Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T, V, H > &hndl, const SG::VarHandleKeyType &)
 specialization for handling Gaudi::Property<SG::VarHandleKey>

Private Attributes

DataObjIDColl m_extendedExtraObjects
 Extra output dependency collection, extended by AthAlgorithmDHUpdate to add symlinks.
StoreGateSvc_t m_evtStore
 Pointer to StoreGate (event store by default)
StoreGateSvc_t m_detStore
 Pointer to StoreGate (detector store by default)
std::vector< SG::VarHandleKeyArray * > m_vhka
bool m_varHandleArraysDeclared

Algorithm properties

Gaudi::Property< std::string > m_pixelFileName
 Name of the input pixel file to load.
Gaudi::Property< int > m_batchSize {this, "BatchSize", 1, "No. of elements/example in a batch"}
 The following properties need to be considered if the .onnx model is evaluated in batch mode.
ToolHandle< AthInfer::IAthInferenceTool > m_tritonTool
 Tool handle for the Triton client.
std::vector< std::vector< std::vector< float > > > m_input_tensor_values_notFlat
std::vector< std::vector< std::vector< float > > > read_mnist_pixel_notFlat (const std::string &full_path) const
std::vector< float > flattenNestedVectors (const std::vector< std::vector< float > > &nestedVector) const

Detailed Description

Algorithm demonstrating the usage of the Triton Client API.

Author
Xiangyang Ju xju@lbl.gov

Definition at line 20 of file ExampleMLInferenceWithTriton.h.

Member Typedef Documentation

◆ StoreGateSvc_t

typedef ServiceHandle<StoreGateSvc> AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::StoreGateSvc_t
private, inherited

Definition at line 388 of file AthCommonDataStore.h.

Member Function Documentation

◆ cardinality()

unsigned int AthCommonReentrantAlgorithm< Gaudi::Algorithm >::cardinality ( ) const
override, virtual, inherited

Cardinality (maximum number of clones that can exist); the special value 0 means that the algorithm is reentrant.

Override this to return 0 for reentrant algorithms.

Definition at line 75 of file AthCommonReentrantAlgorithm.cxx.

64{
65 return 0;
66}

◆ declareGaudiProperty()

Gaudi::Details::PropertyBase & AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::declareGaudiProperty ( Gaudi::Property< T, V, H > & hndl,
const SG::VarHandleKeyType &  )
inline, private, inherited

specialization for handling Gaudi::Property<SG::VarHandleKey>

Definition at line 156 of file AthCommonDataStore.h.

158 {
160 hndl.value(),
161 hndl.documentation());
162
163 }

◆ declareProperty()

Gaudi::Details::PropertyBase & AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::declareProperty ( Gaudi::Property< T, V, H > & t)
inline, inherited

Definition at line 145 of file AthCommonDataStore.h.

145 {
146 typedef typename SG::HandleClassifier<T>::type htype;
148 }

◆ detStore()

const ServiceHandle< StoreGateSvc > & AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::detStore ( ) const
inline, inherited

The standard StoreGateSvc/DetectorStore. Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 95 of file AthCommonDataStore.h.

◆ evtStore()

ServiceHandle< StoreGateSvc > & AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::evtStore ( )
inline, inherited

The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 85 of file AthCommonDataStore.h.

◆ execute()

StatusCode AthInfer::ExampleMLInferenceWithTriton::execute ( const EventContext & ctx) const
override, virtual

Function executing the algorithm for a single event.

Definition at line 32 of file ExampleMLInferenceWithTriton.cxx.

32 {
33
34 // prepare inputs
35 std::vector<float> inputDataVector;
36 inputDataVector.reserve(m_input_tensor_values_notFlat.size());
37 for (const std::vector<std::vector<float> >& imageData : m_input_tensor_values_notFlat){
38
39 std::vector<float> flatten;
40 int total_size = 0;
41 for(const auto& feature : imageData) total_size += feature.size();
42 flatten.reserve(total_size);
43 for (const auto& feature : imageData)
44 for (const auto& elem : feature)
45 flatten.push_back(elem);
46
47 inputDataVector.insert(inputDataVector.end(), flatten.begin(), flatten.end());
48 }
49 std::vector<int64_t> inputShape = {m_batchSize, 28, 28};
50
51 AthInfer::InputDataMap inputData;
52 inputData["flatten_input:0"] = std::make_pair(
53 inputShape, std::move(inputDataVector)
54 );
55
56 AthInfer::OutputDataMap outputData;
57 outputData["dense_1/Softmax:0"] = std::make_pair(
58 std::vector<int64_t>{m_batchSize, 10}, std::vector<float>{}
59 );
60
61 ATH_CHECK(m_tritonTool->inference(inputData, outputData));
62
63 auto& outputScores = std::get<std::vector<float>>(outputData["dense_1/Softmax:0"].second);
64 auto inRange = [&outputScores](int idx)->bool{return (idx>=0) and (idx<std::ssize(outputScores));};
65 ATH_MSG_DEBUG("Label for the input test data: ");
66 for(int ibatch = 0; ibatch < m_batchSize; ibatch++){
67 float max = -999;
68 int max_index{-1};
69 for (int i = 0; i < 10; i++){
70 ATH_MSG_DEBUG("Score for class "<< i <<" = "<<outputScores[i] << " in batch " << ibatch);
71 int index = i + ibatch * 10;
72 if (not inRange(index)) continue;
73 if (max < outputScores[index]){
74 max = outputScores[index];
75 max_index = index;
76 }
77 }
78 if (not inRange(max_index)){
79 ATH_MSG_ERROR("No maximum found in ExampleMLInferenceWithTriton::execute");
80 return StatusCode::FAILURE;
81 }
82 ATH_MSG_DEBUG("Class: "<<max_index<<" has the highest score: "<<outputScores[max_index] << " in batch " << ibatch);
83 }
84
85 return StatusCode::SUCCESS;
86}
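
The per-sample class selection in the listing above can be sketched in isolation. This is a minimal standalone sketch, not the code of the algorithm: the helper name argmaxPerSample and its signature are hypothetical, and it assumes the score buffer is laid out row-major as [batch, classes]. Unlike the listing, where the per-class debug message reads outputScores[i] without the batch offset and max_index holds the flat index into the buffer, the sketch applies the offset to every access and reports the class index within each sample.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of AthInfer): for a row-major
// [nBatch, nClasses] score buffer, store the winning class index
// (0 .. nClasses-1) of every sample in `best`.
// Returns false if the buffer is too short for the requested layout.
bool argmaxPerSample(const std::vector<float>& scores, std::size_t nBatch,
                     std::size_t nClasses, std::vector<int>& best) {
  if (nClasses == 0 || scores.size() < nBatch * nClasses) return false;

  best.assign(nBatch, 0);
  for (std::size_t b = 0; b < nBatch; ++b) {
    const std::size_t offset = b * nClasses;  // first score of sample b
    std::size_t arg = 0;
    for (std::size_t c = 1; c < nClasses; ++c) {
      if (scores[offset + c] > scores[offset + arg]) arg = c;
    }
    best[b] = static_cast<int>(arg);
  }
  return true;
}
```

Checking the buffer length once up front replaces the per-element inRange test of the listing; which of the two styles suits the real algorithm better depends on what the inference tool guarantees about the output size, which is not stated here.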

◆ extraDeps_update_handler()

void AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::extraDeps_update_handler ( Gaudi::Details::PropertyBase & ExtraDeps)
protected, inherited

Add StoreName to extra input/output deps as needed.

Use the logic of the VarHandleKey to parse the DataObjID keys supplied via the ExtraInputs and ExtraOutputs properties, adding the StoreName if it is not explicitly given.

◆ extraOutputDeps()

const DataObjIDColl & AthCommonReentrantAlgorithm< Gaudi::Algorithm >::extraOutputDeps ( ) const
override, virtual, inherited

Return the list of extra output dependencies.

This list is extended to include symlinks implied by inheritance relations.

Definition at line 94 of file AthCommonReentrantAlgorithm.cxx.

90{
91 // If we didn't find any symlinks to add, just return the collection
92 // from the base class. Otherwise, return the extended collection.
93 if (!m_extendedExtraObjects.empty()) {
95 }
97}

◆ filterPassed()

virtual bool AthCommonReentrantAlgorithm< Gaudi::Algorithm >::filterPassed ( const EventContext & ctx) const
inline, virtual, inherited

Definition at line 96 of file AthCommonReentrantAlgorithm.h.

96 {
97 return execState( ctx ).filterPassed();
98 }

◆ flattenNestedVectors()

std::vector< float > AthInfer::ExampleMLInferenceWithTriton::flattenNestedVectors ( const std::vector< std::vector< float > > & nestedVector) const
private
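
No definition is listed for flattenNestedVectors() on this page. Judging by the signature alone, it turns a 2-D vector into a 1-D one. A minimal sketch, assuming the helper simply concatenates the inner vectors in order (the free function flattenRows below is hypothetical and is not the member's actual body):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical free-function counterpart of flattenNestedVectors():
// concatenates the rows of a 2-D vector into one contiguous 1-D vector.
std::vector<float> flattenRows(const std::vector<std::vector<float>>& nested) {
  // Count the elements first so the result is allocated only once.
  std::size_t total = 0;
  for (const auto& row : nested) total += row.size();

  std::vector<float> flat;
  flat.reserve(total);
  for (const auto& row : nested) {
    flat.insert(flat.end(), row.begin(), row.end());
  }
  return flat;
}
```

For a 28 x 28 image this yields 784 values in row-major order, which is the layout a {BatchSize, 28, 28} input shape would expect under that assumption.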

◆ initialize()

StatusCode AthInfer::ExampleMLInferenceWithTriton::initialize ( )
override, virtual

Function initialising the algorithm.

Definition at line 14 of file ExampleMLInferenceWithTriton.cxx.

14 {
15 // Fetch tools
16 ATH_CHECK( m_tritonTool.retrieve() );
17
18 if(m_batchSize > 10000){
19 ATH_MSG_INFO("The total no. of sample crossed the no. of available sample ....");
20 return StatusCode::FAILURE;
21 }
22 // read input file, and the target file for comparison.
23 std::string pixelFilePath = PathResolver::find_calib_file(m_pixelFileName);
24 ATH_MSG_INFO( "Using pixel file: " << pixelFilePath );
25
26 m_input_tensor_values_notFlat = read_mnist_pixel_notFlat(pixelFilePath);
27 ATH_MSG_INFO("Total no. of samples: "<<m_input_tensor_values_notFlat.size());
28
29 return StatusCode::SUCCESS;
30}

◆ inputHandles()

virtual std::vector< Gaudi::DataHandle * > AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::inputHandles ( ) const
override, virtual, inherited

Return this algorithm's input handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ isClonable()

◆ msg()

MsgStream & AthCommonMsg< Gaudi::Algorithm >::msg ( ) const
inline, inherited

Definition at line 24 of file AthCommonMsg.h.

24 {
25 return this->msgStream();
26 }

◆ msgLvl()

bool AthCommonMsg< Gaudi::Algorithm >::msgLvl ( const MSG::Level lvl) const
inline, inherited

Definition at line 30 of file AthCommonMsg.h.

30 {
31 return this->msgLevel(lvl);
32 }

◆ outputHandles()

virtual std::vector< Gaudi::DataHandle * > AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::outputHandles ( ) const
override, virtual, inherited

Return this algorithm's output handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ read_mnist_pixel_notFlat()

std::vector< std::vector< std::vector< float > > > AthInfer::ExampleMLInferenceWithTriton::read_mnist_pixel_notFlat ( const std::string & full_path) const
private

Definition at line 89 of file ExampleMLInferenceWithTriton.cxx.

90{
91 std::vector<std::vector<std::vector<float>>> input_tensor_values;
92 input_tensor_values.resize(10000, std::vector<std::vector<float> >(28,std::vector<float>(28)));
93 std::ifstream file (full_path.c_str(), std::ios::binary);
94 int magic_number=0;
95 int number_of_images=0;
96 int n_rows=0;
97 int n_cols=0;
98 file.read(reinterpret_cast<char*>(&magic_number),sizeof(magic_number));
99 magic_number= ntohl(magic_number);
100 file.read(reinterpret_cast<char*>(&number_of_images),sizeof(number_of_images));
101 number_of_images= ntohl(number_of_images);
102 file.read(reinterpret_cast<char*>(&n_rows),sizeof(n_rows));
103 n_rows= ntohl(n_rows);
104 file.read(reinterpret_cast<char*>(&n_cols),sizeof(n_cols));
105 n_cols= ntohl(n_cols);
106 for(int i=0;i<number_of_images;++i)
107 {
108 for(int r=0;r<n_rows;++r)
109 {
110 for(int c=0;c<n_cols;++c)
111 {
112 unsigned char temp=0;
113 file.read(reinterpret_cast<char*>(&temp),sizeof(temp));
114 input_tensor_values[i][r][c]= float(temp)/255;
115 }
116 }
117 }
118 return input_tensor_values;
119}

◆ renounce()

std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::renounce ( T & h)
inline, protected, inherited

Definition at line 380 of file AthCommonDataStore.h.

381 {
382 h.renounce();
384 }

◆ renounceArray()

void AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::renounceArray ( SG::VarHandleKeyArray & handlesArray)
inline, protected, inherited

remove all handles from I/O resolution

Definition at line 364 of file AthCommonDataStore.h.

364 {
366 }

◆ setFilterPassed()

virtual void AthCommonReentrantAlgorithm< Gaudi::Algorithm >::setFilterPassed ( bool state,
const EventContext & ctx ) const
inline, virtual, inherited

Definition at line 100 of file AthCommonReentrantAlgorithm.h.

100 {
102 }

◆ sysExecute()

StatusCode AthCommonReentrantAlgorithm< Gaudi::Algorithm >::sysExecute ( const EventContext & ctx)
override, virtual, inherited

Execute an algorithm.

We override this in order to work around an issue with the Algorithm base class storing the event context in a member variable that can cause crashes in MT jobs.

Definition at line 85 of file AthCommonReentrantAlgorithm.cxx.

77{
78 return BaseAlg::sysExecute (ctx);
79}

◆ sysInitialize()

StatusCode AthCommonReentrantAlgorithm< Gaudi::Algorithm >::sysInitialize ( )
override, virtual, inherited

Override sysInitialize.

Override sysInitialize from the base class.

Loop through all output handles and, if they are WriteCondHandles, automatically register them and this algorithm with the CondSvc.

Reimplemented from AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >.

Reimplemented in HypoBase, and InputMakerBase.

Definition at line 61 of file AthCommonReentrantAlgorithm.cxx.

107 {
109
110 if (sc.isFailure()) {
111 return sc;
112 }
113
114 ServiceHandle<ICondSvc> cs("CondSvc",name());
115 for (auto h : outputHandles()) {
116 if (h->isCondition() && h->mode() == Gaudi::DataHandle::Writer) {
117 // do this inside the loop so we don't create the CondSvc until needed
118 if ( cs.retrieve().isFailure() ) {
119 ATH_MSG_WARNING("no CondSvc found: won't autoreg WriteCondHandles");
120 return StatusCode::SUCCESS;
121 }
122 if (cs->regHandle(this,*h).isFailure()) {
124 ATH_MSG_ERROR("unable to register WriteCondHandle " << h->fullKey()
125 << " with CondSvc");
126 }
127 }
128 }
129 return sc;
130}

◆ sysStart()

virtual StatusCode AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::sysStart ( )
override, virtual, inherited

Handle START transition.

We override this in order to make sure that conditions handle keys can cache a pointer to the conditions container.

◆ updateVHKA()

void AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::updateVHKA ( Gaudi::Details::PropertyBase & )
inline, inherited

Definition at line 308 of file AthCommonDataStore.h.

308 {
309 // debug() << "updateVHKA for property " << p.name() << " " << p.toString()
310 // << " size: " << m_vhka.size() << endmsg;
311 for (auto &a : m_vhka) {
313 for (auto k : keys) {
314 k->setOwner(this);
315 }
316 }
317 }

Member Data Documentation

◆ m_batchSize

Gaudi::Property<int> AthInfer::ExampleMLInferenceWithTriton::m_batchSize {this, "BatchSize", 1, "No. of elements/example in a batch"}
private

The following properties need to be considered if the .onnx model is evaluated in batch mode.

Definition at line 46 of file ExampleMLInferenceWithTriton.h.

46{this, "BatchSize", 1, "No. of elements/example in a batch"};
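
In the execute() listing, BatchSize becomes the leading dimension of the input shape {BatchSize, 28, 28} and of the output shape {BatchSize, 10}, while the data buffer is built from every image held in m_input_tensor_values_notFlat. How the inference tool treats a buffer whose length differs from the declared shape is not stated on this page. A consistency check of the following kind could make such a mismatch visible; it is a sketch only, and elementCount and shapeMatchesData are hypothetical helpers, not part of AthInfer.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Number of elements implied by a tensor shape (product of its dimensions).
// Returns 0 for an empty shape or for any non-positive dimension.
std::size_t elementCount(const std::vector<std::int64_t>& shape) {
  if (shape.empty()) return 0;
  std::size_t n = 1;
  for (std::int64_t d : shape) {
    if (d <= 0) return 0;
    n *= static_cast<std::size_t>(d);
  }
  return n;
}

// True if a flat buffer holds exactly as many values as the shape declares.
bool shapeMatchesData(const std::vector<std::int64_t>& shape,
                      std::size_t nValues) {
  const std::size_t expected = elementCount(shape);
  return expected != 0 && expected == nValues;
}
```

For the default BatchSize of 1 the shape {1, 28, 28} corresponds to 784 values, so a buffer holding more than one flattened image would fail the check.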

◆ m_detStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::m_detStore
private, inherited

Pointer to StoreGate (detector store by default)

Definition at line 393 of file AthCommonDataStore.h.

◆ m_evtStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::m_evtStore
private, inherited

Pointer to StoreGate (event store by default)

Definition at line 390 of file AthCommonDataStore.h.

◆ m_extendedExtraObjects

DataObjIDColl AthCommonReentrantAlgorithm< Gaudi::Algorithm >::m_extendedExtraObjects
private, inherited

Extra output dependency collection, extended by AthAlgorithmDHUpdate to add symlinks.

Empty if no symlinks were found.

Definition at line 114 of file AthCommonReentrantAlgorithm.h.

◆ m_input_tensor_values_notFlat

std::vector<std::vector<std::vector<float> > > AthInfer::ExampleMLInferenceWithTriton::m_input_tensor_values_notFlat
private

Definition at line 53 of file ExampleMLInferenceWithTriton.h.

◆ m_pixelFileName

Gaudi::Property< std::string > AthInfer::ExampleMLInferenceWithTriton::m_pixelFileName
private
Initial value:
{ this, "InputDataPixel",
"dev/MLTest/2020-03-31/t10k-images-idx3-ubyte",
"Name of the input pixel file to load" }

Name of the input pixel file to load.

Definition at line 41 of file ExampleMLInferenceWithTriton.h.

41 { this, "InputDataPixel",
42 "dev/MLTest/2020-03-31/t10k-images-idx3-ubyte",
43 "Name of the input pixel file to load" };

◆ m_tritonTool

ToolHandle< AthInfer::IAthInferenceTool > AthInfer::ExampleMLInferenceWithTriton::m_tritonTool
private
Initial value:
{
this, "InferenceTool", "AthInfer::TritonTool", "Triton client tool"
}

Tool handle for the Triton client.

Definition at line 49 of file ExampleMLInferenceWithTriton.h.

49 {
50 this, "InferenceTool", "AthInfer::TritonTool", "Triton client tool"
51 };

◆ m_varHandleArraysDeclared

bool AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::m_varHandleArraysDeclared
private, inherited

Definition at line 399 of file AthCommonDataStore.h.

◆ m_vhka

std::vector<SG::VarHandleKeyArray*> AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::m_vhka
private, inherited

Definition at line 398 of file AthCommonDataStore.h.


The documentation for this class was generated from the following files: