ATLAS Offline Software
PFEnergyPredictorTool Class Reference

#include <PFEnergyPredictorTool.h>


Public Member Functions

 PFEnergyPredictorTool (const std::string &type, const std::string &name, const IInterface *parent)
 
virtual StatusCode initialize () override
 
virtual StatusCode finalize () override
 
float runOnnxInference (std::vector< float > &tensor) const
 
float nnEnergyPrediction (const eflowRecTrack *ptr) const
 
void NormalizeTensor (std::vector< float > &tensor, size_t limit) const
 
ServiceHandle< StoreGateSvc > & evtStore ()
 The standard StoreGateSvc (event store). Returns a reference to the ServiceHandle, which behaves like a pointer to the StoreGateSvc.
 
const ServiceHandle< StoreGateSvc > & evtStore () const
 The standard StoreGateSvc (event store). Returns a const reference to the ServiceHandle, which behaves like a pointer to the StoreGateSvc.
 
const ServiceHandle< StoreGateSvc > & detStore () const
 The standard StoreGateSvc/DetectorStore. Returns a const reference to the ServiceHandle, which behaves like a pointer to the StoreGateSvc.
 
virtual StatusCode sysInitialize () override
 Perform system initialization for an algorithm.
 
virtual StatusCode sysStart () override
 Handle START transition.
 
virtual std::vector< Gaudi::DataHandle * > inputHandles () const override
 Return this algorithm's input handles.
 
virtual std::vector< Gaudi::DataHandle * > outputHandles () const override
 Return this algorithm's output handles.
 
Gaudi::Details::PropertyBase & declareProperty (Gaudi::Property< T > &t)
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleKey &hndl, const std::string &doc, const SG::VarHandleKeyType &)
 Declare a new Gaudi property.
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleBase &hndl, const std::string &doc, const SG::VarHandleType &)
 Declare a new Gaudi property.
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleKeyArray &hndArr, const std::string &doc, const SG::VarHandleKeyArrayType &)
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, T &property, const std::string &doc, const SG::NotHandleType &)
 Declare a new Gaudi property.
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, T &property, const std::string &doc="none")
 Declare a new Gaudi property.
 
void updateVHKA (Gaudi::Details::PropertyBase &)
 
MsgStream & msg () const
 
MsgStream & msg (const MSG::Level lvl) const
 
bool msgLvl (const MSG::Level lvl) const
 

Static Public Member Functions

static const InterfaceID & interfaceID ()
 

Protected Member Functions

void renounceArray (SG::VarHandleKeyArray &handlesArray)
 Remove all handles from I/O resolution.
 
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce (T &h)
 
void extraDeps_update_handler (Gaudi::Details::PropertyBase &ExtraDeps)
 Add StoreName to extra input/output deps as needed.
 

Private Types

typedef ServiceHandle< StoreGateSvc > StoreGateSvc_t
 

Private Member Functions

Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &hndl, const SG::VarHandleKeyType &)
 Specialization for handling Gaudi::Property<SG::VarHandleKey>.
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &hndl, const SG::VarHandleKeyArrayType &)
 Specialization for handling Gaudi::Property<SG::VarHandleKeyArray>.
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &hndl, const SG::VarHandleType &)
 Specialization for handling Gaudi::Property<SG::VarHandleBase>.
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &t, const SG::NotHandleType &)
 Specialization for handling everything that is not a Gaudi::Property<SG::VarHandleKey> or a Gaudi::Property<SG::VarHandleKeyArray>.
 

Private Attributes

std::unique_ptr< Ort::Session > m_session ATLAS_THREAD_SAFE
 
std::vector< const char * > m_input_node_names
 
std::vector< const char * > m_output_node_names
 
std::vector< int64_t > m_input_node_dims
 
ServiceHandle< AthOnnx::IOnnxRuntimeSvc > m_svc {this, "ONNXRuntimeSvc", "AthOnnx::OnnxRuntimeSvc", "CaloMuonScoreTool ONNXRuntimeSvc"}
 
Gaudi::Property< std::string > m_model_filepath {this, "ModelPath", "////"}
 
Gaudi::Property< float > m_cellE_mean {this,"cellE_mean",-2.2852574689444385}
 Normalization constants for the inputs to the ONNX model.
 
Gaudi::Property< float > m_cellE_std {this,"cellE_std",2.0100506557174946}
 
Gaudi::Property< float > m_cellPhi_std {this,"cellPhi_std",0.6916977411859621}
 
StoreGateSvc_t m_evtStore
 Pointer to StoreGate (event store by default).
 
StoreGateSvc_t m_detStore
 Pointer to StoreGate (detector store by default).
 
std::vector< SG::VarHandleKeyArray * > m_vhka
 
bool m_varHandleArraysDeclared
 

Detailed Description

Definition at line 16 of file PFEnergyPredictorTool.h.

Member Typedef Documentation

◆ StoreGateSvc_t

typedef ServiceHandle<StoreGateSvc> AthCommonDataStore< AthCommonMsg< AlgTool > >::StoreGateSvc_t
[private, inherited]

Definition at line 388 of file AthCommonDataStore.h.

Constructor & Destructor Documentation

◆ PFEnergyPredictorTool()

PFEnergyPredictorTool::PFEnergyPredictorTool ( const std::string &  type,
const std::string &  name,
const IInterface *  parent 
)

Definition at line 12 of file PFEnergyPredictorTool.cxx.

13 {
14 
15 }

Member Function Documentation

◆ declareGaudiProperty() [1/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > &  hndl,
const SG::VarHandleKeyArrayType &  
)
[inline, private, inherited]

Specialization for handling Gaudi::Property<SG::VarHandleKeyArray>.

Definition at line 170 of file AthCommonDataStore.h.

172  {
173  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
174  hndl.value(),
175  hndl.documentation());
176 
177  }

◆ declareGaudiProperty() [2/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > &  hndl,
const SG::VarHandleKeyType &  
)
[inline, private, inherited]

Specialization for handling Gaudi::Property<SG::VarHandleKey>.

Definition at line 156 of file AthCommonDataStore.h.

158  {
159  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
160  hndl.value(),
161  hndl.documentation());
162 
163  }

◆ declareGaudiProperty() [3/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > &  hndl,
const SG::VarHandleType &  
)
[inline, private, inherited]

Specialization for handling Gaudi::Property<SG::VarHandleBase>.

Definition at line 184 of file AthCommonDataStore.h.

186  {
187  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
188  hndl.value(),
189  hndl.documentation());
190  }

◆ declareGaudiProperty() [4/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > &  t,
const SG::NotHandleType &  
)
[inline, private, inherited]

Specialization for handling everything that is not a Gaudi::Property<SG::VarHandleKey> or a Gaudi::Property<SG::VarHandleKeyArray>.

Definition at line 199 of file AthCommonDataStore.h.

200  {
201  return PBASE::declareProperty(t);
202  }

◆ declareProperty() [1/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
SG::VarHandleBase &  hndl,
const std::string &  doc,
const SG::VarHandleType &  
)
[inline, inherited]

Declare a new Gaudi property.

Parameters
name: Name of the property.
hndl: Object holding the property value.
doc: Documentation string for the property.

This is the version for types that derive from SG::VarHandleBase. The property value object is put on the input and output lists as appropriate; then we forward to the base class.

Definition at line 245 of file AthCommonDataStore.h.

249  {
250  this->declare(hndl.vhKey());
251  hndl.vhKey().setOwner(this);
252 
253  return PBASE::declareProperty(name,hndl,doc);
254  }

◆ declareProperty() [2/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
SG::VarHandleKey &  hndl,
const std::string &  doc,
const SG::VarHandleKeyType &  
)
[inline, inherited]

Declare a new Gaudi property.

Parameters
name: Name of the property.
hndl: Object holding the property value.
doc: Documentation string for the property.

This is the version for types that derive from SG::VarHandleKey. The property value object is put on the input and output lists as appropriate; then we forward to the base class.

Definition at line 221 of file AthCommonDataStore.h.

225  {
226  this->declare(hndl);
227  hndl.setOwner(this);
228 
229  return PBASE::declareProperty(name,hndl,doc);
230  }

◆ declareProperty() [3/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
SG::VarHandleKeyArray &  hndArr,
const std::string &  doc,
const SG::VarHandleKeyArrayType &  
)
[inline, inherited]

Definition at line 259 of file AthCommonDataStore.h.

263  {
264 
265  // std::ostringstream ost;
266  // ost << Algorithm::name() << " VHKA declareProp: " << name
267  // << " size: " << hndArr.keys().size()
268  // << " mode: " << hndArr.mode()
269  // << " vhka size: " << m_vhka.size()
270  // << "\n";
271  // debug() << ost.str() << endmsg;
272 
273  hndArr.setOwner(this);
274  m_vhka.push_back(&hndArr);
275 
276  Gaudi::Details::PropertyBase* p = PBASE::declareProperty(name, hndArr, doc);
277  if (p != 0) {
278  p->declareUpdateHandler(&AthCommonDataStore<PBASE>::updateVHKA, this);
279  } else {
280  ATH_MSG_ERROR("unable to call declareProperty on VarHandleKeyArray "
281  << name);
282  }
283 
284  return p;
285 
286  }

◆ declareProperty() [4/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
T &  property,
const std::string &  doc,
const SG::NotHandleType &  
)
[inline, inherited]

Declare a new Gaudi property.

Parameters
name: Name of the property.
property: Object holding the property value.
doc: Documentation string for the property.

This is the generic version, for types that do not derive from SG::VarHandleKey. It just forwards to the base class version of declareProperty.

Definition at line 333 of file AthCommonDataStore.h.

337  {
338  return PBASE::declareProperty(name, property, doc);
339  }

◆ declareProperty() [5/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
T &  property,
const std::string &  doc = "none" 
)
[inline, inherited]

Declare a new Gaudi property.

Parameters
name: Name of the property.
property: Object holding the property value.
doc: Documentation string for the property.

This dispatches to either the generic declareProperty or the one for VarHandle/Key/KeyArray.

Definition at line 352 of file AthCommonDataStore.h.

355  {
356  typedef typename SG::HandleClassifier<T>::type htype;
357  return declareProperty (name, property, doc, htype());
358  }
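The dispatch in declareProperty [5/6] relies on tag types: `SG::HandleClassifier<T>::type` names a tag type, and passing a default-constructed instance of it selects the matching overload at compile time. A minimal sketch of the same pattern with stand-in types follows; `Classifier`, `KeyTag`, `PlainTag`, `FakeKey` and `describe` are hypothetical names, not part of the Athena API, and the real classifier distinguishes more cases (key, key array, handle, other).

```cpp
#include <cassert>
#include <string>
#include <type_traits>

// Hypothetical tag types, standing in for SG::VarHandleKeyType and SG::NotHandleType.
struct KeyTag {};
struct PlainTag {};

// Hypothetical base class standing in for SG::VarHandleKey.
struct FakeKey {};

// Hypothetical classifier in the spirit of SG::HandleClassifier<T>::type:
// classes derived from (or equal to) FakeKey map to KeyTag, everything else to PlainTag.
template <class T>
struct Classifier {
  using type = std::conditional_t<std::is_base_of_v<FakeKey, T>, KeyTag, PlainTag>;
};

// The classification is resolved entirely at compile time.
static_assert(std::is_same_v<Classifier<int>::type, PlainTag>);
static_assert(std::is_same_v<Classifier<FakeKey>::type, KeyTag>);

// Overloads selected only by the (empty) tag argument.
template <class T>
std::string describe(const T&, KeyTag) { return "handle key"; }

template <class T>
std::string describe(const T&, PlainTag) { return "plain property"; }

// Front end: compute the tag type and forward, as declareProperty [5/6] does with htype().
template <class T>
std::string describe(const T& value) {
  using tag = typename Classifier<T>::type;
  return describe(value, tag{});
}
```

Here `describe(42)` yields "plain property" and `describe(FakeKey{})` yields "handle key" with no run-time branching, which is why the Athena overloads can take an unnamed tag parameter.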

◆ declareProperty() [6/6]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( Gaudi::Property< T > &  t)
[inline, inherited]

Definition at line 145 of file AthCommonDataStore.h.

145  {
146  typedef typename SG::HandleClassifier<T>::type htype;
147  return AthCommonDataStore<PBASE>::declareGaudiProperty(t, htype());
148  }

◆ detStore()

const ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::detStore ( ) const
[inline, inherited]

The standard StoreGateSvc/DetectorStore. Returns a const reference to the ServiceHandle, which behaves like a pointer to the StoreGateSvc.

Definition at line 95 of file AthCommonDataStore.h.

95 { return m_detStore; }

◆ evtStore() [1/2]

ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::evtStore ( )
[inline, inherited]

The standard StoreGateSvc (event store). Returns a reference to the ServiceHandle, which behaves like a pointer to the StoreGateSvc.

Definition at line 85 of file AthCommonDataStore.h.

85 { return m_evtStore; }

◆ evtStore() [2/2]

const ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::evtStore ( ) const
[inline, inherited]

The standard StoreGateSvc (event store). Returns a const reference to the ServiceHandle, which behaves like a pointer to the StoreGateSvc.

Definition at line 90 of file AthCommonDataStore.h.

90 { return m_evtStore; }

◆ extraDeps_update_handler()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::extraDeps_update_handler ( Gaudi::Details::PropertyBase &  ExtraDeps)
[protected, inherited]

Add StoreName to extra input/output deps as needed.

Uses the logic of the VarHandleKey to parse the DataObjID keys supplied via the ExtraInputs and ExtraOutputs properties, adding the StoreName if it is not explicitly given.

◆ finalize()

StatusCode PFEnergyPredictorTool::finalize ( )
[override, virtual]

Definition at line 323 of file PFEnergyPredictorTool.cxx.

324 {
325  return StatusCode::SUCCESS;
326 }

◆ initialize()

StatusCode PFEnergyPredictorTool::initialize ( )
[override, virtual]

Definition at line 18 of file PFEnergyPredictorTool.cxx.

19 {
20  ATH_MSG_DEBUG("Initializing " << name());
21  if(m_model_filepath == "////"){
22  ATH_MSG_WARNING("model not provided tool will not work");
23  return StatusCode::SUCCESS;
24  }
25  ATH_CHECK(m_svc.retrieve());
26  std::string path = m_model_filepath;//Add path resolving code
27 
28  Ort::SessionOptions session_options;
29  Ort::AllocatorWithDefaultOptions allocator;
30  session_options.SetIntraOpNumThreads(1);
31  session_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
32  m_session = std::make_unique<Ort::Session>(m_svc->env(), path.c_str(), session_options);
33 
34  ATH_MSG_INFO("Created ONNX runtime session with model " << path);
35 
36  size_t num_input_nodes = m_session->GetInputCount();
37  m_input_node_names.resize(num_input_nodes);
38 
39  for (std::size_t i = 0; i < num_input_nodes; i++) {
40  // print input node names
41  char* input_name = m_session->GetInputNameAllocated(i, allocator).release();
42  ATH_MSG_INFO("Input " << i << " : "
43  << " name= " << input_name);
44  m_input_node_names[i] = input_name;
45  // print input node types
46  Ort::TypeInfo type_info = m_session->GetInputTypeInfo(i);
47  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
48  ONNXTensorElementDataType type = tensor_info.GetElementType();
49  ATH_MSG_INFO("Input " << i << " : "
50  << " type= " << type);
51 
52  // print input shapes/dims
53  m_input_node_dims = tensor_info.GetShape();
54  m_input_node_dims[1] = 5430/5;
55  ATH_MSG_INFO("Input " << i << " : num_dims= " << m_input_node_dims.size());
56  for (std::size_t j = 0; j < m_input_node_dims.size(); j++) {
57  if (m_input_node_dims[j] < 0) m_input_node_dims[j] = 1;
58  ATH_MSG_INFO("Input " << i << " : dim " << j << "= " << m_input_node_dims[j]);
59  }
60  }
61 
62  // output nodes
63  std::vector<int64_t> output_node_dims;
64  size_t num_output_nodes = m_session->GetOutputCount();
65  ATH_MSG_INFO("Have output nodes " << num_output_nodes);
66  m_output_node_names.resize(num_output_nodes);
67 
68  for (std::size_t i = 0; i < num_output_nodes; i++) {
69  // print output node names
70  char* output_name = m_session->GetOutputNameAllocated(i, allocator).release();
71  ATH_MSG_INFO("Output " << i << " : "
72  << " name= " << output_name);
73  m_output_node_names[i] = output_name;
74 
75  Ort::TypeInfo type_info = m_session->GetOutputTypeInfo(i);
76  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
77  ONNXTensorElementDataType type = tensor_info.GetElementType();
78  ATH_MSG_INFO("Output " << i << " : "
79  << " type= " << type);
80 
81  // print output shapes/dims
82  output_node_dims = tensor_info.GetShape();
83  ATH_MSG_INFO("Output " << i << " : num_dims= " << output_node_dims.size());
84  for (std::size_t j = 0; j < output_node_dims.size(); j++) {
85  if (output_node_dims[j] < 0) output_node_dims[j] = 1;
86  ATH_MSG_INFO("Output" << i << " : dim " << j << "= " << output_node_dims[j]);
87  }
88  }
89 
90  return StatusCode::SUCCESS;
91 }
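The dimension handling in initialize() can be read in isolation: the second input dimension is pinned to 5430 / 5 = 1086 rows of five features, and any remaining dynamic (negative) dimension, such as a batch axis, is replaced by 1. The sketch below assumes a model that declares its input as `{-1, -1, 5}`; that shape and the helper name `fixInputDims` are hypothetical, since the tool reads the real shape from the model at run time.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the dimension handling in initialize():
// dims[1] is forced to the number of 5-feature rows, then every dynamic
// (negative) dimension is set to 1.
std::vector<int64_t> fixInputDims(std::vector<int64_t> dims,
                                  int64_t tensorSize = 5430,
                                  int64_t featuresPerRow = 5) {
  assert(dims.size() >= 2);  // the tool indexes dims[1] unconditionally
  dims[1] = tensorSize / featuresPerRow;
  for (auto& d : dims) {
    if (d < 0) d = 1;
  }
  return dims;
}
```

With the assumed shape the result is `{1, 1086, 5}`, whose element count (5430) matches the length of the `inputnn` vector that nnEnergyPrediction() hands to runOnnxInference(); the shape passed to `Ort::Value::CreateTensor` should describe exactly as many elements as the buffer holds.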

◆ inputHandles()

virtual std::vector<Gaudi::DataHandle*> AthCommonDataStore< AthCommonMsg< AlgTool > >::inputHandles ( ) const
[override, virtual, inherited]

Return this algorithm's input handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ interfaceID()

const InterfaceID & PFEnergyPredictorTool::interfaceID ( )
[inline, static]

Definition at line 49 of file PFEnergyPredictorTool.h.

49 { return IID_PFEnergyPredictorTool; }

◆ msg() [1/2]

MsgStream& AthCommonMsg< AlgTool >::msg ( ) const
[inline, inherited]

Definition at line 24 of file AthCommonMsg.h.

24  {
25  return this->msgStream();
26  }

◆ msg() [2/2]

MsgStream& AthCommonMsg< AlgTool >::msg ( const MSG::Level  lvl) const
[inline, inherited]

Definition at line 27 of file AthCommonMsg.h.

27  {
28  return this->msgStream(lvl);
29  }

◆ msgLvl()

bool AthCommonMsg< AlgTool >::msgLvl ( const MSG::Level  lvl) const
[inline, inherited]

Definition at line 30 of file AthCommonMsg.h.

30  {
31  return this->msgLevel(lvl);
32  }

◆ nnEnergyPrediction()

float PFEnergyPredictorTool::nnEnergyPrediction ( const eflowRecTrack *  ptr) const

Definition at line 134 of file PFEnergyPredictorTool.cxx.

134  {
135 
136  constexpr std::array<int,19> calo_numbers{1,2,3,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20};
137  constexpr std::array<int,12> fixed_r_numbers = {1,2,3,12,13,14,15,16,17,18,19,20};
138  constexpr std::array<double,12> fixed_r_vals = {1532.18, 1723.89, 1923.02, 2450.00, 2995.00, 3630.00, 3215.00,
139  3630.00, 2246.50, 2450.00, 2870.00, 3480.00
140  };
141  constexpr std::array<int, 7> fixed_z_numbers = {5,6,7,8,9,10,11};
142  constexpr std::array<double, 7> fixed_z_vals = {3790.03, 3983.68, 4195.84, 4461.25, 4869.50, 5424.50, 5905.00};
143  std::unordered_map<int, double> r_calo_dict;//change to flatmap in c++23
144  std::unordered_map<int, double> z_calo_dict;
145  for(size_t i=0; i<fixed_r_vals.size(); i++) r_calo_dict[fixed_r_numbers[i]] = fixed_r_vals[i];
146  for(size_t i=0; i<fixed_z_numbers.size(); i++) z_calo_dict[fixed_z_numbers[i]] = fixed_z_vals[i];
147 
148  std::vector<float> inputnn;
149  inputnn.assign(5430, 0.0);
150  std::vector<eflowRecCluster*> matchedClusters;
151  std::vector<eflowTrackClusterLink*> links = ptr->getClusterMatches();
152 
153  std::array<double, 19> etatotal = getEtaTrackCalo(ptr->getTrackCaloPoints());
154  std::array<double, 19> phitotal = getPhiTrackCalo(ptr->getTrackCaloPoints());
155 
156  const std::array<double, 2> track{ptr->getTrack()->eta(), ptr->getTrack()->phi()};
157 
158  for(auto *clink : links){
159  auto *cell = clink->getCluster()->getCluster();
160  float clusterE = cell->e()*1e-3;
161  float clusterEta = cell->eta();
162 
163  if (clusterE < 0.0 || clusterE > 1e4f || std::abs(clusterEta) > 2.5) continue;
164 
165  constexpr bool cutOnR = false;
166  if(cutOnR){
167  std::array<double, 2> p{clink->getCluster()->getCluster()->eta(), clink->getCluster()->getCluster()->phi()};
168  double part1 = p[0] - track[0];
169  double part2 = p[1] - track[1];
170  while(part1 > M_PI) part1 -= 2*M_PI;
171  while(part1 < -M_PI) part1 += 2*M_PI;
172  while(part2 > M_PI) part2 -= 2*M_PI;
173  while(part2 < -M_PI) part2 += 2*M_PI;
174  double R = std::sqrt(part1 * part1 + part2*part2);
175  if(R >= 1.2) continue;
176  }
177 
178  matchedClusters.push_back(clink->getCluster());
179  }
180 
181  std::vector<std::array<double, 5>> cells;
182 
183  const eflowTrackCaloPoints& trackCaloPoints = ptr->getTrackCaloPoints();
184  bool trk_bool_em[2] = {false,false};
185  std::array<double,2> trk_em_eta = {trackCaloPoints.getEta(eflowCalo::EMB2), trackCaloPoints.getEta(eflowCalo::EME2)};
186  std::array<double,2> trk_em_phi = {trackCaloPoints.getPhi(eflowCalo::EMB2), trackCaloPoints.getPhi(eflowCalo::EME2)};
187  double eta_ctr;
188  double phi_ctr;
189  for(int i =0; i<2; i++) {
190  trk_bool_em[i] = std::abs(trk_em_eta[i]) < 2.5 && std::abs(trk_em_phi[i]) <= M_PI;
191  }
192  int nProj_em = (int)trk_bool_em[0] + (int)trk_bool_em[1];
193 
194  if(nProj_em ==1) {
195  eta_ctr = trk_bool_em[0] ? trk_em_eta[0] : trk_em_eta[1];
196  phi_ctr = trk_bool_em[0] ? trk_em_phi[0] : trk_em_phi[1];
197  } else if(nProj_em==2) {
198  eta_ctr = (trk_em_eta[0] + trk_em_eta[1]) / 2.0;
199  phi_ctr = (trk_em_phi[0] + trk_em_phi[1]) / 2.0;
200  } else {
201  eta_ctr = ptr->getTrack()->eta();
202  phi_ctr = ptr->getTrack()->phi();
203  }
204 
205 
206 
207  for(auto *cptr : matchedClusters){
208  auto *clustlink = cptr->getCluster();
209 
210  for(auto it_cell = clustlink->cell_begin(); it_cell != clustlink->cell_end(); it_cell++){
211  const CaloCell* cell = (*it_cell);
212  float cellE = cell->e()*(it_cell.weight())*1e-3f;
213  if(cellE < 0.005) continue;//Cut from ntuple maker
214  const auto *theDDE=it_cell->caloDDE();
215  double cx=theDDE->x();
216  double cy=theDDE->y();
217 
218  cells.emplace_back( std::array<double, 5> { cellE,
219  theDDE->eta() - eta_ctr,
220  theDDE->phi() - phi_ctr,
221  std::hypot(cx,cy), //rperp
222  0.0 } );
223  }
224  }
225 
226 
227  std::vector<bool> trk_bool(calo_numbers.size(), false);
228  std::vector<std::array<double,4>> trk_full(calo_numbers.size());
229  for(size_t j=0; j<phitotal.size(); j++) {
230  int cnum = calo_numbers[j];
231  double eta = etatotal[j];
232  double phi = phitotal[j];
233  if(std::abs(eta) < 2.5 && std::abs(phi) <= M_PI) {
234  trk_bool[j] = true;
235  trk_full[j][0] = eta;
236  trk_full[j][1] = phi;
237  trk_full[j][3] = cnum;
238  double rPerp =-99999;
239  if(auto itr = r_calo_dict.find(cnum); itr != r_calo_dict.end()) rPerp = itr->second;
240  else if(auto itr = z_calo_dict.find(cnum); itr != z_calo_dict.end())
241  {
242  double z = itr->second;
243  if(eta != 0.0){
244  double aeta = std::abs(eta);
245  rPerp = z*2.*std::exp(aeta)/(std::exp(2.0*aeta)-1.0);
246  }else rPerp =0.0; //Check if this makes sense
247  } else {
248  throw std::runtime_error("Calo sample num not found in dicts..");
249  }
250  trk_full[j][2] = rPerp;
251  } else {
252  trk_full[j].fill(0.0);
253  }
254  }
255  double trackP = std::abs(1. / ptr->getTrack()->qOverP()) * 1e-3;
256  int trk_proj_num = std::accumulate(trk_bool.begin(), trk_bool.end(), 0);
257  if(trk_proj_num ==0) {
258  trk_proj_num =1;
259  std::array<double,5> trk_arr{};
260 
261  trk_arr[0] = trackP;
262  trk_arr[1] = ptr->getTrack()->eta() - eta_ctr;
263  trk_arr[2] = ptr->getTrack()->phi() - phi_ctr;
264  trk_arr[3] = 1532.18; // just place it in EMB1
265  trk_arr[4] = 1.;
266 
267  cells.emplace_back(trk_arr);
268  } else {
269  for(size_t i =0; i<calo_numbers.size(); i++) {
270  if(!trk_bool[i]) continue;
271  std::array<double,5> trk_arr{};
272  trk_arr[0]= trackP/double(trk_proj_num);
273  trk_arr[1]= trk_full[i][0] - eta_ctr;
274  trk_arr[2]= trk_full[i][1] - phi_ctr;
275  trk_arr[3]= trk_full[i][2];
276  trk_arr[4]= 1.;
277 
278  cells.emplace_back(trk_arr);
279  }
280  }
281 
282  int index = 0;
283  for(auto &in : cells){
284  std::copy(in.begin(), in.end(), inputnn.begin() + index);
285  index+=5;
286  if(index >= static_cast<int>(inputnn.size()-4)) {
287  ATH_MSG_WARNING("Data exceeded tensor size");
288  break;
289  }
290  }
291 
292  //Normalization prior to training
293  NormalizeTensor(inputnn, cells.size() * 5 );
294 
295  float predictedEnergy = exp(runOnnxInference(inputnn)) * 1000.0;//Correct to MeV units
296  ATH_MSG_DEBUG("NN Predicted energy " << predictedEnergy);
297  return predictedEnergy;
298 
299 }
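For the endcap samples listed in `fixed_z_vals`, nnEnergyPrediction() converts the layer's fixed |z| into a transverse radius with r⊥ = z · 2e^{|η|} / (e^{2|η|} − 1), which simplifies to r⊥ = z / sinh|η| (the radius at which a straight line from the origin with pseudorapidity η reaches that |z|). The disabled `cutOnR` branch wraps the angular differences into [−π, π] before forming ΔR = sqrt(Δη² + Δφ²). A small sketch of both calculations follows; `rPerpFromZ`, `wrapToPi` and `deltaR` are hypothetical helper names, and unlike the listing the sketch wraps only Δφ, since Δη is not periodic.

```cpp
#include <cassert>
#include <cmath>

constexpr double kPi = 3.14159265358979323846;

// Transverse radius at longitudinal position z for a direction with pseudorapidity eta.
// Algebraically equal to z*2*exp(|eta|)/(exp(2|eta|)-1), the expression in the listing.
double rPerpFromZ(double z, double eta) {
  if (eta == 0.0) return 0.0;  // mirrors the special case in the listing
  return z / std::sinh(std::abs(eta));
}

// Fold an angular difference into [-pi, pi].
double wrapToPi(double x) {
  while (x > kPi) x -= 2.0 * kPi;
  while (x < -kPi) x += 2.0 * kPi;
  return x;
}

// Distance in (eta, phi) space, with the phi difference wrapped.
double deltaR(double eta1, double phi1, double eta2, double phi2) {
  const double dEta = eta1 - eta2;
  const double dPhi = wrapToPi(phi1 - phi2);
  return std::hypot(dEta, dPhi);
}
```

For example, `deltaR(0.0, 3.0, 0.0, -3.0)` is about 0.283 rather than 6, because the two directions lie close together across the φ = ±π boundary.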

◆ NormalizeTensor()

void PFEnergyPredictorTool::NormalizeTensor ( std::vector< float > &  tensor,
size_t  limit 
) const

Definition at line 301 of file PFEnergyPredictorTool.cxx.

301  {
302  size_t i=0;
303  for(i =0;i<limit;i+=5){
304  auto &f = inputnn[i+3];
305  if(f!= 0.0f) f/= 3630.f;
306  auto &e = inputnn[i+0];
307  if(e!= 0.0f){
308  e = std::log(e);
309  e = (e - m_cellE_mean)/m_cellE_std;
310  }
311  auto &eta = inputnn[i+1];
312  if(eta!= 0.0) eta /= 0.7f;
313  auto &phi = inputnn[i+2];
314  if(phi!= 0.0) phi /= m_cellPhi_std;
315  }
316  if(i> inputnn.size()){
317  ATH_MSG_ERROR("Index exceeded tensor MEMORY CORRUPTION");
318  }
319 }
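Each row of the flattened tensor holds (E in GeV, Δη, Δφ, r⊥ in mm, track flag). For non-zero entries NormalizeTensor() applies E → (ln E − cellE_mean) / cellE_std, Δη → Δη / 0.7, Δφ → Δφ / cellPhi_std and r⊥ → r⊥ / 3630, leaving the track flag unchanged. In nnEnergyPrediction() the limit passed is `cells.size() * 5`, while the packing loop stops once the 5430-entry tensor is full, so with more than 1086 rows the limit appears to exceed the tensor length. The sketch below applies the same transform but clamps the limit; `packRows`, `normalizeRows` and `NormConstants` are hypothetical names, and the constants are copied from the property defaults and the literals in the listing.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

constexpr std::size_t kFeatures = 5;  // [E (GeV), d_eta, d_phi, r_perp (mm), is_track]

// Hypothetical holder for the normalization constants. The first three defaults are the
// cellE_mean, cellE_std and cellPhi_std property defaults; the last two are hard-coded
// in NormalizeTensor().
struct NormConstants {
  float cellE_mean = -2.2852574689444385f;
  float cellE_std = 2.0100506557174946f;
  float cellPhi_std = 0.6916977411859621f;
  float eta_scale = 0.7f;
  float rPerp_scale = 3630.f;
};

// Copy rows into a fixed-size, zero-padded tensor; returns how many rows fit.
std::size_t packRows(const std::vector<std::array<double, kFeatures>>& rows,
                     std::vector<float>& tensor) {
  const std::size_t n = std::min(rows.size(), tensor.size() / kFeatures);
  for (std::size_t r = 0; r < n; ++r) {
    for (std::size_t k = 0; k < kFeatures; ++k) {
      tensor[r * kFeatures + k] = static_cast<float>(rows[r][k]);
    }
  }
  return n;
}

// Normalize the first `limit` entries row by row. Zero entries (padding or empty
// features) are left untouched, as in NormalizeTensor().
void normalizeRows(std::vector<float>& tensor, std::size_t limit,
                   const NormConstants& c = NormConstants{}) {
  // Clamp to whole rows that exist, so an oversized limit cannot overrun the buffer.
  limit = std::min(limit, tensor.size() - tensor.size() % kFeatures);
  for (std::size_t i = 0; i < limit; i += kFeatures) {
    float& e = tensor[i + 0];
    if (e != 0.0f) e = (std::log(e) - c.cellE_mean) / c.cellE_std;
    float& eta = tensor[i + 1];
    if (eta != 0.0f) eta /= c.eta_scale;
    float& phi = tensor[i + 2];
    if (phi != 0.0f) phi /= c.cellPhi_std;
    float& r = tensor[i + 3];
    if (r != 0.0f) r /= c.rPerp_scale;
    // tensor[i + 4] (the track flag) is deliberately not modified.
  }
}
```

Passing `packRows(rows, tensor) * kFeatures` as the limit normalizes exactly the rows that were written, so the bound never depends on how many candidate rows were collected.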

◆ outputHandles()

virtual std::vector<Gaudi::DataHandle*> AthCommonDataStore< AthCommonMsg< AlgTool > >::outputHandles ( ) const
[override, virtual, inherited]

Return this algorithm's output handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ renounce()

std::enable_if_t<std::is_void_v<std::result_of_t<decltype(&T::renounce)(T)> > && !std::is_base_of_v<SG::VarHandleKeyArray, T> && std::is_base_of_v<Gaudi::DataHandle, T>, void> AthCommonDataStore< AthCommonMsg< AlgTool > >::renounce ( T &  h)
[inline, protected, inherited]

Definition at line 380 of file AthCommonDataStore.h.

381  {
382  h.renounce();
383  PBASE::renounce (h);
384  }

◆ renounceArray()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::renounceArray ( SG::VarHandleKeyArray &  handlesArray)
[inline, protected, inherited]

Remove all handles from I/O resolution.

Definition at line 364 of file AthCommonDataStore.h.

364  {
365  handlesArray.renounce();
366  }

◆ runOnnxInference()

float PFEnergyPredictorTool::runOnnxInference ( std::vector< float > &  tensor) const

Definition at line 93 of file PFEnergyPredictorTool.cxx.

93  {
94  using std::endl;
95  using std::cout;
96  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
97  auto input_tensor_size = tensor.size();
98 
99  Ort::Value input_tensor =
100  Ort::Value::CreateTensor<float>(memory_info, tensor.data(), input_tensor_size,
101  m_input_node_dims.data(), m_input_node_dims.size());
102 
103  auto output_tensors = m_session->Run(Ort::RunOptions{nullptr}, m_input_node_names.data(), &input_tensor, m_input_node_names.size(),
104  m_output_node_names.data(), m_output_node_names.size());
105 
106  const float *output_score_array = output_tensors.front().GetTensorData<float>();
107 
108  // Binary classification - the score is just the first element of the output tensor
109  float output_score = output_score_array[0];
110 
111  return output_score;
112 }

◆ sysInitialize()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysInitialize ( )
[override, virtual, inherited]

Perform system initialization for an algorithm.

We override this to declare all the elements of handle key arrays at the end of initialization. See comments on updateVHKA.

Reimplemented in DerivationFramework::CfAthAlgTool, AthCheckedComponent< AthAlgTool >, AthCheckedComponent<::AthAlgTool >, and asg::AsgMetadataTool.

◆ sysStart()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysStart ( )
[override, virtual, inherited]

Handle START transition.

We override this in order to make sure that conditions handle keys can cache a pointer to the conditions container.

◆ updateVHKA()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::updateVHKA ( Gaudi::Details::PropertyBase &  )
[inline, inherited]

Definition at line 308 of file AthCommonDataStore.h.

308  {
309  // debug() << "updateVHKA for property " << p.name() << " " << p.toString()
310  // << " size: " << m_vhka.size() << endmsg;
311  for (auto &a : m_vhka) {
312  std::vector<SG::VarHandleKey*> keys = a->keys();
313  for (auto k : keys) {
314  k->setOwner(this);
315  }
316  }
317  }

Member Data Documentation

◆ m_session

std::unique_ptr<Ort::Session> PFEnergyPredictorTool::m_session ATLAS_THREAD_SAFE
private

Definition at line 32 of file PFEnergyPredictorTool.h.

◆ m_cellE_mean

Gaudi::Property<float> PFEnergyPredictorTool::m_cellE_mean {this,"cellE_mean",-2.2852574689444385}
private

Normalization constants for the inputs to the onnx model.

Definition at line 43 of file PFEnergyPredictorTool.h.

◆ m_cellE_std

Gaudi::Property<float> PFEnergyPredictorTool::m_cellE_std {this,"cellE_std",2.0100506557174946}
private

Definition at line 44 of file PFEnergyPredictorTool.h.

◆ m_cellPhi_std

Gaudi::Property<float> PFEnergyPredictorTool::m_cellPhi_std {this,"cellPhi_std",0.6916977411859621}
private

Definition at line 45 of file PFEnergyPredictorTool.h.

◆ m_detStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_detStore
[private, inherited]

Pointer to StoreGate (detector store by default)

Definition at line 393 of file AthCommonDataStore.h.

◆ m_evtStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_evtStore
[private, inherited]

Pointer to StoreGate (event store by default)

Definition at line 390 of file AthCommonDataStore.h.

◆ m_input_node_dims

std::vector<int64_t> PFEnergyPredictorTool::m_input_node_dims
private

Definition at line 38 of file PFEnergyPredictorTool.h.

◆ m_input_node_names

std::vector<const char *> PFEnergyPredictorTool::m_input_node_names
private

Definition at line 34 of file PFEnergyPredictorTool.h.

◆ m_model_filepath

Gaudi::Property<std::string> PFEnergyPredictorTool::m_model_filepath {this, "ModelPath", "////"}
private

Definition at line 40 of file PFEnergyPredictorTool.h.

◆ m_output_node_names

std::vector<const char *> PFEnergyPredictorTool::m_output_node_names
private

Definition at line 36 of file PFEnergyPredictorTool.h.

◆ m_svc

ServiceHandle<AthOnnx::IOnnxRuntimeSvc> PFEnergyPredictorTool::m_svc {this, "ONNXRuntimeSvc", "AthOnnx::OnnxRuntimeSvc", "CaloMuonScoreTool ONNXRuntimeSvc"}
private

Definition at line 39 of file PFEnergyPredictorTool.h.

◆ m_varHandleArraysDeclared

bool AthCommonDataStore< AthCommonMsg< AlgTool > >::m_varHandleArraysDeclared
[private, inherited]

Definition at line 399 of file AthCommonDataStore.h.

◆ m_vhka

std::vector<SG::VarHandleKeyArray*> AthCommonDataStore< AthCommonMsg< AlgTool > >::m_vhka
[private, inherited]

Definition at line 398 of file AthCommonDataStore.h.


The documentation for this class was generated from the following files: PFEnergyPredictorTool.h and PFEnergyPredictorTool.cxx.
StoreGateSvc_t m_detStore
Pointer to StoreGate (detector store by default)
Definition: AthCommonDataStore.h:393
xAOD::double
double
Definition: CompositeParticle_v1.cxx:159
AthAlgTool::AthAlgTool
AthAlgTool()
Default constructor:
part2
Definition: part2.py:1
SG::VarHandleKeyArray::renounce
virtual void renounce()=0
SG::HandleClassifier::type
std::conditional< std::is_base_of< SG::VarHandleKeyArray, T >::value, VarHandleKeyArrayType, type2 >::type type
Definition: HandleClassifier.h:54
xAOD::TrackParticle_v1::qOverP
float qOverP() const
Returns the parameter.
PFEnergyPredictorTool::m_cellE_mean
Gaudi::Property< float > m_cellE_mean
Normalization constants for the inputs to the onnx model.
Definition: PFEnergyPredictorTool.h:43
eflowRecTrack::getTrack
const xAOD::TrackParticle * getTrack() const
Definition: eflowRecTrack.h:53
merge_scale_histograms.doc
string doc
Definition: merge_scale_histograms.py:9
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:192
eflowRecTrack::getTrackCaloPoints
const eflowTrackCaloPoints & getTrackCaloPoints() const
Definition: eflowRecTrack.h:55
DiTauMassTools::MaxHistStrategyV2::e
e
Definition: PhysicsAnalysis/TauID/DiTauMassTools/DiTauMassTools/HelperFunctions.h:26
a
TList * a
Definition: liststreamerinfos.cxx:10
h
CaloCell
Data object for each calorimeter readout cell.
Definition: CaloCell.h:57
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
PlotCalibFromCool.cy
cy
Definition: PlotCalibFromCool.py:667
PFEnergyPredictorTool::m_cellE_std
Gaudi::Property< float > m_cellE_std
Definition: PFEnergyPredictorTool.h:44
python.CaloScaleNoiseConfig.type
type
Definition: CaloScaleNoiseConfig.py:78
python.CaloCondTools.log
log
Definition: CaloCondTools.py:20
SG::VarHandleBase::vhKey
SG::VarHandleKey & vhKey()
Return a non-const reference to the HandleKey.
Definition: StoreGate/src/VarHandleBase.cxx:616
declareProperty
#define declareProperty(n, p, h)
Definition: BaseFakeBkgTool.cxx:15
python.Bindings.keys
keys
Definition: Control/AthenaPython/python/Bindings.py:790
xAOD::track
@ track
Definition: TrackingPrimitives.h:512
calibdata.copy
bool copy
Definition: calibdata.py:27
eflowTrackCaloPoints::getEta
double getEta(eflowCalo::LAYER layer) const
Definition: eflowTrackCaloPoints.h:43
updateCoolNtuple.limit
int limit
Definition: updateCoolNtuple.py:45
AthCommonDataStore::declareGaudiProperty
Gaudi::Details::PropertyBase & declareGaudiProperty(Gaudi::Property< T > &hndl, const SG::VarHandleKeyType &)
specialization for handling Gaudi::Property<SG::VarHandleKey>
Definition: AthCommonDataStore.h:156
getEtaTrackCalo
std::array< double, 19 > getEtaTrackCalo(const eflowTrackCaloPoints &trackCaloPoints)
Definition: PFEnergyPredictorTool.cxx:114
xAOD::TrackParticle_v1::phi
virtual double phi() const override final
The azimuthal angle ( ) of the particle (has range to .)
fitman.k
k
Definition: fitman.py:528
PFEnergyPredictorTool::m_input_node_names
std::vector< const char * > m_input_node_names
Definition: PFEnergyPredictorTool.h:34
PFEnergyPredictorTool::runOnnxInference
float runOnnxInference(std::vector< float > &tensor) const
Definition: PFEnergyPredictorTool.cxx:93