ATLAS Offline Software
PhotonVertexSelectionTool.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 // Local includes
8 
9 // EDM includes
12 #include "xAODEgamma/EgammaDefs.h"
15 
16 // Framework includes
20 
21 // ROOT includes
22 #include "TMVA/Reader.h"
23 
24 // std includes
25 #include <algorithm>
26 
27 namespace CP {
28 
29  // helper function to get the vertex of a track
  // Returns the vertex in `vertices` that links `track` with the largest track
  // weight (strictly greater than 0), or nullptr if no vertex links it, or if
  // `vertices` is null.
31  const xAOD::VertexContainer* vertices)
32  {
33  const xAOD::Vertex* vtxWithLargestWeight = nullptr;
34  float largestWeight = 0;
35 
  if (vertices == nullptr) return vtxWithLargestWeight;

36  for (const auto *vtx : *vertices) {
37  //Search for vertex linked to this track
38  const auto& trkLinks=vtx->trackParticleLinks();
39  const size_t nTrackLinks=trkLinks.size();
40  for (unsigned i=0;i<nTrackLinks;++i) {
41  if (trkLinks[i].isValid() && *(trkLinks[i]) == track) {//ptr comparison
  // Fetch the weights only for a matching link (as before), and never index
  // past the end of a weight vector shorter than the link vector.
  const std::vector<float>& weights = vtx->trackWeights();
  if( i < weights.size() && weights[i] > largestWeight ){
43  vtxWithLargestWeight = vtx;
  largestWeight = weights[i];
45  }
46  }
47  }
48  }
49 
50  return vtxWithLargestWeight;
51  }
52 
53  //____________________________________________________________________________
55  : asg::AsgTool(name)
56  {
57  // run 2 NN model:
58  // m_doSkipByZSigma = true, m_isTMVA = true
59  // run 3 NN model:
60  // m_doSkipByZSigma = false, m_isTMVA = false
61 
62  // default variables
63  declareProperty("nVars", m_nVars = 4);
64  declareProperty("conversionPtCut", m_convPtCut = 2e3);
65  declareProperty("DoSkipByZSigma", m_doSkipByZSigma = false);
66 
67  declareProperty("derivationPrefix", m_derivationPrefix = "");
68 
69  // boolean for TMVA, default true
70  declareProperty("isTMVA", m_isTMVA = false);
71 
72  // config files (TMVA), default paths if not set
73  declareProperty("ConfigFileCase1",
74  m_TMVAModelFilePath1 = "PhotonVertexSelection/v1/DiphotonVertex_case1.weights.xml");
75  declareProperty("ConfigFileCase2",
76  m_TMVAModelFilePath2 = "PhotonVertexSelection/v1/DiphotonVertex_case2.weights.xml");
77 
78  // config files (ONNX), default paths if not set
79  declareProperty("ONNXModelFileCase1", m_ONNXModelFilePath1 = "PhotonVertexSelection/run3nn/model1.onnx");
80  declareProperty("ONNXModelFileCase2", m_ONNXModelFilePath2 = "PhotonVertexSelection/run3nn/model2.onnx");
81  }
82 
83  //____________________________________________________________________________
85  = default;
86 
87  //____________________________________________________________________________
88  //new additions for ONNX
  // Score one feature vector with an ONNX model (batch size 1).
  //  nVars             : expected number of input features; only cross-checked
  //                      against the data, never used to size the tensor
  //  input_data        : feature vectors; only the first entry is scored
  //  sessionHandle     : ONNX Runtime session holding the model
  //  input_node_dims   : shape given to the input tensor
  //  input_node_names / output_node_names : node names passed to Session::Run
  // Returns the first value of the first output tensor. Returns -99999 when there
  // is no input or no output to read. That value equals the initial best score in
  // getVertexImp, so a failed evaluation is never selected there.
90  int nVars, const std::vector<std::vector<float>>& input_data,
91  const std::shared_ptr<Ort::Session>& sessionHandle,
92  std::vector<int64_t> input_node_dims,
93  std::vector<const char*> input_node_names,
94  std::vector<const char*> output_node_names) const {
  // Score returned when nothing can be evaluated.
  constexpr float failedScore = -99999.0f;

  // Only batch_size 1 is supported. Guard before taking the first entry, which
  // is undefined behaviour on an empty batch.
  if (input_data.empty()) {
    ATH_MSG_ERROR("No input data given to the ONNX model, returning " << failedScore);
    return failedScore;
  }

  // Copy the first (only) feature vector. CreateTensor below wraps this buffer
  // without copying it and needs a non-const pointer.
  std::vector<float> input_tensor_values = input_data.front();

  // The element count must describe the buffer actually handed to ONNX Runtime.
  // Using nVars here would let ONNX Runtime read past the end of the buffer
  // whenever nVars exceeds the number of features provided.
  const size_t input_tensor_size = input_tensor_values.size();
  if (input_tensor_size != static_cast<size_t>(nVars)) {
    ATH_MSG_WARNING("Number of input features (" << input_tensor_size
                    << ") differs from nVars (" << nVars << ")");
  }

105  // create input tensor object from data values
106  auto memory_info =
107  Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
108  // create tensor using info from inputs
109  Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
110  memory_info, input_tensor_values.data(), input_tensor_size,
111  input_node_dims.data(), input_node_dims.size());
112 
113  // check if input is of type tensor
114  assert(input_tensor.IsTensor());
115 
116  // run the inference
  // NOTE: a single input tensor is passed, so the model is assumed to have
  // exactly one input node.
117  auto output_tensors =
118  sessionHandle->Run(Ort::RunOptions{nullptr}, input_node_names.data(),
119  &input_tensor, input_node_names.size(),
120  output_node_names.data(), output_node_names.size());
121 
  if (output_tensors.empty()) {
    ATH_MSG_ERROR("ONNX model returned no output tensor, returning " << failedScore);
    return failedScore;
  }
