ATLAS Offline Software
Loading...
Searching...
No Matches
TrackOverlayDecisionAlg.cxx
Go to the documentation of this file.
1
5
7
8#include "GaudiKernel/SystemOfUnits.h"
12//
15
16#include <Eigen/Core>
17
19 TrackOverlayDecisionAlg::TrackOverlayDecisionAlg( const std::string& name, ISvcLocator* pSvcLocator ) :
20 ::AthReentrantAlgorithm( name, pSvcLocator )
21 {
22 }
23
25{
 // Set up the filter bookkeeping, the data-handle keys and the truth selection tool,
 // then create the ONNX Runtime session for the overlay-decision model.
 // Read keys are only activated for the PileUpSwitch modes that read them, and only when
 // the key is non-empty.
26 ATH_CHECK(m_filterParams.initialize(false));
27 ATH_CHECK( m_eventInfoContainerName.initialize() );
 // NOTE(review): this key is also activated for "HardScatter", although getTruthParticles()
 // only reads m_truthParticleName in "All" mode.
28 ATH_CHECK( m_truthParticleName.initialize( (m_pileupSwitch == "HardScatter" or m_pileupSwitch == "All") and not m_truthParticleName.key().empty() ) );
 // The selection tool is only enabled when m_truthParticleName is set.
 // NOTE(review): execute() calls m_truthSelectionTool->accept() unconditionally; confirm the
 // tool is always enabled in the configurations actually used.
29 ATH_CHECK(m_truthSelectionTool.retrieve(EnableTool {not m_truthParticleName.key().empty()} ));
30 ATH_CHECK( m_truthEventName.initialize( (m_pileupSwitch == "HardScatter" or m_pileupSwitch == "All") and not m_truthEventName.key().empty() ) );
31 ATH_CHECK( m_truthPileUpEventName.initialize( (m_pileupSwitch == "PileUp" or m_pileupSwitch == "All") and not m_truthPileUpEventName.key().empty() ) );
 // ONNX Runtime service providing the Ort environment used to build the session below.
32 ATH_CHECK(m_svc.retrieve());
 // NOTE(review): this_file is never used in this block; looks like a leftover.
33 std::string this_file = __FILE__;
 // The model is located through PathResolver; there is no check here that the lookup
 // succeeded before the path is handed to Ort::Session.
34 const std::string model_path = PathResolverFindCalibFile("TrackOverlay/TrackOverlay_J7_model.onnx");
35 Ort::SessionOptions session_options;
36
37 m_session = std::make_unique<Ort::Session>(m_svc->env(), model_path.c_str(), session_options);
 // NOTE(review): the embedded numbering skips 38-39 here. execute() reads m_inputInfo and
 // m_outputInfo, which presumably are filled from m_session at this point
 // (cf. GetInputNodeInfo / GetOutputNodeInfo); confirm against the real source.
40 return StatusCode::SUCCESS;
41}
42
44{
 // End-of-job hook: writes the FilterReporter summary (m_filterParams.summary()) between
 // separator lines. Everything here is emitted at VERBOSE level.
 // NOTE(review): at VERBOSE the filter pass/fail summary is invisible in ordinary jobs;
 // consider INFO if the summary is meant to be read.
45 ATH_MSG_VERBOSE ( "Finalizing ..." );
46 ATH_MSG_VERBOSE("-----------------------------------------------------------------");
 // The quoted text is a literal label, followed by the actual summary string.
47 ATH_MSG_VERBOSE("m_filterParams.summary()" << m_filterParams.summary());
48 ATH_MSG_VERBOSE("-----------------------------------------------------------------");
 // NOTE(review): the embedded numbering skips 49 here; that source line is not visible.
50 ATH_MSG_VERBOSE(" =====================================================================");
51
52 return StatusCode::SUCCESS;
53}
54
55const std::vector<const xAOD::TruthParticle*> TrackOverlayDecisionAlg::getTruthParticles() const {
56 std::vector<const xAOD::TruthParticle*> tempVec {};
57 if (m_pileupSwitch == "All") {
58 if (m_truthParticleName.key().empty()) {
59 return tempVec;
60 }
62 if (not truthParticleContainer.isValid()) {
63 return tempVec;
64 }
65 tempVec.insert(tempVec.begin(), truthParticleContainer->begin(), truthParticleContainer->end());
66 } else {
67 if (m_pileupSwitch == "HardScatter") {
68 if (not m_truthEventName.key().empty()) {
69 ATH_MSG_VERBOSE("Getting TruthEvents container.");
71 const xAOD::TruthEvent* event = (truthEventContainer.isValid()) ? truthEventContainer->at(0) : nullptr;
72 if (not event) {
73 return tempVec;
74 }
75 const auto& links = event->truthParticleLinks();
76 tempVec.reserve(event->nTruthParticles());
77 for (const auto& link : links) {
78 if (link.isValid()){
79 tempVec.push_back(*link);
80 }
81 }
82 }
83 }else if (m_pileupSwitch == "PileUp") {
84 if (not m_truthPileUpEventName.key().empty()) {
85 ATH_MSG_VERBOSE("getting TruthPileupEvents container");
86 // get truth particles from all pileup events
88 if (truthPileupEventContainer.isValid()) {
89 const unsigned int nPileup = truthPileupEventContainer->size();
90 tempVec.reserve(nPileup * 200); // quick initial guess, will still save some time
91 for (unsigned int i(0); i != nPileup; ++i) {
92 const auto *eventPileup = truthPileupEventContainer->at(i);
93 // get truth particles from each pileup event
94 int ntruth = eventPileup->nTruthParticles();
95 ATH_MSG_VERBOSE("Adding " << ntruth << " truth particles from TruthPileupEvents container");
96 const auto& links = eventPileup->truthParticleLinks();
97 for (const auto& link : links) {
98 if (link.isValid()){
99 tempVec.push_back(*link);
100 }
101 }
102 }
103 } else {
104 ATH_MSG_ERROR("no entries in TruthPileupEvents container!");
105 }
106 }
107 } else {
108 ATH_MSG_ERROR("bad value for PileUpSwitch");
109 }
110 }
111 return tempVec;
112}
113
114StatusCode TrackOverlayDecisionAlg::execute(const EventContext &ctx) const
115{
116 ATH_MSG_DEBUG ("Executing ...");
117
118 std::vector<const xAOD::TruthParticle*> truthParticlesVec = TrackOverlayDecisionAlg::getTruthParticles();
119
120 //Access truth info for the NN input
121 float eventPxSum = 0.0;
122 float eventPySum = 0.0;
123 float eventPt = 0.0;
124 float puEvents = 0.0;
125
126 std::vector<float> pxValues, pyValues, pzValues, eValues, etaValues, phiValues, ptValues;
127 float truthMultiplicity = 0.0;
128 const int truthParticles = truthParticlesVec.size();
129 for (int itruth = 0; itruth < truthParticles; itruth++) {
130 const xAOD::TruthParticle* thisTruth = truthParticlesVec[itruth];
131 const IAthSelectionTool::CutResult accept = m_truthSelectionTool->accept(thisTruth);
132 if(accept){
133 pxValues.push_back((thisTruth->px()*0.001-1.46988000e+03)* px_diff); //as MinMaxScaler: 1.46988000e+03 is the lowest value of px from a J7 sample; *(0.001) is used to convert unit rather than *(1/1000) to speed up.
134 pyValues.push_back((thisTruth->py()*0.001-1.35142000e+03)* py_diff); //the lowest value of py: 1.35142000e+03
135 pzValues.push_back((thisTruth->pz()*0.001-1.50464000e+03)* pz_diff); //the lowest value of pz: 1.50464000e+03
136 ptValues.push_back((thisTruth->pt()*0.001-5.00006000e-01)* pt_diff); //the lowest value of pt: 5.00006000e-01
137
138 etaValues.push_back(thisTruth->eta());
139 phiValues.push_back(thisTruth->phi());
140 eValues.push_back((thisTruth->e()*0.001-5.08307000e-01)*e_diff); //the lowest value of energy: 5.08307000e-01
141
142 eventPxSum += thisTruth->px();
143 eventPySum += thisTruth->py();
144 truthMultiplicity++;
145 }//accept
146 }//for itruth
147 SG::ReadHandle<xAOD::TruthPileupEventContainer> truthPileupEventContainer;
149 if (!m_truthPileUpEventName.key().empty()) {
151 }
152 puEvents = !m_truthPileUpEventName.key().empty() and truthPileupEventContainer.isValid() ? static_cast<int>( truthPileupEventContainer->size() ) : pie.isValid() ? pie->actualInteractionsPerCrossing() : 0;
153 eventPt = std::sqrt(eventPxSum*eventPxSum + eventPySum*eventPySum)*0.001;
154
155 std::vector<float> puEventsVec(pxValues.size(), (puEvents-1.55000000e+01)*pu_diff); //min of puEvents= 15.5, max of puEvents=84.5
156 std::vector<float> truthMultiplicityVec(pxValues.size(), (truthMultiplicity-1.80000000e+01)*multi_diff);
157 std::vector<float> eventPtVec(pxValues.size(), (eventPt-3.42359395e-01)*eventPt_diff);
158 std::vector<float> predictions;
159
160 //Compute the distances using Eigen for Eigen's optimized operations. Initialize matrices. Observed a significant improvement on computing calculation.
161 Eigen::VectorXf ptEigen = Eigen::VectorXf::Map(ptValues.data(), ptValues.size());
162 Eigen::VectorXf phiEigen = Eigen::VectorXf::Map(phiValues.data(), phiValues.size());
163 Eigen::VectorXf etaEigen = Eigen::VectorXf::Map(etaValues.data(), etaValues.size());
164 for (std::size_t i = 0; i < truthMultiplicity; ++i) {
165 float multiplicity_0p05 = 0.0, multiplicity_0p2 = 0.0;
166 float sum_0p05 = 0.0, sum_0p2 = 0.0;
167 float pt_0p05 = 0.0, pt_0p2 = 0.0;
168 float deltaEtaI = etaEigen[i];
169 float phiI = phiEigen[i];
170 for (std::size_t j = 0; j < truthMultiplicity; ++j) {
171 if (i == j) continue; // Skip the particle itself
172 float deltaEta = deltaEtaI - etaEigen[j];
173 float deltaPhi = phiI - phiEigen[j];
174 if (deltaPhi > M_PI) {
175 deltaPhi -= 2.0 * M_PI;
176 }
177 else if (deltaPhi < -M_PI) {
178 deltaPhi += 2.0 * M_PI;
179 }
180 float distances = std::sqrt(deltaEta * deltaEta + deltaPhi * deltaPhi);
181 if (distances < 0.05){
182 multiplicity_0p05++;
183 sum_0p05 += distances;
184 pt_0p05 += ptEigen[j];
185 }
186 if (distances < 0.2){
187 multiplicity_0p2++;
188 sum_0p2 += distances;
189 pt_0p2 += ptEigen[j];
190 }
191 }// for j
192
193 std::vector<float> featData;
194 featData.push_back(pxValues[i]);
195 featData.push_back(pyValues[i]);
196 featData.push_back(pzValues[i]);
197 featData.push_back(eValues[i]);
198 featData.push_back(ptValues[i]);
199 featData.push_back((multiplicity_0p2 * area0p2) * constant1);
200 featData.push_back((multiplicity_0p05 * area0p05) * constant2);
201 featData.push_back((sum_0p2 * area0p2) * constant3);
202 featData.push_back((sum_0p05 * area0p05) * constant4);
203 featData.push_back(pt_0p2 * constant5);
204 featData.push_back(pt_0p05 * constant6);
205
206 featData.push_back(puEventsVec[i]);
207 featData.push_back(truthMultiplicityVec[i]);
208 featData.push_back(eventPtVec[i]);
209
210 std::vector<int64_t> input_node_dims;
211 std::vector<char*> input_node_names;
212 input_node_dims = std::get<0>(m_inputInfo);
213 input_node_names = std::get<1>(m_inputInfo);
214
215 std::vector<int64_t> output_node_dims;
216 std::vector<char*> output_node_names;
217 output_node_dims = std::get<0>(m_outputInfo);
218 output_node_names = std::get<1>(m_outputInfo);
219
220 Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
221 input_node_dims[0]=1;
222 Ort::Value input_data = Ort::Value::CreateTensor(memoryInfo, featData.data(), featData.size(), input_node_dims.data(), input_node_dims.size());
223 Ort::RunOptions run_options(nullptr);
224 //Run the inference
225 Ort::Session& mysession ATLAS_THREAD_SAFE = *m_session;
226 auto output_values = mysession.Run(run_options, input_node_names.data(), &input_data, input_node_names.size(), output_node_names.data(), output_node_names.size());
227 float* predictionData = output_values[0].GetTensorMutableData<float>();
228 float prediction = predictionData[0];
229
230 predictions.push_back(prediction);
231 }//for i
232 float threshold = m_MLthreshold;
233 ATH_MSG_ALWAYS("ML threshold:" << threshold);
234 int badTracks = 0;
235 for (float prediction : predictions) {
236 if (prediction > threshold) {
237 badTracks++;
238 }
239 }
240 if (truthMultiplicity == 0){
241 ATH_MSG_ERROR("truthMultiplicity is zero!");
242 return StatusCode::FAILURE;
243 }
244 float rouletteScore = static_cast<float>(badTracks) / static_cast<float>(truthMultiplicity);
245
246 FilterReporter filter(m_filterParams, false, ctx);
247 bool pass = false;
248 int decision = rouletteScore == 0;
249 if (decision==0){ //if ML decision is False, it goes to the MC-overlay workflow
250 pass = true;
251 }
252 else{
253 pass = false;
254 }
255
256 if (m_invertfilter) {
257 pass =! pass;
258 }
259 filter.setPassed(pass);
260 ATH_MSG_ALWAYS("End TrackOverlayDecisionAlg, difference in filters: "<<(pass ? "found" : "not found")<<"="<<pass<<", invert="<<m_invertfilter);
261 return StatusCode::SUCCESS;
262}
263
264
265}// end namespace TrackOverlayDecisionAlg
#define M_PI
Scalar deltaPhi(const MatrixBase< Derived > &vec) const
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_VERBOSE(x)
#define ATH_MSG_ALWAYS(x)
#define ATH_MSG_DEBUG(x)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Handle class for reading from StoreGate.
#define ATLAS_THREAD_SAFE
An algorithm that can be simultaneously executed in multiple threads.
a guard class for use with FilterReporterParams
virtual bool isValid() override final
Can the handle be successfully dereferenced?
SG::ReadHandleKey< xAOD::EventInfo > m_eventInfoContainerName
std::tuple< std::vector< int64_t >, std::vector< char * > > GetInputNodeInfo(const std::unique_ptr< Ort::Session > &session)
Gaudi::Property< bool > m_invertfilter
invert filter decision at the end
virtual StatusCode execute(const EventContext &ctx) const override final
Athena algorithm's interface method execute().
std::tuple< std::vector< int64_t >, std::vector< char * > > m_outputInfo
SG::ReadHandleKey< xAOD::TruthEventContainer > m_truthEventName
std::tuple< std::vector< int64_t >, std::vector< char * > > GetOutputNodeInfo(const std::unique_ptr< Ort::Session > &session)
std::tuple< std::vector< int64_t >, std::vector< char * > > m_inputInfo
virtual StatusCode finalize() override final
Athena algorithm's interface method finalize().
TrackOverlayDecisionAlg(const std::string &name, ISvcLocator *pSvcLocator)
Constructor with parameters.
virtual StatusCode initialize() override final
Athena algorithm's interface method initialize().
const std::vector< const xAOD::TruthParticle * > getTruthParticles() const
SG::ReadHandleKey< xAOD::TruthParticleContainer > m_truthParticleName
SG::ReadHandleKey< xAOD::TruthPileupEventContainer > m_truthPileUpEventName
ServiceHandle< AthOnnx::IOnnxRuntimeSvc > m_svc
float px() const
The x component of the particle's momentum.
virtual double e() const override final
The total energy of the particle.
virtual double pt() const override final
The transverse momentum ($p_T$) of the particle.
float py() const
The y component of the particle's momentum.
virtual double eta() const override final
The pseudorapidity ($\eta$) of the particle.
virtual double phi() const override final
The azimuthal angle ($\phi$) of the particle.
float pz() const
The z component of the particle's momentum.
Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration.
TruthEvent_v1 TruthEvent
Typedef to implementation.
Definition TruthEvent.h:17
TruthParticle_v1 TruthParticle
Typedef to implementation.