ATLAS Offline Software
PFEnergyPredictorTool.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef PFENERFYPREDICTORTOOL_H
6 #define PFENERFYPREDICTORTOOL_H
7 
9 #include "GaudiKernel/ServiceHandle.h"
11 #include <fstream> // std::fstream
12 
// Gaudi interface identifier for this tool, returned by
// PFEnergyPredictorTool::interfaceID() near the end of this header.
// Being a namespace-scope `static`, every translation unit that includes this
// header gets its own copy of the object.
13 static const InterfaceID IID_PFEnergyPredictorTool("PFEnergyPredictorTool", 1, 0);
// Forward declaration: this header only refers to eflowRecTrack through a
// pointer (nnEnergyPrediction), so the full definition is not needed here.
14 class eflowRecTrack;
15 
17 {
18 public:
19  PFEnergyPredictorTool(const std::string& type, const std::string& name, const IInterface* parent);
20  virtual StatusCode initialize() override;
21  virtual StatusCode finalize() override;
22 
23  float runOnnxInference(std::vector<float> &tensor) const;
24  static const InterfaceID& interfaceID();
25 
26  float nnEnergyPrediction(const eflowRecTrack *ptr) const;
27  void NormalizeTensor(std::vector<float> &tensor, size_t limit) const;
28 
29 private:
30  //mark as thread safe because we need to call the run function of Session, which is not const
31  //the onnx documentation states that this is thread safe
32  std::unique_ptr<Ort::Session> m_session ATLAS_THREAD_SAFE;
33 
34  std::vector<const char *> m_input_node_names;
35 
36  std::vector<const char *> m_output_node_names;
37 
38  std::vector<int64_t> m_input_node_dims;
39  ServiceHandle<AthOnnx::IOnnxRuntimeSvc> m_svc{this, "ONNXRuntimeSvc", "AthOnnx::OnnxRuntimeSvc", "CaloMuonScoreTool ONNXRuntimeSvc"};
40  Gaudi::Property<std::string> m_model_filepath{this, "ModelPath", "////"};
41 
43  Gaudi::Property<float> m_cellE_mean{this,"cellE_mean",-2.2852574689444385};
44  Gaudi::Property<float> m_cellE_std{this,"cellE_std",2.0100506557174946};
45  Gaudi::Property<float> m_cellPhi_std{this,"cellPhi_std",0.6916977411859621};
46 
47 };
48 
49 inline const InterfaceID& PFEnergyPredictorTool::interfaceID() { return IID_PFEnergyPredictorTool; }
50 
51 
52 #endif
53 
PFEnergyPredictorTool::finalize
virtual StatusCode finalize() override
Definition: PFEnergyPredictorTool.cxx:323
PFEnergyPredictorTool::m_input_node_dims
std::vector< int64_t > m_input_node_dims
Definition: PFEnergyPredictorTool.h:38
PFEnergyPredictorTool::NormalizeTensor
void NormalizeTensor(std::vector< float > &tensor, size_t limit) const
Definition: PFEnergyPredictorTool.cxx:301
PFEnergyPredictorTool::initialize
virtual StatusCode initialize() override
Definition: PFEnergyPredictorTool.cxx:18
PFEnergyPredictorTool
Definition: PFEnergyPredictorTool.h:17
PFEnergyPredictorTool::m_model_filepath
Gaudi::Property< std::string > m_model_filepath
Definition: PFEnergyPredictorTool.h:40
PFEnergyPredictorTool::m_cellPhi_std
Gaudi::Property< float > m_cellPhi_std
Definition: PFEnergyPredictorTool.h:45
eflowRecTrack
This class extends the information about an xAOD::Track.
Definition: eflowRecTrack.h:45
PFEnergyPredictorTool::m_output_node_names
std::vector< const char * > m_output_node_names
Definition: PFEnergyPredictorTool.h:36
PFEnergyPredictorTool::m_svc
ServiceHandle< AthOnnx::IOnnxRuntimeSvc > m_svc
Definition: PFEnergyPredictorTool.h:39
PFEnergyPredictorTool::m_session
std::unique_ptr< Ort::Session > m_session ATLAS_THREAD_SAFE
Definition: PFEnergyPredictorTool.h:32
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
AthAlgTool.h
test_pyathena.parent
parent
Definition: test_pyathena.py:15
PFEnergyPredictorTool::m_cellE_mean
Gaudi::Property< float > m_cellE_mean
Normalization constants for the inputs to the onnx model.
Definition: PFEnergyPredictorTool.h:43
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:192
PFEnergyPredictorTool::nnEnergyPrediction
float nnEnergyPrediction(const eflowRecTrack *ptr) const
Definition: PFEnergyPredictorTool.cxx:134
PFEnergyPredictorTool::PFEnergyPredictorTool
PFEnergyPredictorTool(const std::string &type, const std::string &name, const IInterface *parent)
Definition: PFEnergyPredictorTool.cxx:12
IOnnxRuntimeSvc.h
PFEnergyPredictorTool::m_cellE_std
Gaudi::Property< float > m_cellE_std
Definition: PFEnergyPredictorTool.h:44
python.CaloScaleNoiseConfig.type
type
Definition: CaloScaleNoiseConfig.py:78
PFEnergyPredictorTool::interfaceID
static const InterfaceID & interfaceID()
Definition: PFEnergyPredictorTool.h:49
AthAlgTool
Definition: AthAlgTool.h:26
updateCoolNtuple.limit
int limit
Definition: updateCoolNtuple.py:45
PFEnergyPredictorTool::m_input_node_names
std::vector< const char * > m_input_node_names
Definition: PFEnergyPredictorTool.h:34
ServiceHandle< AthOnnx::IOnnxRuntimeSvc >
PFEnergyPredictorTool::runOnnxInference
float runOnnxInference(std::vector< float > &tensor) const
Definition: PFEnergyPredictorTool.cxx:93