122  // check size of output tensor
123  assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
124 
  // Number of values in the first output tensor, taken from its shape information.
  const size_t arrSize = output_tensors.front().GetTensorTypeAndShapeInfo().GetElementCount();
130  ATH_MSG_DEBUG("The size of the array is: " << arrSize);
  if (arrSize == 0) {
    ATH_MSG_ERROR("ONNX model returned an empty output tensor, returning " << failedScore);
    return failedScore;
  }

125  // get pointer to output tensor float values
  const float* floatarr = output_tensors.front().GetTensorMutableData<float>();
131  ATH_MSG_DEBUG("floatarr[0] = " << floatarr[0]);
132  return floatarr[0];
133  }
134 
135  //new additions for ONNX
136  std::tuple<std::vector<int64_t>, std::vector<const char*>>
138  const std::shared_ptr<Ort::Session>& sessionHandle,
139  Ort::AllocatorWithDefaultOptions& allocator) {
140  // input nodes
141  std::vector<int64_t> input_node_dims;
142  size_t num_input_nodes = sessionHandle->GetInputCount();
143  std::vector<const char*> input_node_names(num_input_nodes);
144 
145  // Loop the input nodes
146  for( std::size_t i = 0; i < num_input_nodes; i++ ) {
147  // Print input node names
148  char* input_name = sessionHandle->GetInputNameAllocated(i, allocator).release();
149  ATH_MSG_DEBUG("Input "<<i<<" : "<<" name= "<<input_name);
150  input_node_names[i] = input_name;
151 
152  // Print input node types
153  Ort::TypeInfo type_info = sessionHandle->GetInputTypeInfo(i);
154  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
155  ONNXTensorElementDataType type = tensor_info.GetElementType();
156  ATH_MSG_DEBUG("Input "<<i<<" : "<<" type= "<<type);
157 
158  // Print input shapes/dims
159  input_node_dims = tensor_info.GetShape();
160  ATH_MSG_DEBUG("Input "<<i<<" : num_dims= "<<input_node_dims.size());
161  for (std::size_t j = 0; j < input_node_dims.size(); j++){
162  if(input_node_dims[j]<0){input_node_dims[j] =1;}
163  ATH_MSG_DEBUG("Input"<<i<<" : dim "<<j<<"= "<<input_node_dims[j]);
164  }
165  }
166  return std::make_tuple(input_node_dims, input_node_names);
167  }
168 
169  //new additions for ONNX
170  std::tuple<std::vector<int64_t>, std::vector<const char*>>
172  const std::shared_ptr<Ort::Session>& sessionHandle,
173  Ort::AllocatorWithDefaultOptions& allocator) {
174  // output nodes
175  std::vector<int64_t> output_node_dims;
176  size_t num_output_nodes = sessionHandle->GetOutputCount();
177  std::vector<const char*> output_node_names(num_output_nodes);
178 
179  // Loop the output nodes
180  for( std::size_t i = 0; i < num_output_nodes; i++ ) {
181  // Print output node names
182  char* output_name = sessionHandle->GetOutputNameAllocated(i, allocator).release();
183  ATH_MSG_DEBUG("Output "<<i<<" : "<<" name= "<<output_name);
184  output_node_names[i] = output_name;
185 
186  Ort::TypeInfo type_info = sessionHandle->GetOutputTypeInfo(i);
187  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
188  ONNXTensorElementDataType type = tensor_info.GetElementType();
189  ATH_MSG_DEBUG("Output "<<i<<" : "<<" type= "<<type);
190 
191  // Print output shapes/dims
192  output_node_dims = tensor_info.GetShape();
193  ATH_MSG_DEBUG("Output "<<i<<" : num_dims= "<<output_node_dims.size());
194  for (std::size_t j = 0; j < output_node_dims.size(); j++){
195  if(output_node_dims[j]<0){output_node_dims[j] =1;}
196  ATH_MSG_DEBUG("Output"<<i<<" : dim "<<j<<"= "<<output_node_dims[j]);
197  }
198  }
199  return std::make_tuple(output_node_dims, output_node_names);
200  }
201 
202  //new additions for ONNX
203  std::tuple<std::shared_ptr<Ort::Session>, Ort::AllocatorWithDefaultOptions>
205  const std::string& modelFilePath) {
206  // Find the model file.
207  const std::string modelFileName = PathResolverFindCalibFile( modelFilePath );
208  ATH_MSG_INFO( "Using model file: " << modelFileName );
209 
210  // set onnx session options
211  Ort::SessionOptions sessionOptions;
212  sessionOptions.SetIntraOpNumThreads( 1 );
213  sessionOptions.SetGraphOptimizationLevel( ORT_ENABLE_BASIC );
214  // set allocator
215  Ort::AllocatorWithDefaultOptions allocator;
216  // set the onnx runtime session
217  std::shared_ptr<Ort::Session> sessionHandle = std::make_shared<Ort::Session>( env, modelFileName.c_str(), sessionOptions );
218 
219  ATH_MSG_INFO( "Created the ONNX Runtime session for model file = " << modelFileName);
220  return std::make_tuple(sessionHandle, allocator);
221  }
222 
223  //____________________________________________________________________________
  // initialize: set up the inference backend (TMVA readers when isTMVA is set,
  // otherwise ONNX Runtime sessions) and initialise the data-handle keys.
