ATLAS Offline Software
TrackOverlayDecisionAlg.cxx
Go to the documentation of this file.
1 /*
2  * * Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3  * * */
4 //
5 #include <fstream>
6 #include <string>
7 #include <iostream>
8 #include <vector>
9 #include <sstream>
10 #include <unistd.h>
11 #include <Eigen/Core>
12 //
13 #include "GaudiKernel/SystemOfUnits.h"
14 #include "Gaudi/Property.h"
15 #include "EventInfo/EventInfo.h"
16 #include "EventInfo/EventID.h"
21 
22 // ONNX Runtime include(s).
23 #include <onnxruntime_cxx_api.h>
24 //
27 
29  TrackOverlayDecisionAlg::TrackOverlayDecisionAlg( const std::string& name, ISvcLocator* pSvcLocator ) :
30  ::AthReentrantAlgorithm( name, pSvcLocator )
31  {
32  }
33 
35 {
36  ATH_CHECK(m_filterParams.initialize(false));
37  ATH_CHECK( m_eventInfoContainerName.initialize() );
38  ATH_CHECK( m_truthParticleName.initialize( (m_pileupSwitch == "HardScatter" or m_pileupSwitch == "All") and not m_truthParticleName.key().empty() ) );
39  ATH_CHECK(m_truthSelectionTool.retrieve(EnableTool {not m_truthParticleName.key().empty()} ));
40  ATH_CHECK( m_truthEventName.initialize( (m_pileupSwitch == "HardScatter" or m_pileupSwitch == "All") and not m_truthEventName.key().empty() ) );
41  ATH_CHECK( m_truthPileUpEventName.initialize( (m_pileupSwitch == "PileUp" or m_pileupSwitch == "All") and not m_truthPileUpEventName.key().empty() ) );
42  ATH_CHECK(m_svc.retrieve());
43  std::string this_file = __FILE__;
44  const std::string model_path = PathResolverFindCalibFile("TrackOverlay/TrackOverlay_J7_model.onnx");
45  Ort::SessionOptions session_options;
46 
47  m_session = std::make_unique<Ort::Session>(m_svc->env(), model_path.c_str(), session_options);
48  m_inputInfo = TrackOverlayDecisionAlg::GetInputNodeInfo(m_session);
49  m_outputInfo = TrackOverlayDecisionAlg::GetOutputNodeInfo(m_session);
50  return StatusCode::SUCCESS;
51 }
52 
54 {
55  ATH_MSG_VERBOSE ( "Finalizing ..." );
56  ATH_MSG_VERBOSE("-----------------------------------------------------------------");
57  ATH_MSG_VERBOSE("m_filterParams.summary()" << m_filterParams.summary());
58  ATH_MSG_VERBOSE("-----------------------------------------------------------------");
59  ATH_MSG_INFO(m_filterParams.summary());
60  ATH_MSG_VERBOSE(" =====================================================================");
61 
62  return StatusCode::SUCCESS;
63 }
64 
65 const std::vector<const xAOD::TruthParticle*> TrackOverlayDecisionAlg::getTruthParticles() const {
66  std::vector<const xAOD::TruthParticle*> tempVec {};
67  if (m_pileupSwitch == "All") {
68  if (m_truthParticleName.key().empty()) {
69  return tempVec;
70  }
71  SG::ReadHandle<xAOD::TruthParticleContainer> truthParticleContainer( m_truthParticleName);
72  if (not truthParticleContainer.isValid()) {
73  return tempVec;
74  }
75  tempVec.insert(tempVec.begin(), truthParticleContainer->begin(), truthParticleContainer->end());
76  } else {
77  if (m_pileupSwitch == "HardScatter") {
78  if (not m_truthEventName.key().empty()) {
79  ATH_MSG_VERBOSE("Getting TruthEvents container.");
80  SG::ReadHandle<xAOD::TruthEventContainer> truthEventContainer( m_truthEventName);
81  const xAOD::TruthEvent* event = (truthEventContainer.isValid()) ? truthEventContainer->at(0) : nullptr;
82  if (not event) {
83  return tempVec;
84  }
85  const auto& links = event->truthParticleLinks();
86  tempVec.reserve(event->nTruthParticles());
87  for (const auto& link : links) {
88  if (link.isValid()){
89  tempVec.push_back(*link);
90  }
91  }
92  }
93  }else if (m_pileupSwitch == "PileUp") {
94  if (not m_truthPileUpEventName.key().empty()) {
95  ATH_MSG_VERBOSE("getting TruthPileupEvents container");
96  // get truth particles from all pileup events
97  SG::ReadHandle<xAOD::TruthPileupEventContainer> truthPileupEventContainer(m_truthPileUpEventName);
98  if (truthPileupEventContainer.isValid()) {
99  const unsigned int nPileup = truthPileupEventContainer->size();
100  tempVec.reserve(nPileup * 200); // quick initial guess, will still save some time
101  for (unsigned int i(0); i != nPileup; ++i) {
102  const auto *eventPileup = truthPileupEventContainer->at(i);
103  // get truth particles from each pileup event
104  int ntruth = eventPileup->nTruthParticles();
105  ATH_MSG_VERBOSE("Adding " << ntruth << " truth particles from TruthPileupEvents container");
106  const auto& links = eventPileup->truthParticleLinks();
107  for (const auto& link : links) {
108  if (link.isValid()){
109  tempVec.push_back(*link);
110  }
111  }
112  }
113  } else {
114  ATH_MSG_ERROR("no entries in TruthPileupEvents container!");
115  }
116  }
117  } else {
118  ATH_MSG_ERROR("bad value for PileUpSwitch");
119  }
120  }
121  return tempVec;
122 }
123 
124 StatusCode TrackOverlayDecisionAlg::execute(const EventContext &ctx) const
125 {
126  ATH_MSG_DEBUG ("Executing ...");
127 
128  std::vector<const xAOD::TruthParticle*> truthParticlesVec = TrackOverlayDecisionAlg::getTruthParticles();
129 
130  //Access truth info for the NN input
131  float eventPxSum = 0.0;
132  float eventPySum = 0.0;
133  float eventPt = 0.0;
134  float puEvents = 0.0;
135 
136  std::vector<float> pxValues, pyValues, pzValues, eValues, etaValues, phiValues, ptValues;
137  float truthMultiplicity = 0.0;
138  const int truthParticles = truthParticlesVec.size();
139  for (int itruth = 0; itruth < truthParticles; itruth++) {
140  const xAOD::TruthParticle* thisTruth = truthParticlesVec[itruth];
141  const IAthSelectionTool::CutResult accept = m_truthSelectionTool->accept(thisTruth);
142  if(accept){
143  pxValues.push_back((thisTruth->px()*0.001-1.46988000e+03)* px_diff); //as MinMaxScaler: 1.46988000e+03 is the lowest value of px from a J7 sample; *(0.001) is used to convert unit rather than *(1/1000) to speed up.
144  pyValues.push_back((thisTruth->py()*0.001-1.35142000e+03)* py_diff); //the lowest value of py: 1.35142000e+03
145  pzValues.push_back((thisTruth->pz()*0.001-1.50464000e+03)* pz_diff); //the lowest value of pz: 1.50464000e+03
146  ptValues.push_back((thisTruth->pt()*0.001-5.00006000e-01)* pt_diff); //the lowest value of pt: 5.00006000e-01
147 
148  etaValues.push_back(thisTruth->eta());
149  phiValues.push_back(thisTruth->phi());
150  eValues.push_back((thisTruth->e()*0.001-5.08307000e-01)*e_diff); //the lowest value of energy: 5.08307000e-01
151 
152  eventPxSum += thisTruth->px();
153  eventPySum += thisTruth->py();
154  truthMultiplicity++;
155  }//accept
156  }//for itruth
157  SG::ReadHandle<xAOD::TruthPileupEventContainer> truthPileupEventContainer;
158  SG::ReadHandle<xAOD::EventInfo> pie = SG::ReadHandle<xAOD::EventInfo>(m_eventInfoContainerName, ctx);
159  if (!m_truthPileUpEventName.key().empty()) {
160  truthPileupEventContainer = SG::ReadHandle<xAOD::TruthPileupEventContainer>(m_truthPileUpEventName, ctx);
161  }
162  puEvents = !m_truthPileUpEventName.key().empty() and truthPileupEventContainer.isValid() ? static_cast<int>( truthPileupEventContainer->size() ) : pie.isValid() ? pie->actualInteractionsPerCrossing() : 0;
163  eventPt = std::sqrt(eventPxSum*eventPxSum + eventPySum*eventPySum)*0.001;
164 
165  std::vector<float> puEventsVec(pxValues.size(), (puEvents-1.55000000e+01)*pu_diff); //min of puEvents= 15.5, max of puEvents=84.5
166  std::vector<float> truthMultiplicityVec(pxValues.size(), (truthMultiplicity-1.80000000e+01)*multi_diff);
167  std::vector<float> eventPtVec(pxValues.size(), (eventPt-3.42359395e-01)*eventPt_diff);
168  std::vector<float> predictions;
169 
170  //Compute the distances using Eigen for Eigen's optimized operations. Initialize matirces. Observed a significant improvement on computing calculation.
171  Eigen::VectorXf ptEigen = Eigen::VectorXf::Map(ptValues.data(), ptValues.size());
172  Eigen::VectorXf phiEigen = Eigen::VectorXf::Map(phiValues.data(), phiValues.size());
173  Eigen::VectorXf etaEigen = Eigen::VectorXf::Map(etaValues.data(), etaValues.size());
174  for (std::size_t i = 0; i < truthMultiplicity; ++i) {
175  float multiplicity_0p05 = 0.0, multiplicity_0p2 = 0.0;
176  float sum_0p05 = 0.0, sum_0p2 = 0.0;
177  float pt_0p05 = 0.0, pt_0p2 = 0.0;
178  float deltaEtaI = etaEigen[i];
179  float phiI = phiEigen[i];
180  for (std::size_t j = 0; j < truthMultiplicity; ++j) {
181  if (i == j) continue; // Skip the particle itself
182  float deltaEta = deltaEtaI - etaEigen[j];
183  float deltaPhi = phiI - phiEigen[j];
184  if (deltaPhi > M_PI) {
185  deltaPhi -= 2.0 * M_PI;
186  }
187  float distances = std::sqrt(deltaEta * deltaEta + deltaPhi * deltaPhi);
188  if (distances < 0.05){
189  multiplicity_0p05++;
190  sum_0p05 += distances;
191  pt_0p05 += ptEigen[j];
192  }
193  if (distances < 0.2){
194  multiplicity_0p2++;
195  sum_0p2 += distances;
196  pt_0p2 += ptEigen[j];
197  }
198  }// for j
199 
200  std::vector<float> featData;
201  featData.push_back(pxValues[i]);
202  featData.push_back(pyValues[i]);
203  featData.push_back(pzValues[i]);
204  featData.push_back(eValues[i]);
205  featData.push_back(ptValues[i]);
206  featData.push_back((multiplicity_0p2 * area0p2) * constant1);
207  featData.push_back((multiplicity_0p05 * area0p05) * constant2);
208  featData.push_back((sum_0p2 * area0p2) * constant3);
209  featData.push_back((sum_0p05 * area0p05) * constant4);
210  featData.push_back(pt_0p2 * constant5);
211  featData.push_back(pt_0p05 * constant6);
212 
213  featData.push_back(puEventsVec[i]);
214  featData.push_back(truthMultiplicityVec[i]);
215  featData.push_back(eventPtVec[i]);
216 
217  std::vector<int64_t> input_node_dims;
218  std::vector<char*> input_node_names;
219  input_node_dims = std::get<0>(m_inputInfo);
220  input_node_names = std::get<1>(m_inputInfo);
221 
222  std::vector<int64_t> output_node_dims;
223  std::vector<char*> output_node_names;
224  output_node_dims = std::get<0>(m_outputInfo);
225  output_node_names = std::get<1>(m_outputInfo);
226 
227  Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
228  input_node_dims[0]=1;
229  Ort::Value input_data = Ort::Value::CreateTensor(memoryInfo, featData.data(), featData.size(), input_node_dims.data(), input_node_dims.size());
230  Ort::RunOptions run_options(nullptr);
231  //Run the inference
232  Ort::Session& mysession ATLAS_THREAD_SAFE = *m_session;
233  auto output_values = mysession.Run(run_options, input_node_names.data(), &input_data, input_node_names.size(), output_node_names.data(), output_node_names.size());
234  float* predictionData = output_values[0].GetTensorMutableData<float>();
235  float prediction = predictionData[0];
236 
237  predictions.push_back(prediction);
238  }//for i
239  float threshold = m_MLthreshold;
240  ATH_MSG_ALWAYS("ML threshold:" << threshold);
241  int badTracks = 0;
242  for (float prediction : predictions) {
243  if (prediction > threshold) {
244  badTracks++;
245  }
246  }
247  float rouletteScore = static_cast<float>(badTracks) / static_cast<float>(truthMultiplicity);
248 
249  FilterReporter filter(m_filterParams, false, ctx);
250  bool pass = false;
251  int decision = rouletteScore == 0;
252  if (decision==0){ //if ML decision is False, it goes to the MC-overlay workflow
253  pass = true;
254  }
255  else{
256  pass = false;
257  }
258 
259  if (m_invertfilter) {
260  pass =! pass;
261  }
262  filter.setPassed(pass);
263  ATH_MSG_ALWAYS("End TrackOverlayDecisionAlg, difference in filters: "<<(pass ? "found" : "not found")<<"="<<pass<<", invert="<<m_invertfilter);
264  return StatusCode::SUCCESS;
265 }
266 
267 
268 }// end namespace TrackOverlayDecisionAlg
FilterReporter
a guard class for use with FilterReporterParams
Definition: FilterReporter.h:35
python.tests.PyTestsLib.finalize
def finalize(self)
_info( "content of StoreGate..." ) self.sg.dump()
Definition: PyTestsLib.py:53
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
TrackOverlayDecisionAlg::px_diff
const float px_diff
Definition: TrackOverlayDecisionAlg.h:29
TrackOverlayDecisionAlg::area0p05
const float area0p05
Definition: TrackOverlayDecisionAlg.h:38
xAOD::TruthParticle_v1::pz
float pz() const
The z component of the particle's momentum.
IAthSelectionTool::CutResult
Definition: IAthSelectionTool.h:30
SG::ReadHandle
Definition: StoreGate/StoreGate/ReadHandle.h:70
xAOD::deltaPhi
setSAddress setEtaMS setDirPhiMS setDirZMS setBarrelRadius setEndcapAlpha setEndcapRadius setInterceptInner setEtaMap setEtaBin setIsTgcFailure setDeltaPt deltaPhi
Definition: L2StandAloneMuon_v1.cxx:160
python.PyParticleTools.getTruthParticles
def getTruthParticles(aKey)
Definition: PyParticleTools.py:102
initialize
void initialize()
Definition: run_EoverP.cxx:894
CutsMETMaker::accept
StatusCode accept(const xAOD::Muon *mu)
Definition: CutsMETMaker.cxx:18
TruthPileupEvent.h
xAOD::TruthParticle_v1::px
float px() const
The x component of the particle's momentum.
M_PI
#define M_PI
Definition: ActiveFraction.h:11
TrackOverlayDecisionAlg::constant3
const float constant3
Definition: TrackOverlayDecisionAlg.h:41
xAOD::TruthParticle_v1::py
float py() const
The y component of the particle's momentum.
ATH_MSG_VERBOSE
#define ATH_MSG_VERBOSE(x)
Definition: AthMsgStreamMacros.h:28
python.oracle.Session
Session
Definition: oracle.py:78
TrackOverlayDecisionAlg
Definition: TrackOverlayDecisionAlg.cxx:28
AthReentrantAlgorithm
An algorithm that can be simultaneously executed in multiple threads.
Definition: AthReentrantAlgorithm.h:83
covarianceTool.filter
filter
Definition: covarianceTool.py:514
LArG4FSStartPointFilterLegacy.execute
execute
Definition: LArG4FSStartPointFilterLegacy.py:20
EventID.h
This class provides a unique identification for each event, in terms of run/event number and/or a tim...
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
xAOD::TruthParticle_v1::e
virtual double e() const override final
The total energy of the particle.
event
POOL::TEvent event(POOL::TEvent::kClassAccess)
P4Helpers::deltaEta
double deltaEta(const I4Momentum &p1, const I4Momentum &p2)
Computes \(\Delta\eta\) efficiently.
Definition: P4Helpers.h:53
lumiFormat.i
int i
Definition: lumiFormat.py:92
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
DMTest::links
links
Definition: CLinks_v1.cxx:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
xAOD::TruthParticle_v1
Class describing a truth particle in the MC record.
Definition: TruthParticle_v1.h:41
xAOD::TruthEvent_v1
Class describing a signal truth event in the MC record.
Definition: TruthEvent_v1.h:35
ATH_MSG_ALWAYS
#define ATH_MSG_ALWAYS(x)
Definition: AthMsgStreamMacros.h:35
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
TrackOverlayDecisionAlg::constant1
const float constant1
Definition: TrackOverlayDecisionAlg.h:39
SG::ReadHandle::isValid
virtual bool isValid() override final
Can the handle be successfully dereferenced?
TrackOverlayDecisionAlg::py_diff
const float py_diff
Definition: TrackOverlayDecisionAlg.h:30
TrackOverlayDecisionAlg::pt_diff
const float pt_diff
Definition: TrackOverlayDecisionAlg.h:32
FilterReporter.h
PathResolver.h
TrackOverlayDecisionAlg::constant5
const float constant5
Definition: TrackOverlayDecisionAlg.h:43
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:195
threshold
Definition: chainparser.cxx:74
xAOD::TruthParticle_v1::eta
virtual double eta() const override final
The pseudorapidity (\(\eta\)) of the particle.
Definition: TruthParticle_v1.cxx:174
TrackOverlayDecisionAlg::area0p2
const float area0p2
Definition: TrackOverlayDecisionAlg.h:37
TrackOverlayDecisionAlg::eventPt_diff
const float eventPt_diff
Definition: TrackOverlayDecisionAlg.h:36
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:431
TrackOverlayDecisionAlg::pz_diff
const float pz_diff
Definition: TrackOverlayDecisionAlg.h:31
DataVector::end
const_iterator end() const noexcept
Return a const_iterator pointing past the end of the collection.
TrackOverlayDecisionAlg::TrackOverlayDecisionAlg
Definition: TrackOverlayDecisionAlg.h:46
TrackOverlayDecisionAlg::constant2
const float constant2
Definition: TrackOverlayDecisionAlg.h:40
xAOD::TruthParticle_v1::phi
virtual double phi() const override final
The azimuthal angle (\(\phi\)) of the particle.
Definition: TruthParticle_v1.cxx:181
TrackOverlayDecisionAlg::constant6
const float constant6
Definition: TrackOverlayDecisionAlg.h:44
TrackOverlayDecisionAlg::multi_diff
const float multi_diff
Definition: TrackOverlayDecisionAlg.h:35
xAOD::TruthParticle_v1::pt
virtual double pt() const override final
The transverse momentum (\(p_T\)) of the particle.
Definition: TruthParticle_v1.cxx:166
TruthPileupEventAuxContainer.h
ATLAS_THREAD_SAFE
#define ATLAS_THREAD_SAFE
Definition: checker_macros.h:211
DataVector::at
const T * at(size_type n) const
Access an element, as an rvalue.
TrackOverlayDecisionAlg::pu_diff
const float pu_diff
Definition: TrackOverlayDecisionAlg.h:34
TrackOverlayDecisionAlg::constant4
const float constant4
Definition: TrackOverlayDecisionAlg.h:42
TruthParticle.h
DataVector::size
size_type size() const noexcept
Returns the number of elements in the collection.
DataVector::begin
const_iterator begin() const noexcept
Return a const_iterator pointing at the beginning of the collection.
TrackOverlayDecisionAlg.h
xAOD::EventInfo_v1::actualInteractionsPerCrossing
float actualInteractionsPerCrossing() const
Average interactions per crossing for the current BCID - for in-time pile-up.
Definition: EventInfo_v1.cxx:380
TrackOverlayDecisionAlg::e_diff
const float e_diff
Definition: TrackOverlayDecisionAlg.h:33