ATLAS Offline Software
TrackOverlayDecisionAlg.cxx
Go to the documentation of this file.
1 /*
2  * * Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3  * * */
4 //
5 #include <fstream>
6 #include <string>
7 #include <iostream>
8 #include <vector>
9 #include <sstream>
10 #include <unistd.h>
11 #include <Eigen/Core>
12 //
13 #include "GaudiKernel/SystemOfUnits.h"
14 #include "Gaudi/Property.h"
19 
20 // ONNX Runtime include(s).
21 #include <onnxruntime_cxx_api.h>
22 //
25 
27  TrackOverlayDecisionAlg::TrackOverlayDecisionAlg( const std::string& name, ISvcLocator* pSvcLocator ) :
28  ::AthReentrantAlgorithm( name, pSvcLocator )
29  {
30  }
31 
33 {
34  ATH_CHECK(m_filterParams.initialize(false));
35  ATH_CHECK( m_eventInfoContainerName.initialize() );
36  ATH_CHECK( m_truthParticleName.initialize( (m_pileupSwitch == "HardScatter" or m_pileupSwitch == "All") and not m_truthParticleName.key().empty() ) );
37  ATH_CHECK(m_truthSelectionTool.retrieve(EnableTool {not m_truthParticleName.key().empty()} ));
38  ATH_CHECK( m_truthEventName.initialize( (m_pileupSwitch == "HardScatter" or m_pileupSwitch == "All") and not m_truthEventName.key().empty() ) );
39  ATH_CHECK( m_truthPileUpEventName.initialize( (m_pileupSwitch == "PileUp" or m_pileupSwitch == "All") and not m_truthPileUpEventName.key().empty() ) );
40  ATH_CHECK(m_svc.retrieve());
41  std::string this_file = __FILE__;
42  const std::string model_path = PathResolverFindCalibFile("TrackOverlay/TrackOverlay_J7_model.onnx");
43  Ort::SessionOptions session_options;
44 
45  m_session = std::make_unique<Ort::Session>(m_svc->env(), model_path.c_str(), session_options);
46  m_inputInfo = TrackOverlayDecisionAlg::GetInputNodeInfo(m_session);
47  m_outputInfo = TrackOverlayDecisionAlg::GetOutputNodeInfo(m_session);
48  return StatusCode::SUCCESS;
49 }
50 
52 {
53  ATH_MSG_VERBOSE ( "Finalizing ..." );
54  ATH_MSG_VERBOSE("-----------------------------------------------------------------");
55  ATH_MSG_VERBOSE("m_filterParams.summary()" << m_filterParams.summary());
56  ATH_MSG_VERBOSE("-----------------------------------------------------------------");
57  ATH_MSG_INFO(m_filterParams.summary());
58  ATH_MSG_VERBOSE(" =====================================================================");
59 
60  return StatusCode::SUCCESS;
61 }
62 
63 const std::vector<const xAOD::TruthParticle*> TrackOverlayDecisionAlg::getTruthParticles() const {
64  std::vector<const xAOD::TruthParticle*> tempVec {};
65  if (m_pileupSwitch == "All") {
66  if (m_truthParticleName.key().empty()) {
67  return tempVec;
68  }
69  SG::ReadHandle<xAOD::TruthParticleContainer> truthParticleContainer( m_truthParticleName);
70  if (not truthParticleContainer.isValid()) {
71  return tempVec;
72  }
73  tempVec.insert(tempVec.begin(), truthParticleContainer->begin(), truthParticleContainer->end());
74  } else {
75  if (m_pileupSwitch == "HardScatter") {
76  if (not m_truthEventName.key().empty()) {
77  ATH_MSG_VERBOSE("Getting TruthEvents container.");
78  SG::ReadHandle<xAOD::TruthEventContainer> truthEventContainer( m_truthEventName);
79  const xAOD::TruthEvent* event = (truthEventContainer.isValid()) ? truthEventContainer->at(0) : nullptr;
80  if (not event) {
81  return tempVec;
82  }
83  const auto& links = event->truthParticleLinks();
84  tempVec.reserve(event->nTruthParticles());
85  for (const auto& link : links) {
86  if (link.isValid()){
87  tempVec.push_back(*link);
88  }
89  }
90  }
91  }else if (m_pileupSwitch == "PileUp") {
92  if (not m_truthPileUpEventName.key().empty()) {
93  ATH_MSG_VERBOSE("getting TruthPileupEvents container");
94  // get truth particles from all pileup events
95  SG::ReadHandle<xAOD::TruthPileupEventContainer> truthPileupEventContainer(m_truthPileUpEventName);
96  if (truthPileupEventContainer.isValid()) {
97  const unsigned int nPileup = truthPileupEventContainer->size();
98  tempVec.reserve(nPileup * 200); // quick initial guess, will still save some time
99  for (unsigned int i(0); i != nPileup; ++i) {
100  const auto *eventPileup = truthPileupEventContainer->at(i);
101  // get truth particles from each pileup event
102  int ntruth = eventPileup->nTruthParticles();
103  ATH_MSG_VERBOSE("Adding " << ntruth << " truth particles from TruthPileupEvents container");
104  const auto& links = eventPileup->truthParticleLinks();
105  for (const auto& link : links) {
106  if (link.isValid()){
107  tempVec.push_back(*link);
108  }
109  }
110  }
111  } else {
112  ATH_MSG_ERROR("no entries in TruthPileupEvents container!");
113  }
114  }
115  } else {
116  ATH_MSG_ERROR("bad value for PileUpSwitch");
117  }
118  }
119  return tempVec;
120 }
121 
122 StatusCode TrackOverlayDecisionAlg::execute(const EventContext &ctx) const
123 {
124  ATH_MSG_DEBUG ("Executing ...");
125 
126  std::vector<const xAOD::TruthParticle*> truthParticlesVec = TrackOverlayDecisionAlg::getTruthParticles();
127 
128  //Access truth info for the NN input
129  float eventPxSum = 0.0;
130  float eventPySum = 0.0;
131  float eventPt = 0.0;
132  float puEvents = 0.0;
133 
134  std::vector<float> pxValues, pyValues, pzValues, eValues, etaValues, phiValues, ptValues;
135  float truthMultiplicity = 0.0;
136  const int truthParticles = truthParticlesVec.size();
137  for (int itruth = 0; itruth < truthParticles; itruth++) {
138  const xAOD::TruthParticle* thisTruth = truthParticlesVec[itruth];
139  const IAthSelectionTool::CutResult accept = m_truthSelectionTool->accept(thisTruth);
140  if(accept){
141  pxValues.push_back((thisTruth->px()*0.001-1.46988000e+03)* px_diff); //as MinMaxScaler: 1.46988000e+03 is the lowest value of px from a J7 sample; *(0.001) is used to convert unit rather than *(1/1000) to speed up.
142  pyValues.push_back((thisTruth->py()*0.001-1.35142000e+03)* py_diff); //the lowest value of py: 1.35142000e+03
143  pzValues.push_back((thisTruth->pz()*0.001-1.50464000e+03)* pz_diff); //the lowest value of pz: 1.50464000e+03
144  ptValues.push_back((thisTruth->pt()*0.001-5.00006000e-01)* pt_diff); //the lowest value of pt: 5.00006000e-01
145 
146  etaValues.push_back(thisTruth->eta());
147  phiValues.push_back(thisTruth->phi());
148  eValues.push_back((thisTruth->e()*0.001-5.08307000e-01)*e_diff); //the lowest value of energy: 5.08307000e-01
149 
150  eventPxSum += thisTruth->px();
151  eventPySum += thisTruth->py();
152  truthMultiplicity++;
153  }//accept
154  }//for itruth
155  SG::ReadHandle<xAOD::TruthPileupEventContainer> truthPileupEventContainer;
156  SG::ReadHandle<xAOD::EventInfo> pie = SG::ReadHandle<xAOD::EventInfo>(m_eventInfoContainerName, ctx);
157  if (!m_truthPileUpEventName.key().empty()) {
158  truthPileupEventContainer = SG::ReadHandle<xAOD::TruthPileupEventContainer>(m_truthPileUpEventName, ctx);
159  }
160  puEvents = !m_truthPileUpEventName.key().empty() and truthPileupEventContainer.isValid() ? static_cast<int>( truthPileupEventContainer->size() ) : pie.isValid() ? pie->actualInteractionsPerCrossing() : 0;
161  eventPt = std::sqrt(eventPxSum*eventPxSum + eventPySum*eventPySum)*0.001;
162 
163  std::vector<float> puEventsVec(pxValues.size(), (puEvents-1.55000000e+01)*pu_diff); //min of puEvents= 15.5, max of puEvents=84.5
164  std::vector<float> truthMultiplicityVec(pxValues.size(), (truthMultiplicity-1.80000000e+01)*multi_diff);
165  std::vector<float> eventPtVec(pxValues.size(), (eventPt-3.42359395e-01)*eventPt_diff);
166  std::vector<float> predictions;
167 
168  //Compute the distances using Eigen for Eigen's optimized operations. Initialize matirces. Observed a significant improvement on computing calculation.
169  Eigen::VectorXf ptEigen = Eigen::VectorXf::Map(ptValues.data(), ptValues.size());
170  Eigen::VectorXf phiEigen = Eigen::VectorXf::Map(phiValues.data(), phiValues.size());
171  Eigen::VectorXf etaEigen = Eigen::VectorXf::Map(etaValues.data(), etaValues.size());
172  for (std::size_t i = 0; i < truthMultiplicity; ++i) {
173  float multiplicity_0p05 = 0.0, multiplicity_0p2 = 0.0;
174  float sum_0p05 = 0.0, sum_0p2 = 0.0;
175  float pt_0p05 = 0.0, pt_0p2 = 0.0;
176  float deltaEtaI = etaEigen[i];
177  float phiI = phiEigen[i];
178  for (std::size_t j = 0; j < truthMultiplicity; ++j) {
179  if (i == j) continue; // Skip the particle itself
180  float deltaEta = deltaEtaI - etaEigen[j];
181  float deltaPhi = phiI - phiEigen[j];
182  if (deltaPhi > M_PI) {
183  deltaPhi -= 2.0 * M_PI;
184  }
185  float distances = std::sqrt(deltaEta * deltaEta + deltaPhi * deltaPhi);
186  if (distances < 0.05){
187  multiplicity_0p05++;
188  sum_0p05 += distances;
189  pt_0p05 += ptEigen[j];
190  }
191  if (distances < 0.2){
192  multiplicity_0p2++;
193  sum_0p2 += distances;
194  pt_0p2 += ptEigen[j];
195  }
196  }// for j
197 
198  std::vector<float> featData;
199  featData.push_back(pxValues[i]);
200  featData.push_back(pyValues[i]);
201  featData.push_back(pzValues[i]);
202  featData.push_back(eValues[i]);
203  featData.push_back(ptValues[i]);
204  featData.push_back((multiplicity_0p2 * area0p2) * constant1);
205  featData.push_back((multiplicity_0p05 * area0p05) * constant2);
206  featData.push_back((sum_0p2 * area0p2) * constant3);
207  featData.push_back((sum_0p05 * area0p05) * constant4);
208  featData.push_back(pt_0p2 * constant5);
209  featData.push_back(pt_0p05 * constant6);
210 
211  featData.push_back(puEventsVec[i]);
212  featData.push_back(truthMultiplicityVec[i]);
213  featData.push_back(eventPtVec[i]);
214 
215  std::vector<int64_t> input_node_dims;
216  std::vector<char*> input_node_names;
217  input_node_dims = std::get<0>(m_inputInfo);
218  input_node_names = std::get<1>(m_inputInfo);
219 
220  std::vector<int64_t> output_node_dims;
221  std::vector<char*> output_node_names;
222  output_node_dims = std::get<0>(m_outputInfo);
223  output_node_names = std::get<1>(m_outputInfo);
224 
225  Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
226  input_node_dims[0]=1;
227  Ort::Value input_data = Ort::Value::CreateTensor(memoryInfo, featData.data(), featData.size(), input_node_dims.data(), input_node_dims.size());
228  Ort::RunOptions run_options(nullptr);
229  //Run the inference
230  Ort::Session& mysession ATLAS_THREAD_SAFE = *m_session;
231  auto output_values = mysession.Run(run_options, input_node_names.data(), &input_data, input_node_names.size(), output_node_names.data(), output_node_names.size());
232  float* predictionData = output_values[0].GetTensorMutableData<float>();
233  float prediction = predictionData[0];
234 
235  predictions.push_back(prediction);
236  }//for i
237  float threshold = m_MLthreshold;
238  ATH_MSG_ALWAYS("ML threshold:" << threshold);
239  int badTracks = 0;
240  for (float prediction : predictions) {
241  if (prediction > threshold) {
242  badTracks++;
243  }
244  }
245  float rouletteScore = static_cast<float>(badTracks) / static_cast<float>(truthMultiplicity);
246 
247  FilterReporter filter(m_filterParams, false, ctx);
248  bool pass = false;
249  int decision = rouletteScore == 0;
250  if (decision==0){ //if ML decision is False, it goes to the MC-overlay workflow
251  pass = true;
252  }
253  else{
254  pass = false;
255  }
256 
257  if (m_invertfilter) {
258  pass =! pass;
259  }
260  filter.setPassed(pass);
261  ATH_MSG_ALWAYS("End TrackOverlayDecisionAlg, difference in filters: "<<(pass ? "found" : "not found")<<"="<<pass<<", invert="<<m_invertfilter);
262  return StatusCode::SUCCESS;
263 }
264 
265 
266 }// end namespace TrackOverlayDecisionAlg
FilterReporter
a guard class for use with ref FilterReporterParams
Definition: FilterReporter.h:35
python.tests.PyTestsLib.finalize
def finalize(self)
_info( "content of StoreGate..." ) self.sg.dump()
Definition: PyTestsLib.py:50
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
TrackOverlayDecisionAlg::px_diff
const float px_diff
Definition: TrackOverlayDecisionAlg.h:29
TrackOverlayDecisionAlg::area0p05
const float area0p05
Definition: TrackOverlayDecisionAlg.h:38
xAOD::TruthParticle_v1::pz
float pz() const
The z component of the particle's momentum.
IAthSelectionTool::CutResult
Definition: IAthSelectionTool.h:30
SG::ReadHandle< xAOD::TruthParticleContainer >
xAOD::deltaPhi
setSAddress setEtaMS setDirPhiMS setDirZMS setBarrelRadius setEndcapAlpha setEndcapRadius setInterceptInner setEtaMap setEtaBin setIsTgcFailure setDeltaPt deltaPhi
Definition: L2StandAloneMuon_v1.cxx:161
python.PyParticleTools.getTruthParticles
def getTruthParticles(aKey)
Definition: PyParticleTools.py:102
initialize
void initialize()
Definition: run_EoverP.cxx:894
CutsMETMaker::accept
StatusCode accept(const xAOD::Muon *mu)
Definition: CutsMETMaker.cxx:18
TruthPileupEvent.h
xAOD::TruthParticle_v1::px
float px() const
The x component of the particle's momentum.
M_PI
#define M_PI
Definition: ActiveFraction.h:11
TrackOverlayDecisionAlg::constant3
const float constant3
Definition: TrackOverlayDecisionAlg.h:41
xAOD::TruthParticle_v1::py
float py() const
The y component of the particle's momentum.
ATH_MSG_VERBOSE
#define ATH_MSG_VERBOSE(x)
Definition: AthMsgStreamMacros.h:28
python.oracle.Session
Session
Definition: oracle.py:78
TrackOverlayDecisionAlg
Definition: TrackOverlayDecisionAlg.cxx:26
AthReentrantAlgorithm
An algorithm that can be simultaneously executed in multiple threads.
Definition: AthReentrantAlgorithm.h:74
covarianceTool.filter
filter
Definition: covarianceTool.py:514
LArG4FSStartPointFilterLegacy.execute
execute
Definition: LArG4FSStartPointFilterLegacy.py:20
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
xAOD::TruthParticle_v1::e
virtual double e() const override final
The total energy of the particle.
event
POOL::TEvent event(POOL::TEvent::kClassAccess)
P4Helpers::deltaEta
double deltaEta(const I4Momentum &p1, const I4Momentum &p2)
Computes efficiently .
Definition: P4Helpers.h:66
lumiFormat.i
int i
Definition: lumiFormat.py:85
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
DMTest::links
links
Definition: CLinks_v1.cxx:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
xAOD::TruthParticle_v1
Class describing a truth particle in the MC record.
Definition: TruthParticle_v1.h:37
xAOD::TruthEvent_v1
Class describing a signal truth event in the MC record.
Definition: TruthEvent_v1.h:35
ATH_MSG_ALWAYS
#define ATH_MSG_ALWAYS(x)
Definition: AthMsgStreamMacros.h:35
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
TrackOverlayDecisionAlg::constant1
const float constant1
Definition: TrackOverlayDecisionAlg.h:39
SG::ReadHandle::isValid
virtual bool isValid() override final
Can the handle be successfully dereferenced?
TrackOverlayDecisionAlg::py_diff
const float py_diff
Definition: TrackOverlayDecisionAlg.h:30
TrackOverlayDecisionAlg::pt_diff
const float pt_diff
Definition: TrackOverlayDecisionAlg.h:32
FilterReporter.h
PathResolver.h
TrackOverlayDecisionAlg::constant5
const float constant5
Definition: TrackOverlayDecisionAlg.h:43
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
threshold
Definition: chainparser.cxx:74
xAOD::TruthParticle_v1::eta
virtual double eta() const override final
The pseudorapidity ( ) of the particle.
Definition: TruthParticle_v1.cxx:169
TrackOverlayDecisionAlg::area0p2
const float area0p2
Definition: TrackOverlayDecisionAlg.h:37
TrackOverlayDecisionAlg::eventPt_diff
const float eventPt_diff
Definition: TrackOverlayDecisionAlg.h:36
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:283
TrackOverlayDecisionAlg::pz_diff
const float pz_diff
Definition: TrackOverlayDecisionAlg.h:31
DataVector::end
const_iterator end() const noexcept
Return a const_iterator pointing past the end of the collection.
TrackOverlayDecisionAlg::TrackOverlayDecisionAlg
Definition: TrackOverlayDecisionAlg.h:46
TrackOverlayDecisionAlg::constant2
const float constant2
Definition: TrackOverlayDecisionAlg.h:40
xAOD::TruthParticle_v1::phi
virtual double phi() const override final
The azimuthal angle ( ) of the particle.
Definition: TruthParticle_v1.cxx:176
TrackOverlayDecisionAlg::constant6
const float constant6
Definition: TrackOverlayDecisionAlg.h:44
TrackOverlayDecisionAlg::multi_diff
const float multi_diff
Definition: TrackOverlayDecisionAlg.h:35
xAOD::TruthParticle_v1::pt
virtual double pt() const override final
The transverse momentum ( ) of the particle.
Definition: TruthParticle_v1.cxx:161
TruthPileupEventAuxContainer.h
ATLAS_THREAD_SAFE
#define ATLAS_THREAD_SAFE
Definition: checker_macros.h:211
DataVector::at
const T * at(size_type n) const
Access an element, as an rvalue.
TrackOverlayDecisionAlg::pu_diff
const float pu_diff
Definition: TrackOverlayDecisionAlg.h:34
TrackOverlayDecisionAlg::constant4
const float constant4
Definition: TrackOverlayDecisionAlg.h:42
TruthParticle.h
DataVector::size
size_type size() const noexcept
Returns the number of elements in the collection.
DataVector::begin
const_iterator begin() const noexcept
Return a const_iterator pointing at the beginning of the collection.
TrackOverlayDecisionAlg.h
xAOD::EventInfo_v1::actualInteractionsPerCrossing
float actualInteractionsPerCrossing() const
Average interactions per crossing for the current BCID - for in-time pile-up.
Definition: EventInfo_v1.cxx:380
TrackOverlayDecisionAlg::e_diff
const float e_diff
Definition: TrackOverlayDecisionAlg.h:33