ATLAS Offline Software
Loading...
Searching...
No Matches
PhotonVertexSelectionTool.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef PhotonVertexSelection_PhotonVertexSelectionTool_h
6#define PhotonVertexSelection_PhotonVertexSelectionTool_h
7
8// Framework includes
9#include "AsgTools/AsgTool.h"
10#include "AsgTools/ToolHandle.h"
15
16// EDM includes
20
21// Local includes
23
24// ONNX Runtime include(s).
25#include <onnxruntime_cxx_api.h>
26
27// Forward declarations
28namespace TMVA { class Reader; }
29
30namespace CP {
31
41 public asg::AsgTool {
42
45
46 private:
48 int m_nVars; // number of input variables
52 std::string m_derivationPrefix;
53
56 this, "EventInfoContName", "EventInfo", "event info key"};
58 this, "VertexContainer", "PrimaryVertices", "Vertex container name"};
59
62 bool m_isTMVA; // boolean to use TMVA, if false assume to be ONNX-based
63 std::string m_TMVAModelFilePath1; //TMVA config file, converted case
64 std::string m_TMVAModelFilePath2; //TMVA config file, unconverted case
65
66 // MVA readers
67 // Ideally these would be const, but the main method called on them, EvaluateMVA, is non-const.
68 std::unique_ptr<TMVA::Reader> m_mva1;
69 std::unique_ptr<TMVA::Reader> m_mva2;
70
71 // ONNX
72 // ==================================================
74 Gaudi::Property<int> m_ONNXLogLevel{this, "LogLevel", 2, "ONNX Runtime logging level (0=VERBOSE, 1=INFO, 2=WARNING, 3=ERROR, 4=FATAL)"};
75
77 std::unique_ptr< Ort::Env > m_env;
78
79 // Name of the ONNX model file to load
80 std::string m_ONNXModelFilePath1; //converted case
81 std::string m_ONNXModelFilePath2; //unconverted case
82
83 // declare node vars
84 // converted
86 std::vector<const char*> m_input_node_names1, m_output_node_names1;
87 // unconverted
89 std::vector<const char*> m_input_node_names2, m_output_node_names2;
90
91 // The ONNX session handlers
92 std::shared_ptr<Ort::Session> m_sessionHandle1; //converted case
93 std::shared_ptr<Ort::Session> m_sessionHandle2; //unconverted case
94 // The ONNX memory allocators, for looping
95 Ort::AllocatorWithDefaultOptions m_allocator1; //converted case
96 Ort::AllocatorWithDefaultOptions m_allocator2; //unconverted case
97
98 // ONNX Methods
99 // Create an ONNX session from the caller-supplied ONNX environment (env) and the model file at modelFilePath; returns a tuple of (session handle, allocator).
100 std::tuple<std::shared_ptr<Ort::Session>, Ort::AllocatorWithDefaultOptions>
101 setONNXSession(Ort::Env& env, const std::string& modelFilePath);
102 // Get the input-node information of an ONNX model through its session and
103 // allocator. Returns a tuple of (int64_t vector, node-name C strings).
// NOTE(review): the int64_t vector presumably holds the input node dimensions
// (cf. the input_node_dims parameter of getScore) - confirm in the implementation.
104 std::tuple<std::vector<int64_t>, std::vector<const char*>> getInputNodes(
105 const std::shared_ptr<Ort::Session>& sessionHandle,
106 Ort::AllocatorWithDefaultOptions& allocator);
107 // Get the output-node information of an ONNX model through its session and
108 // allocator. Returns a tuple of (int64_t vector, node-name C strings).
// NOTE(review): the int64_t vector presumably holds the output node dimensions,
// mirroring getInputNodes - confirm in the implementation.
109 std::tuple<std::vector<int64_t>, std::vector<const char*>> getOutputNodes(
110 const std::shared_ptr<Ort::Session>& sessionHandle,
111 Ort::AllocatorWithDefaultOptions& allocator);
112 // Wrapper returning the NN score from an ONNX model; the model is passed in
113 // as an ONNX session (sessionHandle) rather than as a file path.
// Parameters: nVars and input_data describe the network input; input_node_dims,
// input_node_names and output_node_names are the node descriptions of the model.
// NOTE(review): nVars is presumably the number of input variables (cf. m_nVars)
// and the node vectors presumably come from getInputNodes/getOutputNodes - confirm.
// NOTE(review): the three node vectors are taken by value and so are copied on
// every call; const references would avoid this if the definition is changed to match.
114 float getScore(int nVars, const std::vector<std::vector<float>>& input_data,
115 const std::shared_ptr<Ort::Session>& sessionHandle,
116 std::vector<int64_t> input_node_dims,
117 std::vector<const char*> input_node_names,
118 std::vector<const char*> output_node_names) const;
119 // ==================================================
120
121 private:
// Get the combined 4-vector of the photon (egamma) container.
// failType is a non-const reference, i.e. an output argument; presumably it
// reports why the computation failed - confirm in the implementation.
123 TLorentzVector getEgammaVector(const xAOD::EgammaContainer *egammas, FailType& failType) const;
124
// Sort MLP results: comparison function on (vertex, float) pairs.
// NOTE(review): the float is presumably the MLP score and the ordering
// direction is defined in the implementation - confirm there.
126 static bool sortMLP(const std::pair<const xAOD::Vertex*, float> &a, const std::pair<const xAOD::Vertex*, float> &b);
127
// Given a list of photons, return the MLPs of all vertices in the event.
// Results are delivered through the reference arguments: the vertex pointer,
// the vector of (vertex, float) pairs, the yyVtxType and the FailType.
// NOTE(review): presumably the shared implementation behind the public
// getVertex overloads - confirm in the .cxx file.
129 StatusCode getVertexImp(const xAOD::EgammaContainer& egammas,
130 const xAOD::Vertex*& vertex, bool ignoreConv,
131 bool noDecorate,
132 std::vector<std::pair<const xAOD::Vertex*, float>>&,
133 yyVtxType&, FailType&) const;
134
135 public:
// Constructor taking the tool instance name.
136 PhotonVertexSelectionTool(const std::string &name);
138
141
// Function initialising the tool.
143 virtual StatusCode initialize();
// NOTE(review): presumably the matching clean-up hook called at the end of the job - confirm.
144 virtual StatusCode finalize();
145
147
150
// Given a list of photons, decorate the vertex container with the MVA variables.
// failType is optional (defaults to nullptr); presumably it receives the
// failure reason when a non-null pointer is passed - confirm in the implementation.
152 StatusCode decorateInputs(const xAOD::EgammaContainer &egammas, FailType* failType = nullptr) const;
153
// Given a list of photons, return the most likely vertex based on the MVA likelihood.
// The selected vertex is returned through the `vertex` reference argument.
// NOTE(review): ignoreConv (default false) presumably disables the use of
// photon-conversion information - confirm in the implementation.
155 StatusCode getVertex(const xAOD::EgammaContainer &egammas, const xAOD::Vertex* &vertex, bool ignoreConv = false) const;
156
// Overload of getVertex that returns a vector of (vertex, float) pairs
// instead of a single vertex.
// NOTE(review): the float is presumably the MLP score of each vertex (cf. sortMLP)
// - confirm in the implementation.
// vtxCase and failType are optional pointers (default nullptr); presumably
// they are filled only when non-null - confirm in the implementation.
158 std::vector<std::pair<const xAOD::Vertex*, float>> getVertex(
159 const xAOD::EgammaContainer& egammas,
160 bool ignoreConv = false,
161 bool noDecorate = false,
162 yyVtxType* vtxCase = nullptr,
163 FailType* failType = nullptr) const;
164
166 // Deprecated: do not use this function any more. It always returns -1.
167 int getCase() const { return -1; }
168
171
173
182
183 }; // class PhotonVertexSelectionTool
184
185} // namespace CP
186
187
188#endif // PhotonVertexSelection_PhotonVertexSelectionTool_h
#define ASG_TOOL_CLASS(CLASSNAME, INT1)
Property holding a SG store/key/clid from which a ReadHandle is made.
static Double_t a
FailType
Declare the interface that the class provides.
std::unique_ptr< Ort::Env > m_env
Global runtime environment for Onnx Runtime.
const xAOD::Vertex * getPrimaryVertexFromConv(const xAOD::PhotonContainer *photons) const
Get possible vertex directly associated with photon conversions.
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_sumPtKey
TLorentzVector getEgammaVector(const xAOD::EgammaContainer *egammas, FailType &failType) const
Get combined 4-vector of photon container.
std::vector< const char * > m_output_node_names2
Ort::AllocatorWithDefaultOptions m_allocator2
std::shared_ptr< Ort::Session > m_sessionHandle2
StatusCode getVertex(const xAOD::EgammaContainer &egammas, const xAOD::Vertex *&vertex, bool ignoreConv=false) const
Given a list of photons, return the most likely vertex based on MVA likelihood.
std::unique_ptr< TMVA::Reader > m_mva2
std::vector< const char * > m_output_node_names1
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_deltaZKey
Ort::AllocatorWithDefaultOptions m_allocator1
std::unique_ptr< TMVA::Reader > m_mva1
float getScore(int nVars, const std::vector< std::vector< float > > &input_data, const std::shared_ptr< Ort::Session > &sessionHandle, std::vector< int64_t > input_node_dims, std::vector< const char * > input_node_names, std::vector< const char * > output_node_names) const
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_sumPt2Key
std::tuple< std::shared_ptr< Ort::Session >, Ort::AllocatorWithDefaultOptions > setONNXSession(Ort::Env &env, const std::string &modelFilePath)
std::vector< const char * > m_input_node_names2
static bool sortMLP(const std::pair< const xAOD::Vertex *, float > &a, const std::pair< const xAOD::Vertex *, float > &b)
Sort MLP results.
virtual StatusCode initialize()
Function initialising the tool.
std::tuple< std::vector< int64_t >, std::vector< const char * > > getOutputNodes(const std::shared_ptr< Ort::Session > &sessionHandle, Ort::AllocatorWithDefaultOptions &allocator)
StatusCode decorateInputs(const xAOD::EgammaContainer &egammas, FailType *failType=nullptr) const
Given a list of photons, decorate vertex container with MVA variables.
Gaudi::Property< int > m_ONNXLogLevel
ONNX Runtime logging level (0=VERBOSE, 1=INFO, 2=WARNING, 3=ERROR, 4=FATAL).
SG::ReadHandleKey< xAOD::EventInfo > m_eventInfo
Container declarations.
StatusCode getVertexImp(const xAOD::EgammaContainer &egammas, const xAOD::Vertex *&vertex, bool ignoreConv, bool noDecorate, std::vector< std::pair< const xAOD::Vertex *, float > > &, yyVtxType &, FailType &) const
Given a list of photons, return the MLPs of all vertices in the event.
std::shared_ptr< Ort::Session > m_sessionHandle1
int getCase() const
Return the last case treated:
PhotonVertexSelectionTool(const std::string &name)
SG::ReadHandleKey< xAOD::VertexContainer > m_vertexContainer
int m_nVars
Create a proper constructor for Athena.
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_deltaPhiKey
std::vector< const char * > m_input_node_names1
std::tuple< std::vector< int64_t >, std::vector< const char * > > getInputNodes(const std::shared_ptr< Ort::Session > &sessionHandle, Ort::AllocatorWithDefaultOptions &allocator)
Property holding a SG store/key/clid from which a ReadHandle is made.
Property holding a SG store/key/clid/attr name from which a WriteDecorHandle is made.
Base class for the dual-use tool implementation classes.
Definition AsgTool.h:47
Select isolated Photons, Electrons and Muons.
PhotonContainer_v1 PhotonContainer
Definition of the current "photon container version".
Vertex_v1 Vertex
Define the latest version of the vertex class.
EgammaContainer_v1 EgammaContainer
Definition of the current "egamma container version".