225  {
226  ATH_MSG_INFO("Initializing PhotonVertexSelectionTool...");
227  // initialize the readers or sessions
228  if(m_isTMVA){
229  // Get full path of configuration files for MVA
232  // Setup MVAs
233  std::vector<std::string> var_names = {
234  "deltaZ := TMath::Min(abs(PrimaryVerticesAuxDyn.z-zCommon)/zCommonError,20)",
235  "deltaPhi := abs(deltaPhi(PrimaryVerticesAuxDyn.phi,egamma_phi))" ,
236  "logSumpt := log10(PrimaryVerticesAuxDyn.sumPt)" ,
237  "logSumpt2 := log10(PrimaryVerticesAuxDyn.sumPt2)"
238  };
  // Each reader is owned by a unique_ptr from construction, so it cannot leak if
  // BookMVA throws. mva1 serves the converted case, mva2 the unconverted case.
  auto mva1 = std::make_unique<TMVA::Reader>(var_names, "!Silent:Color");
240  mva1->BookMVA ("MLP method", m_TMVAModelFilePath1 );
  m_mva1 = std::move( mva1 );
242 
243  auto mva2 = std::make_unique<TMVA::Reader>(var_names, "!Silent:Color");
244  mva2->BookMVA ("MLP method", m_TMVAModelFilePath2 );
  m_mva2 = std::move( mva2 );
246  }
247  else{ // assume only ONNX for now
248  // create onnx environment
  // NOTE(review): env is local to initialize(). Confirm that ONNX Runtime allows
  // sessions created from it to outlive this Ort::Env object.
249  Ort::Env env;
250  // converted
254 
255  // unconverted
259  }
260 
261  // initialize the containers
263  ATH_CHECK( m_vertexContainer.initialize() );
264 
269  ATH_CHECK( m_deltaPhiKey.initialize() );
270  ATH_CHECK( m_deltaZKey.initialize() );
271  ATH_CHECK( m_sumPt2Key.initialize() );
272  ATH_CHECK( m_sumPtKey.initialize() );
273 #ifndef XAOD_STANDALONE
276 #endif
277 
278  return StatusCode::SUCCESS;
279  }
280 
281  //____________________________________________________________________________
  // decorateInputs: write the classifier input decorations on every primary or
  // pile-up vertex (other vertex types are skipped):
  //   - sumPt / sumPt2: written only when the decoration is not already available;
  //     sumPt is recomputed from the vertex tracks (power 1, useAux = false);
  //   - deltaPhi: |DeltaPhi(vertex momentum, egamma 4-vector sum)|, or -999 when
  //     getEgammaVector reported FailEgamVect;
  //   - deltaZ: |zCommon - z_vertex| / zCommonError.
  // The failure type found while building the egamma 4-vector is copied to
  // *failType when that pointer is non-null.
  // NOTE(review): presumably sumPt, sumPt2, deltaPhi, deltaZ, eventInfo and vertices
  // are handles built from the m_*Key members. Confirm.
283  auto fail = FailType::NoFail;
284 
285  const EventContext& ctx = Gaudi::Hive::currentContext();
290 
291  // Get the EventInfo
293 
294  // Find the common z-position from beam / photon pointing information
  // zCommon.first is the z position, zCommon.second its uncertainty.
295  std::pair<float, float> zCommon = xAOD::PVHelpers::getZCommonAndError(&*eventInfo, &egammas, m_convPtCut);
296  // Vector sum of photons
  // (also updates `fail`, e.g. to FailEgamVect for a null egamma entry)
297  TLorentzVector vegamma = getEgammaVector(&egammas, fail);
298 
299  // Retrieve PV collection from TEvent
301 
  // Reuse existing sumPt/sumPt2 decorations instead of overwriting them.
302  bool writeSumPt2 = !sumPt2.isAvailable();
303  bool writeSumPt = !sumPt.isAvailable();
304 
305  for (const xAOD::Vertex* vertex: *vertices) {
306 
307  // Skip dummy vertices
308  if (!(vertex->vertexType() == xAOD::VxType::VertexType::PriVtx ||
309  vertex->vertexType() == xAOD::VxType::VertexType::PileUp)) continue;
310 
311  // Set input variables
312  if (writeSumPt) {
313  sumPt(*vertex) = xAOD::PVHelpers::getVertexSumPt(vertex, 1, false);
314  }
315 
316  if (writeSumPt2) {
318  }
319 
320  // Get momentum vector of vertex
321  TLorentzVector vmom = xAOD::PVHelpers::getVertexMomentum(vertex, true, m_derivationPrefix);
322 
  // -999 flags "no egamma direction available" when the egamma vector failed.
323  deltaPhi(*vertex) = (fail != FailType::FailEgamVect) ? std::abs(vmom.DeltaPhi(vegamma)) : -999.;
  // NOTE(review): no guard against zCommon.second == 0 (would give inf/nan).
324  deltaZ(*vertex) = std::abs((zCommon.first - vertex->z())/zCommon.second);
325 
326  } // loop over vertices
327 
328  ATH_MSG_DEBUG("DecorateInputs exit code "<< fail);
329  if(failType!=nullptr)
330  *failType = fail;
331  return StatusCode::SUCCESS;
332  }
333 
334  //____________________________________________________________________________
  // Ordered variant: run getVertexImp and return every (vertex, score) pair it
  // produced. When getVertexImp succeeds, the pairs are sorted by descending score
  // with sortMLP; otherwise whatever it had already appended is returned unsorted.
  // The vertex case and failure type are copied to *vtxCasePtr / *failTypePtr
  // when those pointers are non-null.
  // NOTE(review): std::sort is not stable, so equal scores come out in
  // unspecified order.
