ATLAS Offline Software
Loading...
Searching...
No Matches
PhotonVertexSelectionTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5// Local includes
8
9// EDM includes
15
16// Framework includes
20
21// ROOT includes
22#include "TMVA/Reader.h"
23
24// std includes
25#include <algorithm>
26
27namespace CP {
28
29 // helper function to get the vertex of a track
31 const xAOD::VertexContainer* vertices)
32 {
33 const xAOD::Vertex* vtxWithLargestWeight = nullptr;
34 float largestWeight = 0;
35
36 for (const auto *vtx : *vertices) {
37 //Search for vertex linked to this track
38 const auto& trkLinks=vtx->trackParticleLinks();
39 const size_t nTrackLinks=trkLinks.size();
40 for (unsigned i=0;i<nTrackLinks;++i) {
41 if (trkLinks[i].isValid() && *(trkLinks[i]) == track) {//ptr comparison
42 if( vtx->trackWeights()[i] > largestWeight ){
43 vtxWithLargestWeight = vtx;
44 largestWeight = vtx->trackWeights()[i];
45 }
46 }
47 }
48 }
49
50 return vtxWithLargestWeight;
51 }
52
53 //____________________________________________________________________________
55 : asg::AsgTool(name)
56 {
57 // run 2 NN model:
58 // m_doSkipByZSigma = true, m_isTMVA = true
59 // run 3 NN model:
60 // m_doSkipByZSigma = false, m_isTMVA = false
61
62 // default variables
63 declareProperty("nVars", m_nVars = 4);
64 declareProperty("conversionPtCut", m_convPtCut = 2e3);
65 declareProperty("DoSkipByZSigma", m_doSkipByZSigma = false);
66
67 declareProperty("derivationPrefix", m_derivationPrefix = "");
68
69 // boolean for TMVA, default true
70 declareProperty("isTMVA", m_isTMVA = false);
71
72 // config files (TMVA), default paths if not set
73 declareProperty("ConfigFileCase1",
74 m_TMVAModelFilePath1 = "PhotonVertexSelection/v1/DiphotonVertex_case1.weights.xml");
75 declareProperty("ConfigFileCase2",
76 m_TMVAModelFilePath2 = "PhotonVertexSelection/v1/DiphotonVertex_case2.weights.xml");
77
78 // config files (ONNX), default paths if not set
79 declareProperty("ONNXModelFileCase1", m_ONNXModelFilePath1 = "PhotonVertexSelection/run3nn/model1.onnx");
80 declareProperty("ONNXModelFileCase2", m_ONNXModelFilePath2 = "PhotonVertexSelection/run3nn/model2.onnx");
81 }
82
83 //____________________________________________________________________________
85 = default;
86
87 //____________________________________________________________________________
88 //new additions for ONNX
90 int nVars, const std::vector<std::vector<float>>& input_data,
91 const std::shared_ptr<Ort::Session>& sessionHandle,
92 std::vector<int64_t> input_node_dims,
93 std::vector<const char*> input_node_names,
94 std::vector<const char*> output_node_names) const {
95 //*************************************************************************
96 // score the model using sample data, and inspect values
97 // loading input data
98 std::vector<std::vector<float>> input_tensor_values_ = input_data;
99
100 //preparing container to hold input data
101 size_t input_tensor_size = nVars;
102 std::vector<float> input_tensor_values(nVars);
103 input_tensor_values = input_tensor_values_[0]; //0th element since only batch_size of 1, otherwise loop
104
105 // create input tensor object from data values
106 auto memory_info =
107 Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
108 // create tensor using info from inputs
109 Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
110 memory_info, input_tensor_values.data(), input_tensor_size,
111 input_node_dims.data(), input_node_dims.size());
112
113 // check if input is of type tensor
114 assert(input_tensor.IsTensor());
115
116 // run the inference
117 auto output_tensors =
118 sessionHandle->Run(Ort::RunOptions{nullptr}, input_node_names.data(),
119 &input_tensor, input_node_names.size(),
120 output_node_names.data(), output_node_names.size());
121
122 // check size of output tensor
123 assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
124
125 // get pointer to output tensor float values
126 // float* floatarr = output_tensors.front().GetTensorMutableData<float>();
127 float* floatarr = output_tensors[0].GetTensorMutableData<float>();
128
129 int arrSize = sizeof(*floatarr) / sizeof(floatarr[0]);
130 ATH_MSG_DEBUG("The size of the array is: " << arrSize);
131 ATH_MSG_DEBUG("floatarr[0] = " << floatarr[0]);
132 return floatarr[0];
133 }
134
135 //new additions for ONNX
136 std::tuple<std::vector<int64_t>, std::vector<const char*>>
138 const std::shared_ptr<Ort::Session>& sessionHandle,
139 Ort::AllocatorWithDefaultOptions& allocator) {
140 // input nodes
141 std::vector<int64_t> input_node_dims;
142 size_t num_input_nodes = sessionHandle->GetInputCount();
143 std::vector<const char*> input_node_names(num_input_nodes);
144
145 // Loop the input nodes
146 for( std::size_t i = 0; i < num_input_nodes; i++ ) {
147 // Print input node names
148 char* input_name = sessionHandle->GetInputNameAllocated(i, allocator).release();
149 ATH_MSG_DEBUG("Input "<<i<<" : "<<" name= "<<input_name);
150 input_node_names[i] = input_name;
151
152 // Print input node types
153 Ort::TypeInfo type_info = sessionHandle->GetInputTypeInfo(i);
154 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
155 ONNXTensorElementDataType type = tensor_info.GetElementType();
156 ATH_MSG_DEBUG("Input "<<i<<" : "<<" type= "<<type);
157
158 // Print input shapes/dims
159 input_node_dims = tensor_info.GetShape();
160 ATH_MSG_DEBUG("Input "<<i<<" : num_dims= "<<input_node_dims.size());
161 for (std::size_t j = 0; j < input_node_dims.size(); j++){
162 if(input_node_dims[j]<0){input_node_dims[j] =1;}
163 ATH_MSG_DEBUG("Input"<<i<<" : dim "<<j<<"= "<<input_node_dims[j]);
164 }
165 }
166 return std::make_tuple(input_node_dims, input_node_names);
167 }
168
169 //new additions for ONNX
170 std::tuple<std::vector<int64_t>, std::vector<const char*>>
172 const std::shared_ptr<Ort::Session>& sessionHandle,
173 Ort::AllocatorWithDefaultOptions& allocator) {
174 // output nodes
175 std::vector<int64_t> output_node_dims;
176 size_t num_output_nodes = sessionHandle->GetOutputCount();
177 std::vector<const char*> output_node_names(num_output_nodes);
178
179 // Loop the output nodes
180 for( std::size_t i = 0; i < num_output_nodes; i++ ) {
181 // Print output node names
182 char* output_name = sessionHandle->GetOutputNameAllocated(i, allocator).release();
183 ATH_MSG_DEBUG("Output "<<i<<" : "<<" name= "<<output_name);
184 output_node_names[i] = output_name;
185
186 Ort::TypeInfo type_info = sessionHandle->GetOutputTypeInfo(i);
187 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
188 ONNXTensorElementDataType type = tensor_info.GetElementType();
189 ATH_MSG_DEBUG("Output "<<i<<" : "<<" type= "<<type);
190
191 // Print output shapes/dims
192 output_node_dims = tensor_info.GetShape();
193 ATH_MSG_DEBUG("Output "<<i<<" : num_dims= "<<output_node_dims.size());
194 for (std::size_t j = 0; j < output_node_dims.size(); j++){
195 if(output_node_dims[j]<0){output_node_dims[j] =1;}
196 ATH_MSG_DEBUG("Output"<<i<<" : dim "<<j<<"= "<<output_node_dims[j]);
197 }
198 }
199 return std::make_tuple(output_node_dims, output_node_names);
200 }
201
202 //new additions for ONNX
203 std::tuple<std::shared_ptr<Ort::Session>, Ort::AllocatorWithDefaultOptions>
205 const std::string& modelFilePath) {
206 // Find the model file.
207 const std::string modelFileName = PathResolverFindCalibFile( modelFilePath );
208 ATH_MSG_INFO( "Using model file: " << modelFileName );
209
210 // set onnx session options
211 Ort::SessionOptions sessionOptions;
212 sessionOptions.SetIntraOpNumThreads( 1 );
213 sessionOptions.SetGraphOptimizationLevel( ORT_ENABLE_BASIC );
214 // set allocator
215 Ort::AllocatorWithDefaultOptions allocator;
216 // set the onnx runtime session
217 std::shared_ptr<Ort::Session> sessionHandle = std::make_shared<Ort::Session>( env, modelFileName.c_str(), sessionOptions );
218
219 ATH_MSG_INFO( "Created the ONNX Runtime session for model file = " << modelFileName);
220 return std::make_tuple(sessionHandle, allocator);
221 }
222
223 //____________________________________________________________________________
225 {
226 ATH_MSG_INFO("Initializing PhotonVertexSelectionTool...");
227 // initialize the readers or sessions
228 if(m_isTMVA){
229 // Get full path of configuration files for MVA
232 // Setup MVAs
233 std::vector<std::string> var_names = {
234 "deltaZ := TMath::Min(abs(PrimaryVerticesAuxDyn.z-zCommon)/zCommonError,20)",
235 "deltaPhi := abs(deltaPhi(PrimaryVerticesAuxDyn.phi,egamma_phi))" ,
236 "logSumpt := log10(PrimaryVerticesAuxDyn.sumPt)" ,
237 "logSumpt2 := log10(PrimaryVerticesAuxDyn.sumPt2)"
238 };
239 auto *mva1 = new TMVA::Reader(var_names, "!Silent:Color");
240 mva1->BookMVA ("MLP method", m_TMVAModelFilePath1 );
241 m_mva1 = std::unique_ptr<TMVA::Reader>( mva1 );
242
243 auto mva2 = std::make_unique<TMVA::Reader>(var_names, "!Silent:Color");
244 mva2->BookMVA ("MLP method", m_TMVAModelFilePath2 );
245 m_mva2 = std::unique_ptr<TMVA::Reader>( std::move(mva2) );
246 }
247 else{ // assume only ONNX for now
248 // create onnx environment
249 Ort::Env env;
250 // converted
254
255 // unconverted
259 }
260
261 // initialize the containers
262 ATH_CHECK( m_eventInfo.initialize() );
263 ATH_CHECK( m_vertexContainer.initialize() );
264
269 ATH_CHECK( m_deltaPhiKey.initialize() );
270 ATH_CHECK( m_deltaZKey.initialize() );
271 ATH_CHECK( m_sumPt2Key.initialize() );
272 ATH_CHECK( m_sumPtKey.initialize() );
273#ifndef XAOD_STANDALONE
276#endif
277
278 return StatusCode::SUCCESS;
279 }
280
281 //____________________________________________________________________________
283 auto fail = FailType::NoFail;
284
285 const EventContext& ctx = Gaudi::Hive::currentContext();
290
291 // Get the EventInfo
293
294 // Find the common z-position from beam / photon pointing information
295 std::pair<float, float> zCommon = xAOD::PVHelpers::getZCommonAndError(&*eventInfo, &egammas, m_convPtCut);
296 // Vector sum of photons
297 TLorentzVector vegamma = getEgammaVector(&egammas, fail);
298
299 // Retrieve PV collection from TEvent
301
302 bool writeSumPt2 = !sumPt2.isAvailable();
303 bool writeSumPt = !sumPt.isAvailable();
304
305 for (const xAOD::Vertex* vertex: *vertices) {
306
307 // Skip dummy vertices
308 if (!(vertex->vertexType() == xAOD::VxType::VertexType::PriVtx ||
309 vertex->vertexType() == xAOD::VxType::VertexType::PileUp)) continue;
310
311 // Set input variables
312 if (writeSumPt) {
313 sumPt(*vertex) = xAOD::PVHelpers::getVertexSumPt(vertex, 1, false);
314 }
315
316 if (writeSumPt2) {
317 sumPt2(*vertex) = xAOD::PVHelpers::getVertexSumPt(vertex, 2);
318 }
319
320 // Get momentum vector of vertex
321 TLorentzVector vmom = xAOD::PVHelpers::getVertexMomentum(vertex, true, m_derivationPrefix);
322
323 deltaPhi(*vertex) = (fail != FailType::FailEgamVect) ? std::abs(vmom.DeltaPhi(vegamma)) : -999.;
324 deltaZ(*vertex) = std::abs((zCommon.first - vertex->z())/zCommon.second);
325
326 } // loop over vertices
327
328 ATH_MSG_DEBUG("DecorateInputs exit code "<< fail);
329 if(failType!=nullptr)
330 *failType = fail;
331 return StatusCode::SUCCESS;
332 }
333
334 //____________________________________________________________________________
335 std::vector<std::pair<const xAOD::Vertex*, float>>
337 bool ignoreConv,
338 bool noDecorate,
339 yyVtxType* vtxCasePtr,
340 FailType* failTypePtr) const {
341 const xAOD::Vertex *vertex = nullptr;
342 std::vector<std::pair<const xAOD::Vertex*, float> > vertexMLP;
344 FailType failType = FailType::NoFail;
345 if (getVertexImp( egammas, vertex, ignoreConv, noDecorate, vertexMLP, vtxCase, failType ).isSuccess()) {
346 std::sort(vertexMLP.begin(), vertexMLP.end(), sortMLP);
347 }
348 if(vtxCasePtr!=nullptr)
349 *vtxCasePtr = vtxCase;
350 if(failTypePtr!=nullptr)
351 *failTypePtr = failType;
352
353 return vertexMLP;
354 }
355
356 //____________________________________________________________________________
358 const xAOD::Vertex* &prime_vertex,
359 bool ignoreConv) const
360 {
361 std::vector<std::pair<const xAOD::Vertex*, float> > vertexMLP;
363 FailType failType = FailType::NoFail;
364 return getVertexImp( egammas, prime_vertex, ignoreConv, false, vertexMLP, vtxcase, failType );
365 }
366
368 const xAOD::EgammaContainer& egammas,
369 const xAOD::Vertex*& prime_vertex,
370 bool ignoreConv,
371 bool noDecorate,
372 std::vector<std::pair<const xAOD::Vertex*, float>>& vertexMLP,
373 yyVtxType& vtxCase,
374 FailType& fail) const {
375 // Set default vertex case and declare photon container
376 vtxCase = yyVtxType::Unknown;
377 const xAOD::PhotonContainer *photons = dynamic_cast<const xAOD::PhotonContainer*>(&egammas);
378
379 // Retrieve PV collection from TEvent
381
382 if (!noDecorate && !decorateInputs(egammas).isSuccess()){
383 return StatusCode::FAILURE;
384 }
385
386 // Check if a conversion photon has a track attached to a primary/pileup vertex
387 if (!ignoreConv && photons) {
388 prime_vertex = getPrimaryVertexFromConv(photons);
389 if (prime_vertex != nullptr) {
390 vtxCase = yyVtxType::ConvTrack;
392 vertexMLP.emplace_back(prime_vertex, 0.);
393 return StatusCode::SUCCESS;
394 }
395 }
396
397 if (fail != FailType::NoFail){
398 ATH_MSG_VERBOSE("Returning hardest vertex. Fail detected (type="<< fail <<")");
399 vertexMLP.clear();
400 prime_vertex = xAOD::PVHelpers::getHardestVertex(&*vertices);
401 vertexMLP.emplace_back(prime_vertex, 10.);
402 return StatusCode::SUCCESS;
403 }
404
405 // Get the EventInfo
407
408 // If there are any silicon conversions passing selection
409 // ==> use Model 1 (Conv) otherwise Model 2 (Unconv)
410 // Set default for conversion bool as false unless otherwise
411 bool isConverted = false;
412
413 // assume default NoSiTrack (unconverted) unless otherwise
414 vtxCase = yyVtxType::NoSiTracks;
415 if (!ignoreConv && photons) {
416 for (const auto *photon: *photons) {
417 if (!photon)
418 {
419 ATH_MSG_WARNING("Null pointer to photon");
420 return StatusCode::FAILURE;
421 }
422 // find out if pass conversion selection criteria and tag as SiConvTrack case
424 {
425 isConverted = true;
426 vtxCase = yyVtxType::SiConvTrack;
427 }
428 }
429 }
430
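431 // if TMVA is chosen, select tmva_reader only once (before looping over the vertices). NOTE(review): the Reader allocated with `new` on the next line is never deleted in the visible code and the pointer is overwritten when m_isTMVA is set -- this looks like a per-call leak; initialising tmva_reader to nullptr should be enough, TODO confirm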
432 TMVA::Reader *tmva_reader = new TMVA::Reader();
433 if(m_isTMVA){
434 if(isConverted){
435 // If there are any silicon conversions passing selection, use MVA1 (converted case)
436 tmva_reader = m_mva1.get();
437 }
438 // Otherwise, use MVA2 (unconverted case)
439 if(!isConverted){
440 tmva_reader = m_mva2.get();
441 }
442 }
443 ATH_MSG_DEBUG("Vtx Case: " << vtxCase);
444
445 // Vector sum of photons
446 TLorentzVector vegamma = getEgammaVector(&egammas, fail);
447
452
453 // Loop over vertices and find best candidate
454 std::vector<float> ONNXInputVector;
455 std::vector<std::vector<float>> onnx_input_tensor_values;
456 std::vector<float> TMVAInputVector;
457 TString TMVAMethod;
458 float mlp = 0.0, mlp_max = -99999.0;
459 float doSkipByZSigmaScore = -9999.0;
460 // assign threshold score value to compare later for good vtx
461 float thresGoodVtxScore;
462 if(m_doSkipByZSigma){thresGoodVtxScore = doSkipByZSigmaScore;}
463 else{thresGoodVtxScore = mlp_max;}
464 for (const xAOD::Vertex* vertex: *vertices) {
465 // Skip dummy vertices
466 if (!(vertex->vertexType() == xAOD::VxType::VertexType::PriVtx ||
467 vertex->vertexType() == xAOD::VxType::VertexType::PileUp)) continue;
468
469 onnx_input_tensor_values.clear();
470
471 // Variables used as input features in classifier
472 float sumPt, sumPt2, deltaPhi, deltaZ;
473 float log10_sumPt, log10_sumPt2;
474
475 sumPt = (sumPtA)(*vertex);
476 sumPt2 = (sumPt2A)(*vertex);
477 deltaPhi = (deltaPhiA)(*vertex);
478 deltaZ = (deltaZA)(*vertex);
479 ATH_MSG_VERBOSE("sumPt: " << sumPt <<
480 " sumPt2: " << sumPt2 <<
481 " deltaPhi: " << deltaPhi <<
482 " deltaZ: " << deltaZ);
483
484 // setup the vector of input features based on selected inference framework
485 if(m_isTMVA){
486 // Get likelihood probability from TMVA model
487 TMVAMethod = "MLP method";
488 log10_sumPt = static_cast<float>(log10(sumPt));
489 log10_sumPt2 = static_cast<float>(log10(sumPt2));
490 TMVAInputVector = {deltaZ,deltaPhi,log10_sumPt,log10_sumPt2};
491 }
492 else{ //assume ony ONNX for now
493 // Get likelihood probability from onnx model
494 // check if value is 0, assign small number like 1e-8 as dummy, as we will take log later (log(0) is nan)
495 // note that the ordering here is a bit different, following the order used when training
496 ONNXInputVector = {sumPt2, sumPt, deltaPhi, deltaZ};
497 for (long unsigned int i = 0; i < ONNXInputVector.size(); i++) {
498 // skip log for deltaPhi and take log for the rest
499 if (i == 2) {
500 continue;
501 }
502 if (ONNXInputVector[i] != 0 && std::isinf(ONNXInputVector[i]) != true && std::isnan(ONNXInputVector[i]) != true){
503 ONNXInputVector[i] = log(std::abs(ONNXInputVector[i]));
504 }
505 else{
506 ONNXInputVector[i] = log(std::abs(0.00000001)); //log(abs(1e-8))
507 }
508 } //end ONNXInputVector for loop
509 onnx_input_tensor_values.push_back(ONNXInputVector);
510 }
511
512 // Do the actual calculation of classifier score part
513 if(m_isTMVA){
514 mlp = tmva_reader->EvaluateMVA(TMVAInputVector, TMVAMethod);
515 ATH_MSG_VERBOSE("TMVA output: " << (tmva_reader == m_mva1.get() ? "MVA1 ": "MVA2 ")<< mlp);
516 }
517 else{ //assume ony ONNX for now
518 if(isConverted){
519 mlp = getScore(m_nVars, onnx_input_tensor_values,
522 }
523 if(!isConverted){
524 mlp = getScore(m_nVars, onnx_input_tensor_values,
527 }
528 ATH_MSG_VERBOSE("log(abs(sumPt)): " << sumPt <<
529 " log(abs(sumPt2)): " << sumPt2 <<
530 " deltaPhi: " << deltaPhi <<
531 " log(abs(deltaZ)): " << deltaZ);
532 ATH_MSG_VERBOSE("ONNX output, isConverted = " << isConverted << ", mlp=" << mlp);
533 }
534
535 // Skip vertices above 10 sigma from pointing or 15 sigma from conversion (HPV)
536 // Simply displace the mlp variable we calculate before by a predefined value
538 if ((isConverted && deltaZ > 15) || (!isConverted && deltaZ > 10)) {
539 mlp = doSkipByZSigmaScore;
540 }
541 }
542
543 // add the new vertex and its score to vertexMLP container
544 vertexMLP.emplace_back(vertex, mlp);
545
546 // Keep track of maximal likelihood vertex
547 if (mlp > mlp_max) {
548 mlp_max = mlp;
549 prime_vertex = vertex;
550 }
551 } // end loop over vertices
552
553 // from all the looped vertices, decide the max score which should be more than the minimum we set
554 // (which should be more than the initial mlp_max value above or more than the skip vertex by z-sigma score)
555 // if this does not pass, return hardest primary vertex
556 if (mlp_max <= thresGoodVtxScore) {
557 ATH_MSG_DEBUG("No good vertex candidates from pointing, returning hardest vertex.");
558 prime_vertex = xAOD::PVHelpers::getHardestVertex(&*vertices);
560 vertexMLP.clear();
561 vertexMLP.emplace_back(xAOD::PVHelpers::getHardestVertex(&*vertices), 20.);
562 }
563
564 ATH_MSG_VERBOSE("getVertex case "<< (int)vtxCase << " exit code "<< (int)fail);
565 return StatusCode::SUCCESS;
566 }
567
568 //____________________________________________________________________________
569 bool PhotonVertexSelectionTool::sortMLP(const std::pair<const xAOD::Vertex*, float> &a,
570 const std::pair<const xAOD::Vertex*, float> &b)
571 { return a.second > b.second; }
572
573 //____________________________________________________________________________
575 {
576 if (photons == nullptr) {
577 ATH_MSG_WARNING("Passed nullptr photon container, returning nullptr vertex from getPrimaryVertexFromConv");
578 return nullptr;
579 }
580
581 std::vector<const xAOD::Vertex*> vertices;
582 const xAOD::Vertex *conversionVertex = nullptr, *primary = nullptr;
583 const xAOD::TrackParticle *tp = nullptr;
584 size_t NumberOfTracks = 0;
585
586 // Retrieve PV collection from TEvent
588
589
590 for (const auto *photon: *photons) {
591 conversionVertex = photon->vertex();
592 if (conversionVertex == nullptr) continue;
593
594 NumberOfTracks = conversionVertex->nTrackParticles();
595 for (size_t i = 0; i < NumberOfTracks; ++i) {
596 // Get trackParticle in GSF collection
597 const auto *gsfTp = conversionVertex->trackParticle(i);
598 if (gsfTp == nullptr) continue;
599 if (!xAOD::PVHelpers::passConvSelection(*conversionVertex, i, m_convPtCut)) continue;
600
601 // Get trackParticle in InDet collection
603 if (tp == nullptr) continue;
604
605 primary = getVertexFromTrack(tp, &*all_vertices);
606 if (primary == nullptr) continue;
607
608 if (primary->vertexType() == xAOD::VxType::VertexType::PriVtx ||
609 primary->vertexType() == xAOD::VxType::VertexType::PileUp) {
610 if (std::find(vertices.begin(), vertices.end(), primary) == vertices.end()) {
611 vertices.push_back(primary);
612 continue;
613 }
614 }
615 }
616 }
617
618 if (!vertices.empty()) {
619 if (vertices.size() > 1)
620 ATH_MSG_WARNING("Photons associated to different vertices! Returning lead photon association.");
621 return vertices[0];
622 }
623
624 return nullptr;
625 }
626
627 //____________________________________________________________________________
628 TLorentzVector PhotonVertexSelectionTool::getEgammaVector(const xAOD::EgammaContainer *egammas, FailType& failType) const
629 {
630 TLorentzVector v, v1;
631 const xAOD::CaloCluster *cluster = nullptr;
632 for (const xAOD::Egamma* egamma: *egammas) {
633 if (egamma == nullptr) {
634 ATH_MSG_DEBUG("No egamma object to get four vector");
635 failType = FailType::FailEgamVect;
636 continue;
637 }
638 cluster = egamma->caloCluster();
639 if (cluster == nullptr) {
640 ATH_MSG_WARNING("No cluster associated to egamma, not adding to 4-vector.");
641 continue;
642 }
643
644 v1.SetPtEtaPhiM(egamma->e()/cosh(cluster->etaBE(2)),
645 cluster->etaBE(2),
646 cluster->phiBE(2),
647 0.0);
648 v += v1;
649 }
650 return v;
651 }
652
653} // namespace CP
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_INFO(x)
#define ATH_MSG_VERBOSE(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
Handle class for reading from StoreGate.
Handle class for adding a decoration to an object.
bool isValid(const T &p)
Av: we implement here an ATLAS-specific convention: all particles which are 99xxxxx are fine.
Definition AtlasPID.h:878
static Double_t a
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T, V, H > &t)
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce(T &h)
FailType
Declare the interface that the class provides.
const xAOD::Vertex * getPrimaryVertexFromConv(const xAOD::PhotonContainer *photons) const
Get possible vertex directly associated with photon conversions.
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_sumPtKey
TLorentzVector getEgammaVector(const xAOD::EgammaContainer *egammas, FailType &failType) const
Get combined 4-vector of photon container.
std::vector< const char * > m_output_node_names2
Ort::AllocatorWithDefaultOptions m_allocator2
std::shared_ptr< Ort::Session > m_sessionHandle2
StatusCode getVertex(const xAOD::EgammaContainer &egammas, const xAOD::Vertex *&vertex, bool ignoreConv=false) const
Given a list of photons, return the most likely vertex based on MVA likelihood.
std::unique_ptr< TMVA::Reader > m_mva2
std::vector< const char * > m_output_node_names1
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_deltaZKey
Ort::AllocatorWithDefaultOptions m_allocator1
std::unique_ptr< TMVA::Reader > m_mva1
float getScore(int nVars, const std::vector< std::vector< float > > &input_data, const std::shared_ptr< Ort::Session > &sessionHandle, std::vector< int64_t > input_node_dims, std::vector< const char * > input_node_names, std::vector< const char * > output_node_names) const
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_sumPt2Key
std::tuple< std::shared_ptr< Ort::Session >, Ort::AllocatorWithDefaultOptions > setONNXSession(Ort::Env &env, const std::string &modelFilePath)
std::vector< const char * > m_input_node_names2
static bool sortMLP(const std::pair< const xAOD::Vertex *, float > &a, const std::pair< const xAOD::Vertex *, float > &b)
Sort MLP results.
virtual StatusCode initialize()
Function initialising the tool.
std::tuple< std::vector< int64_t >, std::vector< const char * > > getOutputNodes(const std::shared_ptr< Ort::Session > &sessionHandle, Ort::AllocatorWithDefaultOptions &allocator)
StatusCode decorateInputs(const xAOD::EgammaContainer &egammas, FailType *failType=nullptr) const
Given a list of photons, decorate vertex container with MVA variables.
SG::ReadHandleKey< xAOD::EventInfo > m_eventInfo
Container declarations.
StatusCode getVertexImp(const xAOD::EgammaContainer &egammas, const xAOD::Vertex *&vertex, bool ignoreConv, bool noDecorate, std::vector< std::pair< const xAOD::Vertex *, float > > &, yyVtxType &, FailType &) const
Given a list of photons, return the MLPs of all vertices in the event.
std::shared_ptr< Ort::Session > m_sessionHandle1
PhotonVertexSelectionTool(const std::string &name)
SG::ReadHandleKey< xAOD::VertexContainer > m_vertexContainer
int m_nVars
Create a proper constructor for Athena.
SG::WriteDecorHandleKey< xAOD::VertexContainer > m_deltaPhiKey
std::vector< const char * > m_input_node_names1
std::tuple< std::vector< int64_t >, std::vector< const char * > > getInputNodes(const std::shared_ptr< Ort::Session > &sessionHandle, Ort::AllocatorWithDefaultOptions &allocator)
virtual double e() const
energy
SG::ConstAccessor< T, ALLOC > ConstAccessor
Definition AuxElement.h:569
Handle class for adding a decoration to an object.
bool isAvailable()
Test to see if this variable exists in the store, for the referenced object.
AsgTool(const std::string &name)
Constructor specifying the tool instance's name.
Definition AsgTool.cxx:58
elec/gamma data class.
Definition egamma.h:58
float phiBE(const unsigned layer) const
Get the phi in one layer of the EM Calo.
float etaBE(const unsigned layer) const
Get the eta in one layer of the EM Calo.
size_t nTrackParticles() const
Get the number of tracks associated with this vertex.
const TrackParticle * trackParticle(size_t i) const
Get the pointer to a given track that was used in vertex reco.
const std::vector< float > & trackWeights() const
Get all the track weights.
Select isolated Photons, Electrons and Muons.
const xAOD::Vertex * getVertexFromTrack(const xAOD::TrackParticle *track, const xAOD::VertexContainer *vertices)
std::string decorKeyFromKey(const std::string &key, const std::string &deflt)
Extract the decoration part of key.
void sort(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of sort for DataVector/List.
const xAOD::TrackParticle * getOriginalTrackParticleFromGSF(const xAOD::TrackParticle *trkPar)
Helper function for getting the "Original" Track Particle (i.e before GSF) via the GSF Track Particle...
const xAOD::Vertex * getHardestVertex(const xAOD::VertexContainer *vertices)
Return vertex with highest sum pT^2.
float getVertexSumPt(const xAOD::Vertex *vertex, int power=1, bool useAux=true)
Loop over track particles associated with vertex and return scalar sum of pT^power in GeV (from auxda...
TLorentzVector getVertexMomentum(const xAOD::Vertex *vertex, bool useAux=true, const std::string &derivationPrefix="")
Return vector sum of tracks associated with vertex (from auxdata if available and useAux = true)
bool passConvSelection(const xAOD::Photon *photon, float convPtCut=2e3)
Check if photon is converted, and tracks have Si hits and pass selection.
std::pair< float, float > getZCommonAndError(const xAOD::EventInfo *eventInfo, const xAOD::EgammaContainer *egammas, float convPtCut=2e3)
Return zCommon and zCommonError.
@ PileUp
Pile-up vertex.
@ PriVtx
Primary vertex.
PhotonContainer_v1 PhotonContainer
Definition of the current "photon container version".
CaloCluster_v1 CaloCluster
Define the latest version of the calorimeter cluster class.
setSAddress setEtaMS setDirPhiMS setDirZMS setBarrelRadius setEndcapAlpha setEndcapRadius setInterceptInner setEtaMap setEtaBin setIsTgcFailure setDeltaPt deltaPhi
TrackParticle_v1 TrackParticle
Reference the current persistent version:
VertexContainer_v1 VertexContainer
Definition of the current "Vertex container version".
Vertex_v1 Vertex
Define the latest version of the vertex class.
Egamma_v1 Egamma
Definition of the current "egamma version".
Definition Egamma.h:17
EgammaContainer_v1 EgammaContainer
Definition of the current "egamma container version".