ATLAS Offline Software
PhotonVertexSelectionTool.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef PhotonVertexSelection_PhotonVertexSelectionTool_h
6 #define PhotonVertexSelection_PhotonVertexSelectionTool_h
7 
8 // Framework includes
9 #include "AsgTools/AsgTool.h"
10 #include "AsgTools/ToolHandle.h"
12 
13 // EDM includes
17 
18 // Local includes
20 
21 // ONNX Runtime include(s).
22 #include <onnxruntime_cxx_api.h>
23 
24 // Forward declarations
// TMVA::Reader is used in this header as the pointee type of the
// std::unique_ptr members m_mva1 and m_mva2.
25 namespace TMVA { class Reader; }
26 
27 namespace CP {
28 
38  public asg::AsgTool {
39 
42 
43  private:
// Basic configuration members of the tool.
45  int m_nVars; // number of input variables
// NOTE(review): presumably a pT threshold applied to converted photons; units and usage are not visible in this header -- confirm in the .cxx.
46  float m_convPtCut;
// NOTE(review): a vertex container name held as a plain string; how it relates to the m_vertexContainer read handle key is not visible here -- confirm.
48  std::string m_vertexContainerName;
// NOTE(review): presumably a prefix used for derivation-level names; usage is not visible here -- confirm.
49  std::string m_derivationPrefix;
50 
// Container declarations: StoreGate read handle keys.
// Event info key; declared under the name "EventInfoContName" with default value "EventInfo".
52  SG::ReadHandleKey<xAOD::EventInfo> m_eventInfo{this, "EventInfoContName", "EventInfo", "event info key"};
// Vertex container key; declared under the name "VertexContainer" with default value "PrimaryVertices".
53  SG::ReadHandleKey<xAOD::VertexContainer> m_vertexContainer {this, "VertexContainer", "PrimaryVertices", "Vertex container name"};
54 
// Backend selection and TMVA model locations.
// One model file is configured per photon category (converted / unconverted).
57  bool m_isTMVA; // boolean to use TMVA, if false assume to be ONNX-based
58  std::string m_TMVAModelFilePath1; //TMVA config file, converted case
59  std::string m_TMVAModelFilePath2; //TMVA config file, unconverted case
60 
61  // MVA readers
62  // Ideally these would be const but the main method called, EvaluateMVA, is non const.
// Owned through std::unique_ptr; TMVA::Reader is only forward-declared in this header.
// NOTE(review): presumably 1 = converted case and 2 = unconverted case, following the
// numbering of the model file path members -- confirm in the .cxx.
63  std::unique_ptr<TMVA::Reader> m_mva1;
64  std::unique_ptr<TMVA::Reader> m_mva2;
65 
66  // ONNX
67  // ==================================================
68  // Name of the ONNX model file to load
// One model file per photon category; suffix 1 is the converted case, suffix 2 the unconverted case.
69  std::string m_ONNXModelFilePath1; //converted case
70  std::string m_ONNXModelFilePath2; //unconverted case
71 
72  // declare node vars
// Input/output node dimensions (int64_t) and node names (const char*) for each of the two ONNX models.
// NOTE(review): the names are stored as raw const char*; what owns the underlying character data
// is not visible in this header -- confirm the strings outlive these vectors.
73  // converted
74  std::vector<int64_t> m_input_node_dims1, m_output_node_dims1;
75  std::vector<const char*> m_input_node_names1, m_output_node_names1;
76  // unconverted
77  std::vector<int64_t> m_input_node_dims2, m_output_node_dims2;
78  std::vector<const char*> m_input_node_names2, m_output_node_names2;
79 
80  // The ONNX session handlers
// Held as std::shared_ptr<Ort::Session>, one per photon category.
81  std::shared_ptr<Ort::Session> m_sessionHandle1; //converted case
82  std::shared_ptr<Ort::Session> m_sessionHandle2; //unconverted case
83  // The ONNX memory allocators, for looping
// One allocator per photon category, matching the session handles above.
84  Ort::AllocatorWithDefaultOptions m_allocator1; //converted case
85  Ort::AllocatorWithDefaultOptions m_allocator2; //unconverted case
86 
87  // ONNX Methods
88  // create ONNX session and return both the allocator and session handler from user-defined onnx env and model
// Parameters: env is the caller-provided ONNX Runtime environment (taken by non-const reference);
//             modelFilePath is the model file path (taken by value).
// Returns a tuple whose first element is the session handle and whose second element is the allocator.
89  std::tuple<std::shared_ptr<Ort::Session>, Ort::AllocatorWithDefaultOptions> setONNXSession(Ort::Env& env, std::string modelFilePath);
90  // get the input nodes info from the onnx model file (using session and allocator)
// Returns a tuple of (std::vector<int64_t>, std::vector<const char*>).
// NOTE(review): presumably (node dimensions, node names), matching the types of
// m_input_node_dims* / m_input_node_names* -- confirm in the .cxx.
91  std::tuple<std::vector<int64_t>, std::vector<const char*>> getInputNodes( const std::shared_ptr<Ort::Session> sessionHandle, Ort::AllocatorWithDefaultOptions& allocator);
92  // get the output nodes info from the onnx model file (using session and allocator)
// Returns a tuple of (std::vector<int64_t>, std::vector<const char*>).
// NOTE(review): presumably (node dimensions, node names), matching the types of
// m_output_node_dims* / m_output_node_names* -- confirm in the .cxx.
93  std::tuple<std::vector<int64_t>, std::vector<const char*>> getOutputNodes(const std::shared_ptr<Ort::Session> sessionHandle, Ort::AllocatorWithDefaultOptions& allocator);
94  // wrapper for getting the NN score from onnx model file (passed as onnx session)
// Const member returning the score as a float.
// Parameters: nVars              - number of input variables (int)
//             input_data         - input values as a vector of float vectors
//             sessionHandle      - ONNX session to evaluate
//             input_node_dims    - input node dimensions
//             input_node_names   - input node names
//             output_node_names  - output node names
// NOTE(review): every container argument and the shared_ptr are taken by value, so each call
// copies them; switching to const references would need the matching change in the .cxx definition.
95  float getScore(int nVars, std::vector<std::vector<float>> input_data, const std::shared_ptr<Ort::Session> sessionHandle, std::vector<int64_t> input_node_dims, std::vector<const char*> input_node_names, std::vector<const char*> output_node_names) const;
96  // ==================================================
97 
98  private:
/// Get combined 4-vector of photon container.
/// failType is taken by non-const reference.
/// NOTE(review): presumably failType is set on failure -- confirm in the .cxx.
100  TLorentzVector getEgammaVector(const xAOD::EgammaContainer *egammas, FailType& failType) const;
101 
/// Sort MLP results.
/// Static binary predicate over (vertex, float) pairs.
/// NOTE(review): the ordering it defines is not visible in this header -- confirm in the .cxx.
103  static bool sortMLP(const std::pair<const xAOD::Vertex*, float> &a, const std::pair<const xAOD::Vertex*, float> &b);
104 
105 
/// Given a list of photons, return the MLPs of all vertices in the event.
/// Private implementation method. The vertex pointer, the vector of (vertex, float) pairs,
/// the yyVtxType and the FailType are all taken by non-const reference.
/// NOTE(review): presumably these are output arguments filled by the call -- confirm in the .cxx.
107  StatusCode getVertexImp(const xAOD::EgammaContainer &egammas, const xAOD::Vertex* &vertex, bool ignoreConv, bool noDecorate, std::vector<std::pair<const xAOD::Vertex*, float> >&, yyVtxType& , FailType& ) const;
108 
109 
110  public:
/// Constructor taking the tool name.
111  PhotonVertexSelectionTool(const std::string &name);
113 
116 
/// Function initialising the tool.
118  virtual StatusCode initialize();
119 
121 
124 
/// Given a list of photons, decorate vertex container with MVA variables.
/// failType is an optional pointer argument that defaults to nullptr.
126  StatusCode decorateInputs(const xAOD::EgammaContainer &egammas, FailType* failType = nullptr) const;
127 
/// Given a list of photons, return the most likely vertex based on MVA likelihood.
/// The vertex pointer is taken by non-const reference; ignoreConv defaults to false.
129  StatusCode getVertex(const xAOD::EgammaContainer &egammas, const xAOD::Vertex* &vertex, bool ignoreConv = false) const;
130 
/// Overload returning a vector of (vertex, float) pairs for the given photons.
/// NOTE(review): presumably the float is the per-vertex MVA score -- confirm in the .cxx.
/// ignoreConv and noDecorate default to false; vtxCase and failType are optional pointer
/// arguments that default to nullptr.
132  std::vector<std::pair<const xAOD::Vertex*, float> > getVertex(const xAOD::EgammaContainer &egammas, bool ignoreConv = false, bool noDecorate = false, yyVtxType* vtxCase = nullptr, FailType* failType = nullptr) const;
133 
/// Return the last case treated.
135  // Deprecated no longer use this function
// This inline implementation always returns -1.
136  int getCase() const { return -1; }
137 
/// Get possible vertex directly associated with photon conversions.
/// Returns a pointer to a const xAOD::Vertex.
/// NOTE(review): whether nullptr is returned when no such vertex exists is not visible here -- confirm in the .cxx.
139  const xAOD::Vertex* getPrimaryVertexFromConv(const xAOD::PhotonContainer *photons) const;
140 
142 
143  }; // class PhotonVertexSelectionTool
144 
145 } // namespace CP
146 
147 
148 #endif // PhotonVertexSelection_PhotonVertexSelectionTool_h
CP::PhotonVertexSelectionTool::PhotonVertexSelectionTool
PhotonVertexSelectionTool(const std::string &name)
Definition: PhotonVertexSelectionTool.cxx:54
asg::AsgTool
Base class for the dual-use tool implementation classes.
Definition: AsgTool.h:47
CP::PhotonVertexSelectionTool::m_TMVAModelFilePath2
std::string m_TMVAModelFilePath2
Definition: PhotonVertexSelectionTool.h:59
CP::PhotonVertexSelectionTool::setONNXSession
std::tuple< std::shared_ptr< Ort::Session >, Ort::AllocatorWithDefaultOptions > setONNXSession(Ort::Env &env, std::string modelFilePath)
Definition: PhotonVertexSelectionTool.cxx:186
CP::PhotonVertexSelectionTool::m_derivationPrefix
std::string m_derivationPrefix
Definition: PhotonVertexSelectionTool.h:49
CP::PhotonVertexSelectionTool::getPrimaryVertexFromConv
const xAOD::Vertex * getPrimaryVertexFromConv(const xAOD::PhotonContainer *photons) const
Get possible vertex directly associated with photon conversions.
Definition: PhotonVertexSelectionTool.cxx:533
CP::PhotonVertexSelectionTool::m_TMVAModelFilePath1
std::string m_TMVAModelFilePath1
Definition: PhotonVertexSelectionTool.h:58
CP::PhotonVertexSelectionTool::getEgammaVector
TLorentzVector getEgammaVector(const xAOD::EgammaContainer *egammas, FailType &failType) const
Get combined 4-vector of photon container.
Definition: PhotonVertexSelectionTool.cxx:587
CP::PhotonVertexSelectionTool::m_vertexContainer
SG::ReadHandleKey< xAOD::VertexContainer > m_vertexContainer
Definition: PhotonVertexSelectionTool.h:53
CP::PhotonVertexSelectionTool::m_input_node_dims1
std::vector< int64_t > m_input_node_dims1
Definition: PhotonVertexSelectionTool.h:74
CP::PhotonVertexSelectionTool::m_sessionHandle2
std::shared_ptr< Ort::Session > m_sessionHandle2
Definition: PhotonVertexSelectionTool.h:82
CP::PhotonVertexSelectionTool::m_mva1
std::unique_ptr< TMVA::Reader > m_mva1
Definition: PhotonVertexSelectionTool.h:63
SG::ReadHandleKey< xAOD::EventInfo >
CP::PhotonVertexSelectionTool::sortMLP
static bool sortMLP(const std::pair< const xAOD::Vertex *, float > &a, const std::pair< const xAOD::Vertex *, float > &b)
Sort MLP results.
Definition: PhotonVertexSelectionTool.cxx:528
CP::PhotonVertexSelectionTool::m_ONNXModelFilePath1
std::string m_ONNXModelFilePath1
Definition: PhotonVertexSelectionTool.h:69
CP
Select isolated Photons, Electrons and Muons.
Definition: Control/xAODRootAccess/xAODRootAccess/TEvent.h:48
CP::PhotonVertexSelectionTool::m_sessionHandle1
std::shared_ptr< Ort::Session > m_sessionHandle1
Definition: PhotonVertexSelectionTool.h:81
CP::PhotonVertexSelectionTool::m_output_node_dims2
std::vector< int64_t > m_output_node_dims2
Definition: PhotonVertexSelectionTool.h:77
CP::PhotonVertexSelectionTool::m_eventInfo
SG::ReadHandleKey< xAOD::EventInfo > m_eventInfo
Container declarations.
Definition: PhotonVertexSelectionTool.h:52
CP::IPhotonVertexSelectionTool
Definition: IPhotonVertexSelectionTool.h:26
CP::PhotonVertexSelectionTool
Implementation for the photon vertex selection tool.
Definition: PhotonVertexSelectionTool.h:38
CP::PhotonVertexSelectionTool::initialize
virtual StatusCode initialize()
Function initialising the tool.
Definition: PhotonVertexSelectionTool.cxx:205
CP::PhotonVertexSelectionTool::m_allocator1
Ort::AllocatorWithDefaultOptions m_allocator1
Definition: PhotonVertexSelectionTool.h:84
CP::IPhotonVertexSelectionTool::yyVtxType
yyVtxType
Definition: IPhotonVertexSelectionTool.h:44
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
CP::PhotonVertexSelectionTool::m_vertexContainerName
std::string m_vertexContainerName
Definition: PhotonVertexSelectionTool.h:48
ReadHandleKey.h
Property holding a SG store/key/clid from which a ReadHandle is made.
CP::PhotonVertexSelectionTool::m_convPtCut
float m_convPtCut
Definition: PhotonVertexSelectionTool.h:46
CP::PhotonVertexSelectionTool::getCase
int getCase() const
Return the last case treated:
Definition: PhotonVertexSelectionTool.h:136
CP::PhotonVertexSelectionTool::~PhotonVertexSelectionTool
virtual ~PhotonVertexSelectionTool()
CP::IPhotonVertexSelectionTool::FailType
FailType
Declare the interface that the class provides.
Definition: IPhotonVertexSelectionTool.h:33
LHEF::Reader
Pythia8::Reader Reader
Definition: Prophecy4fMerger.cxx:11
CP::PhotonVertexSelectionTool::m_nVars
int m_nVars
Create a proper constructor for Athena.
Definition: PhotonVertexSelectionTool.h:45
CP::PhotonVertexSelectionTool::getOutputNodes
std::tuple< std::vector< int64_t >, std::vector< const char * > > getOutputNodes(const std::shared_ptr< Ort::Session > sessionHandle, Ort::AllocatorWithDefaultOptions &allocator)
Definition: PhotonVertexSelectionTool.cxx:156
IPhotonVertexSelectionTool.h
CP::PhotonVertexSelectionTool::getVertexImp
StatusCode getVertexImp(const xAOD::EgammaContainer &egammas, const xAOD::Vertex *&vertex, bool ignoreConv, bool noDecorate, std::vector< std::pair< const xAOD::Vertex *, float > > &, yyVtxType &, FailType &) const
Given a list of photons, return the MLPs of all vertices in the event.
Definition: PhotonVertexSelectionTool.cxx:328
DataVector
Derived DataVector<T>.
Definition: DataVector.h:581
CP::PhotonVertexSelectionTool::m_input_node_dims2
std::vector< int64_t > m_input_node_dims2
Definition: PhotonVertexSelectionTool.h:77
CP::PhotonVertexSelectionTool::m_ONNXModelFilePath2
std::string m_ONNXModelFilePath2
Definition: PhotonVertexSelectionTool.h:70
CP::PhotonVertexSelectionTool::m_output_node_names1
std::vector< const char * > m_output_node_names1
Definition: PhotonVertexSelectionTool.h:75
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:221
plotBeamSpotMon.b
b
Definition: plotBeamSpotMon.py:77
CP::PhotonVertexSelectionTool::getVertex
StatusCode getVertex(const xAOD::EgammaContainer &egammas, const xAOD::Vertex *&vertex, bool ignoreConv=false) const
Given a list of photons, return the most likely vertex based on MVA likelihood.
Definition: PhotonVertexSelectionTool.cxx:318
CP::PhotonVertexSelectionTool::m_mva2
std::unique_ptr< TMVA::Reader > m_mva2
Definition: PhotonVertexSelectionTool.h:64
CP::PhotonVertexSelectionTool::m_output_node_names2
std::vector< const char * > m_output_node_names2
Definition: PhotonVertexSelectionTool.h:78
CP::PhotonVertexSelectionTool::m_isTMVA
bool m_isTMVA
Definition: PhotonVertexSelectionTool.h:57
EventInfo.h
CP::PhotonVertexSelectionTool::m_doSkipByZSigma
bool m_doSkipByZSigma
Definition: PhotonVertexSelectionTool.h:47
CP::PhotonVertexSelectionTool::m_allocator2
Ort::AllocatorWithDefaultOptions m_allocator2
Definition: PhotonVertexSelectionTool.h:85
Trk::vertex
@ vertex
Definition: MeasurementType.h:21
VertexContainer.h
a
TList * a
Definition: liststreamerinfos.cxx:10
xAOD::Vertex_v1
Class describing a Vertex.
Definition: Vertex_v1.h:42
CP::PhotonVertexSelectionTool::decorateInputs
StatusCode decorateInputs(const xAOD::EgammaContainer &egammas, FailType *failType=nullptr) const
Given a list of photons, decorate vertex container with MVA variables.
Definition: PhotonVertexSelectionTool.cxx:250
ASG_TOOL_CLASS
#define ASG_TOOL_CLASS(CLASSNAME, INT1)
Definition: AsgToolMacros.h:68
CP::PhotonVertexSelectionTool::m_input_node_names1
std::vector< const char * > m_input_node_names1
Definition: PhotonVertexSelectionTool.h:75
CP::PhotonVertexSelectionTool::m_output_node_dims1
std::vector< int64_t > m_output_node_dims1
Definition: PhotonVertexSelectionTool.h:74
CP::PhotonVertexSelectionTool::getInputNodes
std::tuple< std::vector< int64_t >, std::vector< const char * > > getInputNodes(const std::shared_ptr< Ort::Session > sessionHandle, Ort::AllocatorWithDefaultOptions &allocator)
Definition: PhotonVertexSelectionTool.cxx:125
ToolHandle.h
AsgTool.h
python.DataFormatRates.env
env
Definition: DataFormatRates.py:32
CP::PhotonVertexSelectionTool::m_input_node_names2
std::vector< const char * > m_input_node_names2
Definition: PhotonVertexSelectionTool.h:78
PhotonContainer.h
CP::PhotonVertexSelectionTool::getScore
float getScore(int nVars, std::vector< std::vector< float >> input_data, const std::shared_ptr< Ort::Session > sessionHandle, std::vector< int64_t > input_node_dims, std::vector< const char * > input_node_names, std::vector< const char * > output_node_names) const
Definition: PhotonVertexSelectionTool.cxx:89
TMVA
Definition: PhotonVertexSelectionTool.h:25