ATLAS Offline Software
AthONNX::JSSMLTool Class Reference

Tool using the ONNX Runtime C++ API to retrieve the score of a constituent-based model for boson jet tagging.

#include <JSSMLTool.h>


Public Member Functions

 JSSMLTool (const std::string &name)
virtual StatusCode initialize () override
 Function initialising the tool.
virtual double retrieveConstituentsScore (std::vector< TH2D > Images) const override
 Function executing the tool for a single event.
virtual double retrieveConstituentsScore (std::vector< std::vector< float > > constituents) const override
virtual double retrieveConstituentsScore (std::vector< std::vector< float > > constituents, std::vector< std::vector< std::vector< float > > > interactions) const override
virtual double retrieveConstituentsScore (std::vector< std::vector< float > > constituents, std::vector< std::vector< std::vector< float > > > interactions, std::vector< std::vector< float > > mask) const override
virtual double retrieveHighLevelScore (std::map< std::string, double > JSSVars) const override
std::vector< float > ReadJetImagePixels (std::vector< TH2D > Images) const
std::vector< float > ReadJSSInputs (std::map< std::string, double > JSSVars) const
std::vector< int > ReadOutputLabels () const
StatusCode SetScaler (std::map< std::string, std::vector< double > > scaler) override
virtual void print () const
 Print the state of the tool.
ServiceHandle< StoreGateSvc > & evtStore ()
 The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.
const ServiceHandle< StoreGateSvc > & detStore () const
 The standard StoreGateSvc/DetectorStore. Returns (kind of) a pointer to the StoreGateSvc.
virtual StatusCode sysInitialize () override
 Perform system initialization for an algorithm.
virtual StatusCode sysStart () override
 Handle START transition.
virtual std::vector< Gaudi::DataHandle * > inputHandles () const override
 Return this algorithm's input handles.
virtual std::vector< Gaudi::DataHandle * > outputHandles () const override
 Return this algorithm's output handles.
Gaudi::Details::PropertyBase & declareProperty (Gaudi::Property< T, V, H > &t)
void updateVHKA (Gaudi::Details::PropertyBase &)
MsgStream & msg () const
bool msgLvl (const MSG::Level lvl) const
Additional helper functions, not directly mimicking Athena
template<class T>
const T * getProperty (const std::string &name) const
 Get one of the tool's properties.
const std::string & msg_level_name () const __attribute__((deprecated))
 A deprecated function for getting the message level's name.
const std::string & getName (const void *ptr) const
 Get the name of an object that is / should be in the event store.
SG::sgkey_t getKey (const void *ptr) const
 Get the (hashed) key of an object that is in the event store.

Public Attributes

std::unique_ptr< Ort::Session > m_session
std::unique_ptr< Ort::Env > m_env
std::map< std::string, std::vector< double > > m_scaler
std::map< int, std::string > m_JSSInputMap

Protected Member Functions

void renounceArray (SG::VarHandleKeyArray &handlesArray)
 Remove all handles from I/O resolution.
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce (T &h)
void extraDeps_update_handler (Gaudi::Details::PropertyBase &ExtraDeps)
 Add StoreName to extra input/output deps as needed.

Private Types

typedef ServiceHandle< StoreGateSvc > StoreGateSvc_t

Private Member Functions

Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T, V, H > &hndl, const SG::VarHandleKeyType &)
 specialization for handling Gaudi::Property<SG::VarHandleKey>

Private Attributes

std::string m_modelFileName
 Name of the model file to load.
std::string m_pixelFileName
std::string m_labelFileName
std::vector< int64_t > m_input_node_dims
size_t m_num_input_nodes {}
std::vector< const char * > m_input_node_names
std::vector< int64_t > m_output_node_dims
size_t m_num_output_nodes {}
std::vector< const char * > m_output_node_names
int m_nPixelsX {}
int m_nPixelsY {}
int m_nPixelsZ {}
int m_nvars {}
StoreGateSvc_t m_evtStore
 Pointer to StoreGate (event store by default)
StoreGateSvc_t m_detStore
 Pointer to StoreGate (detector store by default)
std::vector< SG::VarHandleKeyArray * > m_vhka
bool m_varHandleArraysDeclared

Detailed Description

Tool using the ONNX Runtime C++ API to retrieve the score of a constituent-based model for boson jet tagging.

This is inspired by the general Athena example at: https://gitlab.cern.ch/atlas/athena/-/blob/21.2/Control/AthenaExamples/AthExOnnxRuntime/AthExOnnxRuntime/CxxApiAlgorithm.h

This implementation extends the one from release 21 (https://gitlab.cern.ch/atlas/athena/-/tree/21.2/Reconstruction/Jet/AthOnnxRuntimeBJT). Since the plan is to move to the central ONNX interface, the tool has been merged with the BJT tool.

Monitoring JIRA ticket: https://its.cern.ch/jira/browse/ATLJETMET-1893

Author
Antonio Giannini antonio.giannini@cern.ch

Definition at line 46 of file JSSMLTool.h.

Member Typedef Documentation

◆ StoreGateSvc_t

typedef ServiceHandle<StoreGateSvc> AthCommonDataStore< AthCommonMsg< AlgTool > >::StoreGateSvc_t
private, inherited

Definition at line 388 of file AthCommonDataStore.h.

Constructor & Destructor Documentation

◆ JSSMLTool()

AthONNX::JSSMLTool::JSSMLTool ( const std::string & name)

Definition at line 73 of file JSSMLTool.cxx.

73 :
74 AsgTool(name)
75 {
76 declareProperty("ModelPath", m_modelFileName);
77 declareProperty("nPixelsX", m_nPixelsX);
78 declareProperty("nPixelsY", m_nPixelsY);
79 declareProperty("nPixelsZ", m_nPixelsZ);
80 }

Member Function Documentation

◆ declareGaudiProperty()

Gaudi::Details::PropertyBase & AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T, V, H > & hndl,
const SG::VarHandleKeyType &  )
inline, private, inherited

specialization for handling Gaudi::Property<SG::VarHandleKey>

Definition at line 156 of file AthCommonDataStore.h.

158 {
160 hndl.value(),
161 hndl.documentation());
162
163 }

◆ declareProperty()

Gaudi::Details::PropertyBase & AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( Gaudi::Property< T, V, H > & t)
inline, inherited

Definition at line 145 of file AthCommonDataStore.h.

145 {
146 typedef typename SG::HandleClassifier<T>::type htype;
148 }

◆ detStore()

const ServiceHandle< StoreGateSvc > & AthCommonDataStore< AthCommonMsg< AlgTool > >::detStore ( ) const
inline, inherited

The standard StoreGateSvc/DetectorStore. Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 95 of file AthCommonDataStore.h.

◆ evtStore()

ServiceHandle< StoreGateSvc > & AthCommonDataStore< AthCommonMsg< AlgTool > >::evtStore ( )
inline, inherited

The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 85 of file AthCommonDataStore.h.

◆ extraDeps_update_handler()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::extraDeps_update_handler ( Gaudi::Details::PropertyBase & ExtraDeps)
protected, inherited

Add StoreName to extra input/output deps as needed.

Uses the logic of the VarHandleKey to parse the DataObjID keys supplied via the ExtraInputs and ExtraOutputs properties, adding the StoreName if it is not given explicitly.

◆ getKey()

SG::sgkey_t asg::AsgTool::getKey ( const void * ptr) const
inherited

Get the (hashed) key of an object that is in the event store.

This is a bit of a special one. StoreGateSvc and xAOD::TEvent both provide ways for getting the SG::sgkey_t key for an object that is in the store, based on a bare pointer. But they provide different interfaces for doing so.

In order to allow tools to efficiently perform this operation, they can use this helper function.

See also
asg::AsgTool::getName
Parameters
ptr: the bare pointer to the object that the event store should know about
Returns
The hashed key of the object in the store. If not found, an invalid (zero) key.

Definition at line 119 of file AsgTool.cxx.

119 {
120
121 #ifdef XAOD_STANDALONE
122 // In case we use @c xAOD::TEvent, we have a direct function call
123 // for this.
124 return evtStore()->event()->getKey( ptr );
125 #else
126 const SG::DataProxy* proxy = evtStore()->proxy( ptr );
127 return ( proxy == nullptr ? 0 : proxy->sgkey() );
128 #endif // XAOD_STANDALONE
129 }

◆ getName()

const std::string & asg::AsgTool::getName ( const void * ptr) const
inherited

Get the name of an object that is / should be in the event store.

This is a bit of a special one. StoreGateSvc and xAOD::TEvent both provide ways for getting the std::string name for an object that is in the store, based on a bare pointer. But they provide different interfaces for doing so.

In order to allow tools to efficiently perform this operation, they can use this helper function.

See also
asg::AsgTool::getKey
Parameters
ptr: the bare pointer to the object that the event store should know about
Returns
The string name of the object in the store. If not found, an empty string.

Definition at line 106 of file AsgTool.cxx.

106 {
107
108 #ifdef XAOD_STANDALONE
109 // In case we use @c xAOD::TEvent, we have a direct function call
110 // for this.
111 return evtStore()->event()->getName( ptr );
112 #else
113 const SG::DataProxy* proxy = evtStore()->proxy( ptr );
114 static const std::string dummy = "";
115 return ( proxy == nullptr ? dummy : proxy->name() );
116 #endif // XAOD_STANDALONE
117 }

◆ getProperty()

template<class T>
const T * asg::AsgTool::getProperty ( const std::string & name) const
inherited

Get one of the tool's properties.

◆ initialize()

StatusCode AthONNX::JSSMLTool::initialize ( void )
override, virtual

Function initialising the tool.

Reimplemented from asg::AsgTool.

Definition at line 83 of file JSSMLTool.cxx.

83 {
84
85 // Access the service.
86 // Find the model file.
87 ATH_MSG_INFO( "Using model file: " << m_modelFileName );
88
89 // Set up the ONNX Runtime session.
90 Ort::SessionOptions sessionOptions;
91 sessionOptions.SetIntraOpNumThreads( 1 );
92 sessionOptions.SetGraphOptimizationLevel( ORT_ENABLE_BASIC );
93
94 // according to the discussion here https://its.cern.ch/jira/browse/ATLASG-2866
95 // this should reduce memory use while slowing things down slightly
96 sessionOptions.DisableCpuMemArena();
97
98 // declare an allocator
99 Ort::AllocatorWithDefaultOptions allocator;
100
101 // create session and load model into memory
102 m_env = std::make_unique< Ort::Env >(ORT_LOGGING_LEVEL_WARNING, "");
103 m_session = std::make_unique< Ort::Session >( *m_env,
104 m_modelFileName.c_str(),
105 sessionOptions );
106
107 ATH_MSG_INFO( "Created the ONNX Runtime session" );
108
109 m_num_input_nodes = m_session->GetInputCount();
110 m_input_node_names.resize(m_num_input_nodes);
111
112 for( std::size_t i = 0; i < m_num_input_nodes; i++ ) {
113 // print input node names
114 char* input_name = m_session->GetInputNameAllocated(i, allocator).release();
115 ATH_MSG_DEBUG("Input "<<i<<" : "<<" name = "<<input_name);
116 m_input_node_names[i] = input_name;
117 // print input node types
118 Ort::TypeInfo type_info = m_session->GetInputTypeInfo(i);
119 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
120 ONNXTensorElementDataType type = tensor_info.GetElementType();
121 ATH_MSG_DEBUG("Input "<<i<<" : "<<" type = "<<type);
122
123 // print input shapes/dims
124 m_input_node_dims = tensor_info.GetShape();
125 ATH_MSG_DEBUG("Input "<<i<<" : num_dims = "<<m_input_node_dims.size());
126 for (std::size_t j = 0; j < m_input_node_dims.size(); j++){
127 if(m_input_node_dims[j]<0)
128 m_input_node_dims[j] =1;
129 ATH_MSG_DEBUG("Input"<<i<<" : dim "<<j<<" = "<<m_input_node_dims[j]);
130 }
131 }
132
133 m_num_output_nodes = m_session->GetOutputCount();
134 m_output_node_names.resize(m_num_output_nodes);
135
136 for( std::size_t i = 0; i < m_num_output_nodes; i++ ) {
137 // print output node names
138 char* output_name = m_session->GetOutputNameAllocated(i, allocator).release();
139 ATH_MSG_DEBUG("Output "<<i<<" : "<<" name = "<<output_name);
140 m_output_node_names[i] = output_name;
141
142 Ort::TypeInfo type_info = m_session->GetOutputTypeInfo(i);
143 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
144 ONNXTensorElementDataType type = tensor_info.GetElementType();
145 ATH_MSG_DEBUG("Output "<<i<<" : "<<" type = "<<type);
146
147 // print output shapes/dims
148 m_output_node_dims = tensor_info.GetShape();
149 ATH_MSG_INFO("Output "<<i<<" : num_dims = "<<m_output_node_dims.size());
150 for (std::size_t j = 0; j < m_output_node_dims.size(); j++){
151 if(m_output_node_dims[j]<0)
152 m_output_node_dims[j] =1;
153 ATH_MSG_INFO("Output"<<i<<" : dim "<<j<<" = "<<m_output_node_dims[j]);
154 }
155 }
156
157 // Return gracefully.
158 return StatusCode::SUCCESS;
159 } // end initialize ---
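The dimension loops in initialize() are worth spelling out. ONNX Runtime reports dynamic axes (typically the batch axis) as -1, and the loops rewrite them to 1 so that the stored shape can later be handed to CreateTensor, which rejects a buffer that is too small for the requested shape. Note also that m_input_node_dims is reassigned for every input node, so after initialize() it holds the shape of the last input only. The standalone sketch below mirrors the normalisation and a size check without depending on ONNX Runtime; normaliseDims, elementCount and bufferMatchesShape are hypothetical helpers, not part of JSSMLTool, and the 1×40×40×3 shape used in the check is purely illustrative.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: replace dynamic (negative) dimensions by 1, as the
// loops in initialize() do for m_input_node_dims and m_output_node_dims.
std::vector<std::int64_t> normaliseDims(std::vector<std::int64_t> dims) {
  for (auto& d : dims) {
    if (d < 0) d = 1;  // dynamic axes are reported as -1
  }
  return dims;
}

// Hypothetical helper: number of elements implied by a non-negative shape.
std::size_t elementCount(const std::vector<std::int64_t>& dims) {
  std::size_t n = 1;
  for (auto d : dims) n *= static_cast<std::size_t>(d);
  return n;
}

// Hypothetical helper: true if a flat buffer of bufferSize floats holds
// exactly the number of elements that the shape implies.
bool bufferMatchesShape(std::size_t bufferSize,
                        const std::vector<std::int64_t>& dims) {
  return bufferSize == elementCount(dims);
}
```

Checking the buffer size against the normalised shape before building a tensor turns a runtime error inside ONNX Runtime into an explicit, debuggable condition.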

◆ inputHandles()

virtual std::vector< Gaudi::DataHandle * > AthCommonDataStore< AthCommonMsg< AlgTool > >::inputHandles ( ) const
override, virtual, inherited

Return this algorithm's input handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ msg()

MsgStream & AthCommonMsg< AlgTool >::msg ( ) const
inline, inherited

Definition at line 24 of file AthCommonMsg.h.

24 {
25 return this->msgStream();
26 }

◆ msg_level_name()

const std::string & asg::AsgTool::msg_level_name ( ) const
inherited

A deprecated function for getting the message level's name.

Instead of using this weirdly named function, user code should get the string name of the current minimum message level (in case it really needs it) with:

MSG::name( msg().level() )

This function's name doesn't follow the ATLAS coding rules, and as such will be removed in the not too distant future.

Returns
The string name of the current minimum message level that's printed

Definition at line 101 of file AsgTool.cxx.

101 {
102
103 return MSG::name( msg().level() );
104 }

◆ msgLvl()

bool AthCommonMsg< AlgTool >::msgLvl ( const MSG::Level lvl) const
inline, inherited

Definition at line 30 of file AthCommonMsg.h.

30 {
31 return this->msgLevel(lvl);
32 }

◆ outputHandles()

virtual std::vector< Gaudi::DataHandle * > AthCommonDataStore< AthCommonMsg< AlgTool > >::outputHandles ( ) const
override, virtual, inherited

Return this algorithm's output handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ print()

◆ ReadJetImagePixels()

std::vector< float > AthONNX::JSSMLTool::ReadJetImagePixels ( std::vector< TH2D > Images) const

Definition at line 17 of file JSSMLTool.cxx.

18 {
19
20 int n_rows = m_nPixelsX;
21 int n_cols = m_nPixelsY;
22 int n_colors = m_nPixelsZ;
23
24 std::vector<float> input_tensor_values(n_rows*n_cols*n_colors);
25
26 for(int iRow=0; iRow<n_rows; ++iRow){
27 for(int iColumn=0; iColumn<n_cols; ++iColumn){
28 for(int iColor=0; iColor<n_colors; ++iColor){
29 input_tensor_values[ (n_colors*n_cols*iRow) + iColumn*n_colors + iColor] = Images[iColor].GetBinContent(iRow+1, iColumn+1);
30 }
31 }
32 }
33
34 return input_tensor_values;
35 }
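ReadJetImagePixels() packs one TH2D per colour channel into a single row-major buffer of shape (rows, columns, colours), with the colour as the fastest-varying index. The sketch below reproduces that indexing with a hypothetical Image struct standing in for TH2D (it only mimics the 1-based GetBinContent(ix, iy) access), so it can be compiled and checked without ROOT. One deliberate difference: the number of channels is taken from images.size() rather than from a separate property, so a mismatch between the nPixelsZ setting and the number of images cannot cause an out-of-range read.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a TH2D: nx * ny bins with 1-based access.
struct Image {
  int nx, ny;
  std::vector<double> bins;  // bin (ix, iy) stored at (ix - 1) * ny + (iy - 1)
  double GetBinContent(int ix, int iy) const {
    return bins[static_cast<std::size_t>((ix - 1) * ny + (iy - 1))];
  }
};

// Flatten one image per colour channel into a (rows, cols, colours) buffer:
//   flat[(iRow * nCols + iCol) * nColours + iColour]
// which equals the index used in ReadJetImagePixels.
std::vector<float> flattenHWC(const std::vector<Image>& images, int nRows, int nCols) {
  const int nColours = static_cast<int>(images.size());
  std::vector<float> flat(static_cast<std::size_t>(nRows * nCols * nColours));
  for (int iRow = 0; iRow < nRows; ++iRow)
    for (int iCol = 0; iCol < nCols; ++iCol)
      for (int iColour = 0; iColour < nColours; ++iColour)
        flat[static_cast<std::size_t>((iRow * nCols + iCol) * nColours + iColour)] =
            static_cast<float>(images[static_cast<std::size_t>(iColour)]
                                   .GetBinContent(iRow + 1, iCol + 1));
  return flat;
}
```

With two 2×3 images, the first two buffer entries are the (1, 1) bins of channel 0 and channel 1, and the last entry is the (2, 3) bin of channel 1.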

◆ ReadJSSInputs()

std::vector< float > AthONNX::JSSMLTool::ReadJSSInputs ( std::map< std::string, double > JSSVars) const

Definition at line 39 of file JSSMLTool.cxx.

40 {
41
42 std::vector<float> input_tensor_values(m_nvars);
43
44 // apply features scaling
45 for(const auto & var : JSSVars){
46 double mean = m_scaler.find(var.first)->second[0];
47 double std = m_scaler.find(var.first)->second[1];
48 JSSVars[var.first] = (var.second - mean) / std;
49 }
50
51 // then dump it to a vector
52 for(int v=0; v<m_nvars; ++v){
53 std::string name = m_JSSInputMap.find(v)->second;
54 input_tensor_values[v] = JSSVars[name];
55 }
56
57 return input_tensor_values;
58 }
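ReadJSSInputs() standardises each variable as (x - mean) / std with the values stored in m_scaler, then orders the results according to m_JSSInputMap. Both steps dereference the result of find() without checking it, so a variable missing from either map leads to undefined behaviour. A hedged sketch of the same two steps with explicit checks follows; scaleAndOrder is a hypothetical helper, and returning std::nullopt on a missing entry is only one possible policy. Unlike the listing above, it scales only the variables that the model actually consumes. The variable names used in the check (D2, Tau32) are illustrative.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical helper: standardise with scaler[name] = {mean, std} and order
// by inputMap[position] = name. Returns std::nullopt if any lookup fails
// instead of dereferencing an end() iterator.
std::optional<std::vector<float>> scaleAndOrder(
    const std::map<std::string, double>& vars,
    const std::map<std::string, std::vector<double>>& scaler,
    const std::map<int, std::string>& inputMap, int nVars) {
  std::vector<float> out(static_cast<std::size_t>(nVars));
  for (int v = 0; v < nVars; ++v) {
    auto nameIt = inputMap.find(v);
    if (nameIt == inputMap.end()) return std::nullopt;
    auto varIt = vars.find(nameIt->second);
    auto scaleIt = scaler.find(nameIt->second);
    if (varIt == vars.end() || scaleIt == scaler.end() || scaleIt->second.size() < 2)
      return std::nullopt;
    const double mean = scaleIt->second[0];
    const double sigma = scaleIt->second[1];
    out[static_cast<std::size_t>(v)] = static_cast<float>((varIt->second - mean) / sigma);
  }
  return out;
}
```

In an Athena tool the failure branch would more naturally be reported through a StatusCode or an error message; std::optional just keeps the sketch self-contained.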

◆ ReadOutputLabels()

std::vector< int > AthONNX::JSSMLTool::ReadOutputLabels ( ) const

Definition at line 63 of file JSSMLTool.cxx.

64 {
65 std::vector<int> output_tensor_values(1);
66
67 output_tensor_values[0] = 1;
68
69 return output_tensor_values;
70 }

◆ renounce()

std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > AthCommonDataStore< AthCommonMsg< AlgTool > >::renounce ( T & h)
inline, protected, inherited

Definition at line 380 of file AthCommonDataStore.h.

381 {
382 h.renounce();
384 }

◆ renounceArray()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::renounceArray ( SG::VarHandleKeyArray & handlesArray)
inline, protected, inherited

Remove all handles from I/O resolution.

Definition at line 364 of file AthCommonDataStore.h.

364 {
366 }

◆ retrieveConstituentsScore() [1/4]

double AthONNX::JSSMLTool::retrieveConstituentsScore ( std::vector< std::vector< float > > constituents) const
override, virtual

Implements AthONNX::IJSSMLTool.

Definition at line 209 of file JSSMLTool.cxx.

209 {
210
211 // the format of the packed constituents is:
212 // constituents.size() ---> 4, for example, (m, pT, eta, phi)
213 // constituents.at(0).size() ---> number of constituents
214 // the packing can be done for any kind of low level inputs
215 // e.g. PFO/UFO constituents, topo-towers, tracks, etc.
216 // they can be concatenated one after the other in case of multiple inputs
217
218 //*************************************************************************
219 // Score the model using sample data, and inspect values
220 // loading input data
221
222 std::vector<int> output_tensor_values_ = ReadOutputLabels();
223
224 int testSample = 0;
225
226 //preparing container to hold output data
227 int output_tensor_values = output_tensor_values_[testSample];
228
229 // prepare the inputs
230 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
231 std::vector<Ort::Value> input_tensors;
232 for (long unsigned int i=0; i<constituents.size(); i++) {
233
234 // test
235 std::vector<int64_t> const_dim = {1, static_cast<int64_t>(constituents.at(i).size())};
236
237 input_tensors.push_back(Ort::Value::CreateTensor<float>(
238 memory_info,
239 constituents.at(i).data(), constituents.at(i).size(), const_dim.data(), const_dim.size()
240 )
241 );
242 }
243
244 auto output_tensors = m_session->Run(Ort::RunOptions{nullptr}, m_input_node_names.data(), input_tensors.data(), m_input_node_names.size(), m_output_node_names.data(), m_output_node_names.size());
245 assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
246
247 // Get pointer to output tensor float values
248 float* floatarr = output_tensors.front().GetTensorMutableData<float>();
249 int arrSize = static_cast<int>(output_tensors.front().GetTensorTypeAndShapeInfo().GetElementCount());
250
251 // show true label for the test input
252 ATH_MSG_DEBUG("Label for the input test data = "<<output_tensor_values);
253 float ConstScore = -999;
254 int max_index = 0;
255 for (int i = 0; i < arrSize; i++){
256 ATH_MSG_VERBOSE("Score for class "<<i<<" = "<<floatarr[i]<<std::endl);
257 ATH_MSG_VERBOSE(" +++ Score for class "<<i<<" = "<<floatarr[i]<<std::endl);
258 if (ConstScore<floatarr[i]){
259 ConstScore = floatarr[i];
260 max_index = i;
261 }
262 }
263 ATH_MSG_DEBUG("Class: "<<max_index<<" has the highest score: "<<floatarr[max_index]);
264
265 return ConstScore;
266
267 } // end retrieve constituents score ----
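The retrieveConstituentsScore() listings compute the number of output scores as sizeof(*floatarr)/sizeof(floatarr[0]). Because floatarr is a plain float*, both operands equal sizeof(float) and the result is always 1, so only the first output value would ever be inspected. With ONNX Runtime the count can be taken from the output tensor itself, via output_tensors.front().GetTensorTypeAndShapeInfo().GetElementCount(). The standalone sketch below demonstrates the pitfall and the argmax loop on a plain buffer; pointerIdiom, countFromShape and argmax are hypothetical names, and the three-class scores in the check are made up.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// The expression from the listings: for a float* both operands are
// sizeof(float), so this returns 1 whatever the buffer length is.
std::size_t pointerIdiom(const float* p) { return sizeof(*p) / sizeof(p[0]); }

// Element count derived from the output shape, which is the quantity
// GetElementCount() reports for an ONNX Runtime tensor.
std::size_t countFromShape(const std::vector<std::int64_t>& shape) {
  std::size_t n = 1;
  for (auto d : shape) n *= static_cast<std::size_t>(d);
  return n;
}

// Return (index, value) of the largest of n scores; requires n > 0.
std::pair<std::size_t, float> argmax(const float* scores, std::size_t n) {
  std::size_t best = 0;
  for (std::size_t i = 1; i < n; ++i)
    if (scores[i] > scores[best]) best = i;
  return {best, scores[best]};
}
```

Starting from the first score rather than from a sentinel such as -999 also avoids a wrong result if a model ever emits raw logits below the sentinel.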

◆ retrieveConstituentsScore() [2/4]

double AthONNX::JSSMLTool::retrieveConstituentsScore ( std::vector< std::vector< float > > constituents,
std::vector< std::vector< std::vector< float > > > interactions ) const
override, virtual

Implements AthONNX::IJSSMLTool.

Definition at line 270 of file JSSMLTool.cxx.

270 {
271
272 // the format of the constituents/interaction variables is:
273 // constituents ---> (nConstituents + nTowers, 7)
274 // interactions ---> (i, j, 4), with i, j in {nConstituents + nTowers}
275 // the packing can be done for any kind of low level inputs
276 // e.g. PFO/UFO constituents, topo-towers, tracks, etc.
277 // they can be concatenated one after the other in case of multiple inputs
278
279 //*************************************************************************
280 // Score the model using sample data, and inspect values
281 // loading input data
282
283 std::vector<int> output_tensor_values_ = ReadOutputLabels();
284
285 int testSample = 0;
286
287 //preparing container to hold output data
288 int output_tensor_values = output_tensor_values_[testSample];
289
290 // prepare the inputs
291 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
292 std::vector<Ort::Value> input_tensors;
293
294 // unroll the inputs
295 std::vector<float> constituents_values; //(constituents.size()*7);
296 for (long unsigned int i=0; i<constituents.size(); i++) {
297 for (long unsigned int j=0; j<7; j++) {
298 constituents_values.push_back(constituents.at(i).at(j));
299 }
300 }
301
302 std::vector<float> interactions_values; //(interactions.size()*interactions.size()*4);
303 for (long unsigned int i=0; i<interactions.size(); i++) {
304 for (long unsigned int k=0; k<interactions.size(); k++) {
305 for (long unsigned int j=0; j<4; j++) {
306 interactions_values.push_back(interactions.at(i).at(k).at(j));
307 }
308 }
309 }
310
311 std::vector<int64_t> const_dim = {1, static_cast<int64_t>(constituents.size()), 7};
312 input_tensors.push_back(Ort::Value::CreateTensor<float>(
313 memory_info,
314 constituents_values.data(), constituents_values.size(), const_dim.data(), const_dim.size()
315 )
316 );
317
318 std::vector<int64_t> inter_dim = {1, static_cast<int64_t>(constituents.size()), static_cast<int64_t>(constituents.size()), 4};
319 input_tensors.push_back(Ort::Value::CreateTensor<float>(
320 memory_info,
321 interactions_values.data(), interactions_values.size(), inter_dim.data(), inter_dim.size()
322 )
323 );
324
325 auto output_tensors = m_session->Run(Ort::RunOptions{nullptr}, m_input_node_names.data(), input_tensors.data(), m_input_node_names.size(), m_output_node_names.data(), m_output_node_names.size());
326 assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
327
328 // Get pointer to output tensor float values
329 float* floatarr = output_tensors.front().GetTensorMutableData<float>();
330 int arrSize = static_cast<int>(output_tensors.front().GetTensorTypeAndShapeInfo().GetElementCount());
331
332 // show true label for the test input
333 ATH_MSG_DEBUG("Label for the input test data = "<<output_tensor_values);
334 float ConstScore = -999;
335 int max_index = 0;
336 for (int i = 0; i < arrSize; i++){
337 ATH_MSG_VERBOSE("Score for class "<<i<<" = "<<floatarr[i]<<std::endl);
338 ATH_MSG_VERBOSE(" +++ Score for class "<<i<<" = "<<floatarr[i]<<std::endl);
339 if (ConstScore<floatarr[i]){
340 ConstScore = floatarr[i];
341 max_index = i;
342 }
343 }
344 ATH_MSG_DEBUG("Class: "<<max_index<<" has the highest score: "<<floatarr[max_index]);
345
346 return ConstScore;
347
348 } // end retrieve constituents score ----
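The two-argument overload unrolls its inputs in row-major order to match tensors of shape {1, N, 7} and {1, N, N, 4}, where N is the number of constituents. A minimal sketch of that flattening, with the feature counts passed in rather than hard-coded (flattenNF and flattenNNK are hypothetical names), makes the resulting index formulas explicit. Like the listing, it uses .at(), so a row shorter than expected throws std::out_of_range.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major flattening for a {1, N, F} constituent tensor:
// element (i, f) lands at i * F + f.
std::vector<float> flattenNF(const std::vector<std::vector<float>>& x, std::size_t F) {
  std::vector<float> out;
  out.reserve(x.size() * F);
  for (const auto& row : x)
    for (std::size_t f = 0; f < F; ++f) out.push_back(row.at(f));
  return out;
}

// Row-major flattening for a {1, N, N, K} interaction tensor:
// element (i, j, k) lands at (i * N + j) * K + k.
std::vector<float> flattenNNK(const std::vector<std::vector<std::vector<float>>>& u,
                              std::size_t K) {
  const std::size_t N = u.size();
  std::vector<float> out;
  out.reserve(N * N * K);
  for (std::size_t i = 0; i < N; ++i)
    for (std::size_t j = 0; j < N; ++j)
      for (std::size_t k = 0; k < K; ++k) out.push_back(u.at(i).at(j).at(k));
  return out;
}
```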

◆ retrieveConstituentsScore() [3/4]

double AthONNX::JSSMLTool::retrieveConstituentsScore ( std::vector< std::vector< float > > constituents,
std::vector< std::vector< std::vector< float > > > interactions,
std::vector< std::vector< float > > mask ) const
override, virtual

Implements AthONNX::IJSSMLTool.

Definition at line 351 of file JSSMLTool.cxx.

351 {
352
353 // the format of the constituents/interaction variables is:
354 // constituents ---> (nConstituents, 7)
355 // interactions ---> (i, j, 4), with i, j in {nConstituents}
356 // masks ---> (nConstituents, 1)
357 // the packing can be done for any kind of low level inputs
358 // e.g. PFO/UFO constituents, topo-towers, tracks, etc.
359 // they can be concatenated one after the other in case of multiple inputs
360
361 //*************************************************************************
362 // Score the model using sample data, and inspect values
363 // loading input data
364
365 std::vector<int> output_tensor_values_ = ReadOutputLabels();
366
367 int testSample = 0;
368
369 //preparing container to hold output data
370 int output_tensor_values = output_tensor_values_[testSample];
371
372 // prepare the inputs
373 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
374 std::vector<Ort::Value> input_tensors;
375
376 // unroll the inputs
377 std::vector<float> constituents_values;
378 for (long unsigned int j=0; j<7; j++) {
379 for (long unsigned int i=0; i<constituents.size(); i++) {
380 constituents_values.push_back(constituents.at(i).at(j));
381 }
382 }
383
384 std::vector<float> interactions_values;
385 for (long unsigned int k=0; k<4; k++) {
386 for (long unsigned int i=0; i<interactions.size(); i++) {
387 for (long unsigned int j=0; j<interactions.size(); j++) {
388 interactions_values.push_back(interactions.at(i).at(j).at(k));
389 }
390 }
391 }
392
393 std::vector<float> mask_values;
394 for (long unsigned int j=0; j<1; j++) {
395 for (long unsigned int i=0; i<mask.size(); i++) {
396 mask_values.push_back(mask.at(i).at(j));
397 }
398 }
399
400 std::vector<int64_t> const_dim = {1, 7, static_cast<int64_t>(constituents.size())};
401 input_tensors.push_back(Ort::Value::CreateTensor<float>(
402 memory_info,
403 constituents_values.data(), constituents_values.size(), const_dim.data(), const_dim.size()
404 )
405 );
406
407 std::vector<int64_t> inter_dim = {1, 4, static_cast<int64_t>(interactions.size()), static_cast<int64_t>(interactions.size())};
408 input_tensors.push_back(Ort::Value::CreateTensor<float>(
409 memory_info,
410 interactions_values.data(), interactions_values.size(), inter_dim.data(), inter_dim.size()
411 )
412 );
413
414 std::vector<int64_t> mask_dim = {1, 1, static_cast<int64_t>(mask.size())};
415 input_tensors.push_back(Ort::Value::CreateTensor<float>(
416 memory_info,
417 mask_values.data(), mask_values.size(), mask_dim.data(), mask_dim.size()
418 )
419 );
420
421 auto output_tensors = m_session->Run(Ort::RunOptions{nullptr}, m_input_node_names.data(), input_tensors.data(), m_input_node_names.size(), m_output_node_names.data(), m_output_node_names.size());
422 assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
423
424 // Get pointer to output tensor float values
425 float* floatarr = output_tensors.front().GetTensorMutableData<float>();
426 int arrSize = static_cast<int>(output_tensors.front().GetTensorTypeAndShapeInfo().GetElementCount());
427
428 // show true label for the test input
429 ATH_MSG_DEBUG("Label for the input test data = "<<output_tensor_values);
430 float ConstScore = -999;
431 int max_index = 0;
432 for (int i = 0; i < arrSize; i++){
433 ATH_MSG_VERBOSE("Score for class "<<i<<" = "<<floatarr[i]<<std::endl);
434 ATH_MSG_VERBOSE(" +++ Score for class "<<i<<" = "<<floatarr[i]<<std::endl);
435 if (ConstScore<floatarr[i]){
436 ConstScore = floatarr[i];
437 max_index = i;
438 }
439 }
440 ATH_MSG_DEBUG("Class: "<<max_index<<" has the highest score: "<<floatarr[max_index]);
441
442 return ConstScore;
443
444 } // end retrieve constituents score ----
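The three-argument overload instead lays its inputs out channel-first, as {1, 7, N}, {1, 4, N, N} and {1, 1, N} for the mask, which is why its loops iterate over the feature index outermost. The sketch below (hypothetical helper names, feature counts passed as parameters) shows where each element lands; the mask needs no reordering because it has a single channel. Which layout a given model expects depends on how it was exported, so the two overloads are not interchangeable.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Channel-first flattening for a {1, F, N} constituent tensor:
// element (i, f) of an N x F input lands at f * N + i.
std::vector<float> flattenFN(const std::vector<std::vector<float>>& x, std::size_t F) {
  const std::size_t N = x.size();
  std::vector<float> out(F * N);
  for (std::size_t f = 0; f < F; ++f)
    for (std::size_t i = 0; i < N; ++i) out[f * N + i] = x.at(i).at(f);
  return out;
}

// Channel-first flattening for a {1, K, N, N} interaction tensor:
// element (i, j, k) lands at (k * N + i) * N + j.
std::vector<float> flattenKNN(const std::vector<std::vector<std::vector<float>>>& u,
                              std::size_t K) {
  const std::size_t N = u.size();
  std::vector<float> out(K * N * N);
  for (std::size_t k = 0; k < K; ++k)
    for (std::size_t i = 0; i < N; ++i)
      for (std::size_t j = 0; j < N; ++j) out[(k * N + i) * N + j] = u.at(i).at(j).at(k);
  return out;
}
```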

◆ retrieveConstituentsScore() [4/4]

double AthONNX::JSSMLTool::retrieveConstituentsScore ( std::vector< TH2D > Images) const
override, virtual

Function executing the tool for a single event.

Implements AthONNX::IJSSMLTool.

Definition at line 162 of file JSSMLTool.cxx.

162 {
163
164 //*************************************************************************
165 // Score the model using sample data, and inspect values
166
167 // preparing container to hold input data
168 size_t input_tensor_size = m_nPixelsX*m_nPixelsY*m_nPixelsZ;
169 std::vector<float> input_tensor_values(input_tensor_size);
170
171 // loading input data
172 input_tensor_values = ReadJetImagePixels(Images);
173
174 // preparing container to hold output data
175 int testSample = 0;
176 std::vector<int> output_tensor_values_ = ReadOutputLabels();
177 int output_tensor_values = output_tensor_values_[testSample];
178
179 // create input tensor object from data values
180 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
181 Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_size, m_input_node_dims.data(), m_input_node_dims.size());
182 assert(input_tensor.IsTensor());
183
184 auto output_tensors = m_session->Run(Ort::RunOptions{nullptr}, m_input_node_names.data(), &input_tensor, m_input_node_names.size(), m_output_node_names.data(), m_output_node_names.size());
185 assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
186
187 // Get pointer to output tensor float values
188 float* floatarr = output_tensors.front().GetTensorMutableData<float>();
189 int arrSize = static_cast<int>(output_tensors.front().GetTensorTypeAndShapeInfo().GetElementCount());
190
191 // show true label for the test input
192 ATH_MSG_DEBUG("Label for the input test data = "<<output_tensor_values);
193 float ConstScore = -999;
194 int max_index = 0;
195 for (int i = 0; i < arrSize; i++){
196 ATH_MSG_VERBOSE("Score for class "<<i<<" = "<<floatarr[i]<<std::endl);
197 if (ConstScore<floatarr[i]){
198 ConstScore = floatarr[i];
199 max_index = i;
200 }
201 }
202 ATH_MSG_DEBUG("Class: "<<max_index<<" has the highest score: "<<floatarr[max_index]);
203
204 return ConstScore;
205
206 } // end retrieve CNN score ----
std::vector< float > ReadJetImagePixels(std::vector< TH2D > Images) const
Definition JSSMLTool.cxx:17

◆ retrieveHighLevelScore()

double AthONNX::JSSMLTool::retrieveHighLevelScore ( std::map< std::string, double > JSSVars) const
override, virtual

Implements AthONNX::IJSSMLTool.

Definition at line 447 of file JSSMLTool.cxx.

447 {
448
449 //*************************************************************************
450 // Score the model using sample data, and inspect values
451
452 //preparing container to hold input data
453 size_t input_tensor_size = m_nvars;
454 std::vector<float> input_tensor_values(m_nvars);
455
456 // loading input data
457 input_tensor_values = ReadJSSInputs(JSSVars);
458
459 // preparing container to hold output data
460 int testSample = 0;
461 std::vector<int> output_tensor_values_ = ReadOutputLabels();
462 int output_tensor_values = output_tensor_values_[testSample];
463
464 // create input tensor object from data values
465 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
466
467 // we need a multiple tensor input structure for DisCo model
468 Ort::Value input1 = Ort::Value::CreateTensor<float>(memory_info, const_cast<float*>(input_tensor_values.data()), input_tensor_size, m_input_node_dims.data(), m_input_node_dims.size());
469 std::vector<float> empty = {1.};
470 Ort::Value input2 = Ort::Value::CreateTensor<float>(memory_info, empty.data(), 1, m_input_node_dims.data(), m_input_node_dims.size());
471 Ort::Value input3 = Ort::Value::CreateTensor<float>(memory_info, empty.data(), 1, m_input_node_dims.data(), m_input_node_dims.size());
472 Ort::Value input4 = Ort::Value::CreateTensor<float>(memory_info, empty.data(), 1, m_input_node_dims.data(), m_input_node_dims.size());
473 std::vector<Ort::Value> input_tensor;
474 std::vector<int64_t> aaa = {1, m_nvars};
475 input_tensor.emplace_back(
476 Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_size, aaa.data(), aaa.size())
477 );
478 input_tensor.emplace_back(
479 Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_size, m_input_node_dims.data(), m_input_node_dims.size())
480 );
481 input_tensor.emplace_back(
482 Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_size, m_input_node_dims.data(), m_input_node_dims.size())
483 );
484 input_tensor.emplace_back(
485 Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_size, m_input_node_dims.data(), m_input_node_dims.size())
486 );
487
488 auto output_tensors = m_session->Run(Ort::RunOptions{nullptr}, m_input_node_names.data(), input_tensor.data(), m_input_node_names.size(), m_output_node_names.data(), m_output_node_names.size());
489 assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
490
491 // Get pointer to output tensor float values
492 float* floatarr = output_tensors.front().GetTensorMutableData<float>();
493 int arrSize = static_cast<int>(output_tensors.front().GetTensorTypeAndShapeInfo().GetElementCount()); // number of class scores in the output tensor
494
495 // show true label for the test input
496 ATH_MSG_DEBUG("Label for the input test data = "<<output_tensor_values);
497 float HLScore = -999;
498 int max_index = 0;
499 for (int i = 0; i < arrSize; i++){
500 ATH_MSG_VERBOSE("Score for class "<<i<<" = "<<floatarr[i]<<std::endl);
501 if (HLScore<floatarr[i]){
502 HLScore = floatarr[i];
503 max_index = i;
504 }
505 }
506 ATH_MSG_DEBUG("Class: "<<max_index<<" has the highest score: "<<floatarr[max_index]);
507
508 return HLScore;
509
510 } // end retrieve HighLevel score ----
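retrieveHighLevelScore wraps the m_nvars JSS inputs of one jet in a tensor of shape {1, m_nvars}. The sketch below uses hypothetical helpers that are not part of the tool to show a consistency check worth making before wrapping a buffer: the product of the shape's dimensions must equal the number of floats in the buffer. Negative (dynamic) dimensions, which a model may report for its batch axis, are treated as unusable here by assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: builds the {1, nvars} shape (one jet, nvars features).
std::vector<std::int64_t> makeRowShape(int nvars) {
  // m_nvars must already be set (e.g. by SetScaler) before scoring.
  assert(nvars > 0);
  return {1, static_cast<std::int64_t>(nvars)};
}

// Number of elements described by a shape; -1 if any axis is dynamic.
std::int64_t shapeElementCount(const std::vector<std::int64_t>& shape) {
  std::int64_t count = 1;
  for (std::int64_t d : shape) {
    if (d < 0) return -1;  // dynamic axis, e.g. a batch dimension of -1
    count *= d;
  }
  return count;
}

// True if a float buffer of `bufferSize` elements exactly fills `shape`.
bool bufferMatchesShape(std::size_t bufferSize,
                        const std::vector<std::int64_t>& shape) {
  const std::int64_t n = shapeElementCount(shape);
  return n >= 0 && static_cast<std::size_t>(n) == bufferSize;
}
```

For m_nvars = 20 the shape is {1, 20}, which a 20-float buffer matches and a one-element buffer does not.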

◆ SetScaler()

StatusCode AthONNX::JSSMLTool::SetScaler ( std::map< std::string, std::vector< double > > scaler)
override virtual

Implements AthONNX::IJSSMLTool.

Definition at line 513 of file JSSMLTool.cxx.

513 {
514 m_scaler = scaler;
515
516 // ToDo:
517 // this will have an overriding config as property
518 m_JSSInputMap = {
519 {0,"pT"}, {1,"CNN"}, {2,"D2"}, {3,"nTracks"}, {4,"ZCut12"},
520 {5,"Tau1_wta"}, {6,"Tau2_wta"}, {7,"Tau3_wta"},
521 {8,"KtDR"}, {9,"Split12"}, {10,"Split23"},
522 {11,"ECF1"}, {12,"ECF2"}, {13,"ECF3"},
523 {14,"Angularity"}, {15,"FoxWolfram0"}, {16,"FoxWolfram2"},
524 {17,"Aplanarity"}, {18,"PlanarFlow"}, {19,"Qw"},
525 };
526 m_nvars = m_JSSInputMap.size();
527
528 return StatusCode::SUCCESS;
529 }
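SetScaler fixes the input order by index in m_JSSInputMap (pT first, Qw last) and sets m_nvars to the map size, 20 entries. The body of ReadJSSInputs is not reproduced on this page, so the following is only a sketch under the assumption that inputs are filled by map index: orderFeatures is a hypothetical helper that turns a name-to-value map such as JSSVars into a float vector in that order. Any scaling with m_scaler is left out because its format is not documented here.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper (not the tool's ReadJSSInputs): orders named JSS
// variables into a flat feature vector following an index -> name map
// shaped like m_JSSInputMap. Throws if a required variable is missing.
std::vector<float> orderFeatures(const std::map<int, std::string>& inputMap,
                                 const std::map<std::string, double>& jssVars) {
  std::vector<float> features(inputMap.size());
  for (const auto& entry : inputMap) {
    const int index = entry.first;
    const std::string& name = entry.second;
    auto it = jssVars.find(name);
    if (it == jssVars.end()) {
      throw std::out_of_range("missing JSS variable: " + name);
    }
    // Indices are assumed to be contiguous from 0, as in m_JSSInputMap.
    assert(index >= 0 && static_cast<std::size_t>(index) < features.size());
    features[index] = static_cast<float>(it->second);
  }
  return features;
}
```

Looking variables up by name, rather than relying on the iteration order of JSSVars, keeps the feature order tied to the index in the map even when the caller fills JSSVars in a different order.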

◆ sysInitialize()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysInitialize ( )
override virtual inherited

Perform system initialization for an algorithm.

We override this to declare all the elements of handle key arrays at the end of initialization. See comments on updateVHKA.

Reimplemented in asg::AsgMetadataTool, AthCheckedComponent< AthAlgTool >, AthCheckedComponent<::AthAlgTool >, and DerivationFramework::CfAthAlgTool.

◆ sysStart()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysStart ( )
override virtual inherited

Handle START transition.

We override this in order to make sure that conditions handle keys can cache a pointer to the conditions container.

◆ updateVHKA()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::updateVHKA ( Gaudi::Details::PropertyBase & )
inline inherited

Definition at line 308 of file AthCommonDataStore.h.

308 {
309 // debug() << "updateVHKA for property " << p.name() << " " << p.toString()
310 // << " size: " << m_vhka.size() << endmsg;
311 for (auto &a : m_vhka) {
312 std::vector<SG::VarHandleKey*> keys = a->keys();
313 for (auto k : keys) {
314 k->setOwner(this);
315 }
316 }
317 }

Member Data Documentation

◆ m_detStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_detStore
private inherited

Pointer to StoreGate (detector store by default)

Definition at line 393 of file AthCommonDataStore.h.

◆ m_env

std::unique_ptr< Ort::Env > AthONNX::JSSMLTool::m_env

Definition at line 73 of file JSSMLTool.h.

◆ m_evtStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_evtStore
private inherited

Pointer to StoreGate (event store by default)

Definition at line 390 of file AthCommonDataStore.h.

◆ m_input_node_dims

std::vector<int64_t> AthONNX::JSSMLTool::m_input_node_dims
private

Definition at line 86 of file JSSMLTool.h.

◆ m_input_node_names

std::vector<const char*> AthONNX::JSSMLTool::m_input_node_names
private

Definition at line 88 of file JSSMLTool.h.

◆ m_JSSInputMap

std::map<int, std::string> AthONNX::JSSMLTool::m_JSSInputMap

Definition at line 76 of file JSSMLTool.h.

◆ m_labelFileName

std::string AthONNX::JSSMLTool::m_labelFileName
private

Definition at line 83 of file JSSMLTool.h.

◆ m_modelFileName

std::string AthONNX::JSSMLTool::m_modelFileName
private

Name of the model file to load.

Definition at line 81 of file JSSMLTool.h.

◆ m_nPixelsX

int AthONNX::JSSMLTool::m_nPixelsX {}
private

Definition at line 96 of file JSSMLTool.h.

96 int m_nPixelsX{}, m_nPixelsY{}, m_nPixelsZ{};

◆ m_nPixelsY

int AthONNX::JSSMLTool::m_nPixelsY {}
private

Definition at line 96 of file JSSMLTool.h.

96 int m_nPixelsX{}, m_nPixelsY{}, m_nPixelsZ{};

◆ m_nPixelsZ

int AthONNX::JSSMLTool::m_nPixelsZ {}
private

Definition at line 96 of file JSSMLTool.h.

96 int m_nPixelsX{}, m_nPixelsY{}, m_nPixelsZ{};
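m_nPixelsX, m_nPixelsY and m_nPixelsZ describe the jet-image input that ReadJetImagePixels reads from a std::vector<TH2D>. That function's body is not shown here, so the pixel ordering below is an assumption: packImages is a hypothetical helper that treats each entry of the vector as one nX by nY image (stored row-major as plain floats rather than TH2D bins) and interleaves the images channels-last, giving the flat index (x * nY + y) * nZ + z.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch: pack nZ images of nX x nY pixels into one flat buffer
// in channels-last order, index = (x * nY + y) * nZ + z.
// Each input image is assumed row-major: pixel (x, y) sits at x * nY + y.
std::vector<float> packImages(const std::vector<std::vector<float>>& images,
                              int nX, int nY) {
  const std::size_t nZ = images.size();
  const std::size_t nPix =
      static_cast<std::size_t>(nX) * static_cast<std::size_t>(nY);
  std::vector<float> out(nPix * nZ);
  for (std::size_t z = 0; z < nZ; ++z) {
    assert(images[z].size() == nPix);
    for (std::size_t p = 0; p < nPix; ++p) {  // p == x * nY + y
      out[p * nZ + z] = images[z][p];
    }
  }
  return out;
}
```

A channels-first layout would instead use z * nX * nY + x * nY + y; which one a given model expects depends on how it was exported.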

◆ m_num_input_nodes

size_t AthONNX::JSSMLTool::m_num_input_nodes {}
private

Definition at line 87 of file JSSMLTool.h.

87 size_t m_num_input_nodes{};

◆ m_num_output_nodes

size_t AthONNX::JSSMLTool::m_num_output_nodes {}
private

Definition at line 92 of file JSSMLTool.h.

92 size_t m_num_output_nodes{};

◆ m_nvars

int AthONNX::JSSMLTool::m_nvars {}
private

Definition at line 98 of file JSSMLTool.h.

98 int m_nvars{};

◆ m_output_node_dims

std::vector<int64_t> AthONNX::JSSMLTool::m_output_node_dims
private

Definition at line 91 of file JSSMLTool.h.

◆ m_output_node_names

std::vector<const char*> AthONNX::JSSMLTool::m_output_node_names
private

Definition at line 93 of file JSSMLTool.h.

◆ m_pixelFileName

std::string AthONNX::JSSMLTool::m_pixelFileName
private

Definition at line 82 of file JSSMLTool.h.

◆ m_scaler

std::map<std::string, std::vector<double> > AthONNX::JSSMLTool::m_scaler

Definition at line 75 of file JSSMLTool.h.

◆ m_session

std::unique_ptr< Ort::Session > AthONNX::JSSMLTool::m_session

Definition at line 72 of file JSSMLTool.h.

◆ m_varHandleArraysDeclared

bool AthCommonDataStore< AthCommonMsg< AlgTool > >::m_varHandleArraysDeclared
private inherited

Definition at line 399 of file AthCommonDataStore.h.

◆ m_vhka

std::vector<SG::VarHandleKeyArray*> AthCommonDataStore< AthCommonMsg< AlgTool > >::m_vhka
private inherited

Definition at line 398 of file AthCommonDataStore.h.


The documentation for this class was generated from the following files: