ATLAS Offline Software
PhotonVertexSelectionTool.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef PhotonVertexSelection_PhotonVertexSelectionTool_h
6 #define PhotonVertexSelection_PhotonVertexSelectionTool_h
7 
8 // Framework includes
9 #include "AsgTools/AsgTool.h"
10 #include "AsgTools/ToolHandle.h"
14 
15 // EDM includes
19 
20 // Local includes
22 
23 // ONNX Runtime include(s).
24 #include <onnxruntime_cxx_api.h>
25 
26 // Forward declarations
27 namespace TMVA { class Reader; }
28 
29 namespace CP {
30 
40  public asg::AsgTool {
41 
44 
45  private:
47  int m_nVars; // number of input variables
48  float m_convPtCut;
50  std::string m_vertexContainerName;
51  std::string m_derivationPrefix;
52 
55  this, "EventInfoContName", "EventInfo", "event info key"};
57  this, "VertexContainer", "PrimaryVertices", "Vertex container name"};
58 
61  bool m_isTMVA; // boolean to use TMVA, if false assume to be ONNX-based
62  std::string m_TMVAModelFilePath1; //TMVA config file, converted case
63  std::string m_TMVAModelFilePath2; //TMVA config file, unconverted case
64 
65  // MVA readers
66  // Ideally these would be const, but the main method called, EvaluateMVA, is non-const.
67  std::unique_ptr<TMVA::Reader> m_mva1;
68  std::unique_ptr<TMVA::Reader> m_mva2;
69 
70  // ONNX
71  // ==================================================
72  // Name of the ONNX model file to load
73  std::string m_ONNXModelFilePath1; //converted case
74  std::string m_ONNXModelFilePath2; //unconverted case
75 
76  // declare node vars
77  // converted
78  std::vector<int64_t> m_input_node_dims1, m_output_node_dims1;
79  std::vector<const char*> m_input_node_names1, m_output_node_names1;
80  // unconverted
81  std::vector<int64_t> m_input_node_dims2, m_output_node_dims2;
82  std::vector<const char*> m_input_node_names2, m_output_node_names2;
83 
84  // The ONNX session handlers
85  std::shared_ptr<Ort::Session> m_sessionHandle1; //converted case
86  std::shared_ptr<Ort::Session> m_sessionHandle2; //unconverted case
87  // The ONNX memory allocators, for looping
88  Ort::AllocatorWithDefaultOptions m_allocator1; //converted case
89  Ort::AllocatorWithDefaultOptions m_allocator2; //unconverted case
90 
91  // ONNX Methods
92  // create ONNX session and return both the allocator and session handler from user-defined onnx env and model
93  std::tuple<std::shared_ptr<Ort::Session>, Ort::AllocatorWithDefaultOptions>
94  setONNXSession(Ort::Env& env, const std::string& modelFilePath);
95  // get the input nodes info from the onnx model file (using session and
96  // allocator)
97  std::tuple<std::vector<int64_t>, std::vector<const char*>> getInputNodes(
98  const std::shared_ptr<Ort::Session>& sessionHandle,
99  Ort::AllocatorWithDefaultOptions& allocator);
100  // get the output nodes info from the onnx model file (using session and
101  // allocator)
102  std::tuple<std::vector<int64_t>, std::vector<const char*>> getOutputNodes(
103  const std::shared_ptr<Ort::Session>& sessionHandle,
104  Ort::AllocatorWithDefaultOptions& allocator);
105  // wrapper for getting the NN score from onnx model file (passed as onnx
106  // session)
107  float getScore(int nVars, const std::vector<std::vector<float>>& input_data,
108  const std::shared_ptr<Ort::Session>& sessionHandle,
109  std::vector<int64_t> input_node_dims,
110  std::vector<const char*> input_node_names,
111  std::vector<const char*> output_node_names) const;
112  // ==================================================
113 
114  private:
116  TLorentzVector getEgammaVector(const xAOD::EgammaContainer *egammas, FailType& failType) const;
117 
119  static bool sortMLP(const std::pair<const xAOD::Vertex*, float> &a, const std::pair<const xAOD::Vertex*, float> &b);
120 
123  const xAOD::Vertex*& vertex, bool ignoreConv,
124  bool noDecorate,
125  std::vector<std::pair<const xAOD::Vertex*, float>>&,
126  yyVtxType&, FailType&) const;
127 
128  public:
129  PhotonVertexSelectionTool(const std::string &name);
131 
134 
136  virtual StatusCode initialize();
137 
139 
142 
144  StatusCode decorateInputs(const xAOD::EgammaContainer &egammas, FailType* failType = nullptr) const;
145 
147  StatusCode getVertex(const xAOD::EgammaContainer &egammas, const xAOD::Vertex* &vertex, bool ignoreConv = false) const;
148 
150  std::vector<std::pair<const xAOD::Vertex*, float>> getVertex(
151  const xAOD::EgammaContainer& egammas,
152  bool ignoreConv = false,
153  bool noDecorate = false,
154  yyVtxType* vtxCase = nullptr,
155  FailType* failType = nullptr) const;
156 
158  // Deprecated: do not use this function; it always returns -1.
159  int getCase() const { return -1; }
160 
162  const xAOD::Vertex* getPrimaryVertexFromConv(const xAOD::PhotonContainer *photons) const;
163 
165 
167  { this, "DeltaPhiKey", m_vertexContainer, "deltaPhi" };
169  { this, "DeltaZKey", m_vertexContainer, "deltaZ" };
171  { this, "SumPt2Key", m_vertexContainer, "sumPt2" };
173  { this, "SumPtKey", m_vertexContainer, "sumPt" };
174 
175  }; // class PhotonVertexSelectionTool
176 
177 } // namespace CP
178 
179 
180 #endif // PhotonVertexSelection_PhotonVertexSelectionTool_h
CP::PhotonVertexSelectionTool::PhotonVertexSelectionTool
PhotonVertexSelectionTool(const std::string &name)
Definition: PhotonVertexSelectionTool.cxx:54
SG::WriteDecorHandleKey
Property holding a SG store/key/clid/attr name from which a WriteDecorHandle is made.
Definition: StoreGate/StoreGate/WriteDecorHandleKey.h:89
CP::PhotonVertexSelectionTool::setONNXSession
std::tuple< std::shared_ptr< Ort::Session >, Ort::AllocatorWithDefaultOptions > setONNXSession(Ort::Env &env, const std::string &modelFilePath)
Definition: PhotonVertexSelectionTool.cxx:204
asg::AsgTool
Base class for the dual-use tool implementation classes.
Definition: AsgTool.h:47
CP::PhotonVertexSelectionTool::m_TMVAModelFilePath2
std::string m_TMVAModelFilePath2
Definition: PhotonVertexSelectionTool.h:63
CP::PhotonVertexSelectionTool::m_derivationPrefix
std::string m_derivationPrefix
Definition: PhotonVertexSelectionTool.h:51
CP::PhotonVertexSelectionTool::getPrimaryVertexFromConv
const xAOD::Vertex * getPrimaryVertexFromConv(const xAOD::PhotonContainer *photons) const
Get possible vertex directly associated with photon conversions.
Definition: PhotonVertexSelectionTool.cxx:574
CP::PhotonVertexSelectionTool::m_sumPt2Key
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_sumPt2Key
Definition: PhotonVertexSelectionTool.h:171
CurrentContext.h
CP::PhotonVertexSelectionTool::m_TMVAModelFilePath1
std::string m_TMVAModelFilePath1
Definition: PhotonVertexSelectionTool.h:62
CP::PhotonVertexSelectionTool::getEgammaVector
TLorentzVector getEgammaVector(const xAOD::EgammaContainer *egammas, FailType &failType) const
Get combined 4-vector of photon container.
Definition: PhotonVertexSelectionTool.cxx:628
CP::PhotonVertexSelectionTool::m_vertexContainer
SG::ReadHandleKey< xAOD::VertexContainer > m_vertexContainer
Definition: PhotonVertexSelectionTool.h:56
CP::PhotonVertexSelectionTool::m_input_node_dims1
std::vector< int64_t > m_input_node_dims1
Definition: PhotonVertexSelectionTool.h:78
CP::PhotonVertexSelectionTool::m_sessionHandle2
std::shared_ptr< Ort::Session > m_sessionHandle2
Definition: PhotonVertexSelectionTool.h:86
CP::PhotonVertexSelectionTool::m_mva1
std::unique_ptr< TMVA::Reader > m_mva1
Definition: PhotonVertexSelectionTool.h:67
SG::ReadHandleKey< xAOD::EventInfo >
CP::PhotonVertexSelectionTool::sortMLP
static bool sortMLP(const std::pair< const xAOD::Vertex *, float > &a, const std::pair< const xAOD::Vertex *, float > &b)
Sort MLP results.
Definition: PhotonVertexSelectionTool.cxx:569
CP::PhotonVertexSelectionTool::m_ONNXModelFilePath1
std::string m_ONNXModelFilePath1
Definition: PhotonVertexSelectionTool.h:73
CP
Select isolated Photons, Electrons and Muons.
Definition: Control/xAODRootAccess/xAODRootAccess/TEvent.h:49
CP::PhotonVertexSelectionTool::m_sessionHandle1
std::shared_ptr< Ort::Session > m_sessionHandle1
Definition: PhotonVertexSelectionTool.h:85
CP::PhotonVertexSelectionTool::m_output_node_dims2
std::vector< int64_t > m_output_node_dims2
Definition: PhotonVertexSelectionTool.h:81
CP::PhotonVertexSelectionTool::m_eventInfo
SG::ReadHandleKey< xAOD::EventInfo > m_eventInfo
Container declarations.
Definition: PhotonVertexSelectionTool.h:54
CP::IPhotonVertexSelectionTool
Definition: IPhotonVertexSelectionTool.h:26
CP::PhotonVertexSelectionTool
Implementation for the photon vertex selection tool.
Definition: PhotonVertexSelectionTool.h:40
CP::PhotonVertexSelectionTool::initialize
virtual StatusCode initialize()
Function initialising the tool.
Definition: PhotonVertexSelectionTool.cxx:224
CP::PhotonVertexSelectionTool::m_allocator1
Ort::AllocatorWithDefaultOptions m_allocator1
Definition: PhotonVertexSelectionTool.h:88
CP::PhotonVertexSelectionTool::getVertexImp
StatusCode getVertexImp(const xAOD::EgammaContainer &egammas, const xAOD::Vertex *&vertex, bool ignoreConv, bool noDecorate, std::vector< std::pair< const xAOD::Vertex *, float >> &, yyVtxType &, FailType &) const
Given a list of photons, return the MLPs of all vertices in the event.
Definition: PhotonVertexSelectionTool.cxx:367
CP::IPhotonVertexSelectionTool::yyVtxType
yyVtxType
Definition: IPhotonVertexSelectionTool.h:44
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
CP::PhotonVertexSelectionTool::m_vertexContainerName
std::string m_vertexContainerName
Definition: PhotonVertexSelectionTool.h:50
CP::PhotonVertexSelectionTool::m_sumPtKey
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_sumPtKey
Definition: PhotonVertexSelectionTool.h:173
ReadHandleKey.h
Property holding a SG store/key/clid from which a ReadHandle is made.
CP::PhotonVertexSelectionTool::m_convPtCut
float m_convPtCut
Definition: PhotonVertexSelectionTool.h:48
CP::PhotonVertexSelectionTool::getOutputNodes
std::tuple< std::vector< int64_t >, std::vector< const char * > > getOutputNodes(const std::shared_ptr< Ort::Session > &sessionHandle, Ort::AllocatorWithDefaultOptions &allocator)
Definition: PhotonVertexSelectionTool.cxx:171
CP::PhotonVertexSelectionTool::getCase
int getCase() const
Deprecated: no longer returns the last case treated; always returns -1.
Definition: PhotonVertexSelectionTool.h:159
CP::PhotonVertexSelectionTool::~PhotonVertexSelectionTool
virtual ~PhotonVertexSelectionTool()
CP::IPhotonVertexSelectionTool::FailType
FailType
Declare the interface that the class provides.
Definition: IPhotonVertexSelectionTool.h:33
LHEF::Reader
Pythia8::Reader Reader
Definition: Prophecy4fMerger.cxx:11
CP::PhotonVertexSelectionTool::m_nVars
int m_nVars
Number of input variables.
Definition: PhotonVertexSelectionTool.h:47
IPhotonVertexSelectionTool.h
CP::PhotonVertexSelectionTool::getInputNodes
std::tuple< std::vector< int64_t >, std::vector< const char * > > getInputNodes(const std::shared_ptr< Ort::Session > &sessionHandle, Ort::AllocatorWithDefaultOptions &allocator)
Definition: PhotonVertexSelectionTool.cxx:137
DataVector
Derived DataVector<T>.
Definition: DataVector.h:794
CP::PhotonVertexSelectionTool::m_deltaPhiKey
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_deltaPhiKey
Definition: PhotonVertexSelectionTool.h:167
CP::PhotonVertexSelectionTool::m_input_node_dims2
std::vector< int64_t > m_input_node_dims2
Definition: PhotonVertexSelectionTool.h:81
CP::PhotonVertexSelectionTool::m_ONNXModelFilePath2
std::string m_ONNXModelFilePath2
Definition: PhotonVertexSelectionTool.h:74
CP::PhotonVertexSelectionTool::m_output_node_names1
std::vector< const char * > m_output_node_names1
Definition: PhotonVertexSelectionTool.h:79
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
plotBeamSpotMon.b
b
Definition: plotBeamSpotMon.py:76
CP::PhotonVertexSelectionTool::getVertex
StatusCode getVertex(const xAOD::EgammaContainer &egammas, const xAOD::Vertex *&vertex, bool ignoreConv=false) const
Given a list of photons, return the most likely vertex based on MVA likelihood.
Definition: PhotonVertexSelectionTool.cxx:357
CP::PhotonVertexSelectionTool::m_mva2
std::unique_ptr< TMVA::Reader > m_mva2
Definition: PhotonVertexSelectionTool.h:68
CP::PhotonVertexSelectionTool::m_output_node_names2
std::vector< const char * > m_output_node_names2
Definition: PhotonVertexSelectionTool.h:82
CP::PhotonVertexSelectionTool::m_isTMVA
bool m_isTMVA
Definition: PhotonVertexSelectionTool.h:61
EventInfo.h
CP::PhotonVertexSelectionTool::m_doSkipByZSigma
bool m_doSkipByZSigma
Definition: PhotonVertexSelectionTool.h:49
WriteDecorHandleKey.h
CP::PhotonVertexSelectionTool::m_allocator2
Ort::AllocatorWithDefaultOptions m_allocator2
Definition: PhotonVertexSelectionTool.h:89
Trk::vertex
@ vertex
Definition: MeasurementType.h:21
VertexContainer.h
a
TList * a
Definition: liststreamerinfos.cxx:10
xAOD::Vertex_v1
Class describing a Vertex.
Definition: Vertex_v1.h:42
CP::PhotonVertexSelectionTool::decorateInputs
StatusCode decorateInputs(const xAOD::EgammaContainer &egammas, FailType *failType=nullptr) const
Given a list of photons, decorate vertex container with MVA variables.
Definition: PhotonVertexSelectionTool.cxx:282
ASG_TOOL_CLASS
#define ASG_TOOL_CLASS(CLASSNAME, INT1)
Definition: AsgToolMacros.h:68
CP::PhotonVertexSelectionTool::getScore
float getScore(int nVars, const std::vector< std::vector< float >> &input_data, const std::shared_ptr< Ort::Session > &sessionHandle, std::vector< int64_t > input_node_dims, std::vector< const char * > input_node_names, std::vector< const char * > output_node_names) const
Definition: PhotonVertexSelectionTool.cxx:89
CP::PhotonVertexSelectionTool::m_input_node_names1
std::vector< const char * > m_input_node_names1
Definition: PhotonVertexSelectionTool.h:79
CP::PhotonVertexSelectionTool::m_output_node_dims1
std::vector< int64_t > m_output_node_dims1
Definition: PhotonVertexSelectionTool.h:78
ToolHandle.h
CP::PhotonVertexSelectionTool::m_deltaZKey
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_deltaZKey
Definition: PhotonVertexSelectionTool.h:169
AsgTool.h
python.DataFormatRates.env
env
Definition: DataFormatRates.py:32
CP::PhotonVertexSelectionTool::m_input_node_names2
std::vector< const char * > m_input_node_names2
Definition: PhotonVertexSelectionTool.h:82
PhotonContainer.h
TMVA
Definition: PhotonVertexSelectionTool.h:27