ATLAS Offline Software
AthONNX::JSSMLTool Class Reference

Tool using the ONNX Runtime C++ API to retrieve scores from a constituent-based model for boson jet tagging.

#include <JSSMLTool.h>


Public Member Functions

 JSSMLTool (const std::string &name)
 
virtual StatusCode initialize () override
 Function initialising the tool.

virtual double retrieveConstituentsScore (std::vector< TH2D > Images) const override
 Function executing the tool for a single event.
 
virtual double retrieveConstituentsScore (std::vector< std::vector< float >> constituents) const override
 
virtual double retrieveHighLevelScore (std::map< std::string, double > JSSVars) const override
 
std::vector< float > ReadJetImagePixels (std::vector< TH2D > Images) const
 
std::vector< float > ReadJSSInputs (std::map< std::string, double > JSSVars) const
 
std::vector< int > ReadOutputLabels () const
 
StatusCode SetScaler (std::map< std::string, std::vector< double >> scaler) override
 
virtual void print () const
 Print the state of the tool.

ServiceHandle< StoreGateSvc > & evtStore ()
 The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.

const ServiceHandle< StoreGateSvc > & evtStore () const
 The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.

const ServiceHandle< StoreGateSvc > & detStore () const
 The standard StoreGateSvc/DetectorStore. Returns (kind of) a pointer to the StoreGateSvc.

virtual StatusCode sysInitialize () override
 Perform system initialization for an algorithm.

virtual StatusCode sysStart () override
 Handle START transition.

virtual std::vector< Gaudi::DataHandle * > inputHandles () const override
 Return this algorithm's input handles.

virtual std::vector< Gaudi::DataHandle * > outputHandles () const override
 Return this algorithm's output handles.
 
Gaudi::Details::PropertyBase & declareProperty (Gaudi::Property< T > &t)
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleKey &hndl, const std::string &doc, const SG::VarHandleKeyType &)
 Declare a new Gaudi property.

Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleBase &hndl, const std::string &doc, const SG::VarHandleType &)
 Declare a new Gaudi property.

Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleKeyArray &hndArr, const std::string &doc, const SG::VarHandleKeyArrayType &)

Gaudi::Details::PropertyBase * declareProperty (const std::string &name, T &property, const std::string &doc, const SG::NotHandleType &)
 Declare a new Gaudi property.

Gaudi::Details::PropertyBase * declareProperty (const std::string &name, T &property, const std::string &doc="none")
 Declare a new Gaudi property.
 
void updateVHKA (Gaudi::Details::PropertyBase &)
 
MsgStream & msg () const
 
MsgStream & msg (const MSG::Level lvl) const
 
bool msgLvl (const MSG::Level lvl) const
 

Public Attributes

std::unique_ptr< Ort::Session > m_session
 
std::unique_ptr< Ort::Env > m_env
 
std::map< std::string, std::vector< double > > m_scaler
 
std::map< int, std::string > m_JSSInputMap
 

Protected Member Functions

void renounceArray (SG::VarHandleKeyArray &handlesArray)
 Remove all handles from I/O resolution.

std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce (T &h)

void extraDeps_update_handler (Gaudi::Details::PropertyBase &ExtraDeps)
 Add StoreName to extra input/output deps as needed.
 

Private Types

typedef ServiceHandle< StoreGateSvc > StoreGateSvc_t
 

Private Member Functions

Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &hndl, const SG::VarHandleKeyType &)
 Specialization for handling Gaudi::Property<SG::VarHandleKey>.

Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &hndl, const SG::VarHandleKeyArrayType &)
 Specialization for handling Gaudi::Property<SG::VarHandleKeyArray>.

Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &hndl, const SG::VarHandleType &)
 Specialization for handling Gaudi::Property<SG::VarHandleBase>.

Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &t, const SG::NotHandleType &)
 Specialization for handling everything that's not a Gaudi::Property<SG::VarHandleKey> or a <SG::VarHandleKeyArray>.
 

Private Attributes

std::string m_modelFileName
 Name of the model file to load.
 
std::string m_pixelFileName
 
std::string m_labelFileName
 
std::vector< int64_t > m_input_node_dims
 
size_t m_num_input_nodes
 
std::vector< const char * > m_input_node_names
 
std::vector< int64_t > m_output_node_dims
 
size_t m_num_output_nodes
 
std::vector< const char * > m_output_node_names
 
int m_nPixelsX
 
int m_nPixelsY
 
int m_nPixelsZ
 
int m_nvars
 
StoreGateSvc_t m_evtStore
 Pointer to StoreGate (event store by default).

StoreGateSvc_t m_detStore
 Pointer to StoreGate (detector store by default).
 
std::vector< SG::VarHandleKeyArray * > m_vhka
 
bool m_varHandleArraysDeclared
 

Detailed Description

Tool using the ONNX Runtime C++ API to retrieve scores from a constituent-based model for boson jet tagging.

This is inspired by the general Athena example here: https://gitlab.cern.ch/atlas/athena/-/blob/21.2/Control/AthenaExamples/AthExOnnxRuntime/AthExOnnxRuntime/CxxApiAlgorithm.h

This implementation is an extension of the one done in rel. 21: https://gitlab.cern.ch/atlas/athena/-/tree/21.2/Reconstruction/Jet/AthOnnxRuntimeBJT. As the plan is to move to the central ONNX interface, the tool has been merged with the BJT.

Monitoring Jira ticket: https://its.cern.ch/jira/browse/ATLJETMET-1893

Author
Antonio Giannini antonio.giannini@cern.ch

Definition at line 48 of file JSSMLTool.h.

Member Typedef Documentation

◆ StoreGateSvc_t

typedef ServiceHandle<StoreGateSvc> AthCommonDataStore< AthCommonMsg< AlgTool > >::StoreGateSvc_t
private, inherited

Definition at line 388 of file AthCommonDataStore.h.

Constructor & Destructor Documentation

◆ JSSMLTool()

AthONNX::JSSMLTool::JSSMLTool ( const std::string &  name)

Definition at line 72 of file JSSMLTool.cxx.

72  :
73  AsgTool(name)
74  {
75  declareProperty("ModelPath", m_modelFileName);
76  declareProperty("nPixelsX", m_nPixelsX);
77  declareProperty("nPixelsY", m_nPixelsY);
78  declareProperty("nPixelsZ", m_nPixelsZ);
79  }

Member Function Documentation

◆ declareGaudiProperty() [1/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > & hndl, const SG::VarHandleKeyArrayType & )
inline, private, inherited

specialization for handling Gaudi::Property<SG::VarHandleKeyArray>

Definition at line 170 of file AthCommonDataStore.h.

172  {
173  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
174  hndl.value(),
175  hndl.documentation());
176 
177  }

◆ declareGaudiProperty() [2/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > & hndl, const SG::VarHandleKeyType & )
inline, private, inherited

specialization for handling Gaudi::Property<SG::VarHandleKey>

Definition at line 156 of file AthCommonDataStore.h.

158  {
159  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
160  hndl.value(),
161  hndl.documentation());
162 
163  }

◆ declareGaudiProperty() [3/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > & hndl, const SG::VarHandleType & )
inline, private, inherited

specialization for handling Gaudi::Property<SG::VarHandleBase>

Definition at line 184 of file AthCommonDataStore.h.

186  {
187  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
188  hndl.value(),
189  hndl.documentation());
190  }

◆ declareGaudiProperty() [4/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > & t, const SG::NotHandleType & )
inline, private, inherited

specialization for handling everything that's not a Gaudi::Property<SG::VarHandleKey> or a <SG::VarHandleKeyArray>

Definition at line 199 of file AthCommonDataStore.h.

200  {
201  return PBASE::declareProperty(t);
202  }

◆ declareProperty() [1/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string & name, SG::VarHandleBase & hndl, const std::string & doc, const SG::VarHandleType & )
inline, inherited

Declare a new Gaudi property.

Parameters
name  Name of the property.
hndl  Object holding the property value.
doc   Documentation string for the property.

This is the version for types that derive from SG::VarHandleBase. The property value object is put on the input and output lists as appropriate; then we forward to the base class.

Definition at line 245 of file AthCommonDataStore.h.

249  {
250  this->declare(hndl.vhKey());
251  hndl.vhKey().setOwner(this);
252 
253  return PBASE::declareProperty(name,hndl,doc);
254  }

◆ declareProperty() [2/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string & name, SG::VarHandleKey & hndl, const std::string & doc, const SG::VarHandleKeyType & )
inline, inherited

Declare a new Gaudi property.

Parameters
name  Name of the property.
hndl  Object holding the property value.
doc   Documentation string for the property.

This is the version for types that derive from SG::VarHandleKey. The property value object is put on the input and output lists as appropriate; then we forward to the base class.

Definition at line 221 of file AthCommonDataStore.h.

225  {
226  this->declare(hndl);
227  hndl.setOwner(this);
228 
229  return PBASE::declareProperty(name,hndl,doc);
230  }

◆ declareProperty() [3/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string & name, SG::VarHandleKeyArray & hndArr, const std::string & doc, const SG::VarHandleKeyArrayType & )
inline, inherited

Definition at line 259 of file AthCommonDataStore.h.

263  {
264 
265  // std::ostringstream ost;
266  // ost << Algorithm::name() << " VHKA declareProp: " << name
267  // << " size: " << hndArr.keys().size()
268  // << " mode: " << hndArr.mode()
269  // << " vhka size: " << m_vhka.size()
270  // << "\n";
271  // debug() << ost.str() << endmsg;
272 
273  hndArr.setOwner(this);
274  m_vhka.push_back(&hndArr);
275 
276  Gaudi::Details::PropertyBase* p = PBASE::declareProperty(name, hndArr, doc);
277  if (p != 0) {
278  p->declareUpdateHandler(&AthCommonDataStore<PBASE>::updateVHKA, this);
279  } else {
280  ATH_MSG_ERROR("unable to call declareProperty on VarHandleKeyArray "
281  << name);
282  }
283 
284  return p;
285 
286  }

◆ declareProperty() [4/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string & name, T & property, const std::string & doc, const SG::NotHandleType & )
inline, inherited

Declare a new Gaudi property.

Parameters
name      Name of the property.
property  Object holding the property value.
doc       Documentation string for the property.

This is the generic version, for types that do not derive from SG::VarHandleKey. It just forwards to the base class version of declareProperty.

Definition at line 333 of file AthCommonDataStore.h.

337  {
338  return PBASE::declareProperty(name, property, doc);
339  }

◆ declareProperty() [5/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string & name, T & property, const std::string & doc = "none" )
inline, inherited

Declare a new Gaudi property.

Parameters
name      Name of the property.
property  Object holding the property value.
doc       Documentation string for the property.

This dispatches to either the generic declareProperty or the one for VarHandle/Key/KeyArray.

Definition at line 352 of file AthCommonDataStore.h.

355  {
356  typedef typename SG::HandleClassifier<T>::type htype;
357  return declareProperty (name, property, doc, htype());
358  }
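
The dispatch described above is a tag-dispatch pattern: a trait maps the property type to a tag type, and overload resolution on that tag selects the implementation. The following is a minimal standalone sketch of the idea only. `KeyLike`, `DerivedKey`, `KeyTag`, `NotHandleTag`, `Classifier` and `declare` are hypothetical stand-ins introduced for illustration; they are not the real SG:: types or the Athena implementation.

```cpp
#include <cassert>
#include <string>
#include <type_traits>

// Hypothetical stand-ins for a handle-key base class and the tag types.
struct KeyLike {};
struct DerivedKey : KeyLike {};
struct KeyTag {};
struct NotHandleTag {};

// Trait mapping a type to its tag (plays the role of SG::HandleClassifier<T>).
template <class T>
struct Classifier {
  using type =
      std::conditional_t<std::is_base_of_v<KeyLike, T>, KeyTag, NotHandleTag>;
};

// One overload per tag; the tag parameter is unnamed because only its type matters.
template <class T>
std::string declare(T&, KeyTag) { return "handle-key path"; }

template <class T>
std::string declare(T&, NotHandleTag) { return "generic path"; }

// Front end: compute the tag, then let overload resolution pick the overload.
template <class T>
std::string declare(T& property) {
  using htype = typename Classifier<T>::type;
  return declare(property, htype());
}
```

Because the tag is resolved at compile time, adding a new category of property only needs a new tag and overload; the front-end function stays unchanged.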

◆ declareProperty() [6/6]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( Gaudi::Property< T > &  t)
inline, inherited

Definition at line 145 of file AthCommonDataStore.h.

145  {
146  typedef typename SG::HandleClassifier<T>::type htype;
147  return AthCommonDataStore<PBASE>::declareGaudiProperty(t, htype());
148  }

◆ detStore()

const ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::detStore ( ) const
inline, inherited

The standard StoreGateSvc/DetectorStore. Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 95 of file AthCommonDataStore.h.

95 { return m_detStore; }

◆ evtStore() [1/2]

ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::evtStore ( )
inline, inherited

The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 85 of file AthCommonDataStore.h.

85 { return m_evtStore; }

◆ evtStore() [2/2]

const ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::evtStore ( ) const
inline, inherited

The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 90 of file AthCommonDataStore.h.

90 { return m_evtStore; }

◆ extraDeps_update_handler()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::extraDeps_update_handler ( Gaudi::Details::PropertyBase &  ExtraDeps)
protected, inherited

Add StoreName to extra input/output deps as needed.

Use the logic of the VarHandleKey to parse the DataObjID keys supplied via the ExtraInputs and ExtraOutputs properties, adding the StoreName if it is not explicitly given.

◆ getKey()

SG::sgkey_t asg::AsgTool::getKey ( const void *  ptr) const
inherited

Get the (hashed) key of an object that is in the event store.

This is a bit of a special one. StoreGateSvc and xAOD::TEvent both provide ways for getting the SG::sgkey_t key for an object that is in the store, based on a bare pointer. But they provide different interfaces for doing so.

In order to allow tools to efficiently perform this operation, they can use this helper function.

See also
asg::AsgTool::getName
Parameters
ptr  The bare pointer to the object that the event store should know about
Returns
The hashed key of the object in the store. If not found, an invalid (zero) key.

Definition at line 119 of file AsgTool.cxx.

119  {
120 
121 #ifdef XAOD_STANDALONE
122  // In case we use @c xAOD::TEvent, we have a direct function call
123  // for this.
124  return evtStore()->event()->getKey( ptr );
125 #else
126  const SG::DataProxy* proxy = evtStore()->proxy( ptr );
127  return ( proxy == nullptr ? 0 : proxy->sgkey() );
128 #endif // XAOD_STANDALONE
129  }

◆ getName()

const std::string & asg::AsgTool::getName ( const void *  ptr) const
inherited

Get the name of an object that is / should be in the event store.

This is a bit of a special one. StoreGateSvc and xAOD::TEvent both provide ways for getting the std::string name for an object that is in the store, based on a bare pointer. But they provide different interfaces for doing so.

In order to allow tools to efficiently perform this operation, they can use this helper function.

See also
asg::AsgTool::getKey
Parameters
ptr  The bare pointer to the object that the event store should know about
Returns
The string name of the object in the store. If not found, an empty string.

Definition at line 106 of file AsgTool.cxx.

106  {
107 
108 #ifdef XAOD_STANDALONE
109  // In case we use @c xAOD::TEvent, we have a direct function call
110  // for this.
111  return evtStore()->event()->getName( ptr );
112 #else
113  const SG::DataProxy* proxy = evtStore()->proxy( ptr );
114  static const std::string dummy = "";
115  return ( proxy == nullptr ? dummy : proxy->name() );
116 #endif // XAOD_STANDALONE
117  }

◆ getProperty()

template<class T >
const T* asg::AsgTool::getProperty ( const std::string &  name) const
inherited

Get one of the tool's properties.

◆ initialize()

StatusCode AthONNX::JSSMLTool::initialize ( )
override, virtual

Function initialising the tool.

Reimplemented from asg::AsgTool.

Definition at line 82 of file JSSMLTool.cxx.

82  {
83 
84  // Access the service.
85  // Find the model file.
86  ATH_MSG_INFO( "Using model file: " << m_modelFileName );
87 
88  // Set up the ONNX Runtime session.
89  Ort::SessionOptions sessionOptions;
90  sessionOptions.SetIntraOpNumThreads( 1 );
91  sessionOptions.SetGraphOptimizationLevel( ORT_ENABLE_BASIC );
92  Ort::AllocatorWithDefaultOptions allocator;
93  m_env = std::make_unique< Ort::Env >(ORT_LOGGING_LEVEL_WARNING, "");
94  m_session = std::make_unique< Ort::Session >( *m_env,
95  m_modelFileName.c_str(),
96  sessionOptions );
97 
98  ATH_MSG_INFO( "Created the ONNX Runtime session" );
99 
100  m_num_input_nodes = m_session->GetInputCount();
101  m_input_node_names.resize(m_num_input_nodes);
102 
103  for( std::size_t i = 0; i < m_num_input_nodes; i++ ) {
104  // print input node names
105  char* input_name = m_session->GetInputNameAllocated(i, allocator).release();
106  ATH_MSG_DEBUG("Input "<<i<<" : "<<" name = "<<input_name);
107  m_input_node_names[i] = input_name;
108  // print input node types
109  Ort::TypeInfo type_info = m_session->GetInputTypeInfo(i);
110  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
111  ONNXTensorElementDataType type = tensor_info.GetElementType();
112  ATH_MSG_DEBUG("Input "<<i<<" : "<<" type = "<<type);
113 
114  // print input shapes/dims
115  m_input_node_dims = tensor_info.GetShape();
116  ATH_MSG_DEBUG("Input "<<i<<" : num_dims = "<<m_input_node_dims.size());
117  for (std::size_t j = 0; j < m_input_node_dims.size(); j++){
118  if(m_input_node_dims[j]<0)
119  m_input_node_dims[j] =1;
120  ATH_MSG_DEBUG("Input"<<i<<" : dim "<<j<<" = "<<m_input_node_dims[j]);
121  }
122  }
123 
124  m_num_output_nodes = m_session->GetOutputCount();
125  m_output_node_names.resize(m_num_output_nodes);
126 
127  for( std::size_t i = 0; i < m_num_output_nodes; i++ ) {
128  // print output node names
129  char* output_name = m_session->GetOutputNameAllocated(i, allocator).release();
130  ATH_MSG_DEBUG("Output "<<i<<" : "<<" name = "<<output_name);
131  m_output_node_names[i] = output_name;
132 
133  Ort::TypeInfo type_info = m_session->GetOutputTypeInfo(i);
134  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
135  ONNXTensorElementDataType type = tensor_info.GetElementType();
136  ATH_MSG_DEBUG("Output "<<i<<" : "<<" type = "<<type);
137 
138  // print output shapes/dims
139  m_output_node_dims = tensor_info.GetShape();
140  ATH_MSG_INFO("Output "<<i<<" : num_dims = "<<m_output_node_dims.size());
141  for (std::size_t j = 0; j < m_output_node_dims.size(); j++){
142  if(m_output_node_dims[j]<0)
143  m_output_node_dims[j] =1;
144  ATH_MSG_INFO("Output"<<i<<" : dim "<<j<<" = "<<m_output_node_dims[j]);
145  }
146  }
147 
148  // Return gracefully.
149  return StatusCode::SUCCESS;
150  } // end initialize ---
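
In the listing above, `GetInputNameAllocated(i, allocator).release()` detaches each name from its owning smart pointer and stores the raw `char*` in `m_input_node_names`; nothing shown on this page frees those strings later. One possible alternative, sketched here without ONNX Runtime, keeps the names in `std::string` storage and exposes a `const char*` array of the kind the `Run` call consumes. `NameTable` and its members are hypothetical and are not part of JSSMLTool.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper: owns node names and exposes a stable const char* view.
class NameTable {
public:
  explicit NameTable(std::vector<std::string> names)
      : m_names(std::move(names)) {
    m_ptrs.reserve(m_names.size());
    for (const std::string& n : m_names) m_ptrs.push_back(n.c_str());
  }

  // Copying (or moving) would leave m_ptrs pointing into another object's
  // strings, so both are disabled.
  NameTable(const NameTable&) = delete;
  NameTable& operator=(const NameTable&) = delete;

  const char* const* data() const { return m_ptrs.data(); }
  std::size_t size() const { return m_ptrs.size(); }

private:
  std::vector<std::string> m_names;  // owns the characters
  std::vector<const char*> m_ptrs;   // views into m_names, valid while *this lives
};
```

With this layout the names are released automatically when the table is destroyed, at the cost of one string copy per node during initialisation.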

◆ inputHandles()

virtual std::vector<Gaudi::DataHandle*> AthCommonDataStore< AthCommonMsg< AlgTool > >::inputHandles ( ) const
override, virtual, inherited

Return this algorithm's input handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ msg() [1/2]

MsgStream& AthCommonMsg< AlgTool >::msg ( ) const
inlineinherited

Definition at line 24 of file AthCommonMsg.h.

24  {
25  return this->msgStream();
26  }

◆ msg() [2/2]

MsgStream& AthCommonMsg< AlgTool >::msg ( const MSG::Level  lvl) const
inline, inherited

Definition at line 27 of file AthCommonMsg.h.

27  {
28  return this->msgStream(lvl);
29  }

◆ msg_level_name()

const std::string & asg::AsgTool::msg_level_name ( ) const
inherited

A deprecated function for getting the message level's name.

Instead of using this weirdly named function, user code should get the string name of the current minimum message level (in case it is really needed) with:

MSG::name( msg().level() )

This function's name doesn't follow the ATLAS coding rules, and as such will be removed in the not too distant future.

Returns
The string name of the current minimum message level that's printed

Definition at line 101 of file AsgTool.cxx.

101  {
102 
103  return MSG::name( msg().level() );
104  }

◆ msgLvl()

bool AthCommonMsg< AlgTool >::msgLvl ( const MSG::Level  lvl) const
inline, inherited

Definition at line 30 of file AthCommonMsg.h.

30  {
31  return this->msgLevel(lvl);
32  }

◆ outputHandles()

virtual std::vector<Gaudi::DataHandle*> AthCommonDataStore< AthCommonMsg< AlgTool > >::outputHandles ( ) const
override, virtual, inherited

Return this algorithm's output handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ print()

void asg::AsgTool::print ( ) const
virtual, inherited

◆ ReadJetImagePixels()

std::vector< float > AthONNX::JSSMLTool::ReadJetImagePixels ( std::vector< TH2D >  Images) const

Definition at line 16 of file JSSMLTool.cxx.

17  {
18 
19  int n_rows = m_nPixelsX;
20  int n_cols = m_nPixelsY;
21  int n_colors = m_nPixelsZ;
22 
23  std::vector<float> input_tensor_values(n_rows*n_cols*n_colors);
24 
25  for(int iRow=0; iRow<n_rows; ++iRow){
26  for(int iColumn=0; iColumn<n_cols; ++iColumn){
27  for(int iColor=0; iColor<n_colors; ++iColor){
28  input_tensor_values[ (n_colors*n_cols*iRow) + iColumn*n_colors + iColor] = Images[iColor].GetBinContent(iRow+1, iColumn+1);
29  }
30  }
31  }
32 
33  return input_tensor_values;
34  }
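
The loop above writes pixel (iRow, iColumn, iColor) to the flat offset `n_colors*n_cols*iRow + iColumn*n_colors + iColor`, which is a channel-last layout: colour varies fastest and row slowest. The `+1` in `GetBinContent(iRow+1, iColumn+1)` skips ROOT's underflow bin 0. The sketch below reproduces the same indexing without ROOT; nested vectors stand in for `std::vector<TH2D>`, and `Grid` and `flattenChannelLast` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// images[c][r][k]: one 2D grid per colour channel (stand-in for std::vector<TH2D>).
using Grid = std::vector<std::vector<float>>;

std::vector<float> flattenChannelLast(const std::vector<Grid>& images,
                                      std::size_t nRows, std::size_t nCols) {
  const std::size_t nColors = images.size();
  std::vector<float> flat(nRows * nCols * nColors);
  for (std::size_t r = 0; r < nRows; ++r) {
    for (std::size_t k = 0; k < nCols; ++k) {
      for (std::size_t c = 0; c < nColors; ++c) {
        // Same offset as the listing: colour varies fastest, row slowest.
        flat[(nColors * nCols * r) + k * nColors + c] = images[c][r][k];
      }
    }
  }
  return flat;
}
```

For two 2x2 channels, the values of both channels for pixel (0,0) come first, then both channels for pixel (0,1), and so on.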

◆ ReadJSSInputs()

std::vector< float > AthONNX::JSSMLTool::ReadJSSInputs ( std::map< std::string, double >  JSSVars) const

Definition at line 38 of file JSSMLTool.cxx.

39  {
40 
41  std::vector<float> input_tensor_values(m_nvars);
42 
43  // apply features scaling
44  for(auto var : JSSVars){
45  double mean = m_scaler.find(var.first)->second[0];
46  double std = m_scaler.find(var.first)->second[1];
47  JSSVars[var.first] = (var.second - mean) / std;
48  }
49 
50  // then dump it to a vector
51  for(int v=0; v<m_nvars; ++v){
52  std::string name = m_JSSInputMap.find(v)->second;
53  input_tensor_values[v] = JSSVars[name];
54  }
55 
56  return input_tensor_values;
57  }
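
The scaling step above is a standard-score transform, (x - mean) / std, with the mean and standard deviation stored at positions 0 and 1 of each scaler entry. The listing dereferences the result of `m_scaler.find(...)` and `m_JSSInputMap.find(...)` without comparing against `end()`, so a variable missing from either map leads to undefined behaviour. A defensive variant could look like the sketch below; `scaleInputs` and its signature are hypothetical and not part of the tool.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical helper: standardise the variables listed in `order`.
// Returns std::nullopt if a variable or a usable (mean, std) pair is missing.
std::optional<std::vector<float>> scaleInputs(
    const std::map<std::string, double>& vars,
    const std::map<std::string, std::vector<double>>& scaler,
    const std::vector<std::string>& order) {
  std::vector<float> out;
  out.reserve(order.size());
  for (const std::string& name : order) {
    const auto v = vars.find(name);
    const auto s = scaler.find(name);
    if (v == vars.end() || s == scaler.end()) return std::nullopt;
    if (s->second.size() < 2 || s->second[1] == 0.0) return std::nullopt;
    const double mean = s->second[0];
    const double sigma = s->second[1];
    out.push_back(static_cast<float>((v->second - mean) / sigma));
  }
  return out;
}
```

Returning an empty optional lets the caller decide whether a missing input should be an error or a fallback score, instead of reading past the end of a map.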

◆ ReadOutputLabels()

std::vector< int > AthONNX::JSSMLTool::ReadOutputLabels ( ) const

Definition at line 62 of file JSSMLTool.cxx.

63  {
64  std::vector<int> output_tensor_values(1);
65 
66  output_tensor_values[0] = 1;
67 
68  return output_tensor_values;
69  }

◆ renounce()

std::enable_if_t<std::is_void_v<std::result_of_t<decltype(&T::renounce)(T)> > && !std::is_base_of_v<SG::VarHandleKeyArray, T> && std::is_base_of_v<Gaudi::DataHandle, T>, void> AthCommonDataStore< AthCommonMsg< AlgTool > >::renounce ( T &  h)
inline, protected, inherited

Definition at line 380 of file AthCommonDataStore.h.

381  {
382  h.renounce();
383  PBASE::renounce (h);
384  }

◆ renounceArray()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::renounceArray ( SG::VarHandleKeyArray handlesArray)
inline, protected, inherited

Remove all handles from I/O resolution.

Definition at line 364 of file AthCommonDataStore.h.

364  {
365  handlesArray.renounce();
366  }

◆ retrieveConstituentsScore() [1/2]

double AthONNX::JSSMLTool::retrieveConstituentsScore ( std::vector< std::vector< float >>  constituents) const
override, virtual

Implements AthONNX::IJSSMLTool.

Definition at line 200 of file JSSMLTool.cxx.

200  {
201 
202  // the format of the packed constituents is:
203  // constituents.size() ---> 4, for example, (m, pT, eta, phi)
204  // constituents.at(0) ---> number of constituents
205  // the packing can be done for any kind of low level inputs
206  // i.e. PFO/UFO constituents, topo-towers, tracks, etc
207  // they can be concatenated one after the other in case of multiple inputs
208 
209  //*************************************************************************
210  // Score the model using sample data, and inspect values
211  // loading input data
212 
213  std::vector<int> output_tensor_values_ = ReadOutputLabels();
214 
215  int testSample = 0;
216 
217  //preparing container to hold output data
218  int output_tensor_values = output_tensor_values_[testSample];
219 
220  // prepare the inputs
221  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
222  std::vector<Ort::Value> input_tensors;
223  for (long unsigned int i=0; i<constituents.size(); i++) {
224 
225  // test
226  std::vector<int64_t> const_dim = {1, static_cast<int64_t>(constituents.at(i).size())};
227 
228  input_tensors.push_back(Ort::Value::CreateTensor<float>(
229  memory_info,
230  constituents.at(i).data(), constituents.at(i).size(), const_dim.data(), const_dim.size()
231  )
232  );
233  }
234 
235  auto output_tensors = m_session->Run(Ort::RunOptions{nullptr}, m_input_node_names.data(), input_tensors.data(), m_input_node_names.size(), m_output_node_names.data(), m_output_node_names.size());
236  assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
237 
238  // Get pointer to output tensor float values
239  float* floatarr = output_tensors.front().GetTensorMutableData<float>();
240  int arrSize = sizeof(*floatarr)/sizeof(floatarr[0]);
241 
242  // show true label for the test input
243  ATH_MSG_DEBUG("Label for the input test data = "<<output_tensor_values);
244  float ConstScore = -999;
245  int max_index = 0;
246  for (int i = 0; i < arrSize; i++){
247  ATH_MSG_VERBOSE("Score for class "<<i<<" = "<<floatarr[i]<<std::endl);
248  ATH_MSG_VERBOSE(" +++ Score for class "<<i<<" = "<<floatarr[i]<<std::endl);
249  if (ConstScore<floatarr[i]){
250  ConstScore = floatarr[i];
251  max_index = i;
252  }
253  }
254  ATH_MSG_DEBUG("Class: "<<max_index<<" has the highest score: "<<floatarr[max_index]);
255 
256  return ConstScore;
257 
258  } // end retrieve constituents score ----
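
Note that `sizeof(*floatarr)/sizeof(floatarr[0])` in the listing divides the size of one `float` by the size of one `float`, so `arrSize` is always 1 and the loop only ever inspects the first output value. For a model with several output classes, the element count would have to come from the output tensor's shape instead (with ONNX Runtime, presumably from the tensor's type-and-shape information; that call is an assumption and is not shown in the source). The sketch below works on a plain shape vector and score buffer; `elementCount` and `argmax` are hypothetical helpers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Number of elements implied by a shape; dynamic (negative) dimensions count
// as 1, mirroring the convention used in initialize().
std::size_t elementCount(const std::vector<std::int64_t>& shape) {
  std::size_t n = 1;
  for (std::int64_t d : shape) n *= static_cast<std::size_t>(d < 0 ? 1 : d);
  return n;
}

// Index of the largest score among the first n values (0 if n == 0).
std::size_t argmax(const float* scores, std::size_t n) {
  std::size_t best = 0;
  for (std::size_t i = 1; i < n; ++i) {
    if (scores[i] > scores[best]) best = i;
  }
  return best;
}
```

With a shape of {-1, 3} the count is 3, so all three class scores take part in the maximum search rather than only the first.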

◆ retrieveConstituentsScore() [2/2]

double AthONNX::JSSMLTool::retrieveConstituentsScore ( std::vector< TH2D >  Images) const
override, virtual

Function executing the tool for a single event.

Implements AthONNX::IJSSMLTool.

Definition at line 153 of file JSSMLTool.cxx.

153  {
154 
155  //*************************************************************************
156  // Score the model using sample data, and inspect values
157 
158  // preparing container to hold input data
159  size_t input_tensor_size = m_nPixelsX*m_nPixelsY*m_nPixelsZ;
160  std::vector<float> input_tensor_values(input_tensor_size);
161 
162  // loading input data
163  input_tensor_values = ReadJetImagePixels(Images);
164 
165  // preparing container to hold output data
166  int testSample = 0;
167  std::vector<int> output_tensor_values_ = ReadOutputLabels();
168  int output_tensor_values = output_tensor_values_[testSample];
169 
170  // create input tensor object from data values
171  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
172  Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_size, m_input_node_dims.data(), m_input_node_dims.size());
173  assert(input_tensor.IsTensor());
174 
175  auto output_tensors = m_session->Run(Ort::RunOptions{nullptr}, m_input_node_names.data(), &input_tensor, m_input_node_names.size(), m_output_node_names.data(), m_output_node_names.size());
176  assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
177 
178  // Get pointer to output tensor float values
179  float* floatarr = output_tensors.front().GetTensorMutableData<float>();
180  int arrSize = sizeof(*floatarr)/sizeof(floatarr[0]);
181 
182  // show true label for the test input
183  ATH_MSG_DEBUG("Label for the input test data = "<<output_tensor_values);
184  float ConstScore = -999;
185  int max_index = 0;
186  for (int i = 0; i < arrSize; i++){
187  ATH_MSG_VERBOSE("Score for class "<<i<<" = "<<floatarr[i]<<std::endl);
188  if (ConstScore<floatarr[i]){
189  ConstScore = floatarr[i];
190  max_index = i;
191  }
192  }
193  ATH_MSG_DEBUG("Class: "<<max_index<<" has the highest score: "<<floatarr[max_index]);
194 
195  return ConstScore;
196 
197  } // end retrieve CNN score ----

◆ retrieveHighLevelScore()

double AthONNX::JSSMLTool::retrieveHighLevelScore ( std::map< std::string, double >  JSSVars) const
override, virtual

Implements AthONNX::IJSSMLTool.

Definition at line 261 of file JSSMLTool.cxx.

261  {
262 
263  //*************************************************************************
264  // Score the model using sample data, and inspect values
265 
266  //preparing container to hold input data
267  size_t input_tensor_size = m_nvars;
268  std::vector<float> input_tensor_values(m_nvars);
269 
270  // loading input data
271  input_tensor_values = ReadJSSInputs(JSSVars);
272 
273  // preparing container to hold output data
274  int testSample = 0;
275  std::vector<int> output_tensor_values_ = ReadOutputLabels();
276  int output_tensor_values = output_tensor_values_[testSample];
277 
278  // create input tensor object from data values
279  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
280 
281  // we need a multiple tensor input structure for DisCo model
282  Ort::Value input1 = Ort::Value::CreateTensor<float>(memory_info, const_cast<float*>(input_tensor_values.data()), input_tensor_size, m_input_node_dims.data(), m_input_node_dims.size());
283  std::vector<float> empty = {1.};
284  Ort::Value input2 = Ort::Value::CreateTensor<float>(memory_info, empty.data(), 1, m_input_node_dims.data(), m_input_node_dims.size());
285  Ort::Value input3 = Ort::Value::CreateTensor<float>(memory_info, empty.data(), 1, m_input_node_dims.data(), m_input_node_dims.size());
286  Ort::Value input4 = Ort::Value::CreateTensor<float>(memory_info, empty.data(), 1, m_input_node_dims.data(), m_input_node_dims.size());
287  std::vector<Ort::Value> input_tensor;
288  std::vector<int64_t> aaa = {1, m_nvars};
289  input_tensor.emplace_back(
290  Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_size, aaa.data(), aaa.size())
291  );
292  input_tensor.emplace_back(
293  Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_size, m_input_node_dims.data(), m_input_node_dims.size())
294  );
295  input_tensor.emplace_back(
296  Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_size, m_input_node_dims.data(), m_input_node_dims.size())
297  );
298  input_tensor.emplace_back(
299  Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_size, m_input_node_dims.data(), m_input_node_dims.size())
300  );
301 
302  auto output_tensors = m_session->Run(Ort::RunOptions{nullptr}, m_input_node_names.data(), input_tensor.data(), m_input_node_names.size(), m_output_node_names.data(), m_output_node_names.size());
303  assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
304 
305  // Get pointer to output tensor float values
306  float* floatarr = output_tensors.front().GetTensorMutableData<float>();
307  int arrSize = sizeof(*floatarr)/sizeof(floatarr[0]);
308 
309  // show true label for the test input
310  ATH_MSG_DEBUG("Label for the input test data = "<<output_tensor_values);
311  float HLScore = -999;
312  int max_index = 0;
313  for (int i = 0; i < arrSize; i++){
314  ATH_MSG_VERBOSE("Score for class "<<i<<" = "<<floatarr[i]<<std::endl);
315  if (HLScore<floatarr[i]){
316  HLScore = floatarr[i];
317  max_index = i;
318  }
319  }
320  ATH_MSG_DEBUG("Class: "<<max_index<<" has the highest score: "<<floatarr[max_index]);
321 
322  return HLScore;
323 
324  } // end retrieve HighLevel score ----
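
Note that `sizeof(*floatarr)/sizeof(floatarr[0])` in the listing above divides the size of one `float` by itself, so `arrSize` is always 1 and the loop inspects only the first output value. The following is a minimal sketch of an arg-max over an explicitly supplied element count. `ScoreResult` and `argmaxScore` are hypothetical names used for illustration and are not part of JSSMLTool. With ONNX Runtime, the count could be obtained from `GetTensorTypeAndShapeInfo().GetElementCount()` on the output `Ort::Value`; the sketch itself avoids that dependency so it stays self-contained.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Hypothetical helper, not part of JSSMLTool: result of scanning the scores.
struct ScoreResult {
  std::size_t index;  // position of the highest score
  float score;        // value of the highest score
};

// Scan `count` scores starting at `data` and return the largest one.
// `count` must be passed explicitly; it cannot be recovered from the pointer.
inline ScoreResult argmaxScore(const float* data, std::size_t count) {
  assert(data != nullptr && count > 0);
  const float* best = std::max_element(data, data + count);
  return {static_cast<std::size_t>(best - data), *best};
}
```

Unlike the listing, this sketch has no `-999` sentinel: it requires at least one score and always returns an element of the buffer.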

◆ SetScaler()

StatusCode AthONNX::JSSMLTool::SetScaler ( std::map< std::string, std::vector< double >>  scaler)
override virtual

Implements AthONNX::IJSSMLTool.

Definition at line 327 of file JSSMLTool.cxx.

327  {
328  m_scaler = scaler;
329 
330  // ToDo:
331  // this will have an overriding config as property
332  m_JSSInputMap = {
333  {0,"pT"}, {1,"CNN"}, {2,"D2"}, {3,"nTracks"}, {4,"ZCut12"},
334  {5,"Tau1_wta"}, {6,"Tau2_wta"}, {7,"Tau3_wta"},
335  {8,"KtDR"}, {9,"Split12"}, {10,"Split23"},
336  {11,"ECF1"}, {12,"ECF2"}, {13,"ECF3"},
337  {14,"Angularity"}, {15,"FoxWolfram0"}, {16,"FoxWolfram2"},
338  {17,"Aplanarity"}, {18,"PlanarFlow"}, {19,"Qw"},
339  };
340  m_nvars = m_JSSInputMap.size();
341 
342  return StatusCode::SUCCESS;
343  }
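
The index-to-name map filled in above fixes the order in which variables are fed to the model. As a rough illustration of how such a map and a scaler could be combined into an ordered input vector, consider the sketch below. `orderAndScale` is a hypothetical helper, not the actual `ReadJSSInputs` implementation, and the scaler layout (`{mean, std}` per variable, applied as `(value - mean) / std`) is an assumption; this page does not show the real convention.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical helper, not the actual ReadJSSInputs implementation.
// ASSUMPTION: scaler.at(name) holds {mean, std} and the transform is
// (value - mean) / std; the real convention is not shown on this page.
inline std::vector<float> orderAndScale(
    const std::map<int, std::string>& inputMap,
    const std::map<std::string, double>& vars,
    const std::map<std::string, std::vector<double>>& scaler) {
  std::vector<float> out;
  out.reserve(inputMap.size());
  // std::map iterates in ascending key order, so the output follows the indices.
  for (const auto& entry : inputMap) {
    const std::string& name = entry.second;
    const std::vector<double>& s = scaler.at(name);  // throws if name is missing
    assert(s.size() >= 2 && s[1] != 0.0);
    out.push_back(static_cast<float>((vars.at(name) - s[0]) / s[1]));
  }
  return out;
}
```

Using `at()` rather than `operator[]` means a variable missing from either map raises `std::out_of_range` instead of silently inserting a default value.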

◆ sysInitialize()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysInitialize ( )
override virtual inherited

Perform system initialization for an algorithm.

We override this to declare all the elements of handle key arrays at the end of initialization. See comments on updateVHKA.

Reimplemented in DerivationFramework::CfAthAlgTool, AthCheckedComponent< AthAlgTool >, AthCheckedComponent<::AthAlgTool >, and asg::AsgMetadataTool.

◆ sysStart()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysStart ( )
override virtual inherited

Handle START transition.

We override this in order to make sure that conditions handle keys can cache a pointer to the conditions container.

◆ updateVHKA()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::updateVHKA ( Gaudi::Details::PropertyBase &  )
inline inherited

Definition at line 308 of file AthCommonDataStore.h.

308  {
309  // debug() << "updateVHKA for property " << p.name() << " " << p.toString()
310  // << " size: " << m_vhka.size() << endmsg;
311  for (auto &a : m_vhka) {
312  std::vector<SG::VarHandleKey*> keys = a->keys();
313  for (auto k : keys) {
314  k->setOwner(this);
315  }
316  }
317  }

Member Data Documentation

◆ m_detStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_detStore
private inherited

Pointer to StoreGate (detector store by default)

Definition at line 393 of file AthCommonDataStore.h.

◆ m_env

std::unique_ptr< Ort::Env > AthONNX::JSSMLTool::m_env

Definition at line 73 of file JSSMLTool.h.

◆ m_evtStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_evtStore
private inherited

Pointer to StoreGate (event store by default)

Definition at line 390 of file AthCommonDataStore.h.

◆ m_input_node_dims

std::vector<int64_t> AthONNX::JSSMLTool::m_input_node_dims
private

Definition at line 86 of file JSSMLTool.h.

◆ m_input_node_names

std::vector<const char*> AthONNX::JSSMLTool::m_input_node_names
private

Definition at line 88 of file JSSMLTool.h.

◆ m_JSSInputMap

std::map<int, std::string> AthONNX::JSSMLTool::m_JSSInputMap

Definition at line 76 of file JSSMLTool.h.

◆ m_labelFileName

std::string AthONNX::JSSMLTool::m_labelFileName
private

Definition at line 83 of file JSSMLTool.h.

◆ m_modelFileName

std::string AthONNX::JSSMLTool::m_modelFileName
private

Name of the model file to load.

Definition at line 81 of file JSSMLTool.h.

◆ m_nPixelsX

int AthONNX::JSSMLTool::m_nPixelsX
private

Definition at line 96 of file JSSMLTool.h.

◆ m_nPixelsY

int AthONNX::JSSMLTool::m_nPixelsY
private

Definition at line 96 of file JSSMLTool.h.

◆ m_nPixelsZ

int AthONNX::JSSMLTool::m_nPixelsZ
private

Definition at line 96 of file JSSMLTool.h.

◆ m_num_input_nodes

size_t AthONNX::JSSMLTool::m_num_input_nodes
private

Definition at line 87 of file JSSMLTool.h.

◆ m_num_output_nodes

size_t AthONNX::JSSMLTool::m_num_output_nodes
private

Definition at line 92 of file JSSMLTool.h.

◆ m_nvars

int AthONNX::JSSMLTool::m_nvars
private

Definition at line 98 of file JSSMLTool.h.

◆ m_output_node_dims

std::vector<int64_t> AthONNX::JSSMLTool::m_output_node_dims
private

Definition at line 91 of file JSSMLTool.h.

◆ m_output_node_names

std::vector<const char*> AthONNX::JSSMLTool::m_output_node_names
private

Definition at line 93 of file JSSMLTool.h.

◆ m_pixelFileName

std::string AthONNX::JSSMLTool::m_pixelFileName
private

Definition at line 82 of file JSSMLTool.h.

◆ m_scaler

std::map<std::string, std::vector<double> > AthONNX::JSSMLTool::m_scaler

Definition at line 75 of file JSSMLTool.h.

◆ m_session

std::unique_ptr< Ort::Session > AthONNX::JSSMLTool::m_session

Definition at line 72 of file JSSMLTool.h.

◆ m_varHandleArraysDeclared

bool AthCommonDataStore< AthCommonMsg< AlgTool > >::m_varHandleArraysDeclared
private inherited

Definition at line 399 of file AthCommonDataStore.h.

◆ m_vhka

std::vector<SG::VarHandleKeyArray*> AthCommonDataStore< AthCommonMsg< AlgTool > >::m_vhka
private inherited

Definition at line 398 of file AthCommonDataStore.h.


The documentation for this class was generated from the following files: