ATLAS Offline Software
Loading...
Searching...
No Matches
SegmentEdgeClassifierTool.cxx
Go to the documentation of this file.
2#include "InferenceUtils.h"
6#include "GaudiKernel/SystemOfUnits.h"
7#include <nlohmann/json.hpp>
8#include <algorithm>
9#include <array>
10#include <cmath>
11#include <fstream>
12#include <mutex>
13#include <map>
14#include <optional>
15#include <sstream>
16#include <tuple>
17#include <unordered_map>
18#include <unordered_set>
19
20namespace {
21using SegmentGroupKey = std::tuple<int, int, int>; // sector, chamberIndex, etaIndex
22
23SegmentGroupKey segmentGroupKey(const xAOD::MuonSegment& seg) {
24 return {seg.sector(), static_cast<int>(seg.chamberIndex()), seg.etaIndex()};
25}
26
27int segmentLayerCount(const xAOD::MuonSegment& seg) {
28 return seg.nPrecisionHits() + seg.nPhiLayers() + seg.nTrigEtaLayers();
29}
30
/// Distance between two sector indices.
///
/// @param a    first sector index
/// @param b    second sector index
/// @param mod  number of sectors on the ring; values <= 0 disable wrapping
/// @return     |a - b| when wrapping is disabled, otherwise the shortest
///             distance going either way around a ring of @p mod sectors.
///
/// The linear distance is reduced modulo @p mod before the ring minimum is
/// taken, so inputs that are not already wrapped into [0, mod) still yield a
/// non-negative result (previously `mod - d` could go negative for d > mod).
inline int sectorDistance(int a, int b, int mod) {
    const int linear = std::abs(a - b);
    if (mod <= 0) return linear;
    const int wrapped = linear % mod;
    return std::min(wrapped, mod - wrapped);
}
40
41std::optional<MuonML::SegmentNodeFeatureId> nodeFeatureIdFromName(const std::string& name) {
42 using FeatureId = MuonML::SegmentNodeFeatureId;
43 if (name == "segmentPositionX_m") return FeatureId::SegmentPositionX;
44 if (name == "segmentPositionY_m") return FeatureId::SegmentPositionY;
45 if (name == "segmentPositionZ_m") return FeatureId::SegmentPositionZ;
46 if (name == "segmentDirectionX") return FeatureId::SegmentDirectionX;
47 if (name == "segmentDirectionY") return FeatureId::SegmentDirectionY;
48 if (name == "segmentDirectionZ") return FeatureId::SegmentDirectionZ;
49 if (name == "bucket_chamberIndex") return FeatureId::BucketChamberIndex;
50 if (name == "bucket_layers") return FeatureId::BucketLayers;
51 if (name == "bucket_sector") return FeatureId::BucketSector;
52 if (name == "bucket_segments") return FeatureId::BucketSegments;
53 return std::nullopt;
54}
55
56float nodeFeatureValue(MuonML::SegmentNodeFeatureId feature,
57 const Amg::Vector3D& pos,
58 const Amg::Vector3D& dir,
59 const MuonML::BucketSegmentFeatures& bucket) {
60 using FeatureId = MuonML::SegmentNodeFeatureId;
61 switch (feature) {
62 case FeatureId::SegmentPositionX: return static_cast<float>(pos.x());
63 case FeatureId::SegmentPositionY: return static_cast<float>(pos.y());
64 case FeatureId::SegmentPositionZ: return static_cast<float>(pos.z());
65 case FeatureId::SegmentDirectionX: return static_cast<float>(dir.x());
66 case FeatureId::SegmentDirectionY: return static_cast<float>(dir.y());
67 case FeatureId::SegmentDirectionZ: return static_cast<float>(dir.z());
68 case FeatureId::BucketChamberIndex: return static_cast<float>(bucket.chamberIndex);
69 case FeatureId::BucketLayers: return static_cast<float>(bucket.layers);
70 case FeatureId::BucketSector: return static_cast<float>(bucket.sector);
71 case FeatureId::BucketSegments: return static_cast<float>(bucket.nSegments);
72 }
73 return 0.f;
74}
75}
76
77namespace MuonML {
78
81
82 // Resolve node feature names from model metadata, matching the ONNX exporter.
83 {
84 Ort::AllocatorWithDefaultOptions allocator;
85 Ort::ModelMetadata meta = model().GetModelMetadata();
86 auto keys = meta.GetCustomMetadataMapKeysAllocated(allocator);
87 std::vector<std::string> keyList;
88 keyList.reserve(keys.size());
89 for (const auto& k : keys) keyList.emplace_back(k.get());
90
91 constexpr std::array<std::string_view, 4> candidates{
92 "x_feature_names", "node_feature_names", "feature_names", "input_feature_names"};
93 std::string usedKey;
94 std::vector<std::string> names;
95 for (std::string_view key : candidates) {
96 const std::string keyStr{key};
97 if (std::find(keyList.begin(), keyList.end(), keyStr) == keyList.end()) continue;
98 names = parseFeatureNames(meta.LookupCustomMetadataMapAllocated(keyStr.c_str(), allocator).get());
99 if (!names.empty()) {
100 usedKey = keyStr;
101 break;
102 }
103 }
104
105 if (names.empty()) {
107 ATH_MSG_WARNING("Model metadata has no usable node feature name key"
108 " (tried x_feature_names/node_feature_names/feature_names/input_feature_names)."
109 " Falling back to default training order.");
110 } else {
111 if (names.size() != kNodeFeatureCount) {
112 ATH_MSG_ERROR("Model metadata key '" << usedKey << "' has " << names.size()
113 << " features, expected " << kNodeFeatureCount);
114 return StatusCode::FAILURE;
115 }
116 for (const std::string& n : names) {
117 if (!nodeFeatureIdFromName(n).has_value()) {
118 ATH_MSG_ERROR("Unsupported node feature name in model metadata ('" << usedKey
119 << "'): '" << n << "'."
120 " Add mapping in SegmentEdgeClassifierTool::nodeFeatureValue().");
121 return StatusCode::FAILURE;
122 }
123 }
124 m_nodeFeatureNames = std::move(names);
125 ATH_MSG_DEBUG("Using node feature names from model metadata key '" << usedKey << "'.");
126 }
127
128 m_nodeFeatureIds.reserve(m_nodeFeatureNames.size());
129 for (const std::string& n : m_nodeFeatureNames) {
130 const auto id = nodeFeatureIdFromName(n);
131 if (!id.has_value()) {
132 ATH_MSG_ERROR("Internal feature-id resolution failed for node feature name '" << n << "'.");
133 return StatusCode::FAILURE;
134 }
135 m_nodeFeatureIds.push_back(*id);
136 }
137
138 std::ostringstream order;
139 order << "Node feature order:";
140 for (std::size_t i = 0; i < m_nodeFeatureNames.size(); ++i) {
141 order << " f" << i << "=" << m_nodeFeatureNames[i];
142 if (i + 1 < m_nodeFeatureNames.size()) order << ",";
143 }
144 ATH_MSG_DEBUG(order.str());
145 }
146
148 ATH_MSG_ERROR("Internal node feature setup has " << m_nodeFeatureNames.size()
149 << " entries, expected " << kNodeFeatureCount);
150 return StatusCode::FAILURE;
151 }
152 if (m_nodeFeatureIds.size() != kNodeFeatureCount) {
153 ATH_MSG_ERROR("Internal node feature id setup has " << m_nodeFeatureIds.size()
154 << " entries, expected " << kNodeFeatureCount);
155 return StatusCode::FAILURE;
156 }
157
158 m_cosMin = std::cos(m_maxDeltaThetaDeg.value() * Gaudi::Units::deg);
159
160 if (!m_debugDumpFile.value().empty()) {
161 std::ofstream out{m_debugDumpFile.value(), std::ios::out | std::ios::trunc};
162 if (!out) {
163 ATH_MSG_ERROR("Could not create segment-edge debug dump file: "
164 << m_debugDumpFile.value());
165 return StatusCode::FAILURE;
166 }
167
168 nlohmann::ordered_json metadata;
169 metadata["record_type"] = "metadata";
170 metadata["format_version"] = 1;
171 metadata["tool"] = "SegmentEdgeClassifierTool";
172 metadata["input_names"] = {m_inputNodeName.value(),
173 m_inputEdgeIndexName.value(),
174 m_inputEdgeAttrName.value()};
175 metadata["output_name"] = m_outputName.value();
176 metadata["x_feature_names"] = m_nodeFeatureNames;
177 metadata["edge_attr_feature_names"] = {
178 "deltaPositionX_m", "deltaPositionY_m", "deltaPositionZ_m",
179 "distance_m", "cos_opening_angle", "same_chamber", "same_sector"};
180 metadata["edge_index_layout"] = "row_major_2_by_E";
181 metadata["edge_order"] = "directed src_to_dst; row 0 then row 1";
182 metadata["max_delta_theta_deg"] = m_maxDeltaThetaDeg.value();
183 metadata["max_delta_sector"] = m_maxDeltaSector.value();
184 metadata["sector_modulo"] = m_sectorModulo.value();
185 metadata["debug_dump_max_events"] = m_debugDumpMaxEvents.value();
186 out << metadata.dump() << '\n';
187
188 ATH_MSG_INFO("Writing segment-edge ONNX debug dump to "
189 << m_debugDumpFile.value()
190 << " (DebugDumpMaxEvents="
191 << m_debugDumpMaxEvents.value() << ")");
192 }
193
194 return StatusCode::SUCCESS;
195}
196
197StatusCode SegmentEdgeClassifierTool::runGraphInference(const EventContext&, GraphRawData&) const {
198 ATH_MSG_ERROR("runGraphInference is not supported by SegmentEdgeClassifierTool. Use SegmentEdgeInferenceAlg + ISegmentEdgeClassifierTool methods.");
199 return StatusCode::FAILURE;
200}
201
202StatusCode SegmentEdgeClassifierTool::buildGraph(const EventContext&, const xAOD::MuonSegmentContainer& segments, SegmentEdgeGraph& graph) const {
203 graph = SegmentEdgeGraph{};
204 graph.nNodes = segments.size();
205 graph.segments.reserve(graph.nNodes);
206 graph.nodeFeatures.reserve(graph.nNodes * kNodeFeatureCount);
207
208 std::vector<Amg::Vector3D> pos, dir;
209 std::vector<BucketSegmentFeatures> bucket;
210 pos.reserve(graph.nNodes); dir.reserve(graph.nNodes); bucket.reserve(graph.nNodes);
211
212 std::map<SegmentGroupKey, int> segmentMultiplicity{};
213 for (const xAOD::MuonSegment* seg : segments) {
214 if (!seg) continue;
215 ++segmentMultiplicity[segmentGroupKey(*seg)];
216 }
217
218 for (const xAOD::MuonSegment* seg : segments) {
219 if (!seg) continue;
220 const Amg::Vector3D p = seg->position();
221 Amg::Vector3D d = seg->direction();
222
223 const int chamberIdx = static_cast<int>(seg->chamberIndex());
224 const int layers = segmentLayerCount(*seg);
225 const int sec = seg->sector();
226 const auto multIt = segmentMultiplicity.find(segmentGroupKey(*seg));
227 const int nSeg = (multIt != segmentMultiplicity.end()) ? multIt->second : 1;
228
229 graph.segments.push_back(seg);
230 pos.emplace_back(p.x() / Gaudi::Units::m,
231 p.y() / Gaudi::Units::m,
232 p.z() / Gaudi::Units::m);
233 dir.emplace_back(d.x(), d.y(), d.z());
234 bucket.emplace_back(BucketSegmentFeatures{chamberIdx, layers, sec, nSeg});
235 for (const SegmentNodeFeatureId featureId : m_nodeFeatureIds) {
236 graph.nodeFeatures.push_back(nodeFeatureValue(featureId, pos.back(), dir.back(), bucket.back()));
237 }
238 }
239 graph.nNodes = graph.segments.size();
240
241 // Consistency check: all vectors must have same size
242 if (pos.size() != graph.nNodes || dir.size() != graph.nNodes || bucket.size() != graph.nNodes) {
243 ATH_MSG_ERROR("Inconsistent vector sizes during graph building: nodes=" << graph.nNodes
244 << ", pos=" << pos.size() << ", dir=" << dir.size() << ", bucket=" << bucket.size());
245 return StatusCode::FAILURE;
246 }
247
248 if (graph.nNodes < 2) {
249 graph.nEdges = 0;
250 return StatusCode::SUCCESS;
251 }
252
253 std::unordered_map<int, std::vector<std::size_t>> nodesBySector;
254 nodesBySector.reserve(graph.nNodes);
255 for (std::size_t i = 0; i < graph.nNodes; ++i) {
256 nodesBySector[bucket[i].sector].push_back(i);
257 }
258
259 auto normalizeSector = [&](int s) {
260 // m_sectorModulo > 0: wrap sector to [0, modulo); <=0: disable wrapping
261 if (m_sectorModulo.value() > 0) {
262 s %= m_sectorModulo.value();
263 if (s < 0) s += m_sectorModulo.value();
264 }
265 return s;
266 };
267
268 const std::size_t maxEdges = graph.nNodes * (graph.nNodes - 1);
269 graph.edgeIndex.reserve(2 * maxEdges);
270 graph.edgeFeatures.reserve(kEdgeFeatureCount * maxEdges);
271
272 for (std::size_t i = 0; i < graph.nNodes; ++i) {
273 std::unordered_set<int> targetSectors;
274 targetSectors.reserve(2 * m_maxDeltaSector.value() + 1);
275 for (int delta = -m_maxDeltaSector.value(); delta <= m_maxDeltaSector.value(); ++delta) {
276 targetSectors.insert(normalizeSector(bucket[i].sector + delta));
277 }
278
279 for (const int sec : targetSectors) {
280 auto it = nodesBySector.find(sec);
281 if (it == nodesBySector.end()) continue;
282 for (const std::size_t j : it->second) {
283 if (i == j) continue;
284 if (sectorDistance(bucket[i].sector, bucket[j].sector, m_sectorModulo.value()) > m_maxDeltaSector.value()) continue;
285 const float cosang = static_cast<float>(dir[i].dot(dir[j]));
286 if (cosang < m_cosMin) continue;
287
288 graph.edgeIndex.push_back(static_cast<int64_t>(i));
289 graph.edgeIndex.push_back(static_cast<int64_t>(j));
290
291 const Amg::Vector3D delta = pos[j] - pos[i];
292 const float dx = static_cast<float>(delta.x());
293 const float dy = static_cast<float>(delta.y());
294 const float dz = static_cast<float>(delta.z());
295 const float dist = static_cast<float>(delta.mag());
296 graph.edgeFeatures.insert(graph.edgeFeatures.end(), {dx,dy,dz,dist,cosang, float(bucket[i].chamberIndex==bucket[j].chamberIndex), float(bucket[i].sector==bucket[j].sector)});
297 }
298 }
299 }
300 graph.nEdges = graph.edgeIndex.size() / 2;
301 ATH_MSG_DEBUG("buildGraph: input segments=" << segments.size()
302 << ", kept nodes=" << graph.nNodes
303 << ", built edges=" << graph.nEdges);
304 return StatusCode::SUCCESS;
305}
306
307StatusCode SegmentEdgeClassifierTool::classifyEdges(const EventContext& ctx,
308 const SegmentEdgeGraph& graph,
309 std::vector<SegmentEdgeScore>& scores) const {
310 scores.clear();
311 if (!graph.nNodes) return StatusCode::SUCCESS;
312 if (!graph.nEdges) {
313 ATH_CHECK(dumpDebugEvent(ctx, graph, scores));
314 return StatusCode::SUCCESS;
315 }
316
317 if (graph.nodeFeatures.size() != graph.nNodes * kNodeFeatureCount) {
318 ATH_MSG_ERROR("Unexpected node feature size " << graph.nodeFeatures.size()
319 << "; expected " << (graph.nNodes * kNodeFeatureCount));
320 return StatusCode::FAILURE;
321 }
322 if (graph.edgeIndex.size() != 2 * graph.nEdges) {
323 ATH_MSG_ERROR("Unexpected edge index size " << graph.edgeIndex.size()
324 << "; expected " << (2 * graph.nEdges));
325 return StatusCode::FAILURE;
326 }
327 if (graph.edgeFeatures.size() != graph.nEdges * kEdgeFeatureCount) {
328 ATH_MSG_ERROR("Unexpected edge feature size " << graph.edgeFeatures.size()
329 << "; expected " << (graph.nEdges * kEdgeFeatureCount));
330 return StatusCode::FAILURE;
331 }
332
333 GraphRawData raw{};
334 raw.graph = std::make_unique<InferenceGraph>();
335 raw.featureLeaves = graph.nodeFeatures;
336 raw.edgeIndexPacked.reserve(2 * graph.nEdges);
337 raw.srcEdges.reserve(graph.nEdges);
338 raw.desEdges.reserve(graph.nEdges);
339 for (std::size_t e = 0; e < graph.nEdges; ++e) {
340 raw.srcEdges.push_back(graph.edgeIndex[2 * e]);
341 raw.desEdges.push_back(graph.edgeIndex[2 * e + 1]);
342 }
343 raw.edgeIndexPacked.insert(raw.edgeIndexPacked.end(), raw.srcEdges.begin(), raw.srcEdges.end());
344 raw.edgeIndexPacked.insert(raw.edgeIndexPacked.end(), raw.desEdges.begin(), raw.desEdges.end());
345
346 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
347
348 const std::vector<int64_t> nodeShape{static_cast<int64_t>(graph.nNodes), static_cast<int64_t>(kNodeFeatureCount)};
349 raw.graph->dataTensor.emplace_back(
350 Ort::Value::CreateTensor<float>(memInfo,
351 raw.featureLeaves.data(),
352 raw.featureLeaves.size(),
353 nodeShape.data(),
354 nodeShape.size()));
355
356 const std::vector<int64_t> edgeIndexShape{2, static_cast<int64_t>(graph.nEdges)};
357 raw.graph->dataTensor.emplace_back(
358 Ort::Value::CreateTensor<int64_t>(memInfo,
359 raw.edgeIndexPacked.data(),
360 raw.edgeIndexPacked.size(),
361 edgeIndexShape.data(),
362 edgeIndexShape.size()));
363
364 // ONNX Runtime's CreateTensor API takes a non-const pointer, but it does not
365 // mutate input buffers during inference. Avoid copying edge_attr every event.
366 ATLAS_THREAD_SAFE float* edgeFeaturesData = const_cast<float*>(graph.edgeFeatures.data());
367 const std::vector<int64_t> edgeAttrShape{static_cast<int64_t>(graph.nEdges), static_cast<int64_t>(kEdgeFeatureCount)};
368 raw.graph->dataTensor.emplace_back(
369 Ort::Value::CreateTensor<float>(memInfo,
370 edgeFeaturesData,
371 graph.edgeFeatures.size(),
372 edgeAttrShape.data(),
373 edgeAttrShape.size()));
374
375 const std::vector<const char*> inputNames{
376 m_inputNodeName.value().c_str(),
377 m_inputEdgeIndexName.value().c_str(),
378 m_inputEdgeAttrName.value().c_str()};
379 const std::vector<const char*> outputNames{m_outputName.value().c_str()};
380 ATH_MSG_DEBUG("classifyEdges: ONNX inputs shapes x=[" << nodeShape[0] << "," << nodeShape[1]
381 << "], edge_index=[" << edgeIndexShape[0] << "," << edgeIndexShape[1]
382 << "], edge_attr=[" << edgeAttrShape[0] << "," << edgeAttrShape[1] << "]");
383 ATH_CHECK(runNamedInference(raw, inputNames, outputNames));
384
385 if (raw.graph->dataTensor.size() <= inputNames.size()) {
386 ATH_MSG_ERROR("Missing ONNX output tensor for segment edge inference");
387 return StatusCode::FAILURE;
388 }
389
390 const Ort::Value& outTensor = raw.graph->dataTensor[inputNames.size()];
391 const auto outInfo = outTensor.GetTensorTypeAndShapeInfo();
392 const std::vector<int64_t> outShape = outInfo.GetShape();
393 const size_t outSize = outInfo.GetElementCount();
394 if (!outShape.empty()) {
395 ATH_MSG_DEBUG("classifyEdges: ONNX output rank=" << outShape.size()
396 << ", first dim=" << outShape.front()
397 << ", elements=" << outSize);
398 } else {
399 ATH_MSG_DEBUG("classifyEdges: ONNX scalar output, elements=" << outSize);
400 }
401 if (outSize < graph.nEdges) {
402 ATH_MSG_ERROR("ONNX logits tensor has " << outSize << " entries for " << graph.nEdges << " edges");
403 return StatusCode::FAILURE;
404 }
405
406 const float* logits = outTensor.GetTensorData<float>();
407 scores.reserve(graph.nEdges);
408 for (std::size_t e=0; e<graph.nEdges; ++e) {
409 const float l = logits[e];
410 scores.push_back({std::size_t(graph.edgeIndex[2 * e]),
411 std::size_t(graph.edgeIndex[2 * e + 1]),
412 l,
414 }
415
416 ATH_CHECK(dumpDebugEvent(ctx, graph, scores));
417 return StatusCode::SUCCESS;
418}
419
421 const EventContext& ctx,
422 const SegmentEdgeGraph& graph,
423 const std::vector<SegmentEdgeScore>& scores) const {
424 if (m_debugDumpFile.value().empty()) return StatusCode::SUCCESS;
425
426 std::lock_guard<std::mutex> lock{m_debugDumpMutex};
427 if (m_debugDumpMaxEvents.value() != 0 &&
428 m_debugDumpEvents.load(std::memory_order_relaxed) >=
429 m_debugDumpMaxEvents.value()) {
430 return StatusCode::SUCCESS;
431 }
432
433 if (graph.nodeFeatures.size() != graph.nNodes * kNodeFeatureCount ||
434 graph.edgeIndex.size() != graph.nEdges * 2 ||
435 graph.edgeFeatures.size() != graph.nEdges * kEdgeFeatureCount ||
436 scores.size() != graph.nEdges) {
437 ATH_MSG_ERROR("Cannot write segment-edge debug dump: inconsistent graph/output sizes"
438 << " nodes=" << graph.nNodes
439 << " nodeFeatures=" << graph.nodeFeatures.size()
440 << " edges=" << graph.nEdges
441 << " edgeIndex=" << graph.edgeIndex.size()
442 << " edgeFeatures=" << graph.edgeFeatures.size()
443 << " scores=" << scores.size());
444 return StatusCode::FAILURE;
445 }
446
447 nlohmann::json x = nlohmann::json::array();
448 x.get_ref<nlohmann::json::array_t&>().reserve(graph.nodeFeatures.size());
449 for (const float value : graph.nodeFeatures) {
450 x.push_back(std::isfinite(value) ? nlohmann::json(value)
451 : nlohmann::json(nullptr));
452 }
453
454 nlohmann::json edgeIndex = nlohmann::json::array();
455 edgeIndex.get_ref<nlohmann::json::array_t&>().reserve(graph.nEdges * 2);
456 // This is the actual ONNX [2,E] row-major buffer: all sources then all destinations.
457 for (std::size_t edge = 0; edge < graph.nEdges; ++edge) {
458 edgeIndex.push_back(graph.edgeIndex[2 * edge]);
459 }
460 for (std::size_t edge = 0; edge < graph.nEdges; ++edge) {
461 edgeIndex.push_back(graph.edgeIndex[2 * edge + 1]);
462 }
463
464 nlohmann::json edgeAttr = nlohmann::json::array();
465 edgeAttr.get_ref<nlohmann::json::array_t&>().reserve(graph.edgeFeatures.size());
466 for (const float value : graph.edgeFeatures) {
467 edgeAttr.push_back(std::isfinite(value) ? nlohmann::json(value)
468 : nlohmann::json(nullptr));
469 }
470
471 nlohmann::json logits = nlohmann::json::array();
472 nlohmann::json probabilities = nlohmann::json::array();
473 nlohmann::json edgeSrc = nlohmann::json::array();
474 nlohmann::json edgeDst = nlohmann::json::array();
475 logits.get_ref<nlohmann::json::array_t&>().reserve(scores.size());
476 probabilities.get_ref<nlohmann::json::array_t&>().reserve(scores.size());
477 edgeSrc.get_ref<nlohmann::json::array_t&>().reserve(scores.size());
478 edgeDst.get_ref<nlohmann::json::array_t&>().reserve(scores.size());
479 for (const SegmentEdgeScore& score : scores) {
480 edgeSrc.push_back(score.src);
481 edgeDst.push_back(score.dst);
482 logits.push_back(std::isfinite(score.logit) ? nlohmann::json(score.logit)
483 : nlohmann::json(nullptr));
484 probabilities.push_back(std::isfinite(score.probability)
485 ? nlohmann::json(score.probability)
486 : nlohmann::json(nullptr));
487 }
488
489 std::ofstream out{m_debugDumpFile.value(), std::ios::out | std::ios::app};
490 if (!out) {
491 ATH_MSG_ERROR("Could not append to segment-edge debug dump file: "
492 << m_debugDumpFile.value());
493 return StatusCode::FAILURE;
494 }
495
496 const unsigned int dumpIndex =
497 m_debugDumpEvents.fetch_add(1, std::memory_order_relaxed);
498 nlohmann::ordered_json event;
499 event["record_type"] = "event";
500 event["format_version"] = 1;
501 event["dump_index"] = dumpIndex;
502 event["run_number"] = ctx.eventID().run_number();
503 event["lumi_block"] = ctx.eventID().lumi_block();
504 event["event_number"] = ctx.eventID().event_number();
505 event["slot"] = ctx.slot();
506 event["n_nodes"] = graph.nNodes;
507 event["n_edges"] = graph.nEdges;
508 event["x_shape"] = {graph.nNodes, kNodeFeatureCount};
509 event["edge_index_shape"] = {2, graph.nEdges};
510 event["edge_attr_shape"] = {graph.nEdges, kEdgeFeatureCount};
511 event["logits_shape"] = {graph.nEdges};
512 event["x"] = std::move(x);
513 event["edge_index"] = std::move(edgeIndex);
514 event["edge_attr"] = std::move(edgeAttr);
515 event["edge_src"] = std::move(edgeSrc);
516 event["edge_dst"] = std::move(edgeDst);
517 event["logits"] = std::move(logits);
518 event["probabilities"] = std::move(probabilities);
519 out << event.dump() << '\n';
520
521 ATH_MSG_DEBUG("Wrote segment-edge debug event " << dumpIndex
522 << " to " << m_debugDumpFile.value());
523
524 return StatusCode::SUCCESS;
525}
526
527} // namespace MuonML
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
virtual void lock()=0
Interface to allow an object to lock itself when made const in SG.
static Double_t a
#define x
Define macros for attributes used to control the static checker.
#define ATLAS_THREAD_SAFE
size_type size() const noexcept
Returns the number of elements in the collection.
static constexpr std::array< std::string_view, kNodeFeatureCount > kDefaultNodeFeatureNames
static constexpr std::size_t kEdgeFeatureCount
static std::vector< std::string > parseFeatureNames(const std::string &raw)
StatusCode runNamedInference(GraphRawData &graphData, const std::vector< const char * > &inputNames, const std::vector< const char * > &outputNames) const
Generic named inference, for tools with different I/O conventions.
static constexpr std::size_t kNodeFeatureCount
Gaudi::Property< unsigned int > m_debugDumpMaxEvents
Gaudi::Property< std::string > m_outputName
std::atomic< unsigned int > m_debugDumpEvents
Gaudi::Property< std::string > m_inputEdgeAttrName
StatusCode runGraphInference(const EventContext &ctx, GraphRawData &graphData) const override
Not supported by this tool; returns FAILURE.
StatusCode classifyEdges(const EventContext &ctx, const SegmentEdgeGraph &graph, std::vector< SegmentEdgeScore > &scores) const override
Run ONNX inference on graph and populate scores with logit and probability for each edge; called afte...
Gaudi::Property< std::string > m_debugDumpFile
StatusCode buildGraph(const EventContext &ctx, const xAOD::MuonSegmentContainer &segments, SegmentEdgeGraph &graph) const override
Build a GNN graph from segments, computing node and edge features and storing the graph structure in ...
std::vector< std::string > m_nodeFeatureNames
Node feature order expected by the model metadata (resolved at initialize).
Gaudi::Property< std::string > m_inputEdgeIndexName
Gaudi::Property< std::string > m_inputNodeName
std::vector< SegmentNodeFeatureId > m_nodeFeatureIds
StatusCode initialize() override
Retrieve the ONNX model and resolve node feature ordering from metadata.
StatusCode dumpDebugEvent(const EventContext &ctx, const SegmentEdgeGraph &graph, const std::vector< SegmentEdgeScore > &scores) const
int nTrigEtaLayers() const
Returns the number of trigger eta layers.
int nPrecisionHits() const
Amg::Vector3D direction() const
Returns the direction as Amg::Vector.
::Muon::MuonStationIndex::ChIndex chamberIndex() const
Returns the chamber index.
Amg::Vector3D position() const
Returns the position as Amg::Vector.
int nPhiLayers() const
Returns the number of phi layers.
int etaIndex() const
Returns the eta index, which corresponds to stationEta in the offline identifiers.
Eigen::Matrix< double, 3, 1 > Vector3D
SegmentNodeFeatureId
Identifier for each node feature in segment-based GNNs.
Definition MuonMLEvent.h:28
-diff
MuonSegmentContainer_v1 MuonSegmentContainer
Definition of the current "MuonSegment container version".
MuonSegment_v1 MuonSegment
Reference the current persistent version:
Segment features derived from or stored in bucket metadata.
int chamberIndex
Muon chamber index of the segment.
int sector
Sector number (typically 0–15).
int layers
Total number of active layers in the segment.
int nSegments
Count of segments in the same chamber/sector/eta group.
Helper struct to ship the Graph from the space point buckets to ONNX.
Definition GraphData.h:25
FeatureVec_t featureLeaves
Vector containing all features.
Definition GraphData.h:30
EdgeCounterVec_t edgeIndexPacked
Packed edge index buffer (kept alive for ONNX tensors that reference it) This stores [srcEdges,...
Definition GraphData.h:42
std::unique_ptr< InferenceGraph > graph
Pointer to the graph to be parsed to ONNX.
Definition GraphData.h:46
EdgeCounterVec_t srcEdges
Vector encoding the source node index of each edge.
Definition GraphData.h:32
EdgeCounterVec_t desEdges
Vector encoding the destination node index of each edge.
Definition GraphData.h:34
std::vector< float > edgeFeatures
packed [E,7]: dpos(3), dist, cos, same_chamber, same_sector
std::vector< int64_t > edgeIndex
packed edge pairs [src0,dst0,src1,dst1,...]
std::vector< const xAOD::MuonSegment_v1 * > segments
std::vector< float > nodeFeatures
packed [N,10]: pos_m(3), dir_u(3), bucket(4)