ATLAS Offline Software
Loading...
Searching...
No Matches
SPIdDumperAlg.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5#include "SPIdDumperAlg.h"
6
8#include "Identifier/Identifier.h"
14
16
17#include <fstream>
18
19namespace MuonR4 {
20
22 ATH_CHECK(m_readKey.initialize());
23 ATH_CHECK(m_inSegmentKey.initialize());
24 ATH_CHECK(m_tree.init(this));
25 ATH_CHECK(m_graphFilterTool.retrieve());
26
27 ATH_MSG_DEBUG("Successfully initialized SPIdDumperAlg with ONNX model filtering.");
28 return StatusCode::SUCCESS;
29 }
30
32 const EventContext& ctx{Gaudi::Hive::currentContext()};
33
35 if (!readHandle.isValid()) {
36 ATH_MSG_ERROR("Failed to retrieve SpacePointContainer from StoreGate.");
37 return StatusCode::FAILURE;
38 }
39
40 MuonML::GraphRawData graphData;
41 ATH_CHECK(m_graphFilterTool->runGraphInference(ctx, graphData));
42
43 if (graphData.graph->dataTensor.size() < 3) {
44 ATH_MSG_DEBUG("ONNX inference output tensor is missing.");
45 return StatusCode::SUCCESS;
46 }
47
48 const float* predictions = graphData.graph->dataTensor[2].GetTensorMutableData<float>();
49 size_t predictionIndex = 0;
50
51 size_t totalNodes = graphData.graph->dataTensor[2].GetTensorTypeAndShapeInfo().GetElementCount();
52 size_t totalFeatureElements = graphData.graph->dataTensor[0].GetTensorTypeAndShapeInfo().GetElementCount();
53
54 if (totalNodes == 0) {
55 ATH_MSG_ERROR("Total number of nodes is zero! Cannot divide.");
56 return StatusCode::FAILURE;
57 }
58
59 size_t numFeaturesPerNode = totalFeatureElements / totalNodes;
60 ATH_MSG_DEBUG("Inferred " << numFeaturesPerNode << " features per node from tensor of size "
61 << totalFeatureElements << " and " << totalNodes << " nodes.");
62
63 std::unordered_map <const SpacePointBucket*, std::vector<const MuonR4::Segment*>> segmentMap;
64 SG::ReadHandle readSegment(m_inSegmentKey, ctx);
65 ATH_CHECK(readSegment.isPresent());
66 for (const MuonR4::Segment* segment : *readSegment) {
67 segmentMap[segment->parent()->parentBucket()].push_back(segment);
68 }
69
70 for (const MuonR4::SpacePointBucket* bucket : *readHandle) {
71
72 std::unordered_map<const SpacePoint*, std::vector<int16_t>> spacePointToSegment;
73 auto match_itr = segmentMap.find(bucket);
74 if (match_itr != segmentMap.end()) {
75 unsigned int segIdx{0};
76 for (const MuonR4::Segment* segment : match_itr->second) {
77 for (const auto& meas : segment->measurements()) {
78 spacePointToSegment[meas->spacePoint()].push_back(segIdx);
79 }
80 ++segIdx;
81 }
82 }
83
84 SpacePointPerLayerSplitter splitter{*bucket};
85 unsigned int layer{0};
86
87 for (const auto& hitsInLay : splitter.mdtHits()) {
88 for (const auto sp : hitsInLay){
89
90 const auto* dc = static_cast<const xAOD::MdtDriftCircle*>(sp->primaryMeasurement());
91 if (dc->status() != Muon::MdtDriftCircleStatus::MdtStatusDriftTime){
92 continue;
93 }
94 m_spoint_x.push_back(sp->localPosition().x());
95 m_spoint_y.push_back(sp->localPosition().y());
96 m_spoint_z.push_back(sp->localPosition().z());
97 m_spoint_driftR.push_back(sp->driftRadius());
98 m_spoint_station.push_back(m_idHelperSvc->stationName(sp->identify()));
99 m_spoint_layer.push_back(layer);
100
101 if (spacePointToSegment.count(sp) > 0) {
102 m_spoint_label.push_back(1);
103 } else {
104 m_spoint_label.push_back(0);
105 }
106
107 if (predictionIndex < graphData.graph->dataTensor[2].GetTensorTypeAndShapeInfo().GetElementCount()) {
108 m_spoint_predictions.push_back(predictions[predictionIndex]);
109 predictionIndex++;
110 } else {
111 ATH_MSG_WARNING("Prediction index exceeded ONNX output size.");
112 m_spoint_predictions.push_back(-999); // Default invalid value
113 }
114
115
116 }
117 ++layer;
118
119 }
120
121 for (const auto& hitsInLay : splitter.stripHits()) {
122
123 for (const auto sp : hitsInLay){
124
125 m_spoint_x.push_back(sp->localPosition().x());
126 m_spoint_y.push_back(sp->localPosition().y());
127 m_spoint_z.push_back(sp->localPosition().z());
128 m_spoint_driftR.push_back(sp->driftRadius());
129 m_spoint_station.push_back(m_idHelperSvc->stationName(sp->identify()));
130 m_spoint_layer.push_back(layer);
131
132 if (spacePointToSegment.count(sp) > 0) {
133 m_spoint_label.push_back(1);
134 } else {
135 m_spoint_label.push_back(0);
136 }
137
138 if (predictionIndex < graphData.graph->dataTensor[2].GetTensorTypeAndShapeInfo().GetElementCount()) {
139 m_spoint_predictions.push_back(predictions[predictionIndex]);
140 predictionIndex++;
141 } else {
142 ATH_MSG_WARNING("Prediction index exceeded ONNX output size.");
143 m_spoint_predictions.push_back(-999); // Default invalid value
144 }
145
146 }
147 ++layer;
148
149 }
150
151 }
152
153 if (!m_tree.fill(ctx)) return StatusCode::FAILURE;
154 return StatusCode::SUCCESS;
155 }
156
158 ATH_CHECK(m_tree.write());
159 return StatusCode::SUCCESS;
160 }
161
162}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
static Double_t sp
Handle class for reading from StoreGate.
SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey
MuonVal::VectorBranch< float > & m_spoint_driftR
MuonVal::VectorBranch< float > & m_spoint_x
ServiceHandle< Muon::IMuonIdHelperSvc > m_idHelperSvc
MuonVal::MuonTesterTree m_tree
MuonVal::VectorBranch< uint16_t > & m_spoint_layer
MuonVal::VectorBranch< uint8_t > & m_spoint_station
MuonVal::VectorBranch< float > & m_spoint_z
virtual StatusCode finalize() override
virtual StatusCode initialize() override
SG::ReadHandleKey< MuonR4::SegmentContainer > m_inSegmentKey
ToolHandle< MuonML::IGraphInferenceTool > m_graphFilterTool
MuonVal::VectorBranch< uint8_t > & m_spoint_label
virtual StatusCode execute() override
MuonVal::VectorBranch< float > & m_spoint_y
MuonVal::VectorBranch< float > & m_spoint_predictions
Placeholder for what will later be the muon segment EDM representation.
const MeasVec & measurements() const
Returns the associated measurements.
: The muon space point bucket represents a collection of points that will be processed together in t...
The SpacePointPerLayerSplitter takes a set of spacepoints already sorted by layer Identifier (see Muo...
const HitLayVec & mdtHits() const
Returns the sorted Mdt hits.
const HitLayVec & stripHits() const
Returns the sorted strip hits.
virtual bool isValid() override final
Can the handle be successfully dereferenced?
bool isPresent() const
Is the referenced object present in SG?
This header ties the generic definitions in this package.
MdtDriftCircle_v1 MdtDriftCircle
Helper struct to ship the Graph from the space point buckets to ONNX.
Definition GraphData.h:25
std::unique_ptr< InferenceGraph > graph
Pointer to the graph to be passed to ONNX.
Definition GraphData.h:46