ATLAS Offline Software
Loading...
Searching...
No Matches
TrackOverlayDecisionAlg.cxx
Go to the documentation of this file.
1/*
2 * * Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 * * */
4//
5#include <fstream>
6#include <string>
7#include <iostream>
8#include <vector>
9#include <sstream>
10#include <unistd.h>
11#include <Eigen/Core>
12//
13#include "GaudiKernel/SystemOfUnits.h"
14#include "Gaudi/Property.h"
19
20// ONNX Runtime include(s).
21#include <onnxruntime_cxx_api.h>
22//
25
27 TrackOverlayDecisionAlg::TrackOverlayDecisionAlg( const std::string& name, ISvcLocator* pSvcLocator ) :
28 ::AthReentrantAlgorithm( name, pSvcLocator )
29 {
30 }
31
// Body of initialize() (its signature line is not part of this listing).
// It initializes m_filterParams and the StoreGate read keys. Each truth key
// is enabled only for the m_pileupSwitch modes that use it and only when the
// key is non-empty. It then retrieves the truth-selection tool and the ONNX
// Runtime service, and creates the ONNX session for the TrackOverlay J7 model.
// NOTE(review): original lines 46-47 are absent from this listing.
// m_inputInfo / m_outputInfo, which execute() reads, are presumably filled
// there — confirm.
33{
34 ATH_CHECK(m_filterParams.initialize(false));
35 ATH_CHECK( m_eventInfoContainerName.initialize() );
36 ATH_CHECK( m_truthParticleName.initialize( (m_pileupSwitch == "HardScatter" or m_pileupSwitch == "All") and not m_truthParticleName.key().empty() ) );
// The selection tool is enabled only when a truth-particle key is set.
// NOTE(review): the previous statement may blank the key when the key is
// initialized as unused (confirm the VarHandleKey::initialize(false)
// semantics). In that case the tool is disabled in "PileUp" mode, yet
// execute() calls m_truthSelectionTool->accept() unconditionally.
37 ATH_CHECK(m_truthSelectionTool.retrieve(EnableTool {not m_truthParticleName.key().empty()} ));
38 ATH_CHECK( m_truthEventName.initialize( (m_pileupSwitch == "HardScatter" or m_pileupSwitch == "All") and not m_truthEventName.key().empty() ) );
39 ATH_CHECK( m_truthPileUpEventName.initialize( (m_pileupSwitch == "PileUp" or m_pileupSwitch == "All") and not m_truthPileUpEventName.key().empty() ) );
40 ATH_CHECK(m_svc.retrieve());
// NOTE(review): this_file is not used in the visible code; candidate for removal.
41 std::string this_file = __FILE__;
// NOTE(review): the resolved path is not checked. If the calibration file
// cannot be found, the Ort::Session constructor below will presumably fail
// with a less helpful error — consider an explicit check + ATH_MSG_ERROR.
// The model file name is hard-coded; a property would allow other models.
42 const std::string model_path = PathResolverFindCalibFile("TrackOverlay/TrackOverlay_J7_model.onnx");
43 Ort::SessionOptions session_options;
44
45 m_session = std::make_unique<Ort::Session>(m_svc->env(), model_path.c_str(), session_options);
48 return StatusCode::SUCCESS;
49}
50
// Body of finalize() (its signature line is not part of this listing).
// It logs the m_filterParams summary between separator lines at VERBOSE
// level and returns SUCCESS.
// NOTE(review): at VERBOSE the filter summary is typically not shown in
// production jobs; INFO may be more useful. Original line 57 is also absent
// from this listing.
52{
53 ATH_MSG_VERBOSE ( "Finalizing ..." );
54 ATH_MSG_VERBOSE("-----------------------------------------------------------------");
55 ATH_MSG_VERBOSE("m_filterParams.summary()" << m_filterParams.summary());
56 ATH_MSG_VERBOSE("-----------------------------------------------------------------");
58 ATH_MSG_VERBOSE(" =====================================================================");
59
60 return StatusCode::SUCCESS;
61}
62
// Collect the truth particles to be classified, selected by m_pileupSwitch:
//   "All"         -> every particle of the m_truthParticleName container;
//   "HardScatter" -> the valid particle links of the first TruthEvent;
//   "PileUp"      -> the valid particle links of every TruthPileupEvent.
// Returns an empty vector when the relevant key is empty or the container
// cannot be read. Any other switch value logs an error and returns empty.
// NOTE(review): this listing omits the ReadHandle declarations (original
// lines 69, 78 and 95). There is no EventContext parameter, so they
// presumably read in the current event context — confirm.
63const std::vector<const xAOD::TruthParticle*> TrackOverlayDecisionAlg::getTruthParticles() const {
64 std::vector<const xAOD::TruthParticle*> tempVec {};
65 if (m_pileupSwitch == "All") {
66 if (m_truthParticleName.key().empty()) {
67 return tempVec;
68 }
70 if (not truthParticleContainer.isValid()) {
71 return tempVec;
72 }
73 tempVec.insert(tempVec.begin(), truthParticleContainer->begin(), truthParticleContainer->end());
74 } else {
75 if (m_pileupSwitch == "HardScatter") {
76 if (not m_truthEventName.key().empty()) {
77 ATH_MSG_VERBOSE("Getting TruthEvents container.");
// Only the first TruthEvent (index 0) is treated as the hard-scatter event.
// NOTE(review): at(0) on a valid but empty container presumably throws
// (bounds-checked access) — consider testing empty() first.
79 const xAOD::TruthEvent* event = (truthEventContainer.isValid()) ? truthEventContainer->at(0) : nullptr;
80 if (not event) {
81 return tempVec;
82 }
83 const auto& links = event->truthParticleLinks();
84 tempVec.reserve(event->nTruthParticles());
// Dangling links are skipped rather than dereferenced.
85 for (const auto& link : links) {
86 if (link.isValid()){
87 tempVec.push_back(*link);
88 }
89 }
90 }
91 }else if (m_pileupSwitch == "PileUp") {
92 if (not m_truthPileUpEventName.key().empty()) {
93 ATH_MSG_VERBOSE("getting TruthPileupEvents container");
94 // get truth particles from all pileup events
96 if (truthPileupEventContainer.isValid()) {
97 const unsigned int nPileup = truthPileupEventContainer->size();
98 tempVec.reserve(nPileup * 200); // quick initial guess, will still save some time
99 for (unsigned int i(0); i != nPileup; ++i) {
100 const auto *eventPileup = truthPileupEventContainer->at(i);
101 // get truth particles from each pileup event
102 int ntruth = eventPileup->nTruthParticles();
103 ATH_MSG_VERBOSE("Adding " << ntruth << " truth particles from TruthPileupEvents container");
104 const auto& links = eventPileup->truthParticleLinks();
105 for (const auto& link : links) {
106 if (link.isValid()){
107 tempVec.push_back(*link);
108 }
109 }
110 }
111 } else {
// NOTE(review): this branch runs when the handle is not valid, not when the
// container is empty, so the message wording is misleading.
112 ATH_MSG_ERROR("no entries in TruthPileupEvents container!");
113 }
114 }
115 } else {
116 ATH_MSG_ERROR("bad value for PileUpSwitch");
117 }
118 }
119 return tempVec;
120}
121
122StatusCode TrackOverlayDecisionAlg::execute(const EventContext &ctx) const
123{
124 ATH_MSG_DEBUG ("Executing ...");
125
126 std::vector<const xAOD::TruthParticle*> truthParticlesVec = TrackOverlayDecisionAlg::getTruthParticles();
127
128 //Access truth info for the NN input
129 float eventPxSum = 0.0;
130 float eventPySum = 0.0;
131 float eventPt = 0.0;
132 float puEvents = 0.0;
133
134 std::vector<float> pxValues, pyValues, pzValues, eValues, etaValues, phiValues, ptValues;
135 float truthMultiplicity = 0.0;
136 const int truthParticles = truthParticlesVec.size();
137 for (int itruth = 0; itruth < truthParticles; itruth++) {
138 const xAOD::TruthParticle* thisTruth = truthParticlesVec[itruth];
139 const IAthSelectionTool::CutResult accept = m_truthSelectionTool->accept(thisTruth);
140 if(accept){
141 pxValues.push_back((thisTruth->px()*0.001-1.46988000e+03)* px_diff); //as MinMaxScaler: 1.46988000e+03 is the lowest value of px from a J7 sample; *(0.001) is used to convert unit rather than *(1/1000) to speed up.
142 pyValues.push_back((thisTruth->py()*0.001-1.35142000e+03)* py_diff); //the lowest value of py: 1.35142000e+03
143 pzValues.push_back((thisTruth->pz()*0.001-1.50464000e+03)* pz_diff); //the lowest value of pz: 1.50464000e+03
144 ptValues.push_back((thisTruth->pt()*0.001-5.00006000e-01)* pt_diff); //the lowest value of pt: 5.00006000e-01
145
146 etaValues.push_back(thisTruth->eta());
147 phiValues.push_back(thisTruth->phi());
148 eValues.push_back((thisTruth->e()*0.001-5.08307000e-01)*e_diff); //the lowest value of energy: 5.08307000e-01
149
150 eventPxSum += thisTruth->px();
151 eventPySum += thisTruth->py();
152 truthMultiplicity++;
153 }//accept
154 }//for itruth
155 SG::ReadHandle<xAOD::TruthPileupEventContainer> truthPileupEventContainer;
157 if (!m_truthPileUpEventName.key().empty()) {
159 }
160 puEvents = !m_truthPileUpEventName.key().empty() and truthPileupEventContainer.isValid() ? static_cast<int>( truthPileupEventContainer->size() ) : pie.isValid() ? pie->actualInteractionsPerCrossing() : 0;
161 eventPt = std::sqrt(eventPxSum*eventPxSum + eventPySum*eventPySum)*0.001;
162
163 std::vector<float> puEventsVec(pxValues.size(), (puEvents-1.55000000e+01)*pu_diff); //min of puEvents= 15.5, max of puEvents=84.5
164 std::vector<float> truthMultiplicityVec(pxValues.size(), (truthMultiplicity-1.80000000e+01)*multi_diff);
165 std::vector<float> eventPtVec(pxValues.size(), (eventPt-3.42359395e-01)*eventPt_diff);
166 std::vector<float> predictions;
167
168 //Compute the distances using Eigen for Eigen's optimized operations. Initialize matirces. Observed a significant improvement on computing calculation.
169 Eigen::VectorXf ptEigen = Eigen::VectorXf::Map(ptValues.data(), ptValues.size());
170 Eigen::VectorXf phiEigen = Eigen::VectorXf::Map(phiValues.data(), phiValues.size());
171 Eigen::VectorXf etaEigen = Eigen::VectorXf::Map(etaValues.data(), etaValues.size());
172 for (std::size_t i = 0; i < truthMultiplicity; ++i) {
173 float multiplicity_0p05 = 0.0, multiplicity_0p2 = 0.0;
174 float sum_0p05 = 0.0, sum_0p2 = 0.0;
175 float pt_0p05 = 0.0, pt_0p2 = 0.0;
176 float deltaEtaI = etaEigen[i];
177 float phiI = phiEigen[i];
178 for (std::size_t j = 0; j < truthMultiplicity; ++j) {
179 if (i == j) continue; // Skip the particle itself
180 float deltaEta = deltaEtaI - etaEigen[j];
181 float deltaPhi = phiI - phiEigen[j];
182 if (deltaPhi > M_PI) {
183 deltaPhi -= 2.0 * M_PI;
184 }
185 float distances = std::sqrt(deltaEta * deltaEta + deltaPhi * deltaPhi);
186 if (distances < 0.05){
187 multiplicity_0p05++;
188 sum_0p05 += distances;
189 pt_0p05 += ptEigen[j];
190 }
191 if (distances < 0.2){
192 multiplicity_0p2++;
193 sum_0p2 += distances;
194 pt_0p2 += ptEigen[j];
195 }
196 }// for j
197
198 std::vector<float> featData;
199 featData.push_back(pxValues[i]);
200 featData.push_back(pyValues[i]);
201 featData.push_back(pzValues[i]);
202 featData.push_back(eValues[i]);
203 featData.push_back(ptValues[i]);
204 featData.push_back((multiplicity_0p2 * area0p2) * constant1);
205 featData.push_back((multiplicity_0p05 * area0p05) * constant2);
206 featData.push_back((sum_0p2 * area0p2) * constant3);
207 featData.push_back((sum_0p05 * area0p05) * constant4);
208 featData.push_back(pt_0p2 * constant5);
209 featData.push_back(pt_0p05 * constant6);
210
211 featData.push_back(puEventsVec[i]);
212 featData.push_back(truthMultiplicityVec[i]);
213 featData.push_back(eventPtVec[i]);
214
215 std::vector<int64_t> input_node_dims;
216 std::vector<char*> input_node_names;
217 input_node_dims = std::get<0>(m_inputInfo);
218 input_node_names = std::get<1>(m_inputInfo);
219
220 std::vector<int64_t> output_node_dims;
221 std::vector<char*> output_node_names;
222 output_node_dims = std::get<0>(m_outputInfo);
223 output_node_names = std::get<1>(m_outputInfo);
224
225 Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
226 input_node_dims[0]=1;
227 Ort::Value input_data = Ort::Value::CreateTensor(memoryInfo, featData.data(), featData.size(), input_node_dims.data(), input_node_dims.size());
228 Ort::RunOptions run_options(nullptr);
229 //Run the inference
230 Ort::Session& mysession ATLAS_THREAD_SAFE = *m_session;
231 auto output_values = mysession.Run(run_options, input_node_names.data(), &input_data, input_node_names.size(), output_node_names.data(), output_node_names.size());
232 float* predictionData = output_values[0].GetTensorMutableData<float>();
233 float prediction = predictionData[0];
234
235 predictions.push_back(prediction);
236 }//for i
237 float threshold = m_MLthreshold;
238 ATH_MSG_ALWAYS("ML threshold:" << threshold);
239 int badTracks = 0;
240 for (float prediction : predictions) {
241 if (prediction > threshold) {
242 badTracks++;
243 }
244 }
245 float rouletteScore = static_cast<float>(badTracks) / static_cast<float>(truthMultiplicity);
246
247 FilterReporter filter(m_filterParams, false, ctx);
248 bool pass = false;
249 int decision = rouletteScore == 0;
250 if (decision==0){ //if ML decision is False, it goes to the MC-overlay workflow
251 pass = true;
252 }
253 else{
254 pass = false;
255 }
256
257 if (m_invertfilter) {
258 pass =! pass;
259 }
260 filter.setPassed(pass);
261 ATH_MSG_ALWAYS("End TrackOverlayDecisionAlg, difference in filters: "<<(pass ? "found" : "not found")<<"="<<pass<<", invert="<<m_invertfilter);
262 return StatusCode::SUCCESS;
263}
264
265
266}// end namespace TrackOverlayDecisionAlg
#define M_PI
Scalar deltaPhi(const MatrixBase< Derived > &vec) const
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_VERBOSE(x)
#define ATH_MSG_ALWAYS(x)
#define ATH_MSG_DEBUG(x)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
#define ATLAS_THREAD_SAFE
An algorithm that can be simultaneously executed in multiple threads.
A guard class for use with FilterReporterParams.
virtual bool isValid() override final
Can the handle be successfully dereferenced?
SG::ReadHandleKey< xAOD::EventInfo > m_eventInfoContainerName
std::tuple< std::vector< int64_t >, std::vector< char * > > GetInputNodeInfo(const std::unique_ptr< Ort::Session > &session)
Gaudi::Property< bool > m_invertfilter
invert filter decision at the end
virtual StatusCode execute(const EventContext &ctx) const override final
Athena algorithm's interface method execute()
std::tuple< std::vector< int64_t >, std::vector< char * > > m_outputInfo
SG::ReadHandleKey< xAOD::TruthEventContainer > m_truthEventName
std::tuple< std::vector< int64_t >, std::vector< char * > > GetOutputNodeInfo(const std::unique_ptr< Ort::Session > &session)
std::tuple< std::vector< int64_t >, std::vector< char * > > m_inputInfo
virtual StatusCode finalize() override final
Athena algorithm's interface method finalize()
TrackOverlayDecisionAlg(const std::string &name, ISvcLocator *pSvcLocator)
Constructor with parameters.
virtual StatusCode initialize() override final
Athena algorithm's interface method initialize()
const std::vector< const xAOD::TruthParticle * > getTruthParticles() const
SG::ReadHandleKey< xAOD::TruthParticleContainer > m_truthParticleName
SG::ReadHandleKey< xAOD::TruthPileupEventContainer > m_truthPileUpEventName
ServiceHandle< AthOnnx::IOnnxRuntimeSvc > m_svc
float px() const
The x component of the particle's momentum.
virtual double e() const override final
The total energy of the particle.
virtual double pt() const override final
The transverse momentum ($p_{\mathrm{T}}$) of the particle.
float py() const
The y component of the particle's momentum.
virtual double eta() const override final
The pseudorapidity ($\eta$) of the particle.
virtual double phi() const override final
The azimuthal angle ($\phi$) of the particle.
float pz() const
The z component of the particle's momentum.
TruthEvent_v1 TruthEvent
Typedef to implementation.
Definition TruthEvent.h:17
TruthParticle_v1 TruthParticle
Typedef to implementation.