ATLAS Offline Software
Loading...
Searching...
No Matches
FPGATrackSimGNNGraphConstructionTool.cxx
Go to the documentation of this file.
1// Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
2
4
5#include <TFile.h>
6#include <TTree.h>
8#include <cstdint>
9#include <algorithm>
10#include <map>
12// AthAlgTool
13
14FPGATrackSimGNNGraphConstructionTool::FPGATrackSimGNNGraphConstructionTool(const std::string& algname, const std::string &name, const IInterface *ifc)
15 : AthAlgTool(algname, name, ifc) {}
16
18{
20 if(m_graphTool == "ModuleMap") {
21 if (m_FPGATrackSimMapping->getGNNModuleMapString() != "") {
22 m_moduleMapPath = m_FPGATrackSimMapping->getGNNModuleMapString();
23 }
24 else {
25 ATH_MSG_ERROR("Path to 1st stage NN-based fake track removal ONNX file is empty! If you want to run this pipeline, you need to provide an input file.");
26 return StatusCode::FAILURE;
27 }
28
29 if(m_moduleMapType == "doublet") {
30 loadDoubletModuleMap(); // Load the doublet module map and store entry branches in vectors
31 }
32 else if(m_moduleMapType == "triplet") {
33 loadTripletModuleMap(); // Load the triplet module map and store entry branches in vectors
34 }
35 }
36 else if(m_graphTool == "MetricLearning") {
37 ATH_CHECK( m_MLInferenceTool.retrieve() );
38 m_MLInferenceTool->printModelInfo();
39 assert(m_MLFeatureNamesVec.size() == m_MLFeatureScalesVec.size());
40 }
41
42 return StatusCode::SUCCESS;
43}
44
46// Functions
47
48StatusCode FPGATrackSimGNNGraphConstructionTool::getEdges(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits, std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges)
49{
50 if(m_graphTool == "ModuleMap") {
51 doModuleMap(hits, edges);
52 }
53 else if(m_graphTool == "MetricLearning") {
54 doMetricLearning(hits, edges);
55 }
56
57 return StatusCode::SUCCESS;
58}
59
61{
62 std::unique_ptr<TFile> file(TFile::Open(m_moduleMapPath.c_str()));
63 std::unique_ptr<TTree> tree(static_cast<TTree*>(file->Get("TreeModuleDoublet")));
64
65 unsigned mid1_value = 0;
66 unsigned mid2_value = 0;
67 float z0min_12_value = 0.0;
68 float dphimin_12_value = 0.0;
69 float phiSlopemin_12_value = 0.0;
70 float detamin_12_value = 0.0;
71 float z0max_12_value = 0.0;
72 float dphimax_12_value = 0.0;
73 float phiSlopemax_12_value = 0.0;
74 float detamax_12_value = 0.0;
75
76 tree->SetBranchAddress("Module1", &mid1_value);
77 tree->SetBranchAddress("Module2", &mid2_value);
78 tree->SetBranchAddress("z0min_12", &z0min_12_value);
79 tree->SetBranchAddress("dphimin_12", &dphimin_12_value);
80 tree->SetBranchAddress("phiSlopemin_12", &phiSlopemin_12_value);
81 tree->SetBranchAddress("detamin_12", &detamin_12_value);
82 tree->SetBranchAddress("z0max_12", &z0max_12_value);
83 tree->SetBranchAddress("dphimax_12", &dphimax_12_value);
84 tree->SetBranchAddress("phiSlopemax_12", &phiSlopemax_12_value);
85 tree->SetBranchAddress("detamax_12", &detamax_12_value);
86
87 int64_t nEntries = tree->GetEntries();
88 m_cfgs.reserve(nEntries);
89 for (int64_t i = 0; i < nEntries; ++i) {
90 tree->GetEntry(i);
91
93
94 cfg.mid1 = mid1_value;
95 cfg.mid2 = mid2_value;
96
97 cfg.doubletCuts[0].z0.min = z0min_12_value;
98 cfg.doubletCuts[0].z0.max = z0max_12_value;
99 cfg.doubletCuts[0].dphi.min = dphimin_12_value;
100 cfg.doubletCuts[0].dphi.max = dphimax_12_value;
101 cfg.doubletCuts[0].deta.min = detamin_12_value;
102 cfg.doubletCuts[0].deta.max = detamax_12_value;
103 cfg.doubletCuts[0].phiSlope.min = phiSlopemin_12_value;
104 cfg.doubletCuts[0].phiSlope.max = phiSlopemax_12_value;
105
106 m_cfgs.emplace_back(std::move(cfg));
107 }
108}
109
111{
112 std::unique_ptr<TFile> file(TFile::Open(m_moduleMapPath.c_str()));
113 std::unique_ptr<TTree> tree(static_cast<TTree*>(file->Get("TreeModuleTriplet")));
114
115 unsigned mid1_value = 0;
116 unsigned mid2_value = 0;
117 unsigned mid3_value = 0;
118 unsigned occurence_value = 0;
119
120 float z0min_12_value = 0.0;
121 float z0max_12_value = 0.0;
122 float z0sum_12_value = 0.0;
123 float z0sumSq_12_value = 0.0;
124 float z0mean_12_value = 0.0;
125 float z0rms_12_value = 0.0;
126 float z0min_23_value = 0.0;
127 float z0max_23_value = 0.0;
128 float z0sum_23_value = 0.0;
129 float z0sumSq_23_value = 0.0;
130 float z0mean_23_value = 0.0;
131 float z0rms_23_value = 0.0;
132
133 float dphimin_12_value = 0.0;
134 float dphimax_12_value = 0.0;
135 float dphisum_12_value = 0.0;
136 float dphisumSq_12_value = 0.0;
137 float dphimean_12_value = 0.0;
138 float dphirms_12_value = 0.0;
139 float dphimin_23_value = 0.0;
140 float dphimax_23_value = 0.0;
141 float dphisum_23_value = 0.0;
142 float dphisumSq_23_value = 0.0;
143 float dphimean_23_value = 0.0;
144 float dphirms_23_value = 0.0;
145
146 float phiSlopemin_12_value = 0.0;
147 float phiSlopemax_12_value = 0.0;
148 float phiSlopesum_12_value = 0.0;
149 float phiSlopesumSq_12_value = 0.0;
150 float phiSlopemean_12_value = 0.0;
151 float phiSloperms_12_value = 0.0;
152 float phiSlopemin_23_value = 0.0;
153 float phiSlopemax_23_value = 0.0;
154 float phiSlopesum_23_value = 0.0;
155 float phiSlopesumSq_23_value = 0.0;
156 float phiSlopemean_23_value = 0.0;
157 float phiSloperms_23_value = 0.0;
158
159 float detamin_12_value = 0.0;
160 float detamax_12_value = 0.0;
161 float detasum_12_value = 0.0;
162 float detasumSq_12_value = 0.0;
163 float detamean_12_value = 0.0;
164 float detarms_12_value = 0.0;
165 float detamin_23_value = 0.0;
166 float detamax_23_value = 0.0;
167 float detasum_23_value = 0.0;
168 float detasumSq_23_value = 0.0;
169 float detamean_23_value = 0.0;
170 float detarms_23_value = 0.0;
171
172 float diff_dydx_min_value = 0.0;
173 float diff_dydx_max_value = 0.0;
174 float diff_dydx_sum_value = 0.0;
175 float diff_dydx_sumSq_value = 0.0;
176 float diff_dydx_mean_value = 0.0;
177 float diff_dydx_rms_value = 0.0;
178
179 float diff_dzdr_min_value = 0.0;
180 float diff_dzdr_max_value = 0.0;
181 float diff_dzdr_sum_value = 0.0;
182 float diff_dzdr_sumSq_value = 0.0;
183 float diff_dzdr_mean_value = 0.0;
184 float diff_dzdr_rms_value = 0.0;
185
186 tree->SetBranchAddress("Module1", &mid1_value);
187 tree->SetBranchAddress("Module2", &mid2_value);
188 tree->SetBranchAddress("Module3", &mid3_value);
189 tree->SetBranchAddress("Occurence", &occurence_value);
190
191 tree->SetBranchAddress("z0min_12", &z0min_12_value);
192 tree->SetBranchAddress("z0max_12", &z0max_12_value);
193 tree->SetBranchAddress("z0sum_12", &z0sum_12_value);
194 tree->SetBranchAddress("z0sumSq_12", &z0sumSq_12_value);
195 tree->SetBranchAddress("z0_12_mean", &z0mean_12_value);
196 tree->SetBranchAddress("z0_12_rms", &z0rms_12_value);
197 tree->SetBranchAddress("z0min_23", &z0min_23_value);
198 tree->SetBranchAddress("z0max_23", &z0max_23_value);
199 tree->SetBranchAddress("z0sum_23", &z0sum_23_value);
200 tree->SetBranchAddress("z0sumSq_23", &z0sumSq_23_value);
201 tree->SetBranchAddress("z0_23_mean", &z0mean_23_value);
202 tree->SetBranchAddress("z0_23_rms", &z0rms_23_value);
203
204 tree->SetBranchAddress("dphimin_12", &dphimin_12_value);
205 tree->SetBranchAddress("dphimax_12", &dphimax_12_value);
206 tree->SetBranchAddress("dphisum_12", &dphisum_12_value);
207 tree->SetBranchAddress("dphisumSq_12", &dphisumSq_12_value);
208 tree->SetBranchAddress("dphi_12_mean", &dphimean_12_value);
209 tree->SetBranchAddress("dphi_12_rms", &dphirms_12_value);
210 tree->SetBranchAddress("dphimin_23", &dphimin_23_value);
211 tree->SetBranchAddress("dphimax_23", &dphimax_23_value);
212 tree->SetBranchAddress("dphisum_23", &dphisum_23_value);
213 tree->SetBranchAddress("dphisumSq_23", &dphisumSq_23_value);
214 tree->SetBranchAddress("dphi_23_mean", &dphimean_23_value);
215 tree->SetBranchAddress("dphi_23_rms", &dphirms_23_value);
216
217 tree->SetBranchAddress("phiSlopemin_12", &phiSlopemin_12_value);
218 tree->SetBranchAddress("phiSlopemax_12", &phiSlopemax_12_value);
219 tree->SetBranchAddress("phiSlopesum_12", &phiSlopesum_12_value);
220 tree->SetBranchAddress("phiSlopesumSq_12", &phiSlopesumSq_12_value);
221 tree->SetBranchAddress("phiSlope_12_mean", &phiSlopemean_12_value);
222 tree->SetBranchAddress("phiSlope_12_rms", &phiSloperms_12_value);
223 tree->SetBranchAddress("phiSlopemin_23", &phiSlopemin_23_value);
224 tree->SetBranchAddress("phiSlopemax_23", &phiSlopemax_23_value);
225 tree->SetBranchAddress("phiSlopesum_23", &phiSlopesum_23_value);
226 tree->SetBranchAddress("phiSlopesumSq_23", &phiSlopesumSq_23_value);
227 tree->SetBranchAddress("phiSlope_23_mean", &phiSlopemean_23_value);
228 tree->SetBranchAddress("phiSlope_23_rms", &phiSloperms_23_value);
229
230 tree->SetBranchAddress("detamin_12", &detamin_12_value);
231 tree->SetBranchAddress("detamax_12", &detamax_12_value);
232 tree->SetBranchAddress("detasum_12", &detasum_12_value);
233 tree->SetBranchAddress("detasumSq_12", &detasumSq_12_value);
234 tree->SetBranchAddress("deta_12_mean", &detamean_12_value);
235 tree->SetBranchAddress("deta_12_rms", &detarms_12_value);
236 tree->SetBranchAddress("detamin_23", &detamin_23_value);
237 tree->SetBranchAddress("detamax_23", &detamax_23_value);
238 tree->SetBranchAddress("detasum_23", &detasum_23_value);
239 tree->SetBranchAddress("detasumSq_23", &detasumSq_23_value);
240 tree->SetBranchAddress("deta_23_mean", &detamean_23_value);
241 tree->SetBranchAddress("deta_23_rms", &detarms_23_value);
242
243 tree->SetBranchAddress("diff_dzdr_min", &diff_dzdr_min_value);
244 tree->SetBranchAddress("diff_dzdr_max", &diff_dzdr_max_value);
245 tree->SetBranchAddress("diff_dzdr_sum", &diff_dzdr_sum_value);
246 tree->SetBranchAddress("diff_dzdr_sumSq", &diff_dzdr_sumSq_value);
247 tree->SetBranchAddress("diff_dzdr_mean", &diff_dzdr_mean_value);
248 tree->SetBranchAddress("diff_dzdr_rms", &diff_dzdr_rms_value);
249
250 tree->SetBranchAddress("diff_dydx_min", &diff_dydx_min_value);
251 tree->SetBranchAddress("diff_dydx_max", &diff_dydx_max_value);
252 tree->SetBranchAddress("diff_dydx_sum", &diff_dydx_sum_value);
253 tree->SetBranchAddress("diff_dydx_sumSq", &diff_dydx_sumSq_value);
254 tree->SetBranchAddress("diff_dydx_mean", &diff_dydx_mean_value);
255 tree->SetBranchAddress("diff_dydx_rms", &diff_dydx_rms_value);
256
257 int64_t nEntries = tree->GetEntries();
258 m_cfgs.reserve(nEntries);
259 for (int64_t i = 0; i < nEntries; ++i) {
260 tree->GetEntry(i);
261
262 ModuleMapConfig cfg;
263
264 cfg.mid1 = mid1_value;
265 cfg.mid2 = mid2_value;
266 cfg.mid3 = mid3_value;
267 cfg.occurence = occurence_value;
268
269 cfg.doubletCuts[0].z0.min = z0min_12_value;
270 cfg.doubletCuts[0].z0.max = z0max_12_value;
271 cfg.doubletCuts[0].z0.sum = z0sum_12_value;
272 cfg.doubletCuts[0].z0.sumSq = z0sumSq_12_value;
273 cfg.doubletCuts[0].z0.mean = z0mean_12_value;
274 cfg.doubletCuts[0].z0.rms = z0rms_12_value;
275
276 cfg.doubletCuts[1].z0.min = z0min_23_value;
277 cfg.doubletCuts[1].z0.max = z0max_23_value;
278 cfg.doubletCuts[1].z0.sum = z0sum_23_value;
279 cfg.doubletCuts[1].z0.sumSq = z0sumSq_23_value;
280 cfg.doubletCuts[1].z0.mean = z0mean_23_value;
281 cfg.doubletCuts[1].z0.rms = z0rms_23_value;
282
283 // dphi
284 cfg.doubletCuts[0].dphi.min = dphimin_12_value;
285 cfg.doubletCuts[0].dphi.max = dphimax_12_value;
286 cfg.doubletCuts[0].dphi.sum = dphisum_12_value;
287 cfg.doubletCuts[0].dphi.sumSq = dphisumSq_12_value;
288 cfg.doubletCuts[0].dphi.mean = dphimean_12_value;
289 cfg.doubletCuts[0].dphi.rms = dphirms_12_value;
290
291 cfg.doubletCuts[1].dphi.min = dphimin_23_value;
292 cfg.doubletCuts[1].dphi.max = dphimax_23_value;
293 cfg.doubletCuts[1].dphi.sum = dphisum_23_value;
294 cfg.doubletCuts[1].dphi.sumSq = dphisumSq_23_value;
295 cfg.doubletCuts[1].dphi.mean = dphimean_23_value;
296 cfg.doubletCuts[1].dphi.rms = dphirms_23_value;
297
298 // eta
299 cfg.doubletCuts[0].deta.min = detamin_12_value;
300 cfg.doubletCuts[0].deta.max = detamax_12_value;
301 cfg.doubletCuts[0].deta.sum = detasum_12_value;
302 cfg.doubletCuts[0].deta.sumSq = detasumSq_12_value;
303 cfg.doubletCuts[0].deta.mean = detamean_12_value;
304 cfg.doubletCuts[0].deta.rms = detarms_12_value;
305
306 cfg.doubletCuts[1].deta.min = detamin_23_value;
307 cfg.doubletCuts[1].deta.max = detamax_23_value;
308 cfg.doubletCuts[1].deta.sum = detasum_23_value;
309 cfg.doubletCuts[1].deta.sumSq = detasumSq_23_value;
310 cfg.doubletCuts[1].deta.mean = detamean_23_value;
311 cfg.doubletCuts[1].deta.rms = detarms_23_value;
312
313 // slope
314 cfg.doubletCuts[0].phiSlope.min = phiSlopemin_12_value;
315 cfg.doubletCuts[0].phiSlope.max = phiSlopemax_12_value;
316 cfg.doubletCuts[0].phiSlope.sum = phiSlopesum_12_value;
317 cfg.doubletCuts[0].phiSlope.sumSq = phiSlopesumSq_12_value;
318 cfg.doubletCuts[0].phiSlope.mean = phiSlopemean_12_value;
319 cfg.doubletCuts[0].phiSlope.rms = phiSloperms_12_value;
320
321 cfg.doubletCuts[1].phiSlope.min = phiSlopemin_23_value;
322 cfg.doubletCuts[1].phiSlope.max = phiSlopemax_23_value;
323 cfg.doubletCuts[1].phiSlope.sum = phiSlopesum_23_value;
324 cfg.doubletCuts[1].phiSlope.sumSq = phiSlopesumSq_23_value;
325 cfg.doubletCuts[1].phiSlope.mean = phiSlopemean_23_value;
326 cfg.doubletCuts[1].phiSlope.rms = phiSloperms_23_value;
327
328 // extras
329 cfg.tripletCuts.diff_dzdr.min = diff_dzdr_min_value;
330 cfg.tripletCuts.diff_dzdr.max = diff_dzdr_max_value;
331 cfg.tripletCuts.diff_dzdr.sum = diff_dzdr_sum_value;
332 cfg.tripletCuts.diff_dzdr.sumSq = diff_dzdr_sumSq_value;
333 cfg.tripletCuts.diff_dzdr.mean = diff_dzdr_mean_value;
334 cfg.tripletCuts.diff_dzdr.rms = diff_dzdr_rms_value;
335
336 cfg.tripletCuts.diff_dydx.min = diff_dydx_min_value;
337 cfg.tripletCuts.diff_dydx.max = diff_dydx_max_value;
338 cfg.tripletCuts.diff_dydx.sum = diff_dydx_sum_value;
339 cfg.tripletCuts.diff_dydx.sumSq = diff_dydx_sumSq_value;
340 cfg.tripletCuts.diff_dydx.mean = diff_dydx_mean_value;
341 cfg.tripletCuts.diff_dydx.rms = diff_dydx_rms_value;
342
343 m_cfgs.emplace_back(cfg);
344 m_tripletMap.emplace(TripletKey{cfg.mid1, cfg.mid2, cfg.mid3}, &m_cfgs.back());
345 }
346}
347
348void FPGATrackSimGNNGraphConstructionTool::doModuleMap(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits, std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges)
349{
350 // Use Module Map method for edge building
351 // Two types of module maps: Doublet and Triplet
352 // For each type of module map there is three functions: minmax, meanrms, and hybrid
353 // Use the proper configuration set by the input script and passed as Gaudi::Property variables
354 // Currently only Doublet Module Map with minmax cuts exist, but others can be implemented later on as desired
355
356 if(m_moduleMapType == "doublet") {
357 getDoubletEdges(hits, edges, 0);
358 }
359 else if(m_moduleMapType == "triplet") {
360 getTripletEdges(hits, edges);
361 }
362}
363
364void FPGATrackSimGNNGraphConstructionTool::getTripletEdges(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits, std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges)
365{
366 std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> edges_12;
367 std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> edges_23;
368
369 getDoubletEdges(hits, edges_12, 0); // Build doublet edges between Module1 and Module2
370 getDoubletEdges(hits, edges_23, 1); // Build doublet edges between Module2 and Module3
371
372 std::unordered_map<int, std::vector<const FPGATrackSimGNNEdge*>> edge23_by_hit1;
373 for (const auto& e : edges_23) edge23_by_hit1[e->getEdgeIndex1()].push_back(e.get());
374
375 std::unordered_set<uint64_t> seen;
376 auto pack = [](int a, int b) -> uint64_t { return (uint64_t(a) << 32) | uint32_t(b); };
377
378 for (const auto& edge_12 : edges_12) {
379 int hit1 = edge_12->getEdgeIndex1();
380 int hit2 = edge_12->getEdgeIndex2();
381 const auto& h1 = hits[hit1];
382 const auto& h2 = hits[hit2];
383 unsigned mid1 = h1->getIdentifierHash();
384 unsigned mid2 = h2->getIdentifierHash();
385
386 auto it = edge23_by_hit1.find(hit2);
387 if (it == edge23_by_hit1.end()) continue;
388
389 for (const auto* edge_23 : it->second) { // Loop over all edges_23 that have a hit1 that matches to hit2 from edges_12
390 int hit3 = edge_23->getEdgeIndex2();
391 const auto& h3 = hits[hit3];
392 unsigned mid3 = h3->getIdentifierHash();
393
394 TripletKey key{mid1, mid2, mid3};
395 auto cfg_it = m_tripletMap.find(key);
396 if (cfg_it == m_tripletMap.end()) continue;
397
398 // Check doublet cuts with the cfg SPECIFIC to this triplet
399 // for BOTH legs — not the pre-built edge lists
400 const auto& cfg = *cfg_it->second;
401 if (!applyDoubletCuts(h1, h2, cfg.doubletCuts[0])) continue;
402 if (!applyDoubletCuts(h2, h3, cfg.doubletCuts[1])) continue;
403 if (!applyTripletCuts(h1, h2, h3, cfg.tripletCuts)) continue;
404
405 seen.insert(pack(hit1, hit2));
406 seen.insert(pack(hit2, hit3));
407 }
408 }
409 edges.clear();
410 edges.reserve(seen.size());
411
412 for (uint64_t packed : seen) {
413 int index1 = static_cast<int>(packed >> 32);
414 int index2 = static_cast<int>(packed & 0xFFFFFFFF);
415 auto edge = std::make_shared<FPGATrackSimGNNEdge>();
416 edge->setEdgeIndex1(index1);
417 edge->setEdgeIndex2(index2);
418 edges.emplace_back(std::move(edge));
419 }
420}
421
422void FPGATrackSimGNNGraphConstructionTool::getDoubletEdges(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits, std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges, int cutIndex)
423{
424 // Take the list of hits and use the doublet module map to generate all the edges between hits that pass the doublet cuts
425
426 std::unordered_map<unsigned, std::vector<int>> hits_by_module;
427
428 for (size_t i = 0; i < hits.size(); i++) {
429 hits_by_module[hits[i]->getIdentifierHash()].push_back(i);
430 }
431
432 for (const auto& cfg : m_cfgs) {
433 unsigned midA{}, midB{};
434 const ModuleMapConfig::DoubletCuts& cuts = cfg.doubletCuts[cutIndex];
435
436 if (cutIndex == 0) {
437 midA = cfg.mid1;
438 midB = cfg.mid2;
439 }
440 else if (cutIndex == 1) {
441 midA = cfg.mid2;
442 midB = cfg.mid3;
443 }
444
445 const auto& hit1_indices = hits_by_module[midA];
446 const auto& hit2_indices = hits_by_module[midB];
447
448 for (size_t h1 = 0; h1 < hit1_indices.size(); h1++) {
449 for (size_t h2 = 0; h2 < hit2_indices.size(); h2++) {
450 int i1 = hit1_indices[h1];
451 int i2 = hit2_indices[h2];
452 const auto& hit1 = hits[i1];
453 const auto& hit2 = hits[i2];
454 if (!applyDoubletCuts(hit1, hit2, cuts)) continue;
455 auto edge = std::make_shared<FPGATrackSimGNNEdge>();
456 edge->setEdgeIndex1(i1);
457 edge->setEdgeIndex2(i2);
458 edges.emplace_back(std::move(edge));
459 }
460 }
461 }
462}
463
464bool FPGATrackSimGNNGraphConstructionTool::applyDoubletCuts(const std::shared_ptr<FPGATrackSimGNNHit> & hit1, const std::shared_ptr<FPGATrackSimGNNHit> & hit2, const ModuleMapConfig::DoubletCuts& cuts)
465{
466 // Four types of doublet cuts (dEta, z0, dPhi, phiSlope)
467 // If an edge passes all four, then it is a valid edge and can be stored
468
469 // delta_eta cuts
470 float deta = hit2->getEta() - hit1->getEta();
471 if(!doMask(deta, cuts.deta)) return false;
472
473 // z0 cuts
474 float dz = hit2->getZ() - hit1->getZ();
475 float dr = hit2->getR() - hit1->getR();
476 float z0 = dr==0. ? 0. : hit1->getZ() - (hit1->getR() * dz / dr);
477 if(!doMask(z0, cuts.z0)) return false;
478
479 // delta_phi cuts
480 float dphi = P4Helpers::deltaPhi(hit2->getPhi(),hit1->getPhi()); // Look into this issue
481 if(!doMask(dphi, cuts.dphi)) return false;
482
483 // phislope cuts
484 float phiSlope = dr==0. ? 0. : dphi / dr;
485 if(!doMask(phiSlope, cuts.phiSlope)) return false;
486
487 return true;
488}
489
490bool FPGATrackSimGNNGraphConstructionTool::applyTripletCuts(const std::shared_ptr<FPGATrackSimGNNHit> & hit1, const std::shared_ptr<FPGATrackSimGNNHit> & hit2, const std::shared_ptr<FPGATrackSimGNNHit> & hit3, const ModuleMapConfig::TripletCuts& cuts)
491{
492 auto safeDiv = [&](float dA, float dB) { return (dB == 0.) ? 0.0f : (dA / dB); };
493
494 // Diff dydx
495 float dy_12 = hit2->getY() - hit1->getY();
496 float dy_23 = hit3->getY() - hit2->getY();
497 float dx_12 = hit2->getX() - hit1->getX();
498 float dx_23 = hit3->getX() - hit2->getX();
499
500 float diff_dydx = safeDiv(dy_12, dx_12) - safeDiv(dy_23, dx_23);
501 if(!doMask(diff_dydx, cuts.diff_dydx)) return false; // Fails the dydx cut
502
503 // Diff dzdr
504 float dz_12 = hit2->getZ() - hit1->getZ();
505 float dz_23 = hit3->getZ() - hit2->getZ();
506 float dr_12 = hit2->getR() - hit1->getR();
507 float dr_23 = hit3->getR() - hit2->getR();
508
509 float diff_dzdr = safeDiv(dz_12, dr_12) - safeDiv(dz_23, dr_23);
510 if(!doMask(diff_dzdr, cuts.diff_dzdr)) return false; // Fails the drdz cut
511
512 return true; // Passes both triplet cuts
513}
514
516{
 // doMask(val, cuts): returns true when `val` passes the cuts for one feature.
 // The cut function is chosen by m_moduleMapFunc: "minmax" -> doMinMaxMask,
 // "meanrms" -> doMeanRMSMask. Any other setting logs an error and rejects the value;
 // note that this error is emitted on every call.
517 if(m_moduleMapFunc == "minmax") {
518 return doMinMaxMask(val, cuts);
519 }
520 else if(m_moduleMapFunc == "meanrms") {
521 return doMeanRMSMask(val, cuts);
522 }
523 else {
524 ATH_MSG_ERROR("Chosen module map function is not minmax/meanrms which are the only types supported currently.");
525 return false;
526 }
527}
528
530{
 // doMinMaxMask(val, cuts): accept `val` inside [cuts.min, cuts.max] widened by the relative
 // tolerance m_moduleMapTol. Multiplying by featureSign() moves each bound outward whatever its
 // sign (assuming m_moduleMapTol >= 0); a bound equal to 0 is not widened.
531 return (val <= cuts.max * (1.0 + featureSign(cuts.max) * m_moduleMapTol)) && (val >= cuts.min * (1.0 - featureSign(cuts.min) * m_moduleMapTol));
532}
533
535{
 // doMeanRMSMask(val, cuts): accept `val` inside the intersection of
 //   [mean - rms * m_moduleMapRMSThresholdFactor, mean + rms * m_moduleMapRMSThresholdFactor]
 // and the tolerance-widened [min, max] window (same formula as the min/max cut).
 // The RMS window therefore only ever tightens the min/max window.
536 float min_rms = cuts.mean - cuts.rms * m_moduleMapRMSThresholdFactor;
537 float max_rms = cuts.mean + cuts.rms * m_moduleMapRMSThresholdFactor;
538 float tol_min = cuts.min * (1.0 - featureSign(cuts.min) * m_moduleMapTol);
539 float tol_max = cuts.max * (1.0 + featureSign(cuts.max) * m_moduleMapTol);
540
 // Intersect the two windows
541 float capped_min = std::max(tol_min, min_rms);
542 float capped_max = std::min(tol_max, max_rms);
543
544 return (val <= capped_max) && (val >= capped_min);
545}
546
548{
 // featureSign(feature): -1 for negative, +1 for positive, 0 otherwise (0 and NaN).
 // Used to push the min/max cut bounds outward by the module-map tolerance.
549 if(feature < 0.0) { return -1.0; }
550 else if(feature > 0.0) { return 1.0; }
551 else { return 0.0; }
552}
553
554void FPGATrackSimGNNGraphConstructionTool::doMetricLearning(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits, std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges)
555{
556 // Use Metric Learning for edge construction
557 // Clustering properties can be set in the input scripta as Gaudi::Property variables
558 std::vector<float> gNodeFeatures = getNodeFeatures(hits);
559 std::vector<float> gEmbedded = embed(hits);
560 doClustering(hits, edges, gEmbedded);
561}
562
563std::vector<float> FPGATrackSimGNNGraphConstructionTool::getNodeFeatures(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits)
564{
565 std::vector<float> gNodeFeatures;
566
567 for(auto hit : hits) {
568 std::map<std::string, float> features;
569 features["r"] = hit->getR();
570 features["phi"] = hit->getPhi();
571 features["z"] = hit->getZ();
572
573 for(size_t i = 0; i < m_MLFeatureNamesVec.size(); i++){
574 gNodeFeatures.push_back(
575 features[m_MLFeatureNamesVec[i]] / m_MLFeatureScalesVec[i]);
576 }
577 }
578 return gNodeFeatures;
579}
580
581std::vector<float> FPGATrackSimGNNGraphConstructionTool::embed(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits)
582{
583 // Use the ML network to embed the hits in a 12-dim latent space
584 std::vector<float> gNodeFeatures = getNodeFeatures(hits);
585 std::vector<float> gEmbedded;
586
587 std::vector<Ort::Value> gInputTensor;
588 StatusCode s = m_MLInferenceTool->addInput(gInputTensor, gNodeFeatures, 0, hits.size());
589 std::vector<Ort::Value> gOutputTensor;
590 s = m_MLInferenceTool->addOutput(gOutputTensor, gEmbedded, 0, hits.size());
591 s = m_MLInferenceTool->inference(gInputTensor, gOutputTensor);
592
593 return gEmbedded;
594}
595
596void FPGATrackSimGNNGraphConstructionTool::doClustering(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits, std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges,
597 std::vector<float> & gEmbedded)
598{
 // Connect every pair of hits (k, i) with i > k whose squared Euclidean distance in the
 // embedding space is below m_metricLearningR^2. gEmbedded is read (never modified) as
 // hits.size() consecutive rows of n_dim = 12 floats. Each edge is oriented from the hit with
 // the smaller r^2 + z^2 to the one with the larger (on a tie: k -> i). Edges are appended.
599 // Create graph edges based on the hits distance in the latent space
600 // Creates a directed graph
601 int n_dim = 12;
602 int size = hits.size();
603 float r_squared = m_metricLearningR*m_metricLearningR;
604 int index1 = 0;
605 int index2 = 0;
606 int count = 0;
607 float distance = 0.;
608 std::vector<float> start(n_dim);
609
610 // Loop over all hits
611 for(int k = 0; k < size; ++k){
 // Number of edges created so far with hit k as the outer-loop hit
612 count = 0;
613 // Setup current hit
614 for(int j = 0; j < n_dim; ++j){
615 start[j] = gEmbedded[k*n_dim + j];
616 }
617 // Loop over the hits not yet checked
618 for (int i = k + 1; i < size; ++i){
619 distance = 0.;
620 for(int d = 0; d < n_dim; ++d){
621 distance += (start[d] - gEmbedded[i*n_dim + d]) * (start[d] - gEmbedded[i*n_dim + d]);
622 }
623 // Store edge if the distance between the hits meets is below the limit
624 if(distance < r_squared){
625 std::shared_ptr<FPGATrackSimGNNEdge> edge = std::make_shared<FPGATrackSimGNNEdge>();
626 // Set order of edge indices to make a directed graph
627 float d_i_sq = (hits[i]->getR() * hits[i]->getR()) + (hits[i]->getZ() * hits[i]->getZ());
628 float d_k_sq = (hits[k]->getR() * hits[k]->getR()) + (hits[k]->getZ() * hits[k]->getZ());
629 if (d_i_sq < d_k_sq){
630 index1 = i;
631 index2 = k;
632 } else {
633 index1 = k;
634 index2 = i;
635 }
636
637 edge->setEdgeIndex1(index1);
638 edge->setEdgeIndex2(index2);
639 edges.emplace_back(edge);
640 ++count;
641 }
642 // Upper limit for connections of the same hit
 // NOTE(review): the line guarding this `break` (original line 643) is missing from this
 // listing; from the comment above it presumably compares `count` with a configured maximum
 // number of connections per hit. Confirm against the real source before editing.
644 break;
645 }
646 }
647
648 }
649}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
Implements graph construction tool to build edges (connections) between hits.
static Double_t a
size_t size() const
Number of registered mappings.
AthAlgTool(const std::string &type, const std::string &name, const IInterface *parent)
Constructor with parameters:
bool doMask(float val, const ModuleMapConfig::FeatureCuts &cuts)
bool doMinMaxMask(float val, const ModuleMapConfig::FeatureCuts &cuts)
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_MLInferenceTool
bool applyTripletCuts(const std::shared_ptr< FPGATrackSimGNNHit > &hit1, const std::shared_ptr< FPGATrackSimGNNHit > &hit2, const std::shared_ptr< FPGATrackSimGNNHit > &hit3, const ModuleMapConfig::TripletCuts &cuts)
virtual StatusCode getEdges(const std::vector< std::shared_ptr< FPGATrackSimGNNHit > > &hits, std::vector< std::shared_ptr< FPGATrackSimGNNEdge > > &edges)
void getDoubletEdges(const std::vector< std::shared_ptr< FPGATrackSimGNNHit > > &hits, std::vector< std::shared_ptr< FPGATrackSimGNNEdge > > &edges, int cutIndex)
std::vector< float > getNodeFeatures(const std::vector< std::shared_ptr< FPGATrackSimGNNHit > > &hits)
void getTripletEdges(const std::vector< std::shared_ptr< FPGATrackSimGNNHit > > &hits, std::vector< std::shared_ptr< FPGATrackSimGNNEdge > > &edges)
std::vector< float > embed(const std::vector< std::shared_ptr< FPGATrackSimGNNHit > > &hits)
void doClustering(const std::vector< std::shared_ptr< FPGATrackSimGNNHit > > &hits, std::vector< std::shared_ptr< FPGATrackSimGNNEdge > > &edges, std::vector< float > &gEmbedded)
void doMetricLearning(const std::vector< std::shared_ptr< FPGATrackSimGNNHit > > &hits, std::vector< std::shared_ptr< FPGATrackSimGNNEdge > > &edges)
FPGATrackSimGNNGraphConstructionTool(const std::string &, const std::string &, const IInterface *)
void doModuleMap(const std::vector< std::shared_ptr< FPGATrackSimGNNHit > > &hits, std::vector< std::shared_ptr< FPGATrackSimGNNEdge > > &edges)
bool doMeanRMSMask(float val, const ModuleMapConfig::FeatureCuts &cuts)
bool applyDoubletCuts(const std::shared_ptr< FPGATrackSimGNNHit > &hit1, const std::shared_ptr< FPGATrackSimGNNHit > &hit2, const ModuleMapConfig::DoubletCuts &cuts)
ServiceHandle< IFPGATrackSimMappingSvc > m_FPGATrackSimMapping
std::unordered_map< TripletKey, const ModuleMapConfig *, TripletKeyHash > m_tripletMap
int count(std::string s, const std::string &regx)
count how many occurances of a regx are in a string
Definition hcg.cxx:148
double deltaPhi(double phiA, double phiB)
delta Phi in range [-pi,pi[
Definition P4Helpers.h:34
TChain * tree
TFile * file