ATLAS Offline Software
Loading...
Searching...
No Matches
PhotonVertexSelectionTool.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef PhotonVertexSelection_PhotonVertexSelectionTool_h
6#define PhotonVertexSelection_PhotonVertexSelectionTool_h
7
8// Framework includes
9#include "AsgTools/AsgTool.h"
10#include "AsgTools/ToolHandle.h"
14
15// EDM includes
19
20// Local includes
22
23// ONNX Runtime include(s).
24#include <onnxruntime_cxx_api.h>
25
26// Forward declarations
27namespace TMVA { class Reader; }
28
29namespace CP {
30
40 public asg::AsgTool {
41
44
45 private:
47 int m_nVars; // number of input variables
51 std::string m_derivationPrefix;
52
55 this, "EventInfoContName", "EventInfo", "event info key"};
57 this, "VertexContainer", "PrimaryVertices", "Vertex container name"};
58
61 bool m_isTMVA; // boolean to use TMVA, if false assume to be ONNX-based
62 std::string m_TMVAModelFilePath1; //TMVA config file, converted case
63 std::string m_TMVAModelFilePath2; //TMVA config file, unconverted case
64
65 // MVA readers
66 // Ideally these would be const but the main method called, EvaluateMVA, is non const.
67 std::unique_ptr<TMVA::Reader> m_mva1;
68 std::unique_ptr<TMVA::Reader> m_mva2;
69
70 // ONNX
71 // ==================================================
72 // Name of the ONNX model file to load
73 std::string m_ONNXModelFilePath1; //converted case
74 std::string m_ONNXModelFilePath2; //unconverted case
75
76 // declare node vars
77 // converted
79 std::vector<const char*> m_input_node_names1, m_output_node_names1;
80 // unconverted
82 std::vector<const char*> m_input_node_names2, m_output_node_names2;
83
84 // The ONNX session handlers
85 std::shared_ptr<Ort::Session> m_sessionHandle1; //converted case
86 std::shared_ptr<Ort::Session> m_sessionHandle2; //unconverted case
87 // The ONNX memory allocators, for looping
88 Ort::AllocatorWithDefaultOptions m_allocator1; //converted case
89 Ort::AllocatorWithDefaultOptions m_allocator2; //unconverted case
90
91 // ONNX Methods
// Helpers for setting up and evaluating the ONNX models. Only declarations
// appear here; the definitions live outside this header.
92 // create ONNX session and return both the allocator and session handler from user-defined onnx env and model
93 std::tuple<std::shared_ptr<Ort::Session>, Ort::AllocatorWithDefaultOptions>
94 setONNXSession(Ort::Env& env, const std::string& modelFilePath);
95 // get the input nodes info from the onnx model file (using session and
96 // allocator)
// Presumably returns (node dims, node names), matching getScore's
// input_node_dims / input_node_names parameters -- confirm in the .cxx.
97 std::tuple<std::vector<int64_t>, std::vector<const char*>> getInputNodes(
98 const std::shared_ptr<Ort::Session>& sessionHandle,
99 Ort::AllocatorWithDefaultOptions& allocator);
100 // get the output nodes info from the onnx model file (using session and
101 // allocator)
// Same tuple layout as getInputNodes, for the model's output nodes.
102 std::tuple<std::vector<int64_t>, std::vector<const char*>> getOutputNodes(
103 const std::shared_ptr<Ort::Session>& sessionHandle,
104 Ort::AllocatorWithDefaultOptions& allocator);
105 // wrapper for getting the NN score from onnx model file (passed as onnx
106 // session)
// NOTE(review): input_node_dims, input_node_names and output_node_names are
// taken by value, so every call copies three vectors. Switching them to
// const references would avoid that, but the out-of-line definition must be
// changed in the same commit or the signatures will no longer match.
107 float getScore(int nVars, const std::vector<std::vector<float>>& input_data,
108 const std::shared_ptr<Ort::Session>& sessionHandle,
109 std::vector<int64_t> input_node_dims,
110 std::vector<const char*> input_node_names,
111 std::vector<const char*> output_node_names) const;
112 // ==================================================
113
114 private:
/// Get combined 4-vector of photon container.
/// failType is a non-const reference, so it is presumably set to describe
/// why the combination failed -- its values are defined elsewhere.
116 TLorentzVector getEgammaVector(const xAOD::EgammaContainer *egammas, FailType& failType) const;
117
/// Sort MLP results.
/// Static binary predicate over (vertex, float) pairs; the float is
/// presumably the per-vertex MLP score. The sort direction is defined in the
/// .cxx -- confirm before relying on it.
119 static bool sortMLP(const std::pair<const xAOD::Vertex*, float> &a, const std::pair<const xAOD::Vertex*, float> &b);
120
/// Given a list of photons, return the MLPs of all vertices in the event.
/// Results are written through the reference arguments: the chosen vertex,
/// the list of (vertex, float) pairs, the yyVtxType case and the FailType.
/// ignoreConv / noDecorate are behaviour switches, mirroring the public
/// getVertex overloads that take parameters of the same names.
122 StatusCode getVertexImp(const xAOD::EgammaContainer& egammas,
123 const xAOD::Vertex*& vertex, bool ignoreConv,
124 bool noDecorate,
125 std::vector<std::pair<const xAOD::Vertex*, float>>&,
126 yyVtxType&, FailType&) const;
127
128 public:
129 PhotonVertexSelectionTool(const std::string &name);
131
134
136 virtual StatusCode initialize();
137
139
142
/// Given a list of photons, decorate vertex container with MVA variables.
/// failType is optional (defaults to nullptr); when non-null it presumably
/// receives the failure reason -- confirm in the .cxx.
144 StatusCode decorateInputs(const xAOD::EgammaContainer &egammas, FailType* failType = nullptr) const;
145
/// Given a list of photons, return the most likely vertex based on MVA likelihood.
/// The selected vertex is returned through the `vertex` reference argument.
147 StatusCode getVertex(const xAOD::EgammaContainer &egammas, const xAOD::Vertex* &vertex, bool ignoreConv = false) const;
148
/// Overload returning (vertex, float) pairs instead of a single vertex; the
/// float is presumably the MLP score, as for getVertexImp -- confirm.
/// vtxCase and failType are optional out-pointers (nullptr by default).
150 std::vector<std::pair<const xAOD::Vertex*, float>> getVertex(
151 const xAOD::EgammaContainer& egammas,
152 bool ignoreConv = false,
153 bool noDecorate = false,
154 yyVtxType* vtxCase = nullptr,
155 FailType* failType = nullptr) const;
156
158 // Deprecated no longer use this function
159 int getCase() const { return -1; }
160
163
165
174
175 }; // class PhotonVertexSelectionTool
176
177} // namespace CP
178
179
180#endif // PhotonVertexSelection_PhotonVertexSelectionTool_h
#define ASG_TOOL_CLASS(CLASSNAME, INT1)
Property holding a SG store/key/clid from which a ReadHandle is made.
static Double_t a
FailType
Declare the interface that the class provides.
const xAOD::Vertex * getPrimaryVertexFromConv(const xAOD::PhotonContainer *photons) const
Get possible vertex directly associated with photon conversions.
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_sumPtKey
TLorentzVector getEgammaVector(const xAOD::EgammaContainer *egammas, FailType &failType) const
Get combined 4-vector of photon container.
std::vector< const char * > m_output_node_names2
Ort::AllocatorWithDefaultOptions m_allocator2
std::shared_ptr< Ort::Session > m_sessionHandle2
StatusCode getVertex(const xAOD::EgammaContainer &egammas, const xAOD::Vertex *&vertex, bool ignoreConv=false) const
Given a list of photons, return the most likely vertex based on MVA likelihood.
std::unique_ptr< TMVA::Reader > m_mva2
std::vector< const char * > m_output_node_names1
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_deltaZKey
Ort::AllocatorWithDefaultOptions m_allocator1
std::unique_ptr< TMVA::Reader > m_mva1
float getScore(int nVars, const std::vector< std::vector< float > > &input_data, const std::shared_ptr< Ort::Session > &sessionHandle, std::vector< int64_t > input_node_dims, std::vector< const char * > input_node_names, std::vector< const char * > output_node_names) const
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_sumPt2Key
std::tuple< std::shared_ptr< Ort::Session >, Ort::AllocatorWithDefaultOptions > setONNXSession(Ort::Env &env, const std::string &modelFilePath)
std::vector< const char * > m_input_node_names2
static bool sortMLP(const std::pair< const xAOD::Vertex *, float > &a, const std::pair< const xAOD::Vertex *, float > &b)
Sort MLP results.
virtual StatusCode initialize()
Function initialising the tool.
std::tuple< std::vector< int64_t >, std::vector< const char * > > getOutputNodes(const std::shared_ptr< Ort::Session > &sessionHandle, Ort::AllocatorWithDefaultOptions &allocator)
StatusCode decorateInputs(const xAOD::EgammaContainer &egammas, FailType *failType=nullptr) const
Given a list of photons, decorate vertex container with MVA variables.
SG::ReadHandleKey< xAOD::EventInfo > m_eventInfo
Container declarations.
StatusCode getVertexImp(const xAOD::EgammaContainer &egammas, const xAOD::Vertex *&vertex, bool ignoreConv, bool noDecorate, std::vector< std::pair< const xAOD::Vertex *, float > > &, yyVtxType &, FailType &) const
Given a list of photons, return the MLPs of all vertices in the event.
std::shared_ptr< Ort::Session > m_sessionHandle1
int getCase() const
Return the last case treated (deprecated: now always returns -1).
PhotonVertexSelectionTool(const std::string &name)
SG::ReadHandleKey< xAOD::VertexContainer > m_vertexContainer
int m_nVars
Create a proper constructor for Athena.
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_deltaPhiKey
std::vector< const char * > m_input_node_names1
std::tuple< std::vector< int64_t >, std::vector< const char * > > getInputNodes(const std::shared_ptr< Ort::Session > &sessionHandle, Ort::AllocatorWithDefaultOptions &allocator)
Property holding a SG store/key/clid from which a ReadHandle is made.
Property holding a SG store/key/clid/attr name from which a WriteDecorHandle is made.
Base class for the dual-use tool implementation classes.
Definition AsgTool.h:47
Select isolated Photons, Electrons and Muons.
PhotonContainer_v1 PhotonContainer
Definition of the current "photon container version".
Vertex_v1 Vertex
Define the latest version of the vertex class.
EgammaContainer_v1 EgammaContainer
Definition of the current "egamma container version".