Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
PhotonVertexSelectionTool.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 // Local includes
8 
9 // EDM includes
12 #include "xAODEgamma/EgammaDefs.h"
15 
16 // Framework includes
21 
22 // ROOT includes
23 #include "TMVA/Reader.h"
24 
25 // std includes
26 #include <algorithm>
#include <stdexcept>
#include <string>
27 
28 namespace CP {
29 
30  // helper function to get the vertex of a track
32  const xAOD::VertexContainer* vertices)
33  {
34  const xAOD::Vertex* vtxWithLargestWeight = nullptr;
35  float largestWeight = 0;
36 
37  for (const auto *vtx : *vertices) {
38  //Search for vertex linked to this track
39  const auto& trkLinks=vtx->trackParticleLinks();
40  const size_t nTrackLinks=trkLinks.size();
41  for (unsigned i=0;i<nTrackLinks;++i) {
42  if (trkLinks[i].isValid() && *(trkLinks[i]) == track) {//ptr comparison
43  if( vtx->trackWeights()[i] > largestWeight ){
44  vtxWithLargestWeight = vtx;
45  largestWeight = vtx->trackWeights()[i];
46  }
47  }
48  }
49  }
50 
51  return vtxWithLargestWeight;
52  }
53 
54  //____________________________________________________________________________
56  : asg::AsgTool(name)
57  {
58  // run 2 NN model:
59  // m_doSkipByZSigma = true, m_isTMVA = true
60  // run 3 NN model:
61  // m_doSkipByZSigma = false, m_isTMVA = false
62 
63  // default variables
64  declareProperty("nVars", m_nVars = 4);
65  declareProperty("conversionPtCut", m_convPtCut = 2e3);
66  declareProperty("DoSkipByZSigma", m_doSkipByZSigma = false);
67 
68  declareProperty("derivationPrefix", m_derivationPrefix = "");
69 
70  // boolean for TMVA, default true
71  declareProperty("isTMVA", m_isTMVA = false);
72 
73  // config files (TMVA), default paths if not set
74  declareProperty("ConfigFileCase1",
75  m_TMVAModelFilePath1 = "PhotonVertexSelection/v1/DiphotonVertex_case1.weights.xml");
76  declareProperty("ConfigFileCase2",
77  m_TMVAModelFilePath2 = "PhotonVertexSelection/v1/DiphotonVertex_case2.weights.xml");
78 
79  // config files (ONNX), default paths if not set
80  declareProperty("ONNXModelFileCase1", m_ONNXModelFilePath1 = "PhotonVertexSelection/run3nn/model1.onnx");
81  declareProperty("ONNXModelFileCase2", m_ONNXModelFilePath2 = "PhotonVertexSelection/run3nn/model2.onnx");
82  }
83 
84  //____________________________________________________________________________
86  = default;
87 
88  //____________________________________________________________________________
89  //new additions for ONNX
90  float PhotonVertexSelectionTool::getScore(int nVars, const std::vector<std::vector<float>>& input_data, const std::shared_ptr<Ort::Session> sessionHandle, std::vector<int64_t> input_node_dims, std::vector<const char*> input_node_names, std::vector<const char*> output_node_names) const{
91  //*************************************************************************
92  // score the model using sample data, and inspect values
93  // loading input data
94  std::vector<std::vector<float>> input_tensor_values_ = input_data;
95 
96  //preparing container to hold input data
97  size_t input_tensor_size = nVars;
98  std::vector<float> input_tensor_values(nVars);
99  input_tensor_values = input_tensor_values_[0]; //0th element since only batch_size of 1, otherwise loop
100 
101  // create input tensor object from data values
102  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
103  // create tensor using info from inputs
104  Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_size, input_node_dims.data(), input_node_dims.size());
105 
106  // check if input is of type tensor
107  assert(input_tensor.IsTensor());
108 
109  // run the inference
110  auto output_tensors = sessionHandle->Run(Ort::RunOptions{nullptr}, input_node_names.data(), &input_tensor, input_node_names.size(), output_node_names.data(), output_node_names.size());
111 
112  // check size of output tensor
113  assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
114 
115  // get pointer to output tensor float values
116  //float* floatarr = output_tensors.front().GetTensorMutableData<float>();
117  float* floatarr = output_tensors[0].GetTensorMutableData<float>();
118 
119  int arrSize = sizeof(*floatarr)/sizeof(floatarr[0]);
120  ATH_MSG_DEBUG("The size of the array is: " << arrSize);
121  ATH_MSG_DEBUG("floatarr[0] = " << floatarr[0]);
122  return floatarr[0];
123  }
124 
125  //new additions for ONNX
126  std::tuple<std::vector<int64_t>, std::vector<const char*>> PhotonVertexSelectionTool::getInputNodes(const std::shared_ptr<Ort::Session> sessionHandle, Ort::AllocatorWithDefaultOptions& allocator){
127  // input nodes
128  std::vector<int64_t> input_node_dims;
129  size_t num_input_nodes = sessionHandle->GetInputCount();
130  std::vector<const char*> input_node_names(num_input_nodes);
131 
132  // Loop the input nodes
133  for( std::size_t i = 0; i < num_input_nodes; i++ ) {
134  // Print input node names
135  char* input_name = sessionHandle->GetInputNameAllocated(i, allocator).release();
136  ATH_MSG_DEBUG("Input "<<i<<" : "<<" name= "<<input_name);
137  input_node_names[i] = input_name;
138 
139  // Print input node types
140  Ort::TypeInfo type_info = sessionHandle->GetInputTypeInfo(i);
141  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
142  ONNXTensorElementDataType type = tensor_info.GetElementType();
143  ATH_MSG_DEBUG("Input "<<i<<" : "<<" type= "<<type);
144 
145  // Print input shapes/dims
146  input_node_dims = tensor_info.GetShape();
147  ATH_MSG_DEBUG("Input "<<i<<" : num_dims= "<<input_node_dims.size());
148  for (std::size_t j = 0; j < input_node_dims.size(); j++){
149  if(input_node_dims[j]<0){input_node_dims[j] =1;}
150  ATH_MSG_DEBUG("Input"<<i<<" : dim "<<j<<"= "<<input_node_dims[j]);
151  }
152  }
153  return std::make_tuple(input_node_dims, input_node_names);
154  }
155 
156  //new additions for ONNX
157  std::tuple<std::vector<int64_t>, std::vector<const char*>> PhotonVertexSelectionTool::getOutputNodes(const std::shared_ptr<Ort::Session> sessionHandle, Ort::AllocatorWithDefaultOptions& allocator){
158  // output nodes
159  std::vector<int64_t> output_node_dims;
160  size_t num_output_nodes = sessionHandle->GetOutputCount();
161  std::vector<const char*> output_node_names(num_output_nodes);
162 
163  // Loop the output nodes
164  for( std::size_t i = 0; i < num_output_nodes; i++ ) {
165  // Print output node names
166  char* output_name = sessionHandle->GetOutputNameAllocated(i, allocator).release();
167  ATH_MSG_DEBUG("Output "<<i<<" : "<<" name= "<<output_name);
168  output_node_names[i] = output_name;
169 
170  Ort::TypeInfo type_info = sessionHandle->GetOutputTypeInfo(i);
171  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
172  ONNXTensorElementDataType type = tensor_info.GetElementType();
173  ATH_MSG_DEBUG("Output "<<i<<" : "<<" type= "<<type);
174 
175  // Print output shapes/dims
176  output_node_dims = tensor_info.GetShape();
177  ATH_MSG_DEBUG("Output "<<i<<" : num_dims= "<<output_node_dims.size());
178  for (std::size_t j = 0; j < output_node_dims.size(); j++){
179  if(output_node_dims[j]<0){output_node_dims[j] =1;}
180  ATH_MSG_DEBUG("Output"<<i<<" : dim "<<j<<"= "<<output_node_dims[j]);
181  }
182  }
183  return std::make_tuple(output_node_dims, output_node_names);
184  }
185 
186  //new additions for ONNX
187  std::tuple<std::shared_ptr<Ort::Session>, Ort::AllocatorWithDefaultOptions> PhotonVertexSelectionTool::setONNXSession(Ort::Env& env, const std::string& modelFilePath){
188  // Find the model file.
189  const std::string modelFileName = PathResolverFindCalibFile( modelFilePath );
190  ATH_MSG_INFO( "Using model file: " << modelFileName );
191 
192  // set onnx session options
193  Ort::SessionOptions sessionOptions;
194  sessionOptions.SetIntraOpNumThreads( 1 );
195  sessionOptions.SetGraphOptimizationLevel( ORT_ENABLE_BASIC );
196  // set allocator
197  Ort::AllocatorWithDefaultOptions allocator;
198  // set the onnx runtime session
199  std::shared_ptr<Ort::Session> sessionHandle = std::make_shared<Ort::Session>( env, modelFileName.c_str(), sessionOptions );
200 
201  ATH_MSG_INFO( "Created the ONNX Runtime session for model file = " << modelFileName);
202  return std::make_tuple(sessionHandle, allocator);
203  }
204 
205  //____________________________________________________________________________
207  {
208  ATH_MSG_INFO("Initializing PhotonVertexSelectionTool...");
209  // initialize the readers or sessions
210  if(m_isTMVA){
211  // Get full path of configuration files for MVA
214  // Setup MVAs
215  std::vector<std::string> var_names = {
216  "deltaZ := TMath::Min(abs(PrimaryVerticesAuxDyn.z-zCommon)/zCommonError,20)",
217  "deltaPhi := abs(deltaPhi(PrimaryVerticesAuxDyn.phi,egamma_phi))" ,
218  "logSumpt := log10(PrimaryVerticesAuxDyn.sumPt)" ,
219  "logSumpt2 := log10(PrimaryVerticesAuxDyn.sumPt2)"
220  };
221  auto mva1 = new TMVA::Reader(var_names, "!Silent:Color");
222  mva1->BookMVA ("MLP method", m_TMVAModelFilePath1 );
223  m_mva1 = std::unique_ptr<TMVA::Reader>( std::move(mva1) );
224 
225  auto mva2 = std::make_unique<TMVA::Reader>(var_names, "!Silent:Color");
226  mva2->BookMVA ("MLP method", m_TMVAModelFilePath2 );
227  m_mva2 = std::unique_ptr<TMVA::Reader>( std::move(mva2) );
228  }
229  else{ // assume only ONNX for now
230  // create onnx environment
231  Ort::Env env;
232  // converted
236 
237  // unconverted
241  }
242 
243  // initialize the containers
245  ATH_CHECK( m_vertexContainer.initialize() );
246 
251  ATH_CHECK( m_deltaPhiKey.initialize() );
252  ATH_CHECK( m_deltaZKey.initialize() );
253  ATH_CHECK( m_sumPt2Key.initialize() );
254  ATH_CHECK( m_sumPtKey.initialize() );
255 #ifndef XAOD_STANDALONE
258 #endif
259 
260  return StatusCode::SUCCESS;
261  }
262 
263  //____________________________________________________________________________
265  auto fail = FailType::NoFail;
266 
267  const EventContext& ctx = Gaudi::Hive::currentContext();
272 
273  // Get the EventInfo
275 
276  // Find the common z-position from beam / photon pointing information
277  std::pair<float, float> zCommon = xAOD::PVHelpers::getZCommonAndError(&*eventInfo, &egammas, m_convPtCut);
278  // Vector sum of photons
279  TLorentzVector vegamma = getEgammaVector(&egammas, fail);
280 
281  // Retrieve PV collection from TEvent
283 
284  bool writeSumPt2 = !sumPt2.isAvailable();
285  bool writeSumPt = !sumPt.isAvailable();
286 
287  for (const xAOD::Vertex* vertex: *vertices) {
288 
289  // Skip dummy vertices
290  if (!(vertex->vertexType() == xAOD::VxType::VertexType::PriVtx ||
291  vertex->vertexType() == xAOD::VxType::VertexType::PileUp)) continue;
292 
293  // Set input variables
294  if (writeSumPt) {
295  sumPt(*vertex) = xAOD::PVHelpers::getVertexSumPt(vertex, 1, false);
296  }
297 
298  if (writeSumPt2) {
299  sumPt2(*vertex) = xAOD::PVHelpers::getVertexSumPt(vertex, 2);
300  }
301 
302  // Get momentum vector of vertex
303  TLorentzVector vmom = xAOD::PVHelpers::getVertexMomentum(vertex, true, m_derivationPrefix);
304 
305  deltaPhi(*vertex) = (fail != FailType::FailEgamVect) ? std::abs(vmom.DeltaPhi(vegamma)) : -999.;
306  deltaZ(*vertex) = std::abs((zCommon.first - vertex->z())/zCommon.second);
307 
308  } // loop over vertices
309 
310  ATH_MSG_DEBUG("DecorateInputs exit code "<< fail);
311  if(failType!=nullptr)
312  *failType = fail;
313  return StatusCode::SUCCESS;
314  }
315 
316  //____________________________________________________________________________
317  std::vector<std::pair<const xAOD::Vertex*, float> >
318  PhotonVertexSelectionTool::getVertex(const xAOD::EgammaContainer &egammas, bool ignoreConv, bool noDecorate, yyVtxType* vtxCasePtr, FailType* failTypePtr) const
319  {
320  const xAOD::Vertex *vertex = nullptr;
321  std::vector<std::pair<const xAOD::Vertex*, float> > vertexMLP;
322  yyVtxType vtxCase = yyVtxType::Unknown;
323  FailType failType = FailType::NoFail;
324  if (getVertexImp( egammas, vertex, ignoreConv, noDecorate, vertexMLP, vtxCase, failType ).isSuccess()) {
325  std::sort(vertexMLP.begin(), vertexMLP.end(), sortMLP);
326  }
327  if(vtxCasePtr!=nullptr)
328  *vtxCasePtr = vtxCase;
329  if(failTypePtr!=nullptr)
330  *failTypePtr = failType;
331 
332  return vertexMLP;
333  }
334 
335  //____________________________________________________________________________
337  const xAOD::Vertex* &prime_vertex,
338  bool ignoreConv) const
339  {
340  std::vector<std::pair<const xAOD::Vertex*, float> > vertexMLP;
341  yyVtxType vtxcase = yyVtxType::Unknown;
342  FailType failType = FailType::NoFail;
343  return getVertexImp( egammas, prime_vertex, ignoreConv, false, vertexMLP, vtxcase, failType );
344  }
345 
347  const xAOD::Vertex* &prime_vertex,
348  bool ignoreConv,
349  bool noDecorate,
350  std::vector<std::pair<const xAOD::Vertex*, float> >& vertexMLP, yyVtxType& vtxCase, FailType& fail) const
351  {
352  // Set default vertex case and declare photon container
353  vtxCase = yyVtxType::Unknown;
354  const xAOD::PhotonContainer *photons = dynamic_cast<const xAOD::PhotonContainer*>(&egammas);
355 
356  // Retrieve PV collection from TEvent
358 
359  if (!noDecorate && !decorateInputs(egammas).isSuccess()){
360  return StatusCode::FAILURE;
361  }
362 
363  // Check if a conversion photon has a track attached to a primary/pileup vertex
364  if (!ignoreConv && photons) {
365  prime_vertex = getPrimaryVertexFromConv(photons);
366  if (prime_vertex != nullptr) {
367  vtxCase = yyVtxType::ConvTrack;
368  fail = FailType::MatchedTrack;
369  vertexMLP.emplace_back(prime_vertex, 0.);
370  return StatusCode::SUCCESS;
371  }
372  }
373 
374  if (fail != FailType::NoFail){
375  ATH_MSG_VERBOSE("Returning hardest vertex. Fail detected (type="<< fail <<")");
376  vertexMLP.clear();
377  prime_vertex = xAOD::PVHelpers::getHardestVertex(&*vertices);
378  vertexMLP.emplace_back(prime_vertex, 10.);
379  return StatusCode::SUCCESS;
380  }
381 
382  // Get the EventInfo
384 
385  // If there are any silicon conversions passing selection
386  // ==> use Model 1 (Conv) otherwise Model 2 (Unconv)
387  // Set default for conversion bool as false unless otherwise
388  bool isConverted = false;
389 
390  // assume default NoSiTrack (unconverted) unless otherwise
391  vtxCase = yyVtxType::NoSiTracks;
392  if (!ignoreConv && photons) {
393  for (const auto *photon: *photons) {
394  if (!photon)
395  {
396  ATH_MSG_WARNING("Null pointer to photon");
397  return StatusCode::FAILURE;
398  }
399  // find out if pass conversion selection criteria and tag as SiConvTrack case
401  {
402  isConverted = true;
403  vtxCase = yyVtxType::SiConvTrack;
404  }
405  }
406  }
407 
408  // if TMVA chosen, declare tmva_reader only once (before for looping vertex)
409  TMVA::Reader *tmva_reader = new TMVA::Reader();
410  if(m_isTMVA){
411  if(isConverted){
412  // If there are any silicon conversions passing selection, use MVA1 (converted case)
413  tmva_reader = m_mva1.get();
414  }
415  // Otherwise, use MVA2 (unconverted case)
416  if(!isConverted){
417  tmva_reader = m_mva2.get();
418  }
419  }
420  ATH_MSG_DEBUG("Vtx Case: " << vtxCase);
421 
422  // Vector sum of photons
423  TLorentzVector vegamma = getEgammaVector(&egammas, fail);
424 
429 
430  // Loop over vertices and find best candidate
431  std::vector<float> ONNXInputVector;
432  std::vector<std::vector<float>> onnx_input_tensor_values;
433  std::vector<float> TMVAInputVector;
434  TString TMVAMethod;
435  float mlp = 0.0, mlp_max = -99999.0;
436  float doSkipByZSigmaScore = -9999.0;
437  // assign threshold score value to compare later for good vtx
438  float thresGoodVtxScore;
439  if(m_doSkipByZSigma){thresGoodVtxScore = doSkipByZSigmaScore;}
440  else{thresGoodVtxScore = mlp_max;}
441  for (const xAOD::Vertex* vertex: *vertices) {
442  // Skip dummy vertices
443  if (!(vertex->vertexType() == xAOD::VxType::VertexType::PriVtx ||
444  vertex->vertexType() == xAOD::VxType::VertexType::PileUp)) continue;
445 
446  onnx_input_tensor_values.clear();
447 
448  // Variables used as input features in classifier
449  float sumPt, sumPt2, deltaPhi, deltaZ;
450  float log10_sumPt, log10_sumPt2;
451 
452  sumPt = (sumPtA)(*vertex);
453  sumPt2 = (sumPt2A)(*vertex);
454  deltaPhi = (deltaPhiA)(*vertex);
455  deltaZ = (deltaZA)(*vertex);
456  ATH_MSG_VERBOSE("sumPt: " << sumPt <<
457  " sumPt2: " << sumPt2 <<
458  " deltaPhi: " << deltaPhi <<
459  " deltaZ: " << deltaZ);
460 
461  // setup the vector of input features based on selected inference framework
462  if(m_isTMVA){
463  // Get likelihood probability from TMVA model
464  TMVAMethod = "MLP method";
465  log10_sumPt = static_cast<float>(log10(sumPt));
466  log10_sumPt2 = static_cast<float>(log10(sumPt2));
467  TMVAInputVector = {deltaZ,deltaPhi,log10_sumPt,log10_sumPt2};
468  }
469  else{ //assume ony ONNX for now
470  // Get likelihood probability from onnx model
471  // check if value is 0, assign small number like 1e-8 as dummy, as we will take log later (log(0) is nan)
472  // note that the ordering here is a bit different, following the order used when training
473  ONNXInputVector = {sumPt2, sumPt, deltaPhi, deltaZ};
474  for (long unsigned int i = 0; i < ONNXInputVector.size(); i++) {
475  // skip log for deltaPhi and take log for the rest
476  if (i == 2) {
477  continue;
478  }
479  if (ONNXInputVector[i] != 0 && std::isinf(ONNXInputVector[i]) != true && std::isnan(ONNXInputVector[i]) != true){
480  ONNXInputVector[i] = log(std::abs(ONNXInputVector[i]));
481  }
482  else{
483  ONNXInputVector[i] = log(std::abs(0.00000001)); //log(abs(1e-8))
484  }
485  } //end ONNXInputVector for loop
486  onnx_input_tensor_values.push_back(ONNXInputVector);
487  }
488 
489  // Do the actual calculation of classifier score part
490  if(m_isTMVA){
491  mlp = tmva_reader->EvaluateMVA(TMVAInputVector, TMVAMethod);
492  ATH_MSG_VERBOSE("TMVA output: " << (tmva_reader == m_mva1.get() ? "MVA1 ": "MVA2 ")<< mlp);
493  }
494  else{ //assume ony ONNX for now
495  if(isConverted){
496  mlp = getScore(m_nVars, onnx_input_tensor_values,
499  }
500  if(!isConverted){
501  mlp = getScore(m_nVars, onnx_input_tensor_values,
504  }
505  ATH_MSG_VERBOSE("log(abs(sumPt)): " << sumPt <<
506  " log(abs(sumPt2)): " << sumPt2 <<
507  " deltaPhi: " << deltaPhi <<
508  " log(abs(deltaZ)): " << deltaZ);
509  ATH_MSG_VERBOSE("ONNX output, isConverted = " << isConverted << ", mlp=" << mlp);
510  }
511 
512  // Skip vertices above 10 sigma from pointing or 15 sigma from conversion (HPV)
513  // Simply displace the mlp variable we calculate before by a predefined value
514  if(m_doSkipByZSigma){
515  if ((isConverted && deltaZ > 15) || (!isConverted && deltaZ > 10)) {
516  mlp = doSkipByZSigmaScore;
517  }
518  }
519 
520  // add the new vertex and its score to vertexMLP container
521  vertexMLP.emplace_back(vertex, mlp);
522 
523  // Keep track of maximal likelihood vertex
524  if (mlp > mlp_max) {
525  mlp_max = mlp;
526  prime_vertex = vertex;
527  }
528  } // end loop over vertices
529 
530  // from all the looped vertices, decide the max score which should be more than the minimum we set
531  // (which should be more than the initial mlp_max value above or more than the skip vertex by z-sigma score)
532  // if this does not pass, return hardest primary vertex
533  if (mlp_max <= thresGoodVtxScore) {
534  ATH_MSG_DEBUG("No good vertex candidates from pointing, returning hardest vertex.");
535  prime_vertex = xAOD::PVHelpers::getHardestVertex(&*vertices);
536  fail = FailType::NoGdCandidate;
537  vertexMLP.clear();
538  vertexMLP.emplace_back(xAOD::PVHelpers::getHardestVertex(&*vertices), 20.);
539  }
540 
541  ATH_MSG_VERBOSE("getVertex case "<< (int)vtxCase << " exit code "<< (int)fail);
542  return StatusCode::SUCCESS;
543  }
544 
545  //____________________________________________________________________________
546  bool PhotonVertexSelectionTool::sortMLP(const std::pair<const xAOD::Vertex*, float> &a,
547  const std::pair<const xAOD::Vertex*, float> &b)
548  { return a.second > b.second; }
549 
550  //____________________________________________________________________________
552  {
553  if (photons == nullptr) {
554  ATH_MSG_WARNING("Passed nullptr photon container, returning nullptr vertex from getPrimaryVertexFromConv");
555  return nullptr;
556  }
557 
558  std::vector<const xAOD::Vertex*> vertices;
559  const xAOD::Vertex *conversionVertex = nullptr, *primary = nullptr;
560  const xAOD::TrackParticle *tp = nullptr;
561  size_t NumberOfTracks = 0;
562 
563  // Retrieve PV collection from TEvent
565 
566 
567  for (const auto *photon: *photons) {
568  conversionVertex = photon->vertex();
569  if (conversionVertex == nullptr) continue;
570 
571  NumberOfTracks = conversionVertex->nTrackParticles();
572  for (size_t i = 0; i < NumberOfTracks; ++i) {
573  // Get trackParticle in GSF collection
574  const auto *gsfTp = conversionVertex->trackParticle(i);
575  if (gsfTp == nullptr) continue;
576  if (!xAOD::PVHelpers::passConvSelection(*conversionVertex, i, m_convPtCut)) continue;
577 
578  // Get trackParticle in InDet collection
580  if (tp == nullptr) continue;
581 
582  primary = getVertexFromTrack(tp, &*all_vertices);
583  if (primary == nullptr) continue;
584 
585  if (primary->vertexType() == xAOD::VxType::VertexType::PriVtx ||
586  primary->vertexType() == xAOD::VxType::VertexType::PileUp) {
587  if (std::find(vertices.begin(), vertices.end(), primary) == vertices.end()) {
588  vertices.push_back(primary);
589  continue;
590  }
591  }
592  }
593  }
594 
595  if (!vertices.empty()) {
596  if (vertices.size() > 1)
597  ATH_MSG_WARNING("Photons associated to different vertices! Returning lead photon association.");
598  return vertices[0];
599  }
600 
601  return nullptr;
602  }
603 
604  //____________________________________________________________________________
605  TLorentzVector PhotonVertexSelectionTool::getEgammaVector(const xAOD::EgammaContainer *egammas, FailType& failType) const
606  {
607  TLorentzVector v, v1;
608  const xAOD::CaloCluster *cluster = nullptr;
609  for (const xAOD::Egamma* egamma: *egammas) {
610  if (egamma == nullptr) {
611  ATH_MSG_DEBUG("No egamma object to get four vector");
612  failType = FailType::FailEgamVect;
613  continue;
614  }
615  cluster = egamma->caloCluster();
616  if (cluster == nullptr) {
617  ATH_MSG_WARNING("No cluster associated to egamma, not adding to 4-vector.");
618  continue;
619  }
620 
621  v1.SetPtEtaPhiM(egamma->e()/cosh(cluster->etaBE(2)),
622  cluster->etaBE(2),
623  cluster->phiBE(2),
624  0.0);
625  v += v1;
626  }
627  return v;
628  }
629 
630 } // namespace CP
CP::PhotonVertexSelectionTool::PhotonVertexSelectionTool
PhotonVertexSelectionTool(const std::string &name)
Definition: PhotonVertexSelectionTool.cxx:55
CP::PhotonVertexSelectionTool::setONNXSession
std::tuple< std::shared_ptr< Ort::Session >, Ort::AllocatorWithDefaultOptions > setONNXSession(Ort::Env &env, const std::string &modelFilePath)
Definition: PhotonVertexSelectionTool.cxx:187
GetLCDefs::Unknown
@ Unknown
Definition: GetLCDefs.h:21
xAOD::PVHelpers::getVertexSumPt
float getVertexSumPt(const xAOD::Vertex *vertex, int power=1, bool useAux=true)
Loop over track particles associated with vertex and return scalar sum of pT^power in GeV (from auxda...
Definition: PhotonVertexHelpers.cxx:215
CP::PhotonVertexSelectionTool::m_TMVAModelFilePath2
std::string m_TMVAModelFilePath2
Definition: PhotonVertexSelectionTool.h:61
CP::PhotonVertexSelectionTool::m_derivationPrefix
std::string m_derivationPrefix
Definition: PhotonVertexSelectionTool.h:51
xAOD::Vertex_v1::nTrackParticles
size_t nTrackParticles() const
Get the number of tracks associated with this vertex.
Definition: Vertex_v1.cxx:270
CP::PhotonVertexSelectionTool::getPrimaryVertexFromConv
const xAOD::Vertex * getPrimaryVertexFromConv(const xAOD::PhotonContainer *photons) const
Get possible vertex directly associated with photon conversions.
Definition: PhotonVertexSelectionTool.cxx:551
CP::PhotonVertexSelectionTool::m_sumPt2Key
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_sumPt2Key
Definition: PhotonVertexSelectionTool.h:150
SG::WriteDecorHandle::isAvailable
bool isAvailable()
Test to see if this variable exists in the store, for the referenced object.
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
find
std::string find(const std::string &s)
return a remapped string
Definition: hcg.cxx:135
ShowerDepthTool.h
CP::PhotonVertexSelectionTool::getScore
float getScore(int nVars, const std::vector< std::vector< float >> &input_data, const std::shared_ptr< Ort::Session > sessionHandle, std::vector< int64_t > input_node_dims, std::vector< const char * > input_node_names, std::vector< const char * > output_node_names) const
Definition: PhotonVertexSelectionTool.cxx:90
CP::PhotonVertexSelectionTool::m_TMVAModelFilePath1
std::string m_TMVAModelFilePath1
Definition: PhotonVertexSelectionTool.h:60
SG::ReadHandle
Definition: StoreGate/StoreGate/ReadHandle.h:70
AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T > &t)
Definition: AthCommonDataStore.h:145
CP::PhotonVertexSelectionTool::getEgammaVector
TLorentzVector getEgammaVector(const xAOD::EgammaContainer *egammas, FailType &failType) const
Get combined 4-vector of photon container.
Definition: PhotonVertexSelectionTool.cxx:605
xAOD::deltaPhi
setSAddress setEtaMS setDirPhiMS setDirZMS setBarrelRadius setEndcapAlpha setEndcapRadius setInterceptInner setEtaMap setEtaBin setIsTgcFailure setDeltaPt deltaPhi
Definition: L2StandAloneMuon_v1.cxx:160
SG::decorKeyFromKey
std::string decorKeyFromKey(const std::string &key)
Extract the decoration part of key.
Definition: DecorKeyHelpers.cxx:41
xAOD::Vertex_v1::trackParticleLinks
const TrackParticleLinks_t & trackParticleLinks() const
Get all the particles associated with the vertex.
CP::PhotonVertexSelectionTool::m_vertexContainer
SG::ReadHandleKey< xAOD::VertexContainer > m_vertexContainer
Definition: PhotonVertexSelectionTool.h:55
xAOD::Vertex_v1::trackWeights
const std::vector< float > & trackWeights() const
Get all the track weights.
AthCommonDataStore< AthCommonMsg< AlgTool > >::renounce
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce(T &h)
Definition: AthCommonDataStore.h:380
CP::PhotonVertexSelectionTool::m_input_node_dims1
std::vector< int64_t > m_input_node_dims1
Definition: PhotonVertexSelectionTool.h:76
xAOD::PVHelpers::getZCommonAndError
std::pair< float, float > getZCommonAndError(const xAOD::EventInfo *eventInfo, const xAOD::EgammaContainer *egammas, float convPtCut=2e3)
Return zCommon and zCommonError.
Definition: PhotonVertexHelpers.cxx:44
asg
Definition: DataHandleTestTool.h:28
CP::PhotonVertexSelectionTool::m_sessionHandle2
std::shared_ptr< Ort::Session > m_sessionHandle2
Definition: PhotonVertexSelectionTool.h:84
ParticleTest.tp
tp
Definition: ParticleTest.py:25
CP::PhotonVertexSelectionTool::m_mva1
std::unique_ptr< TMVA::Reader > m_mva1
Definition: PhotonVertexSelectionTool.h:65
xAOD::Egamma_v1
Definition: Egamma_v1.h:56
SG::ConstAccessor
Helper class to provide constant type-safe access to aux data.
Definition: ConstAccessor.h:55
ATH_MSG_VERBOSE
#define ATH_MSG_VERBOSE(x)
Definition: AthMsgStreamMacros.h:28
isValid
bool isValid(const T &p)
Av: we implement here an ATLAS-sepcific convention: all particles which are 99xxxxx are fine.
Definition: AtlasPID.h:812
CP::PhotonVertexSelectionTool::sortMLP
static bool sortMLP(const std::pair< const xAOD::Vertex *, float > &a, const std::pair< const xAOD::Vertex *, float > &b)
Sort MLP results.
Definition: PhotonVertexSelectionTool.cxx:546
CP::PhotonVertexSelectionTool::m_ONNXModelFilePath1
std::string m_ONNXModelFilePath1
Definition: PhotonVertexSelectionTool.h:71
CP
Select isolated Photons, Electrons and Muons.
Definition: Control/xAODRootAccess/xAODRootAccess/TEvent.h:49
CP::PhotonVertexSelectionTool::m_sessionHandle1
std::shared_ptr< Ort::Session > m_sessionHandle1
Definition: PhotonVertexSelectionTool.h:83
python.CaloAddPedShiftConfig.type
type
Definition: CaloAddPedShiftConfig.py:42
egamma
Definition: egamma.h:58
ParticleImpl::e
virtual double e() const
energy
Definition: ParticleImpl.h:534
CP::PhotonVertexSelectionTool::m_output_node_dims2
std::vector< int64_t > m_output_node_dims2
Definition: PhotonVertexSelectionTool.h:79
CP::PhotonVertexSelectionTool::m_eventInfo
SG::ReadHandleKey< xAOD::EventInfo > m_eventInfo
Container declarations.
Definition: PhotonVertexSelectionTool.h:54
xAOD::CaloCluster_v1
Description of a calorimeter cluster.
Definition: CaloCluster_v1.h:59
xAOD::PVHelpers::getHardestVertex
const xAOD::Vertex * getHardestVertex(const xAOD::VertexContainer *vertices)
Return vertex with highest sum pT^2.
Definition: PhotonVertexHelpers.cxx:29
EgammaxAODHelpers.h
PhotonVertexSelectionTool.h
StateLessPT_NewConfig.primary
primary
Definition: StateLessPT_NewConfig.py:228
CP::PhotonVertexSelectionTool::initialize
virtual StatusCode initialize()
Function initialising the tool.
Definition: PhotonVertexSelectionTool.cxx:206
CheckAppliedSFs.e3
e3
Definition: CheckAppliedSFs.py:264
CP::PhotonVertexSelectionTool::m_allocator1
Ort::AllocatorWithDefaultOptions m_allocator1
Definition: PhotonVertexSelectionTool.h:86
CP::IPhotonVertexSelectionTool::yyVtxType
yyVtxType
Definition: IPhotonVertexSelectionTool.h:44
lumiFormat.i
int i
Definition: lumiFormat.py:85
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
CP::PhotonVertexSelectionTool::m_sumPtKey
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_sumPtKey
Definition: PhotonVertexSelectionTool.h:152
SG::WriteDecorHandle
Handle class for adding a decoration to an object.
Definition: StoreGate/StoreGate/WriteDecorHandle.h:100
CP::PhotonVertexSelectionTool::m_convPtCut
float m_convPtCut
Definition: PhotonVertexSelectionTool.h:48
CP::PhotonVertexSelectionTool::~PhotonVertexSelectionTool
virtual ~PhotonVertexSelectionTool()
CP::IPhotonVertexSelectionTool::FailType
FailType
Declare the interface that the class provides.
Definition: IPhotonVertexSelectionTool.h:33
xAOD::VxType::PriVtx
@ PriVtx
Primary vertex.
Definition: TrackingPrimitives.h:572
LHEF::Reader
Pythia8::Reader Reader
Definition: Prophecy4fMerger.cxx:11
CP::PhotonVertexSelectionTool::m_nVars
int m_nVars
Create a proper constructor for Athena.
Definition: PhotonVertexSelectionTool.h:47
CP::PhotonVertexSelectionTool::getOutputNodes
std::tuple< std::vector< int64_t >, std::vector< const char * > > getOutputNodes(const std::shared_ptr< Ort::Session > sessionHandle, Ort::AllocatorWithDefaultOptions &allocator)
Definition: PhotonVertexSelectionTool.cxx:157
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
xAOD::Vertex_v1::trackParticle
const TrackParticle * trackParticle(size_t i) const
Get the pointer to a given track that was used in vertex reco.
Definition: Vertex_v1.cxx:249
CP::PhotonVertexSelectionTool::getVertexImp
StatusCode getVertexImp(const xAOD::EgammaContainer &egammas, const xAOD::Vertex *&vertex, bool ignoreConv, bool noDecorate, std::vector< std::pair< const xAOD::Vertex *, float > > &, yyVtxType &, FailType &) const
Given a list of photons, return the MLPs of all vertices in the event.
Definition: PhotonVertexSelectionTool.cxx:346
SG::VarHandleKey::initialize
StatusCode initialize(bool used=true)
If this object is used as a property, then this should be called during the initialize phase.
Definition: AthToolSupport/AsgDataHandles/Root/VarHandleKey.cxx:103
xAOD::PVHelpers::getVertexMomentum
TLorentzVector getVertexMomentum(const xAOD::Vertex *vertex, bool useAux=true, const std::string &derivationPrefix="")
Return vector sum of tracks associated with vertex (from auxdata if available and useAux = true)
Definition: PhotonVertexHelpers.cxx:175
DataVector
Derived DataVector<T>.
Definition: DataVector.h:794
WriteDecorHandle.h
Handle class for adding a decoration to an object.
CP::PhotonVertexSelectionTool::m_deltaPhiKey
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_deltaPhiKey
Definition: PhotonVertexSelectionTool.h:146
PhotonVertexHelpers.h
CP::PhotonVertexSelectionTool::m_input_node_dims2
std::vector< int64_t > m_input_node_dims2
Definition: PhotonVertexSelectionTool.h:79
CP::PhotonVertexSelectionTool::m_ONNXModelFilePath2
std::string m_ONNXModelFilePath2
Definition: PhotonVertexSelectionTool.h:72
xAOD::VxType::PileUp
@ PileUp
Pile-up vertex.
Definition: TrackingPrimitives.h:574
CP::PhotonVertexSelectionTool::m_output_node_names1
std::vector< const char * > m_output_node_names1
Definition: PhotonVertexSelectionTool.h:77
PathResolver.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:228
plotBeamSpotMon.b
b
Definition: plotBeamSpotMon.py:77
ReadHandle.h
Handle class for reading from StoreGate.
CP::PhotonVertexSelectionTool::getVertex
StatusCode getVertex(const xAOD::EgammaContainer &egammas, const xAOD::Vertex *&vertex, bool ignoreConv=false) const
Given a list of photons, return the most likely vertex based on MVA likelihood.
Definition: PhotonVertexSelectionTool.cxx:336
CP::PhotonVertexSelectionTool::m_mva2
std::unique_ptr< TMVA::Reader > m_mva2
Definition: PhotonVertexSelectionTool.h:66
xAOD::EgammaHelpers::getOriginalTrackParticleFromGSF
const xAOD::TrackParticle * getOriginalTrackParticleFromGSF(const xAOD::TrackParticle *trkPar)
Helper function for getting the "Original" Track Particle (i.e before GSF) via the GSF Track Particle...
Definition: ElectronxAODHelpers.cxx:22
CP::PhotonVertexSelectionTool::m_output_node_names2
std::vector< const char * > m_output_node_names2
Definition: PhotonVertexSelectionTool.h:80
CP::PhotonVertexSelectionTool::m_isTMVA
bool m_isTMVA
Definition: PhotonVertexSelectionTool.h:59
EventInfo.h
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:431
CP::PhotonVertexSelectionTool::m_doSkipByZSigma
bool m_doSkipByZSigma
Definition: PhotonVertexSelectionTool.h:49
python.PyAthena.v
v
Definition: PyAthena.py:154
CP::PhotonVertexSelectionTool::m_allocator2
Ort::AllocatorWithDefaultOptions m_allocator2
Definition: PhotonVertexSelectionTool.h:87
VertexContainer.h
xAOD::photon
@ photon
Definition: TrackingPrimitives.h:200
a
TList * a
Definition: liststreamerinfos.cxx:10
xAOD::Vertex_v1
Class describing a Vertex.
Definition: Vertex_v1.h:42
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
CP::PhotonVertexSelectionTool::decorateInputs
StatusCode decorateInputs(const xAOD::EgammaContainer &egammas, FailType *failType=nullptr) const
Given a list of photons, decorate vertex container with MVA variables.
Definition: PhotonVertexSelectionTool.cxx:264
EgammaDefs.h
python.CaloCondTools.log
log
Definition: CaloCondTools.py:20
xAOD::PVHelpers::passConvSelection
bool passConvSelection(const xAOD::Photon *photon, float convPtCut=2e3)
Check if photon is converted, and tracks have Si hits and pass selection.
Definition: PhotonVertexHelpers.cxx:133
CP::PhotonVertexSelectionTool::m_input_node_names1
std::vector< const char * > m_input_node_names1
Definition: PhotonVertexSelectionTool.h:77
CP::PhotonVertexSelectionTool::m_output_node_dims1
std::vector< int64_t > m_output_node_dims1
Definition: PhotonVertexSelectionTool.h:76
xAOD::track
@ track
Definition: TrackingPrimitives.h:513
xAOD::TrackParticle_v1
Class describing a TrackParticle.
Definition: TrackParticle_v1.h:43
CP::PhotonVertexSelectionTool::getInputNodes
std::tuple< std::vector< int64_t >, std::vector< const char * > > getInputNodes(const std::shared_ptr< Ort::Session > sessionHandle, Ort::AllocatorWithDefaultOptions &allocator)
Definition: PhotonVertexSelectionTool.cxx:126
makeComparison.deltaZ
int deltaZ
Definition: makeComparison.py:46
CP::PhotonVertexSelectionTool::m_deltaZKey
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_deltaZKey
Definition: PhotonVertexSelectionTool.h:148
python.DataFormatRates.env
env
Definition: DataFormatRates.py:32
CP::PhotonVertexSelectionTool::m_input_node_names2
std::vector< const char * > m_input_node_names2
Definition: PhotonVertexSelectionTool.h:80
PhotonContainer.h
CP::getVertexFromTrack
const xAOD::Vertex * getVertexFromTrack(const xAOD::TrackParticle *track, const xAOD::VertexContainer *vertices)
Definition: PhotonVertexSelectionTool.cxx:31
beamspotman.fail
def fail(message)
Definition: beamspotman.py:201