335  std::vector<std::pair<const xAOD::Vertex*, float>>
337  bool ignoreConv,
338  bool noDecorate,
339  yyVtxType* vtxCasePtr,
340  FailType* failTypePtr) const {
  // The selected vertex itself is not returned by this variant.
341  const xAOD::Vertex *vertex = nullptr;
342  std::vector<std::pair<const xAOD::Vertex*, float> > vertexMLP;
343  yyVtxType vtxCase = yyVtxType::Unknown;
344  FailType failType = FailType::NoFail;
345  if (getVertexImp( egammas, vertex, ignoreConv, noDecorate, vertexMLP, vtxCase, failType ).isSuccess()) {
346  std::sort(vertexMLP.begin(), vertexMLP.end(), sortMLP);
347  }
348  if(vtxCasePtr!=nullptr)
349  *vtxCasePtr = vtxCase;
350  if(failTypePtr!=nullptr)
351  *failTypePtr = failType;
352 
353  return vertexMLP;
354  }
355 
356  //____________________________________________________________________________
358  const xAOD::Vertex* &prime_vertex,
359  bool ignoreConv) const
360  {
361  std::vector<std::pair<const xAOD::Vertex*, float> > vertexMLP;
362  yyVtxType vtxcase = yyVtxType::Unknown;
363  FailType failType = FailType::NoFail;
364  return getVertexImp( egammas, prime_vertex, ignoreConv, false, vertexMLP, vtxcase, failType );
365  }
366 
368  const xAOD::EgammaContainer& egammas,
369  const xAOD::Vertex*& prime_vertex,
370  bool ignoreConv,
371  bool noDecorate,
372  std::vector<std::pair<const xAOD::Vertex*, float>>& vertexMLP,
373  yyVtxType& vtxCase,
374  FailType& fail) const {
375  // Set default vertex case and declare photon container
376  vtxCase = yyVtxType::Unknown;
377  const xAOD::PhotonContainer *photons = dynamic_cast<const xAOD::PhotonContainer*>(&egammas);
378 
379  // Retrieve PV collection from TEvent
381 
382  if (!noDecorate && !decorateInputs(egammas).isSuccess()){
383  return StatusCode::FAILURE;
384  }
385 
386  // Check if a conversion photon has a track attached to a primary/pileup vertex
387  if (!ignoreConv && photons) {
388  prime_vertex = getPrimaryVertexFromConv(photons);
389  if (prime_vertex != nullptr) {
390  vtxCase = yyVtxType::ConvTrack;
391  fail = FailType::MatchedTrack;
392  vertexMLP.emplace_back(prime_vertex, 0.);
393  return StatusCode::SUCCESS;
394  }
395  }
396 
397  if (fail != FailType::NoFail){
398  ATH_MSG_VERBOSE("Returning hardest vertex. Fail detected (type="<< fail <<")");
399  vertexMLP.clear();
400  prime_vertex = xAOD::PVHelpers::getHardestVertex(&*vertices);
401  vertexMLP.emplace_back(prime_vertex, 10.);
402  return StatusCode::SUCCESS;
403  }
404 
405  // Get the EventInfo
407 
408  // If there are any silicon conversions passing selection
409  // ==> use Model 1 (Conv) otherwise Model 2 (Unconv)
410  // Set default for conversion bool as false unless otherwise
411  bool isConverted = false;
412 
413  // assume default NoSiTrack (unconverted) unless otherwise
414  vtxCase = yyVtxType::NoSiTracks;
415  if (!ignoreConv && photons) {
416  for (const auto *photon: *photons) {
417  if (!photon)
418  {
419  ATH_MSG_WARNING("Null pointer to photon");
420  return StatusCode::FAILURE;
421  }
422  // find out if pass conversion selection criteria and tag as SiConvTrack case
424  {
425  isConverted = true;
426  vtxCase = yyVtxType::SiConvTrack;
427  }
428  }
429  }
430 
431  // if TMVA chosen, declare tmva_reader only once (before for looping vertex)
432  TMVA::Reader *tmva_reader = new TMVA::Reader();
433  if(m_isTMVA){
434  if(isConverted){
435  // If there are any silicon conversions passing selection, use MVA1 (converted case)
436  tmva_reader = m_mva1.get();
437  }
438  // Otherwise, use MVA2 (unconverted case)
439  if(!isConverted){
440  tmva_reader = m_mva2.get();
441  }
442  }
443  ATH_MSG_DEBUG("Vtx Case: " << vtxCase);
444 
445  // Vector sum of photons
446  TLorentzVector vegamma = getEgammaVector(&egammas, fail);
447 
452 
453  // Loop over vertices and find best candidate
454  std::vector<float> ONNXInputVector;
455  std::vector<std::vector<float>> onnx_input_tensor_values;
456  std::vector<float> TMVAInputVector;
457  TString TMVAMethod;
458  float mlp = 0.0, mlp_max = -99999.0;
459  float doSkipByZSigmaScore = -9999.0;
460  // assign threshold score value to compare later for good vtx
461  float thresGoodVtxScore;
462  if(m_doSkipByZSigma){thresGoodVtxScore = doSkipByZSigmaScore;}
463  else{thresGoodVtxScore = mlp_max;}
464  for (const xAOD::Vertex* vertex: *vertices) {
465  // Skip dummy vertices
466  if (!(vertex->vertexType() == xAOD::VxType::VertexType::PriVtx ||
467  vertex->vertexType() == xAOD::VxType::VertexType::PileUp)) continue;
468 
469  onnx_input_tensor_values.clear();
470 
471  // Variables used as input features in classifier
472  float sumPt, sumPt2, deltaPhi, deltaZ;
473  float log10_sumPt, log10_sumPt2;
474 
475  sumPt = (sumPtA)(*vertex);
476  sumPt2 = (sumPt2A)(*vertex);
477  deltaPhi = (deltaPhiA)(*vertex);
478  deltaZ = (deltaZA)(*vertex);
479  ATH_MSG_VERBOSE("sumPt: " << sumPt <<
480  " sumPt2: " << sumPt2 <<
481  " deltaPhi: " << deltaPhi <<
482  " deltaZ: " << deltaZ);
483 
484  // setup the vector of input features based on selected inference framework
485  if(m_isTMVA){
486  // Get likelihood probability from TMVA model
487  TMVAMethod = "MLP method";
488  log10_sumPt = static_cast<float>(log10(sumPt));
489  log10_sumPt2 = static_cast<float>(log10(sumPt2));
490  TMVAInputVector = {deltaZ,deltaPhi,log10_sumPt,log10_sumPt2};
491  }
492  else{ //assume ony ONNX for now
493  // Get likelihood probability from onnx model
494  // check if value is 0, assign small number like 1e-8 as dummy, as we will take log later (log(0) is nan)
495  // note that the ordering here is a bit different, following the order used when training
496  ONNXInputVector = {sumPt2, sumPt, deltaPhi, deltaZ};
497  for (long unsigned int i = 0; i < ONNXInputVector.size(); i++) {
498  // skip log for deltaPhi and take log for the rest
499  if (i == 2) {
500  continue;
501  }
502  if (ONNXInputVector[i] != 0 && std::isinf(ONNXInputVector[i]) != true && std::isnan(ONNXInputVector[i]) != true){
503  ONNXInputVector[i] = log(std::abs(ONNXInputVector[i]));
504  }
505  else{
506  ONNXInputVector[i] = log(std::abs(0.00000001)); //log(abs(1e-8))
507  }
508  } //end ONNXInputVector for loop
509  onnx_input_tensor_values.push_back(ONNXInputVector);
510  }
511 
512  // Do the actual calculation of classifier score part
513  if(m_isTMVA){
514  mlp = tmva_reader->EvaluateMVA(TMVAInputVector, TMVAMethod);
515  ATH_MSG_VERBOSE("TMVA output: " << (tmva_reader == m_mva1.get() ? "MVA1 ": "MVA2 ")<< mlp);
516  }
517  else{ //assume ony ONNX for now
518  if(isConverted){
519  mlp = getScore(m_nVars, onnx_input_tensor_values,
522  }
523  if(!isConverted){
524  mlp = getScore(m_nVars, onnx_input_tensor_values,
527  }
528  ATH_MSG_VERBOSE("log(abs(sumPt)): " << sumPt <<
529  " log(abs(sumPt2)): " << sumPt2 <<
530  " deltaPhi: " << deltaPhi <<
531  " log(abs(deltaZ)): " << deltaZ);
532  ATH_MSG_VERBOSE("ONNX output, isConverted = " << isConverted << ", mlp=" << mlp);
533  }
534 
535  // Skip vertices above 10 sigma from pointing or 15 sigma from conversion (HPV)
536  // Simply displace the mlp variable we calculate before by a predefined value
537  if(m_doSkipByZSigma){
538  if ((isConverted && deltaZ > 15) || (!isConverted && deltaZ > 10)) {
539  mlp = doSkipByZSigmaScore;
540  }
541  }
542 
543  // add the new vertex and its score to vertexMLP container
544  vertexMLP.emplace_back(vertex, mlp);
545 
546  // Keep track of maximal likelihood vertex
547  if (mlp > mlp_max) {
548  mlp_max = mlp;
549  prime_vertex = vertex;
550  }
551  } // end loop over vertices
552 
553  // from all the looped vertices, decide the max score which should be more than the minimum we set
554  // (which should be more than the initial mlp_max value above or more than the skip vertex by z-sigma score)
555  // if this does not pass, return hardest primary vertex
556  if (mlp_max <= thresGoodVtxScore) {
557  ATH_MSG_DEBUG("No good vertex candidates from pointing, returning hardest vertex.");
558  prime_vertex = xAOD::PVHelpers::getHardestVertex(&*vertices);
559  fail = FailType::NoGdCandidate;
560  vertexMLP.clear();
561  vertexMLP.emplace_back(xAOD::PVHelpers::getHardestVertex(&*vertices), 20.);
562  }
563 
564  ATH_MSG_VERBOSE("getVertex case "<< (int)vtxCase << " exit code "<< (int)fail);
565  return StatusCode::SUCCESS;
566  }
567 
568  //____________________________________________________________________________
569  bool PhotonVertexSelectionTool::sortMLP(const std::pair<const xAOD::Vertex*, float> &a,
570  const std::pair<const xAOD::Vertex*, float> &b)
571  { return a.second > b.second; }
572 
573  //____________________________________________________________________________
  // getPrimaryVertexFromConv: for every photon conversion vertex, take each track
  // passing the conversion selection and look up the vertex that carries it
  // (getVertexFromTrack). Returns the first primary or pile-up vertex found in
  // photon order, warning when different photons point to different vertices.
  // Returns nullptr when the container is null or no such vertex exists.
575  {
576  if (photons == nullptr) {
577  ATH_MSG_WARNING("Passed nullptr photon container, returning nullptr vertex from getPrimaryVertexFromConv");
578  return nullptr;
579  }
580 
581  std::vector<const xAOD::Vertex*> vertices;
582  const xAOD::Vertex *conversionVertex = nullptr, *primary = nullptr;
583  const xAOD::TrackParticle *tp = nullptr;
584  size_t NumberOfTracks = 0;
585 
586  // Retrieve PV collection from TEvent
588 
589 
590  for (const auto *photon: *photons) {
  // Skip null entries instead of dereferencing them. getVertexImp, which calls
  // this function before its own null-photon check, then still reports FAILURE.
  if (photon == nullptr) continue;
591  conversionVertex = photon->vertex();
592  if (conversionVertex == nullptr) continue;
593 
594  NumberOfTracks = conversionVertex->nTrackParticles();
595  for (size_t i = 0; i < NumberOfTracks; ++i) {
596  // Get trackParticle in GSF collection
597  const auto *gsfTp = conversionVertex->trackParticle(i);
598  if (gsfTp == nullptr) continue;
599  if (!xAOD::PVHelpers::passConvSelection(*conversionVertex, i, m_convPtCut)) continue;
600 
601  // Get trackParticle in InDet collection
603  if (tp == nullptr) continue;
604 
605  primary = getVertexFromTrack(tp, &*all_vertices);
606  if (primary == nullptr) continue;
607 
608  if (primary->vertexType() == xAOD::VxType::VertexType::PriVtx ||
609  primary->vertexType() == xAOD::VxType::VertexType::PileUp) {
  // Record each distinct vertex once, keeping first-found order.
610  if (std::find(vertices.begin(), vertices.end(), primary) == vertices.end()) {
611  vertices.push_back(primary);
612  continue;
613  }
614  }
615  }
616  }
617 
618  if (!vertices.empty()) {
619  if (vertices.size() > 1)
620  ATH_MSG_WARNING("Photons associated to different vertices! Returning lead photon association.");
621  return vertices[0];
622  }
623 
624  return nullptr;
625  }
626 
627  //____________________________________________________________________________
628  TLorentzVector PhotonVertexSelectionTool::getEgammaVector(const xAOD::EgammaContainer *egammas, FailType& failType) const
629  {
630  TLorentzVector v, v1;
631  const xAOD::CaloCluster *cluster = nullptr;
632  for (const xAOD::Egamma* egamma: *egammas) {
633  if (egamma == nullptr) {
634  ATH_MSG_DEBUG("No egamma object to get four vector");
635  failType = FailType::FailEgamVect;
636  continue;
637  }
638  cluster = egamma->caloCluster();
639  if (cluster == nullptr) {
640  ATH_MSG_WARNING("No cluster associated to egamma, not adding to 4-vector.");
641  continue;
642  }
643 
644  v1.SetPtEtaPhiM(egamma->e()/cosh(cluster->etaBE(2)),
645  cluster->etaBE(2),
646  cluster->phiBE(2),
647  0.0);
648  v += v1;
649  }
650  return v;
651  }
652 
653 } // namespace CP
CP::PhotonVertexSelectionTool::PhotonVertexSelectionTool
PhotonVertexSelectionTool(const std::string &name)
Definition: PhotonVertexSelectionTool.cxx:54
CP::PhotonVertexSelectionTool::setONNXSession
std::tuple< std::shared_ptr< Ort::Session >, Ort::AllocatorWithDefaultOptions > setONNXSession(Ort::Env &env, const std::string &modelFilePath)
Definition: PhotonVertexSelectionTool.cxx:204
GetLCDefs::Unknown
@ Unknown
Definition: GetLCDefs.h:21
xAOD::PVHelpers::getVertexSumPt
float getVertexSumPt(const xAOD::Vertex *vertex, int power=1, bool useAux=true)
Loop over track particles associated with vertex and return scalar sum of pT^power in GeV (from auxda...
Definition: PhotonVertexHelpers.cxx:214
CP::PhotonVertexSelectionTool::m_TMVAModelFilePath2
std::string m_TMVAModelFilePath2
Definition: PhotonVertexSelectionTool.h:63
CP::PhotonVertexSelectionTool::m_derivationPrefix
std::string m_derivationPrefix
Definition: PhotonVertexSelectionTool.h:51
xAOD::Vertex_v1::nTrackParticles
size_t nTrackParticles() const
Get the number of tracks associated with this vertex.
Definition: Vertex_v1.cxx:270
CP::PhotonVertexSelectionTool::getPrimaryVertexFromConv
const xAOD::Vertex * getPrimaryVertexFromConv(const xAOD::PhotonContainer *photons) const
Get possible vertex directly associated with photon conversions.
Definition: PhotonVertexSelectionTool.cxx:574
CP::PhotonVertexSelectionTool::m_sumPt2Key
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_sumPt2Key
Definition: PhotonVertexSelectionTool.h:171
SG::WriteDecorHandle::isAvailable
bool isAvailable()
Test to see if this variable exists in the store, for the referenced object.
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
find
std::string find(const std::string &s)
return a remapped string
Definition: hcg.cxx:135
CP::PhotonVertexSelectionTool::m_TMVAModelFilePath1
std::string m_TMVAModelFilePath1
Definition: PhotonVertexSelectionTool.h:62
SG::ReadHandle
Definition: StoreGate/StoreGate/ReadHandle.h:67
CP::PhotonVertexSelectionTool::getEgammaVector
TLorentzVector getEgammaVector(const xAOD::EgammaContainer *egammas, FailType &failType) const
Get combined 4-vector of photon container.
Definition: PhotonVertexSelectionTool.cxx:628
xAOD::deltaPhi
setSAddress setEtaMS setDirPhiMS setDirZMS setBarrelRadius setEndcapAlpha setEndcapRadius setInterceptInner setEtaMap setEtaBin setIsTgcFailure setDeltaPt deltaPhi
Definition: L2StandAloneMuon_v1.cxx:161
xAOD::Vertex_v1::trackParticleLinks
const TrackParticleLinks_t & trackParticleLinks() const
Get all the particles associated with the vertex.
CP::PhotonVertexSelectionTool::m_vertexContainer
SG::ReadHandleKey< xAOD::VertexContainer > m_vertexContainer
Definition: PhotonVertexSelectionTool.h:56
xAOD::Vertex_v1::trackWeights
const std::vector< float > & trackWeights() const
Get all the track weights.
AthCommonDataStore< AthCommonMsg< AlgTool > >::renounce
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce(T &h)
Definition: AthCommonDataStore.h:380
CP::PhotonVertexSelectionTool::m_input_node_dims1
std::vector< int64_t > m_input_node_dims1
Definition: PhotonVertexSelectionTool.h:78
xAOD::PVHelpers::getZCommonAndError
std::pair< float, float > getZCommonAndError(const xAOD::EventInfo *eventInfo, const xAOD::EgammaContainer *egammas, float convPtCut=2e3)
Return zCommon and zCommonError.
Definition: PhotonVertexHelpers.cxx:43
asg
Definition: DataHandleTestTool.h:28
CP::PhotonVertexSelectionTool::m_sessionHandle2
std::shared_ptr< Ort::Session > m_sessionHandle2
Definition: PhotonVertexSelectionTool.h:86
ParticleTest.tp
tp
Definition: ParticleTest.py:25
CP::PhotonVertexSelectionTool::m_mva1
std::unique_ptr< TMVA::Reader > m_mva1
Definition: PhotonVertexSelectionTool.h:67
xAOD::CaloCluster_v1::phiBE
float phiBE(const unsigned layer) const
Get the phi in one layer of the EM Calo.
Definition: CaloCluster_v1.cxx:634
xAOD::Egamma_v1
Definition: Egamma_v1.h:56
SG::ConstAccessor
Helper class to provide constant type-safe access to aux data.
Definition: ConstAccessor.h:55
ATH_MSG_VERBOSE
#define ATH_MSG_VERBOSE(x)
Definition: AthMsgStreamMacros.h:28
isValid
bool isValid(const T &p)
Av: we implement here an ATLAS-specific convention: all particles which are 99xxxxx are fine.
Definition: AtlasPID.h:867
CP::PhotonVertexSelectionTool::sortMLP
static bool sortMLP(const std::pair< const xAOD::Vertex *, float > &a, const std::pair< const xAOD::Vertex *, float > &b)
Sort MLP results.
Definition: PhotonVertexSelectionTool.cxx:569
CP::PhotonVertexSelectionTool::m_ONNXModelFilePath1
std::string m_ONNXModelFilePath1
Definition: PhotonVertexSelectionTool.h:73
CP
Select isolated Photons, Electrons and Muons.
Definition: Control/xAODRootAccess/xAODRootAccess/TEvent.h:49
CP::PhotonVertexSelectionTool::m_sessionHandle1
std::shared_ptr< Ort::Session > m_sessionHandle1
Definition: PhotonVertexSelectionTool.h:85
python.CaloAddPedShiftConfig.type
type
Definition: CaloAddPedShiftConfig.py:42
egamma
Definition: egamma.h:58
ParticleImpl::e
virtual double e() const
energy
Definition: ParticleImpl.h:534
CP::PhotonVertexSelectionTool::m_output_node_dims2
std::vector< int64_t > m_output_node_dims2
Definition: PhotonVertexSelectionTool.h:81
xAOD::CaloCluster_v1::etaBE
float etaBE(const unsigned layer) const
Get the eta in one layer of the EM Calo.
Definition: CaloCluster_v1.cxx:628
CP::PhotonVertexSelectionTool::m_eventInfo
SG::ReadHandleKey< xAOD::EventInfo > m_eventInfo
Container declarations.
Definition: PhotonVertexSelectionTool.h:54
xAOD::CaloCluster_v1
Description of a calorimeter cluster.
Definition: CaloCluster_v1.h:62
xAOD::PVHelpers::getHardestVertex
const xAOD::Vertex * getHardestVertex(const xAOD::VertexContainer *vertices)
Return vertex with highest sum pT^2.
Definition: PhotonVertexHelpers.cxx:28
EgammaxAODHelpers.h
PhotonVertexSelectionTool.h
StateLessPT_NewConfig.primary
primary
Definition: StateLessPT_NewConfig.py:234
CP::PhotonVertexSelectionTool::initialize
virtual StatusCode initialize()
Function initialising the tool.
Definition: PhotonVertexSelectionTool.cxx:224
CheckAppliedSFs.e3
e3
Definition: CheckAppliedSFs.py:264
CP::PhotonVertexSelectionTool::m_allocator1
Ort::AllocatorWithDefaultOptions m_allocator1
Definition: PhotonVertexSelectionTool.h:88
CP::PhotonVertexSelectionTool::getVertexImp
StatusCode getVertexImp(const xAOD::EgammaContainer &egammas, const xAOD::Vertex *&vertex, bool ignoreConv, bool noDecorate, std::vector< std::pair< const xAOD::Vertex *, float >> &, yyVtxType &, FailType &) const
Given a list of photons, return the MLPs of all vertices in the event.
Definition: PhotonVertexSelectionTool.cxx:367
CP::IPhotonVertexSelectionTool::yyVtxType
yyVtxType
Definition: IPhotonVertexSelectionTool.h:44
lumiFormat.i
int i
Definition: lumiFormat.py:85
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
CP::PhotonVertexSelectionTool::m_sumPtKey
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_sumPtKey
Definition: PhotonVertexSelectionTool.h:173
AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T, V, H > &t)
Definition: AthCommonDataStore.h:145
SG::WriteDecorHandle
Handle class for adding a decoration to an object.
Definition: StoreGate/StoreGate/WriteDecorHandle.h:100
CP::PhotonVertexSelectionTool::m_convPtCut
float m_convPtCut
Definition: PhotonVertexSelectionTool.h:48
CP::PhotonVertexSelectionTool::getOutputNodes
std::tuple< std::vector< int64_t >, std::vector< const char * > > getOutputNodes(const std::shared_ptr< Ort::Session > &sessionHandle, Ort::AllocatorWithDefaultOptions &allocator)
Definition: PhotonVertexSelectionTool.cxx:171
CP::PhotonVertexSelectionTool::~PhotonVertexSelectionTool
virtual ~PhotonVertexSelectionTool()
CP::IPhotonVertexSelectionTool::FailType
FailType
Declare the interface that the class provides.
Definition: IPhotonVertexSelectionTool.h:33
xAOD::VxType::PriVtx
@ PriVtx
Primary vertex.
Definition: TrackingPrimitives.h:572
LHEF::Reader
Pythia8::Reader Reader
Definition: Prophecy4fMerger.cxx:11
CP::PhotonVertexSelectionTool::m_nVars
int m_nVars
Create a proper constructor for Athena.
Definition: PhotonVertexSelectionTool.h:47
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
CP::PhotonVertexSelectionTool::getInputNodes
std::tuple< std::vector< int64_t >, std::vector< const char * > > getInputNodes(const std::shared_ptr< Ort::Session > &sessionHandle, Ort::AllocatorWithDefaultOptions &allocator)
Definition: PhotonVertexSelectionTool.cxx:137
xAOD::Vertex_v1::trackParticle
const TrackParticle * trackParticle(size_t i) const
Get the pointer to a given track that was used in vertex reco.
Definition: Vertex_v1.cxx:249
SG::VarHandleKey::initialize
StatusCode initialize(bool used=true)
If this object is used as a property, then this should be called during the initialize phase.
Definition: AthToolSupport/AsgDataHandles/Root/VarHandleKey.cxx:103
xAOD::PVHelpers::getVertexMomentum
TLorentzVector getVertexMomentum(const xAOD::Vertex *vertex, bool useAux=true, const std::string &derivationPrefix="")
Return vector sum of tracks associated with vertex (from auxdata if available and useAux = true)
Definition: PhotonVertexHelpers.cxx:174
SG::decorKeyFromKey
std::string decorKeyFromKey(const std::string &key, const std::string &deflt)
Extract the decoration part of key.
Definition: DecorKeyHelpers.cxx:42
DataVector
Derived DataVector<T>.
Definition: DataVector.h:794
WriteDecorHandle.h
Handle class for adding a decoration to an object.
CP::PhotonVertexSelectionTool::m_deltaPhiKey
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_deltaPhiKey
Definition: PhotonVertexSelectionTool.h:167
PhotonVertexHelpers.h
CP::PhotonVertexSelectionTool::m_input_node_dims2
std::vector< int64_t > m_input_node_dims2
Definition: PhotonVertexSelectionTool.h:81
CP::PhotonVertexSelectionTool::m_ONNXModelFilePath2
std::string m_ONNXModelFilePath2
Definition: PhotonVertexSelectionTool.h:74
xAOD::VxType::PileUp
@ PileUp
Pile-up vertex.
Definition: TrackingPrimitives.h:574
CP::PhotonVertexSelectionTool::m_output_node_names1
std::vector< const char * > m_output_node_names1
Definition: PhotonVertexSelectionTool.h:79
PathResolver.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
plotBeamSpotMon.b
b
Definition: plotBeamSpotMon.py:76
ReadHandle.h
Handle class for reading from StoreGate.
CP::PhotonVertexSelectionTool::getVertex
StatusCode getVertex(const xAOD::EgammaContainer &egammas, const xAOD::Vertex *&vertex, bool ignoreConv=false) const
Given a list of photons, return the most likely vertex based on MVA likelihood.
Definition: PhotonVertexSelectionTool.cxx:357
CP::PhotonVertexSelectionTool::m_mva2
std::unique_ptr< TMVA::Reader > m_mva2
Definition: PhotonVertexSelectionTool.h:68
xAOD::EgammaHelpers::getOriginalTrackParticleFromGSF
const xAOD::TrackParticle * getOriginalTrackParticleFromGSF(const xAOD::TrackParticle *trkPar)
Helper function for getting the "Original" Track Particle (i.e before GSF) via the GSF Track Particle...
Definition: ElectronxAODHelpers.cxx:22
CP::PhotonVertexSelectionTool::m_output_node_names2
std::vector< const char * > m_output_node_names2
Definition: PhotonVertexSelectionTool.h:82
CP::PhotonVertexSelectionTool::m_isTMVA
bool m_isTMVA
Definition: PhotonVertexSelectionTool.h:61
EventInfo.h
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:283
CP::PhotonVertexSelectionTool::m_doSkipByZSigma
bool m_doSkipByZSigma
Definition: PhotonVertexSelectionTool.h:49
python.PyAthena.v
v
Definition: PyAthena.py:154
CP::PhotonVertexSelectionTool::m_allocator2
Ort::AllocatorWithDefaultOptions m_allocator2
Definition: PhotonVertexSelectionTool.h:89
Trk::vertex
@ vertex
Definition: MeasurementType.h:21
VertexContainer.h
xAOD::photon
@ photon
Definition: TrackingPrimitives.h:200
a
TList * a
Definition: liststreamerinfos.cxx:10
xAOD::Vertex_v1
Class describing a Vertex.
Definition: Vertex_v1.h:42
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
CP::PhotonVertexSelectionTool::decorateInputs
StatusCode decorateInputs(const xAOD::EgammaContainer &egammas, FailType *failType=nullptr) const
Given a list of photons, decorate vertex container with MVA variables.
Definition: PhotonVertexSelectionTool.cxx:282
EgammaDefs.h
python.CaloCondTools.log
log
Definition: CaloCondTools.py:20
CP::PhotonVertexSelectionTool::getScore
float getScore(int nVars, const std::vector< std::vector< float >> &input_data, const std::shared_ptr< Ort::Session > &sessionHandle, std::vector< int64_t > input_node_dims, std::vector< const char * > input_node_names, std::vector< const char * > output_node_names) const
Definition: PhotonVertexSelectionTool.cxx:89
xAOD::PVHelpers::passConvSelection
bool passConvSelection(const xAOD::Photon *photon, float convPtCut=2e3)
Check if photon is converted, and tracks have Si hits and pass selection.
Definition: PhotonVertexHelpers.cxx:132
CP::PhotonVertexSelectionTool::m_input_node_names1
std::vector< const char * > m_input_node_names1
Definition: PhotonVertexSelectionTool.h:79
CP::PhotonVertexSelectionTool::m_output_node_dims1
std::vector< int64_t > m_output_node_dims1
Definition: PhotonVertexSelectionTool.h:78
xAOD::track
@ track
Definition: TrackingPrimitives.h:513
xAOD::TrackParticle_v1
Class describing a TrackParticle.
Definition: TrackParticle_v1.h:43
makeComparison.deltaZ
int deltaZ
Definition: makeComparison.py:46
CP::PhotonVertexSelectionTool::m_deltaZKey
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_deltaZKey
Definition: PhotonVertexSelectionTool.h:169
python.DataFormatRates.env
env
Definition: DataFormatRates.py:32
CP::PhotonVertexSelectionTool::m_input_node_names2
std::vector< const char * > m_input_node_names2
Definition: PhotonVertexSelectionTool.h:82
PhotonContainer.h
CP::getVertexFromTrack
const xAOD::Vertex * getVertexFromTrack(const xAOD::TrackParticle *track, const xAOD::VertexContainer *vertices)
Definition: PhotonVertexSelectionTool.cxx:30
beamspotman.fail
def fail(message)
Definition: beamspotman.py:199