Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
PhotonVertexSelectionTool.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef PhotonVertexSelection_PhotonVertexSelectionTool_h
6 #define PhotonVertexSelection_PhotonVertexSelectionTool_h
7 
8 // Framework includes
9 #include "AsgTools/AsgTool.h"
10 #include "AsgTools/ToolHandle.h"
14 
15 // EDM includes
19 
20 // Local includes
22 
23 // ONNX Runtime include(s).
24 #include <onnxruntime_cxx_api.h>
25 
26 // Forward declarations
27 namespace TMVA { class Reader; }
28 
29 namespace CP {
30 
40  public asg::AsgTool {
41 
44 
45  private:
47  int m_nVars; // number of input variables
48  float m_convPtCut;
50  std::string m_vertexContainerName;
51  std::string m_derivationPrefix;
52 
54  SG::ReadHandleKey<xAOD::EventInfo> m_eventInfo{this, "EventInfoContName", "EventInfo", "event info key"};
55  SG::ReadHandleKey<xAOD::VertexContainer> m_vertexContainer {this, "VertexContainer", "PrimaryVertices", "Vertex container name"};
56 
59  bool m_isTMVA; // boolean to use TMVA, if false assume to be ONNX-based
60  std::string m_TMVAModelFilePath1; //TMVA config file, converted case
61  std::string m_TMVAModelFilePath2; //TMVA config file, unconverted case
62 
63  // MVA readers
64  // Ideally these would be const, but the main method called, EvaluateMVA, is non-const.
65  std::unique_ptr<TMVA::Reader> m_mva1;
66  std::unique_ptr<TMVA::Reader> m_mva2;
67 
68  // ONNX
69  // ==================================================
70  // Name of the ONNX model file to load
71  std::string m_ONNXModelFilePath1; //converted case
72  std::string m_ONNXModelFilePath2; //unconverted case
73 
74  // declare node vars
75  // converted
76  std::vector<int64_t> m_input_node_dims1, m_output_node_dims1;
77  std::vector<const char*> m_input_node_names1, m_output_node_names1;
78  // unconverted
79  std::vector<int64_t> m_input_node_dims2, m_output_node_dims2;
80  std::vector<const char*> m_input_node_names2, m_output_node_names2;
81 
82  // The ONNX session handlers
83  std::shared_ptr<Ort::Session> m_sessionHandle1; //converted case
84  std::shared_ptr<Ort::Session> m_sessionHandle2; //unconverted case
85  // The ONNX memory allocators, for looping
86  Ort::AllocatorWithDefaultOptions m_allocator1; //converted case
87  Ort::AllocatorWithDefaultOptions m_allocator2; //unconverted case
88 
89  // ONNX Methods
90  // Create an ONNX session from a user-defined ONNX environment and model file path; returns both the session handle and the allocator.
91  std::tuple<std::shared_ptr<Ort::Session>, Ort::AllocatorWithDefaultOptions> setONNXSession(Ort::Env& env, const std::string& modelFilePath);
92  // get the input nodes info from the onnx model file (using session and allocator)
93  std::tuple<std::vector<int64_t>, std::vector<const char*>> getInputNodes( const std::shared_ptr<Ort::Session> sessionHandle, Ort::AllocatorWithDefaultOptions& allocator);
94  // get the output nodes info from the onnx model file (using session and allocator)
95  std::tuple<std::vector<int64_t>, std::vector<const char*>> getOutputNodes(const std::shared_ptr<Ort::Session> sessionHandle, Ort::AllocatorWithDefaultOptions& allocator);
96  // wrapper for getting the NN score from onnx model file (passed as onnx session)
97  float getScore(int nVars, const std::vector<std::vector<float>>& input_data, const std::shared_ptr<Ort::Session> sessionHandle, std::vector<int64_t> input_node_dims, std::vector<const char*> input_node_names, std::vector<const char*> output_node_names) const;
98  // ==================================================
99 
100  private:
102  TLorentzVector getEgammaVector(const xAOD::EgammaContainer *egammas, FailType& failType) const;
103 
105  static bool sortMLP(const std::pair<const xAOD::Vertex*, float> &a, const std::pair<const xAOD::Vertex*, float> &b);
106 
107 
109  StatusCode getVertexImp(const xAOD::EgammaContainer &egammas, const xAOD::Vertex* &vertex, bool ignoreConv, bool noDecorate, std::vector<std::pair<const xAOD::Vertex*, float> >&, yyVtxType& , FailType& ) const;
110 
111 
112  public:
113  PhotonVertexSelectionTool(const std::string &name);
115 
118 
120  virtual StatusCode initialize();
121 
123 
126 
128  StatusCode decorateInputs(const xAOD::EgammaContainer &egammas, FailType* failType = nullptr) const;
129 
131  StatusCode getVertex(const xAOD::EgammaContainer &egammas, const xAOD::Vertex* &vertex, bool ignoreConv = false) const;
132 
134  std::vector<std::pair<const xAOD::Vertex*, float> > getVertex(const xAOD::EgammaContainer &egammas, bool ignoreConv = false, bool noDecorate = false, yyVtxType* vtxCase = nullptr, FailType* failType = nullptr) const;
135 
137  // Deprecated: do not use this function any longer.
138  int getCase() const { return -1; }
139 
141  const xAOD::Vertex* getPrimaryVertexFromConv(const xAOD::PhotonContainer *photons) const;
142 
144 
146  { this, "DeltaPhiKey", m_vertexContainer, "deltaPhi" };
148  { this, "DeltaZKey", m_vertexContainer, "deltaZ" };
150  { this, "SumPt2Key", m_vertexContainer, "sumPt2" };
152  { this, "SumPtKey", m_vertexContainer, "sumPt" };
153 
154  }; // class PhotonVertexSelectionTool
155 
156 } // namespace CP
157 
158 
159 #endif // PhotonVertexSelection_PhotonVertexSelectionTool_h
CP::PhotonVertexSelectionTool::PhotonVertexSelectionTool
PhotonVertexSelectionTool(const std::string &name)
Definition: PhotonVertexSelectionTool.cxx:55
SG::WriteDecorHandleKey
Property holding a SG store/key/clid/attr name from which a WriteDecorHandle is made.
Definition: StoreGate/StoreGate/WriteDecorHandleKey.h:89
CP::PhotonVertexSelectionTool::setONNXSession
std::tuple< std::shared_ptr< Ort::Session >, Ort::AllocatorWithDefaultOptions > setONNXSession(Ort::Env &env, const std::string &modelFilePath)
Definition: PhotonVertexSelectionTool.cxx:187
asg::AsgTool
Base class for the dual-use tool implementation classes.
Definition: AsgTool.h:47
CP::PhotonVertexSelectionTool::m_TMVAModelFilePath2
std::string m_TMVAModelFilePath2
Definition: PhotonVertexSelectionTool.h:61
CP::PhotonVertexSelectionTool::m_derivationPrefix
std::string m_derivationPrefix
Definition: PhotonVertexSelectionTool.h:51
CP::PhotonVertexSelectionTool::getPrimaryVertexFromConv
const xAOD::Vertex * getPrimaryVertexFromConv(const xAOD::PhotonContainer *photons) const
Get possible vertex directly associated with photon conversions.
Definition: PhotonVertexSelectionTool.cxx:551
CP::PhotonVertexSelectionTool::m_sumPt2Key
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_sumPt2Key
Definition: PhotonVertexSelectionTool.h:150
CurrentContext.h
CP::PhotonVertexSelectionTool::getScore
float getScore(int nVars, const std::vector< std::vector< float >> &input_data, const std::shared_ptr< Ort::Session > sessionHandle, std::vector< int64_t > input_node_dims, std::vector< const char * > input_node_names, std::vector< const char * > output_node_names) const
Definition: PhotonVertexSelectionTool.cxx:90
CP::PhotonVertexSelectionTool::m_TMVAModelFilePath1
std::string m_TMVAModelFilePath1
Definition: PhotonVertexSelectionTool.h:60
CP::PhotonVertexSelectionTool::getEgammaVector
TLorentzVector getEgammaVector(const xAOD::EgammaContainer *egammas, FailType &failType) const
Get combined 4-vector of photon container.
Definition: PhotonVertexSelectionTool.cxx:605
CP::PhotonVertexSelectionTool::m_vertexContainer
SG::ReadHandleKey< xAOD::VertexContainer > m_vertexContainer
Definition: PhotonVertexSelectionTool.h:55
CP::PhotonVertexSelectionTool::m_input_node_dims1
std::vector< int64_t > m_input_node_dims1
Definition: PhotonVertexSelectionTool.h:76
CP::PhotonVertexSelectionTool::m_sessionHandle2
std::shared_ptr< Ort::Session > m_sessionHandle2
Definition: PhotonVertexSelectionTool.h:84
CP::PhotonVertexSelectionTool::m_mva1
std::unique_ptr< TMVA::Reader > m_mva1
Definition: PhotonVertexSelectionTool.h:65
SG::ReadHandleKey< xAOD::EventInfo >
CP::PhotonVertexSelectionTool::sortMLP
static bool sortMLP(const std::pair< const xAOD::Vertex *, float > &a, const std::pair< const xAOD::Vertex *, float > &b)
Sort MLP results.
Definition: PhotonVertexSelectionTool.cxx:546
CP::PhotonVertexSelectionTool::m_ONNXModelFilePath1
std::string m_ONNXModelFilePath1
Definition: PhotonVertexSelectionTool.h:71
CP
Select isolated Photons, Electrons and Muons.
Definition: Control/xAODRootAccess/xAODRootAccess/TEvent.h:49
CP::PhotonVertexSelectionTool::m_sessionHandle1
std::shared_ptr< Ort::Session > m_sessionHandle1
Definition: PhotonVertexSelectionTool.h:83
CP::PhotonVertexSelectionTool::m_output_node_dims2
std::vector< int64_t > m_output_node_dims2
Definition: PhotonVertexSelectionTool.h:79
CP::PhotonVertexSelectionTool::m_eventInfo
SG::ReadHandleKey< xAOD::EventInfo > m_eventInfo
Container declarations.
Definition: PhotonVertexSelectionTool.h:54
CP::IPhotonVertexSelectionTool
Definition: IPhotonVertexSelectionTool.h:26
CP::PhotonVertexSelectionTool
Implementation for the photon vertex selection tool.
Definition: PhotonVertexSelectionTool.h:40
CP::PhotonVertexSelectionTool::initialize
virtual StatusCode initialize()
Function initialising the tool.
Definition: PhotonVertexSelectionTool.cxx:206
CP::PhotonVertexSelectionTool::m_allocator1
Ort::AllocatorWithDefaultOptions m_allocator1
Definition: PhotonVertexSelectionTool.h:86
CP::IPhotonVertexSelectionTool::yyVtxType
yyVtxType
Definition: IPhotonVertexSelectionTool.h:44
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
CP::PhotonVertexSelectionTool::m_vertexContainerName
std::string m_vertexContainerName
Definition: PhotonVertexSelectionTool.h:50
CP::PhotonVertexSelectionTool::m_sumPtKey
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_sumPtKey
Definition: PhotonVertexSelectionTool.h:152
ReadHandleKey.h
Property holding a SG store/key/clid from which a ReadHandle is made.
CP::PhotonVertexSelectionTool::m_convPtCut
float m_convPtCut
Definition: PhotonVertexSelectionTool.h:48
CP::PhotonVertexSelectionTool::getCase
int getCase() const
Return the last case treated:
Definition: PhotonVertexSelectionTool.h:138
CP::PhotonVertexSelectionTool::~PhotonVertexSelectionTool
virtual ~PhotonVertexSelectionTool()
CP::IPhotonVertexSelectionTool::FailType
FailType
Declare the interface that the class provides.
Definition: IPhotonVertexSelectionTool.h:33
LHEF::Reader
Pythia8::Reader Reader
Definition: Prophecy4fMerger.cxx:11
CP::PhotonVertexSelectionTool::m_nVars
int m_nVars
Create a proper constructor for Athena.
Definition: PhotonVertexSelectionTool.h:47
CP::PhotonVertexSelectionTool::getOutputNodes
std::tuple< std::vector< int64_t >, std::vector< const char * > > getOutputNodes(const std::shared_ptr< Ort::Session > sessionHandle, Ort::AllocatorWithDefaultOptions &allocator)
Definition: PhotonVertexSelectionTool.cxx:157
IPhotonVertexSelectionTool.h
CP::PhotonVertexSelectionTool::getVertexImp
StatusCode getVertexImp(const xAOD::EgammaContainer &egammas, const xAOD::Vertex *&vertex, bool ignoreConv, bool noDecorate, std::vector< std::pair< const xAOD::Vertex *, float > > &, yyVtxType &, FailType &) const
Given a list of photons, return the MLPs of all vertices in the event.
Definition: PhotonVertexSelectionTool.cxx:346
DataVector
Derived DataVector<T>.
Definition: DataVector.h:794
CP::PhotonVertexSelectionTool::m_deltaPhiKey
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_deltaPhiKey
Definition: PhotonVertexSelectionTool.h:146
CP::PhotonVertexSelectionTool::m_input_node_dims2
std::vector< int64_t > m_input_node_dims2
Definition: PhotonVertexSelectionTool.h:79
CP::PhotonVertexSelectionTool::m_ONNXModelFilePath2
std::string m_ONNXModelFilePath2
Definition: PhotonVertexSelectionTool.h:72
CP::PhotonVertexSelectionTool::m_output_node_names1
std::vector< const char * > m_output_node_names1
Definition: PhotonVertexSelectionTool.h:77
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:228
plotBeamSpotMon.b
b
Definition: plotBeamSpotMon.py:77
CP::PhotonVertexSelectionTool::getVertex
StatusCode getVertex(const xAOD::EgammaContainer &egammas, const xAOD::Vertex *&vertex, bool ignoreConv=false) const
Given a list of photons, return the most likely vertex based on MVA likelihood.
Definition: PhotonVertexSelectionTool.cxx:336
CP::PhotonVertexSelectionTool::m_mva2
std::unique_ptr< TMVA::Reader > m_mva2
Definition: PhotonVertexSelectionTool.h:66
CP::PhotonVertexSelectionTool::m_output_node_names2
std::vector< const char * > m_output_node_names2
Definition: PhotonVertexSelectionTool.h:80
CP::PhotonVertexSelectionTool::m_isTMVA
bool m_isTMVA
Definition: PhotonVertexSelectionTool.h:59
EventInfo.h
CP::PhotonVertexSelectionTool::m_doSkipByZSigma
bool m_doSkipByZSigma
Definition: PhotonVertexSelectionTool.h:49
WriteDecorHandleKey.h
CP::PhotonVertexSelectionTool::m_allocator2
Ort::AllocatorWithDefaultOptions m_allocator2
Definition: PhotonVertexSelectionTool.h:87
VertexContainer.h
a
TList * a
Definition: liststreamerinfos.cxx:10
xAOD::Vertex_v1
Class describing a Vertex.
Definition: Vertex_v1.h:42
CP::PhotonVertexSelectionTool::decorateInputs
StatusCode decorateInputs(const xAOD::EgammaContainer &egammas, FailType *failType=nullptr) const
Given a list of photons, decorate vertex container with MVA variables.
Definition: PhotonVertexSelectionTool.cxx:264
ASG_TOOL_CLASS
#define ASG_TOOL_CLASS(CLASSNAME, INT1)
Definition: AsgToolMacros.h:68
CP::PhotonVertexSelectionTool::m_input_node_names1
std::vector< const char * > m_input_node_names1
Definition: PhotonVertexSelectionTool.h:77
CP::PhotonVertexSelectionTool::m_output_node_dims1
std::vector< int64_t > m_output_node_dims1
Definition: PhotonVertexSelectionTool.h:76
CP::PhotonVertexSelectionTool::getInputNodes
std::tuple< std::vector< int64_t >, std::vector< const char * > > getInputNodes(const std::shared_ptr< Ort::Session > sessionHandle, Ort::AllocatorWithDefaultOptions &allocator)
Definition: PhotonVertexSelectionTool.cxx:126
ToolHandle.h
CP::PhotonVertexSelectionTool::m_deltaZKey
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_deltaZKey
Definition: PhotonVertexSelectionTool.h:148
AsgTool.h
python.DataFormatRates.env
env
Definition: DataFormatRates.py:32
CP::PhotonVertexSelectionTool::m_input_node_names2
std::vector< const char * > m_input_node_names2
Definition: PhotonVertexSelectionTool.h:80
PhotonContainer.h
TMVA
Definition: PhotonVertexSelectionTool.h:27