ATLAS Offline Software
MuonML::SegmentEdgeClassifierTool Class Reference [final]

Runs a segment-level GNN on reconstructed muon segments to classify segment-pair edges as "good" or "background".

#include <SegmentEdgeClassifierTool.h>


Public Member Functions

StatusCode initialize () override
 Retrieve the ONNX model and resolve node feature ordering from metadata.
StatusCode runGraphInference (const EventContext &ctx, GraphRawData &graphData) const override
 Not supported by this tool; returns FAILURE.
StatusCode buildGraph (const EventContext &ctx, const xAOD::MuonSegmentContainer &segments, SegmentEdgeGraph &graph) const override
 Build a GNN graph from segments, computing node and edge features and storing the graph structure in graph.
StatusCode classifyEdges (const EventContext &ctx, const SegmentEdgeGraph &graph, std::vector< SegmentEdgeScore > &scores) const override
 Run ONNX inference on graph and populate scores with logit and probability for each edge; called after buildGraph().
StatusCode buildGraph (const EventContext &ctx, GraphRawData &graphData) const
 GNN-style bucket graph builder (features + edges), inherited from BucketInferenceToolBase; retained for tools that need it.
StatusCode runInference (GraphRawData &graphData) const
 Default ONNX run for GNN case: inputs {"features","edge_index"} -> outputs {"logits"}.
 DeclareInterfaceID (ISegmentEdgeClassifierTool, 1, 0)

Protected Member Functions

StatusCode setupModel ()
Ort::Session & model () const
StatusCode buildFeaturesOnly (const EventContext &ctx, GraphRawData &graphData) const
 Build only the (N,6) bucket features; attaches a single tensor as graph->dataTensor[0].
StatusCode buildTransformerInputs (const EventContext &ctx, GraphRawData &graphData) const
 Build Transformer inputs: features [1,S,6] and pad_mask [1,S] (False = valid), as tensors 0 and 1.
StatusCode runNamedInference (GraphRawData &graphData, const std::vector< const char * > &inputNames, const std::vector< const char * > &outputNames) const
 Generic named inference, for tools with different I/O conventions.

Static Protected Member Functions

static std::string trimFeatureToken (std::string s)
static std::vector< std::string > parseFeatureNames (const std::string &raw)

Protected Attributes

SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey {this, "ReadSpacePoints", "MuonSpacePoints"}
ActsTrk::GeoContextReadKey_t m_geoCtxKey {this, "AlignmentKey", "ActsAlignment", "cond handle key"}
Gaudi::Property< int > m_minLayers {this, "MinLayersValid", 3}
Gaudi::Property< int > m_maxChamberDelta {this, "MaxChamberDelta", 13}
Gaudi::Property< int > m_maxSectorDelta {this, "MaxSectorDelta", 1}
Gaudi::Property< double > m_maxDistXY {this, "MaxDistXY", 6800.0}
Gaudi::Property< double > m_maxAbsDz {this, "MaxAbsDz", 15000.0}
Gaudi::Property< unsigned int > m_debugDumpFirstNNodes {this, "DebugDumpFirstNNodes", 5}
Gaudi::Property< unsigned int > m_debugDumpFirstNEdges {this, "DebugDumpFirstNEdges", 12}
Gaudi::Property< bool > m_validateEdges {this, "ValidateEdges", true}
Gaudi::Property< bool > m_sanitizeNonFinitePredictions
bool m_isCuda {false}
int m_cudaDeviceId {0}

Static Protected Attributes

static constexpr std::size_t kBucketFeatureCount = 6
static constexpr std::size_t kNodeFeatureCount = 10
static constexpr std::size_t kEdgeFeatureCount = 7
static constexpr std::array< std::string_view, kNodeFeatureCount > kDefaultNodeFeatureNames

Private Member Functions

StatusCode dumpDebugEvent (const EventContext &ctx, const SegmentEdgeGraph &graph, const std::vector< SegmentEdgeScore > &scores) const

Private Attributes

Gaudi::Property< float > m_maxDeltaThetaDeg {this, "MaxDeltaThetaDeg", 35.f}
Gaudi::Property< int > m_maxDeltaSector {this, "MaxDeltaSector", 1}
Gaudi::Property< int > m_sectorModulo {this, "SectorModulo", 16}
Gaudi::Property< std::string > m_inputNodeName {this, "InputNodeName", "x"}
Gaudi::Property< std::string > m_inputEdgeIndexName {this, "InputEdgeIndexName", "edge_index"}
Gaudi::Property< std::string > m_inputEdgeAttrName {this, "InputEdgeAttrName", "edge_attr"}
Gaudi::Property< std::string > m_outputName {this, "OutputName", "logits"}
Gaudi::Property< std::string > m_debugDumpFile {this, "DebugDumpFile", ""}
Gaudi::Property< unsigned int > m_debugDumpMaxEvents {this, "DebugDumpMaxEvents", 0}
float m_cosMin {0.f}
std::vector< std::string > m_nodeFeatureNames {}
 Node feature order expected by the model metadata (resolved at initialize).
std::vector< SegmentNodeFeatureId > m_nodeFeatureIds {}
std::mutex m_debugDumpMutex
std::atomic< unsigned int > m_debugDumpEvents {0}
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool

Detailed Description

Runs a segment-level GNN on reconstructed muon segments to classify segment-pair edges as "good" or "background".

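The tool reads an xAOD::MuonSegmentContainer and builds a graph where:

  • Nodes are muon segments, each with 10 features (their order is resolved from the model metadata at initialize):
    • Position (3 floats, in m) and direction (3 floats)
    • Chamber index, layer count, sector, and segment multiplicity (4 floats)
  • Edges are directed: an edge i -> j is created for every ordered pair of distinct segments whose directions satisfy cos(angle) >= cos(MaxDeltaThetaDeg) and whose sectors differ by at most MaxDeltaSector (wrapped modulo SectorModulo). Each edge carries 7 features:
    • Spatial displacement pos(j) - pos(i) (3 floats: dx, dy, dz, in m)
    • Distance magnitude (1 float, in m)
    • Cosine of the opening angle (direction dot product, 1 float)
    • Chamber-match and sector-match flags (2 floats, 0 or 1)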
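The edge criterion and the seven edge features above can be illustrated for a single ordered pair. This is a simplified, standalone sketch, not the tool's implementation. SegmentNode, makeEdge and circularSectorDistance are hypothetical names, and plain std::array replaces Amg::Vector3D. The circular form of the sector distance is an assumption, because sectorDistance() is not shown on this page.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstdlib>

// Hypothetical, simplified node (not the tool's real data structure):
// position in metres, unit direction, chamber index and sector.
struct SegmentNode {
    std::array<double, 3> pos;  // metres
    std::array<double, 3> dir;  // assumed to be unit length
    int chamberIndex;
    int sector;
};

// ASSUMPTION: circular sector distance; the tool's sectorDistance() is not shown.
// modulo <= 0 disables wrapping, mirroring the SectorModulo convention.
int circularSectorDistance(int a, int b, int modulo) {
    int d = std::abs(a - b);
    if (modulo > 0) {
        d %= modulo;
        d = std::min(d, modulo - d);
    }
    return d;
}

// Applies the angular and sector cuts to the ordered pair (i, j). If the pair
// passes, writes {dx, dy, dz, dist, cos, chamberMatch, sectorMatch}.
bool makeEdge(const SegmentNode& ni, const SegmentNode& nj,
              double maxDeltaThetaDeg, int maxDeltaSector, int sectorModulo,
              std::array<float, 7>& features) {
    assert(maxDeltaThetaDeg >= 0.0 && maxDeltaThetaDeg <= 180.0);
    const double pi = std::acos(-1.0);
    const double cosMin = std::cos(maxDeltaThetaDeg * pi / 180.0);
    const double cosang = ni.dir[0] * nj.dir[0] + ni.dir[1] * nj.dir[1] + ni.dir[2] * nj.dir[2];
    if (cosang < cosMin) return false;
    if (circularSectorDistance(ni.sector, nj.sector, sectorModulo) > maxDeltaSector) return false;

    const double dx = nj.pos[0] - ni.pos[0];
    const double dy = nj.pos[1] - ni.pos[1];
    const double dz = nj.pos[2] - ni.pos[2];
    const double dist = std::sqrt(dx * dx + dy * dy + dz * dz);
    features = {static_cast<float>(dx), static_cast<float>(dy), static_cast<float>(dz),
                static_cast<float>(dist), static_cast<float>(cosang),
                ni.chamberIndex == nj.chamberIndex ? 1.f : 0.f,
                ni.sector == nj.sector ? 1.f : 0.f};
    return true;
}
```

Because the graph is directed, the tool evaluates both (i, j) and (j, i). In this sketch the reverse edge would have the opposite displacement and the same cosine and flags.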

The tool then runs an ONNX model (typically a GIN or GCN variant) to produce a logit and a probability for each edge. Downstream algorithms can use these scores to reject low-quality segment associations and improve reconstruction efficiency.

Key difference from GraphBucketFilterTool: this tool operates at the segment (edge) level rather than the bucket (node) level. Its interface also exchanges explicit graph structures (SegmentEdgeGraph, SegmentEdgeScore) rather than raw tensors.

Note: runGraphInference() is not supported by this tool; use SegmentEdgeInferenceAlg and the ISegmentEdgeClassifierTool methods instead.

Definition at line 61 of file SegmentEdgeClassifierTool.h.

Member Function Documentation

◆ buildFeaturesOnly()

StatusCode BucketInferenceToolBase::buildFeaturesOnly(const EventContext& ctx,
                                                      GraphRawData& graphData) const
[protected, inherited]

Build only the (N,6) bucket features; attaches a single tensor as graph->dataTensor[0].

Definition at line 87 of file BucketInferenceToolBase.cxx.

{

    graphData.graph.reset();
    graphData.srcEdges.clear();
    graphData.desEdges.clear();
    graphData.edgeIndexPacked.clear();
    graphData.featureLeaves.clear();
    graphData.spacePointsInBucket.clear();
    graphData.graph = std::make_unique<InferenceGraph>();
    graphData.graph->dataTensor.reserve(1); // features input; outputs are reserved in runNamedInference()

    const MuonR4::SpacePointContainer* buckets{nullptr};
    ATH_CHECK(SG::get(buckets, m_readKey, ctx));

    const ActsTrk::GeometryContext* gctx = nullptr;
    ATH_CHECK(SG::get(gctx, m_geoCtxKey, ctx));

    std::vector<BucketGraphUtils::NodeAux> nodes;
    BucketGraphUtils::buildNodesAndFeatures(*buckets, *gctx, nodes,
                                            graphData.featureLeaves,
                                            graphData.spacePointsInBucket); // now int64_t-compatible

    const int64_t numNodes = static_cast<int64_t>(nodes.size());
    ATH_MSG_DEBUG("Total buckets: " << buckets->size()
                  << " -> nodes (size>0): " << numNodes
                  << " | features.size()=" << graphData.featureLeaves.size());

    if (numNodes == 0) {
        ATH_MSG_WARNING("No valid buckets found (all have size 0.0). Skipping inference.");
        return StatusCode::SUCCESS;
    }

    const int64_t nFeatPerNode = static_cast<int64_t>(kBucketFeatureCount);
    if (numNodes * nFeatPerNode != static_cast<int64_t>(graphData.featureLeaves.size())) {
        ATH_MSG_ERROR( "Feature size mismatch: expected " << (numNodes * nFeatPerNode)
                       << " got " << graphData.featureLeaves.size());
        return StatusCode::FAILURE;
    }

    Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    std::vector<int64_t> featShape{numNodes, nFeatPerNode};
    graphData.graph->dataTensor.emplace_back(
        Ort::Value::CreateTensor<float>(memInfo,
                                        graphData.featureLeaves.data(),
                                        graphData.featureLeaves.size(),
                                        featShape.data(),
                                        featShape.size()));
    return StatusCode::SUCCESS;
}
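The (N,6) tensor created above is a view, not a copy. Ort::Value::CreateTensor<float> with a user-supplied buffer wraps graphData.featureLeaves in place, so that vector must outlive the tensor, which is why it is stored in GraphRawData. The following is a minimal sketch of the row-major layout and of the size check. The helper names featureAt and shapeMatches are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major access into a flat (N, F) feature buffer. This is the layout ONNX
// Runtime assumes when the buffer is wrapped with shape {N, F}.
inline float featureAt(const std::vector<float>& leaves, std::size_t node,
                       std::size_t feature, std::size_t nFeat) {
    assert(feature < nFeat);
    return leaves[node * nFeat + feature];
}

// Mirrors the check in buildFeaturesOnly(): the flat buffer must hold exactly
// numNodes * nFeat values before a {numNodes, nFeat} tensor is created on it.
inline bool shapeMatches(const std::vector<float>& leaves, std::int64_t numNodes,
                         std::int64_t nFeat) {
    return numNodes >= 0 && nFeat > 0 &&
           numNodes * nFeat == static_cast<std::int64_t>(leaves.size());
}
```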

◆ buildGraph() [1/2]

StatusCode BucketInferenceToolBase::buildGraph(const EventContext& ctx,
                                               GraphRawData& graphData) const
[inherited]

GNN-style graph builder (features + edges); retained for tools that need it.

Definition at line 200 of file BucketInferenceToolBase.cxx.

{

    graphData.graph.reset();
    graphData.srcEdges.clear();
    graphData.desEdges.clear();
    graphData.featureLeaves.clear();
    graphData.spacePointsInBucket.clear();
    graphData.edgeIndexPacked.clear();
    graphData.graph = std::make_unique<InferenceGraph>();
    graphData.graph->dataTensor.reserve(2); // features and edge_index inputs; outputs are reserved in runNamedInference()

    const MuonR4::SpacePointContainer* buckets{nullptr};
    ATH_CHECK(SG::get(buckets, m_readKey, ctx));

    const ActsTrk::GeometryContext* gctx = nullptr;
    ATH_CHECK(SG::get(gctx, m_geoCtxKey, ctx));

    std::vector<BucketGraphUtils::NodeAux> nodes;

    BucketGraphUtils::buildNodesAndFeatures(*buckets, *gctx, nodes,
                                            graphData.featureLeaves,
                                            graphData.spacePointsInBucket);

    const int64_t numNodes = static_cast<int64_t>(nodes.size());
    ATH_MSG_DEBUG("Total buckets: " << buckets->size()
                  << " -> nodes (size>0): " << numNodes
                  << " | features.size()=" << graphData.featureLeaves.size());

    if (numNodes == 0) {
        ATH_MSG_WARNING("No valid buckets found (all have size 0.0). Skipping graph building.");
        return StatusCode::SUCCESS;
    }

    const int64_t nFeatPerNode = static_cast<int64_t>(kBucketFeatureCount);
    if (numNodes * nFeatPerNode != static_cast<int64_t>(graphData.featureLeaves.size())) {
        ATH_MSG_ERROR("Feature size mismatch: expected " << (numNodes * nFeatPerNode)
                      << " got " << graphData.featureLeaves.size());
        return StatusCode::FAILURE;
    }

    Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    std::vector<int64_t> featShape{numNodes, nFeatPerNode};
    graphData.graph->dataTensor.emplace_back(
        Ort::Value::CreateTensor<float>(memInfo,
                                        graphData.featureLeaves.data(),
                                        graphData.featureLeaves.size(),
                                        featShape.data(),
                                        featShape.size()));

    // [source lines 250-255 are missing from this listing; the next line is the truncated tail of that statement]
        graphData.srcEdges, graphData.desEdges);
    if (m_validateEdges) {
        size_t bad = 0;
        size_t write = 0;
        for (size_t k = 0; k < graphData.srcEdges.size(); ++k) {
            const int64_t u = graphData.srcEdges[k];
            const int64_t v = graphData.desEdges[k];
            const bool okU = (u >= 0 && u < numNodes);
            const bool okV = (v >= 0 && v < numNodes);
            if (okU && okV) {
                graphData.srcEdges[write] = u;
                graphData.desEdges[write] = v;
                ++write;
            } else {
                ++bad;
                ATH_MSG_DEBUG( "Drop invalid edge " << k << ": (" << u << "->" << v
                               << "), valid node range [0," << (numNodes-1) << "]");
            }
        }
        if (bad) {
            ATH_MSG_WARNING( "Removed " << bad << " invalid edges out of "
                             << graphData.srcEdges.size());
            graphData.srcEdges.resize(write);
            graphData.desEdges.resize(write);
        }
    }

    const size_t E = graphData.srcEdges.size();

    if (msgLvl(MSG::DEBUG)) {
        // DEBUG: Count connections per node
        ATH_MSG_DEBUG("Edges built: " << E);
        const size_t dumpE = std::min<std::size_t>(m_debugDumpFirstNEdges.value(), E);
        for (size_t k = 0; k < dumpE; ++k) {
            ATH_MSG_DEBUG("EDGE[" << k << "]: "
                          << graphData.srcEdges[k] << " -> "
                          << graphData.desEdges[k]);
        }

        std::vector<int> nodeConnections(numNodes, 0);
        for (size_t k = 0; k < graphData.srcEdges.size(); ++k) {
            const int64_t u = graphData.srcEdges[k];
            const int64_t v = graphData.desEdges[k];
            if (u >= 0 && u < numNodes) nodeConnections[u]++;
            if (v >= 0 && v < numNodes) nodeConnections[v]++;
        }

        ATH_MSG_DEBUG("=== DEBUGGING: Node Connections (first 10 nodes) ===");
        const int64_t debugNodeCount = std::min(numNodes, static_cast<int64_t>(10));
        for (int64_t i = 0; i < debugNodeCount; ++i) {
            ATH_MSG_DEBUG("Node[" << i << "] connections: " << nodeConnections[i]);
        }
        ATH_MSG_DEBUG("=== END DEBUG NODE CONNECTIONS ===");

        ATH_MSG_DEBUG("=== DEBUGGING: Detailed Edge Connections (first 10 nodes) ===");
        for (int64_t nodeIdx = 0; nodeIdx < debugNodeCount; ++nodeIdx) {
            std::stringstream connections;
            connections << "Node[" << nodeIdx << "] connected to: ";
            bool foundAny = false;

            for (size_t k = 0; k < graphData.srcEdges.size(); ++k) {
                const int64_t u = graphData.srcEdges[k];
                const int64_t v = graphData.desEdges[k];

                if (u == nodeIdx) {
                    if (foundAny) connections << ", ";
                    connections << v;
                    foundAny = true;
                } else if (v == nodeIdx) {
                    if (foundAny) connections << ", ";
                    connections << u;
                    foundAny = true;
                }
            }

            if (!foundAny) connections << "none";
            ATH_MSG_DEBUG(connections.str());
        }
        ATH_MSG_DEBUG("=== END DEBUG DETAILED CONNECTIONS ===");
    }

    nodes = {};

    graphData.edgeIndexPacked.clear();
    const size_t Efinal = BucketGraphUtils::packEdgeIndex(graphData.srcEdges,
                                                          graphData.desEdges,
                                                          graphData.edgeIndexPacked);

    graphData.srcEdges.clear();
    graphData.desEdges.clear();

    std::vector<int64_t> edgeShape{2, static_cast<int64_t>(Efinal)};
    graphData.graph->dataTensor.emplace_back(
        Ort::Value::CreateTensor<int64_t>(memInfo,
                                          graphData.edgeIndexPacked.data(),
                                          graphData.edgeIndexPacked.size(),
                                          edgeShape.data(),
                                          edgeShape.size()));

    ATH_MSG_DEBUG("Built sparse bucket graph: N=" << numNodes << ", E=" << Efinal);
    return StatusCode::SUCCESS;
}
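The ValidateEdges block uses a two-index, in-place compaction. The surviving edges are then packed into the [2,E] edge_index buffer. Below is a standalone sketch of both steps, using the hypothetical names dropInvalidEdges and packEdgeIndexSketch. The packing layout (all sources, then all destinations) is an assumption taken from the edgeIndexPacked documentation and from classifyEdges(), because the body of BucketGraphUtils::packEdgeIndex is not shown here.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// In-place compaction of parallel source/destination arrays, keeping only the
// edges whose endpoints both lie in [0, numNodes). 'write' trails 'k', so
// valid edges are shifted down without a second buffer. Returns the drop count.
std::size_t dropInvalidEdges(std::vector<std::int64_t>& src,
                             std::vector<std::int64_t>& dst,
                             std::int64_t numNodes) {
    assert(src.size() == dst.size());
    std::size_t write = 0;
    for (std::size_t k = 0; k < src.size(); ++k) {
        const bool ok = src[k] >= 0 && src[k] < numNodes &&
                        dst[k] >= 0 && dst[k] < numNodes;
        if (ok) {
            src[write] = src[k];
            dst[write] = dst[k];
            ++write;
        }
    }
    const std::size_t dropped = src.size() - write;
    src.resize(write);
    dst.resize(write);
    return dropped;
}

// ASSUMED behaviour of BucketGraphUtils::packEdgeIndex: concatenate sources
// then destinations into one row-major [2, E] buffer; returns E.
std::size_t packEdgeIndexSketch(const std::vector<std::int64_t>& src,
                                const std::vector<std::int64_t>& dst,
                                std::vector<std::int64_t>& packed) {
    assert(src.size() == dst.size());
    packed.clear();
    packed.reserve(src.size() + dst.size());
    packed.insert(packed.end(), src.begin(), src.end());
    packed.insert(packed.end(), dst.begin(), dst.end());
    return src.size();
}
```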

◆ buildGraph() [2/2]

StatusCode MuonML::SegmentEdgeClassifierTool::buildGraph(const EventContext& ctx,
                                                         const xAOD::MuonSegmentContainer& segments,
                                                         SegmentEdgeGraph& graph) const
[override, virtual]

Build a GNN graph from segments, computing node and edge features and storing the graph structure in graph.

Implements MuonML::ISegmentEdgeClassifierTool.

Definition at line 202 of file SegmentEdgeClassifierTool.cxx.

{
    graph = SegmentEdgeGraph{};
    graph.nNodes = segments.size();
    graph.segments.reserve(graph.nNodes);
    graph.nodeFeatures.reserve(graph.nNodes * kNodeFeatureCount);

    std::vector<Amg::Vector3D> pos, dir;
    std::vector<BucketSegmentFeatures> bucket;
    pos.reserve(graph.nNodes); dir.reserve(graph.nNodes); bucket.reserve(graph.nNodes);

    std::map<SegmentGroupKey, int> segmentMultiplicity{};
    for (const xAOD::MuonSegment* seg : segments) {
        if (!seg) continue;
        ++segmentMultiplicity[segmentGroupKey(*seg)];
    }

    for (const xAOD::MuonSegment* seg : segments) {
        if (!seg) continue;
        const Amg::Vector3D p = seg->position();
        Amg::Vector3D d = seg->direction();

        const int chamberIdx = static_cast<int>(seg->chamberIndex());
        const int layers = segmentLayerCount(*seg);
        const int sec = seg->sector();
        const auto multIt = segmentMultiplicity.find(segmentGroupKey(*seg));
        const int nSeg = (multIt != segmentMultiplicity.end()) ? multIt->second : 1;

        graph.segments.push_back(seg);
        pos.emplace_back(p.x() / Gaudi::Units::m,
                         p.y() / Gaudi::Units::m,
                         p.z() / Gaudi::Units::m);
        dir.emplace_back(d.x(), d.y(), d.z());
        bucket.emplace_back(BucketSegmentFeatures{chamberIdx, layers, sec, nSeg});
        for (const SegmentNodeFeatureId featureId : m_nodeFeatureIds) {
            graph.nodeFeatures.push_back(nodeFeatureValue(featureId, pos.back(), dir.back(), bucket.back()));
        }
    }
    graph.nNodes = graph.segments.size();

    // Consistency check: all vectors must have same size
    if (pos.size() != graph.nNodes || dir.size() != graph.nNodes || bucket.size() != graph.nNodes) {
        ATH_MSG_ERROR("Inconsistent vector sizes during graph building: nodes=" << graph.nNodes
                      << ", pos=" << pos.size() << ", dir=" << dir.size() << ", bucket=" << bucket.size());
        return StatusCode::FAILURE;
    }

    if (graph.nNodes < 2) {
        graph.nEdges = 0;
        return StatusCode::SUCCESS;
    }

    std::unordered_map<int, std::vector<std::size_t>> nodesBySector;
    nodesBySector.reserve(graph.nNodes);
    for (std::size_t i = 0; i < graph.nNodes; ++i) {
        nodesBySector[bucket[i].sector].push_back(i);
    }

    auto normalizeSector = [&](int s) {
        // m_sectorModulo > 0: wrap sector to [0, modulo); <=0: disable wrapping
        if (m_sectorModulo.value() > 0) {
            s %= m_sectorModulo.value();
            if (s < 0) s += m_sectorModulo.value();
        }
        return s;
    };

    const std::size_t maxEdges = graph.nNodes * (graph.nNodes - 1);
    graph.edgeIndex.reserve(2 * maxEdges);
    graph.edgeFeatures.reserve(kEdgeFeatureCount * maxEdges);

    for (std::size_t i = 0; i < graph.nNodes; ++i) {
        std::unordered_set<int> targetSectors;
        targetSectors.reserve(2 * m_maxDeltaSector.value() + 1);
        for (int delta = -m_maxDeltaSector.value(); delta <= m_maxDeltaSector.value(); ++delta) {
            targetSectors.insert(normalizeSector(bucket[i].sector + delta));
        }

        for (const int sec : targetSectors) {
            auto it = nodesBySector.find(sec);
            if (it == nodesBySector.end()) continue;
            for (const std::size_t j : it->second) {
                if (i == j) continue;
                if (sectorDistance(bucket[i].sector, bucket[j].sector, m_sectorModulo.value()) > m_maxDeltaSector.value()) continue;
                const float cosang = static_cast<float>(dir[i].dot(dir[j]));
                if (cosang < m_cosMin) continue;

                graph.edgeIndex.push_back(static_cast<int64_t>(i));
                graph.edgeIndex.push_back(static_cast<int64_t>(j));

                const Amg::Vector3D delta = pos[j] - pos[i];
                const float dx = static_cast<float>(delta.x());
                const float dy = static_cast<float>(delta.y());
                const float dz = static_cast<float>(delta.z());
                const float dist = static_cast<float>(delta.mag());
                graph.edgeFeatures.insert(graph.edgeFeatures.end(),
                                          {dx, dy, dz, dist, cosang,
                                           float(bucket[i].chamberIndex == bucket[j].chamberIndex),
                                           float(bucket[i].sector == bucket[j].sector)});
            }
        }
    }
    graph.nEdges = graph.edgeIndex.size() / 2;
    ATH_MSG_DEBUG("buildGraph: input segments=" << segments.size()
                  << ", kept nodes=" << graph.nNodes
                  << ", built edges=" << graph.nEdges);
    return StatusCode::SUCCESS;
}
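Edge candidates in buildGraph() come from a sector index rather than an all-pairs scan. The sketch below reproduces that lookup in isolation, using the hypothetical names normalizeSector and sectorCandidates. It omits the angular cut and the sectorDistance() re-check. As in the listing, nodes are keyed by their raw sector while the target sectors are wrapped. The lookup is therefore only guaranteed to be consistent when sector values already lie in [0, SectorModulo).

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Same wrapping rule as the normalizeSector lambda above:
// modulo > 0 maps s into [0, modulo); modulo <= 0 leaves s unchanged.
int normalizeSector(int s, int modulo) {
    if (modulo > 0) {
        s %= modulo;
        if (s < 0) s += modulo;
    }
    return s;
}

// Index nodes by sector once, then for each node visit only the sectors within
// +-maxDelta. Returns ordered candidate pairs (i, j) with i != j.
std::vector<std::pair<std::size_t, std::size_t>>
sectorCandidates(const std::vector<int>& sectors, int maxDelta, int modulo) {
    std::unordered_map<int, std::vector<std::size_t>> bySector;
    for (std::size_t i = 0; i < sectors.size(); ++i) bySector[sectors[i]].push_back(i);

    std::vector<std::pair<std::size_t, std::size_t>> pairs;
    for (std::size_t i = 0; i < sectors.size(); ++i) {
        // A set, so a sector reached twice after wrapping is visited once and
        // each ordered pair is emitted at most once.
        std::unordered_set<int> targets;
        for (int d = -maxDelta; d <= maxDelta; ++d) {
            targets.insert(normalizeSector(sectors[i] + d, modulo));
        }
        for (const int sec : targets) {
            const auto it = bySector.find(sec);
            if (it == bySector.end()) continue;
            for (const std::size_t j : it->second) {
                if (j != i) pairs.emplace_back(i, j);
            }
        }
    }
    return pairs;
}
```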

◆ buildTransformerInputs()

StatusCode BucketInferenceToolBase::buildTransformerInputs(const EventContext& ctx,
                                                           GraphRawData& graphData) const
[protected, inherited]

Build Transformer inputs: features [1,S,6] and pad_mask [1,S] (False = valid), as tensors 0 and 1.

Definition at line 138 of file BucketInferenceToolBase.cxx.

{
    // Start from (N,6)
    ATH_CHECK(buildFeaturesOnly(ctx, graphData));

    // Copy features flat buffer for lifetime management
    std::vector<float> featuresFlat = graphData.featureLeaves;
    const int64_t S = static_cast<int64_t>(featuresFlat.size() / kBucketFeatureCount);

    if (S == 0) {
        ATH_MSG_WARNING("No valid features for transformer input. Skipping inference.");
        return StatusCode::SUCCESS;
    }

    if (msgLvl(MSG::DEBUG)) {
        // DEBUG: Print transformer input features for first 10 nodes
        ATH_MSG_DEBUG("=== DEBUGGING: Transformer input features for first 10 nodes ===");
        const int64_t debugNodes = std::min(S, static_cast<int64_t>(10));
        for (int64_t nodeIdx = 0; nodeIdx < debugNodes; ++nodeIdx) {
            const int64_t baseIdx = nodeIdx * static_cast<int64_t>(kBucketFeatureCount);
            ATH_MSG_DEBUG("TransformerNode[" << nodeIdx << "]: "
                          << "x=" << featuresFlat[baseIdx + 0] << ", "
                          << "y=" << featuresFlat[baseIdx + 1] << ", "
                          << "z=" << featuresFlat[baseIdx + 2] << ", "
                          << "layers=" << featuresFlat[baseIdx + 3] << ", "
                          << "nSp=" << featuresFlat[baseIdx + 4] << ", "
                          << "bucketSize=" << featuresFlat[baseIdx + 5]);
        }
        ATH_MSG_DEBUG("=== END DEBUG TRANSFORMER FEATURES ===");
    }

    // Rebuild graph with exactly 2 inputs: features [1,S,6], pad_mask [1,S]
    graphData.graph.reset();
    graphData.graph = std::make_unique<InferenceGraph>();
    graphData.graph->dataTensor.reserve(2); // features and pad_mask inputs; outputs are reserved in runNamedInference()

    Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);

    // features: [1,S,6] (backed by graphData.featureLeaves to keep alive)
    std::vector<int64_t> fShape{1, S, static_cast<int64_t>(kBucketFeatureCount)};
    graphData.featureLeaves.swap(featuresFlat);
    graphData.graph->dataTensor.emplace_back(
        Ort::Value::CreateTensor<float>(memInfo,
                                        graphData.featureLeaves.data(),
                                        graphData.featureLeaves.size(),
                                        fShape.data(),
                                        fShape.size()));

    // pad_mask: [1,S] (bool). Create ORT-owned tensor and fill with False (=valid).
    Ort::AllocatorWithDefaultOptions allocator;
    std::vector<int64_t> mShape{1, S};
    Ort::Value padVal = Ort::Value::CreateTensor(allocator,
                                                 mShape.data(),
                                                 mShape.size(),
                                                 ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL);
    bool* maskPtr = padVal.GetTensorMutableData<bool>();
    for (int64_t i = 0; i < S; ++i) maskPtr[i] = false;
    graphData.graph->dataTensor.emplace_back(std::move(padVal));

    return StatusCode::SUCCESS;
}
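For the Transformer path, the input is a single unpadded sequence of S buckets, so every pad_mask entry is false (valid). Below is a minimal sketch of the shapes and mask built above. makeTransformerInputs and TransformerInputsSketch are hypothetical names. std::vector<char> stands in for the ORT-owned bool tensor because std::vector<bool> is bit-packed and exposes no contiguous bool buffer.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Shapes and mask for one unpadded sequence: features [1, S, F] and
// pad_mask [1, S] with 0 (false) meaning "valid position".
struct TransformerInputsSketch {
    std::array<std::int64_t, 3> featureShape;
    std::array<std::int64_t, 2> maskShape;
    std::vector<char> padMask;  // one byte per entry, like an ONNX bool tensor
};

TransformerInputsSketch makeTransformerInputs(std::size_t nFlatFeatures, std::size_t nFeat) {
    // buildFeaturesOnly() has already checked that the buffer is N * F long.
    assert(nFeat > 0 && nFlatFeatures % nFeat == 0);
    const auto S = static_cast<std::int64_t>(nFlatFeatures / nFeat);
    return {{1, S, static_cast<std::int64_t>(nFeat)},
            {1, S},
            std::vector<char>(static_cast<std::size_t>(S), 0)};
}
```

A batch of several events would instead need padding to a common length, with true marking the padded positions. This tool does not need that, because each call builds a batch of one.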

◆ classifyEdges()

StatusCode MuonML::SegmentEdgeClassifierTool::classifyEdges(const EventContext& ctx,
                                                            const SegmentEdgeGraph& graph,
                                                            std::vector< SegmentEdgeScore >& scores) const
[override, virtual]

Run ONNX inference on graph and populate scores with logit and probability for each edge; called after buildGraph().

Implements MuonML::ISegmentEdgeClassifierTool.

Definition at line 307 of file SegmentEdgeClassifierTool.cxx.

{
    scores.clear();
    if (!graph.nNodes) return StatusCode::SUCCESS;
    if (!graph.nEdges) {
        ATH_CHECK(dumpDebugEvent(ctx, graph, scores));
        return StatusCode::SUCCESS;
    }

    if (graph.nodeFeatures.size() != graph.nNodes * kNodeFeatureCount) {
        ATH_MSG_ERROR("Unexpected node feature size " << graph.nodeFeatures.size()
                      << "; expected " << (graph.nNodes * kNodeFeatureCount));
        return StatusCode::FAILURE;
    }
    if (graph.edgeIndex.size() != 2 * graph.nEdges) {
        ATH_MSG_ERROR("Unexpected edge index size " << graph.edgeIndex.size()
                      << "; expected " << (2 * graph.nEdges));
        return StatusCode::FAILURE;
    }
    if (graph.edgeFeatures.size() != graph.nEdges * kEdgeFeatureCount) {
        ATH_MSG_ERROR("Unexpected edge feature size " << graph.edgeFeatures.size()
                      << "; expected " << (graph.nEdges * kEdgeFeatureCount));
        return StatusCode::FAILURE;
    }

    GraphRawData raw{};
    raw.graph = std::make_unique<InferenceGraph>();
    raw.featureLeaves = graph.nodeFeatures;
    raw.edgeIndexPacked.reserve(2 * graph.nEdges);
    raw.srcEdges.reserve(graph.nEdges);
    raw.desEdges.reserve(graph.nEdges);
    for (std::size_t e = 0; e < graph.nEdges; ++e) {
        raw.srcEdges.push_back(graph.edgeIndex[2 * e]);
        raw.desEdges.push_back(graph.edgeIndex[2 * e + 1]);
    }
    raw.edgeIndexPacked.insert(raw.edgeIndexPacked.end(), raw.srcEdges.begin(), raw.srcEdges.end());
    raw.edgeIndexPacked.insert(raw.edgeIndexPacked.end(), raw.desEdges.begin(), raw.desEdges.end());

    Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);

    const std::vector<int64_t> nodeShape{static_cast<int64_t>(graph.nNodes), static_cast<int64_t>(kNodeFeatureCount)};
    raw.graph->dataTensor.emplace_back(
        Ort::Value::CreateTensor<float>(memInfo,
                                        raw.featureLeaves.data(),
                                        raw.featureLeaves.size(),
                                        nodeShape.data(),
                                        nodeShape.size()));

    const std::vector<int64_t> edgeIndexShape{2, static_cast<int64_t>(graph.nEdges)};
    raw.graph->dataTensor.emplace_back(
        Ort::Value::CreateTensor<int64_t>(memInfo,
                                          raw.edgeIndexPacked.data(),
                                          raw.edgeIndexPacked.size(),
                                          edgeIndexShape.data(),
                                          edgeIndexShape.size()));

    // ONNX Runtime's CreateTensor API takes a non-const pointer, but it does not
    // mutate input buffers during inference. Avoid copying edge_attr every event.
    ATLAS_THREAD_SAFE float* edgeFeaturesData = const_cast<float*>(graph.edgeFeatures.data());
    const std::vector<int64_t> edgeAttrShape{static_cast<int64_t>(graph.nEdges), static_cast<int64_t>(kEdgeFeatureCount)};
    raw.graph->dataTensor.emplace_back(
        Ort::Value::CreateTensor<float>(memInfo,
                                        edgeFeaturesData,
                                        graph.edgeFeatures.size(),
                                        edgeAttrShape.data(),
                                        edgeAttrShape.size()));

    const std::vector<const char*> inputNames{
        m_inputNodeName.value().c_str(),
        m_inputEdgeIndexName.value().c_str(),
        m_inputEdgeAttrName.value().c_str()};
    const std::vector<const char*> outputNames{m_outputName.value().c_str()};
    ATH_MSG_DEBUG("classifyEdges: ONNX inputs shapes x=[" << nodeShape[0] << "," << nodeShape[1]
                  << "], edge_index=[" << edgeIndexShape[0] << "," << edgeIndexShape[1]
                  << "], edge_attr=[" << edgeAttrShape[0] << "," << edgeAttrShape[1] << "]");
    ATH_CHECK(runNamedInference(raw, inputNames, outputNames));

    if (raw.graph->dataTensor.size() <= inputNames.size()) {
        ATH_MSG_ERROR("Missing ONNX output tensor for segment edge inference");
        return StatusCode::FAILURE;
    }

    const Ort::Value& outTensor = raw.graph->dataTensor[inputNames.size()];
    const auto outInfo = outTensor.GetTensorTypeAndShapeInfo();
    const std::vector<int64_t> outShape = outInfo.GetShape();
    const size_t outSize = outInfo.GetElementCount();
    if (!outShape.empty()) {
        ATH_MSG_DEBUG("classifyEdges: ONNX output rank=" << outShape.size()
                      << ", first dim=" << outShape.front()
                      << ", elements=" << outSize);
    } else {
        ATH_MSG_DEBUG("classifyEdges: ONNX scalar output, elements=" << outSize);
    }
    if (outSize < graph.nEdges) {
        ATH_MSG_ERROR("ONNX logits tensor has " << outSize << " entries for " << graph.nEdges << " edges");
        return StatusCode::FAILURE;
    }

    const float* logits = outTensor.GetTensorData<float>();
    scores.reserve(graph.nEdges);
    for (std::size_t e=0; e<graph.nEdges; ++e) {
        const float l = logits[e];
        scores.push_back({std::size_t(graph.edgeIndex[2 * e]),
                          std::size_t(graph.edgeIndex[2 * e + 1]),
                          l,
                          // [source line 413 is missing from this listing]
    }

    ATH_CHECK(dumpDebugEvent(ctx, graph, scores));
    return StatusCode::SUCCESS;
}
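classifyEdges() converts the interleaved SegmentEdgeGraph::edgeIndex (s0, d0, s1, d1, ...) into the row-major [2,E] layout expected by the edge_index input, then reads one logit per edge. The line that derives SegmentEdgeScore::probability from the logit is missing from the listing. The sketch below assumes a logistic sigmoid, which is a common choice for binary edge classifiers but is not confirmed by the source. toBlockedEdgeIndex and sigmoid are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Interleaved (s0, d0, s1, d1, ...) -> row-major [2, E]:
// row 0 holds all sources, row 1 all destinations.
std::vector<std::int64_t> toBlockedEdgeIndex(const std::vector<std::int64_t>& interleaved) {
    assert(interleaved.size() % 2 == 0);
    const std::size_t nEdges = interleaved.size() / 2;
    std::vector<std::int64_t> packed(2 * nEdges);
    for (std::size_t e = 0; e < nEdges; ++e) {
        packed[e] = interleaved[2 * e];               // row 0: sources
        packed[nEdges + e] = interleaved[2 * e + 1];  // row 1: destinations
    }
    return packed;
}

// ASSUMED logit -> probability mapping (logistic sigmoid). Each branch only
// calls exp() on a non-positive argument, so it cannot overflow.
float sigmoid(float logit) {
    if (logit >= 0.f) {
        const float z = std::exp(-logit);
        return 1.f / (1.f + z);
    }
    const float z = std::exp(logit);
    return z / (1.f + z);
}
```

The blocked layout is the same one that classifyEdges() builds via raw.srcEdges and raw.desEdges before inserting them into raw.edgeIndexPacked.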

◆ DeclareInterfaceID()

MuonML::ISegmentEdgeClassifierTool::DeclareInterfaceID(ISegmentEdgeClassifierTool, 1, 0)
[inherited]

◆ dumpDebugEvent()

StatusCode MuonML::SegmentEdgeClassifierTool::dumpDebugEvent(const EventContext& ctx,
                                                             const SegmentEdgeGraph& graph,
                                                             const std::vector< SegmentEdgeScore >& scores) const
[private]

Definition at line 420 of file SegmentEdgeClassifierTool.cxx.

{
    if (m_debugDumpFile.value().empty()) return StatusCode::SUCCESS;

    std::lock_guard<std::mutex> lock{m_debugDumpMutex};
    if (m_debugDumpMaxEvents.value() != 0 &&
        m_debugDumpEvents.load(std::memory_order_relaxed) >=
            m_debugDumpMaxEvents.value()) {
        return StatusCode::SUCCESS;
    }

    if (graph.nodeFeatures.size() != graph.nNodes * kNodeFeatureCount ||
        graph.edgeIndex.size() != graph.nEdges * 2 ||
        graph.edgeFeatures.size() != graph.nEdges * kEdgeFeatureCount ||
        scores.size() != graph.nEdges) {
        ATH_MSG_ERROR("Cannot write segment-edge debug dump: inconsistent graph/output sizes"
                      << " nodes=" << graph.nNodes
                      << " nodeFeatures=" << graph.nodeFeatures.size()
                      << " edges=" << graph.nEdges
                      << " edgeIndex=" << graph.edgeIndex.size()
                      << " edgeFeatures=" << graph.edgeFeatures.size()
                      << " scores=" << scores.size());
        return StatusCode::FAILURE;
    }

    nlohmann::json x = nlohmann::json::array();
    x.get_ref<nlohmann::json::array_t&>().reserve(graph.nodeFeatures.size());
    for (const float value : graph.nodeFeatures) {
        x.push_back(std::isfinite(value) ? nlohmann::json(value)
                                         : nlohmann::json(nullptr));
    }

    nlohmann::json edgeIndex = nlohmann::json::array();
    edgeIndex.get_ref<nlohmann::json::array_t&>().reserve(graph.nEdges * 2);
    // This is the actual ONNX [2,E] row-major buffer: all sources then all destinations.
    for (std::size_t edge = 0; edge < graph.nEdges; ++edge) {
        edgeIndex.push_back(graph.edgeIndex[2 * edge]);
    }
    for (std::size_t edge = 0; edge < graph.nEdges; ++edge) {
        edgeIndex.push_back(graph.edgeIndex[2 * edge + 1]);
    }

    nlohmann::json edgeAttr = nlohmann::json::array();
    edgeAttr.get_ref<nlohmann::json::array_t&>().reserve(graph.edgeFeatures.size());
    for (const float value : graph.edgeFeatures) {
        edgeAttr.push_back(std::isfinite(value) ? nlohmann::json(value)
                                                : nlohmann::json(nullptr));
    }

    nlohmann::json logits = nlohmann::json::array();
    nlohmann::json probabilities = nlohmann::json::array();
    nlohmann::json edgeSrc = nlohmann::json::array();
    nlohmann::json edgeDst = nlohmann::json::array();
    logits.get_ref<nlohmann::json::array_t&>().reserve(scores.size());
    probabilities.get_ref<nlohmann::json::array_t&>().reserve(scores.size());
    edgeSrc.get_ref<nlohmann::json::array_t&>().reserve(scores.size());
    edgeDst.get_ref<nlohmann::json::array_t&>().reserve(scores.size());
    for (const SegmentEdgeScore& score : scores) {
        edgeSrc.push_back(score.src);
        edgeDst.push_back(score.dst);
        logits.push_back(std::isfinite(score.logit) ? nlohmann::json(score.logit)
                                                    : nlohmann::json(nullptr));
        probabilities.push_back(std::isfinite(score.probability)
                                    ? nlohmann::json(score.probability)
                                    : nlohmann::json(nullptr));
    }

    std::ofstream out{m_debugDumpFile.value(), std::ios::out | std::ios::app};
    if (!out) {
        ATH_MSG_ERROR("Could not append to segment-edge debug dump file: "
                      << m_debugDumpFile.value());
        return StatusCode::FAILURE;
    }

    const unsigned int dumpIndex =
        m_debugDumpEvents.fetch_add(1, std::memory_order_relaxed);
    nlohmann::ordered_json event;
    event["record_type"] = "event";
    event["format_version"] = 1;
    event["dump_index"] = dumpIndex;
    event["run_number"] = ctx.eventID().run_number();
    event["lumi_block"] = ctx.eventID().lumi_block();
    event["event_number"] = ctx.eventID().event_number();
    event["slot"] = ctx.slot();
    event["n_nodes"] = graph.nNodes;
    event["n_edges"] = graph.nEdges;
508 event["x_shape"] = {graph.nNodes, kNodeFeatureCount};
509 event["edge_index_shape"] = {2, graph.nEdges};
510 event["edge_attr_shape"] = {graph.nEdges, kEdgeFeatureCount};
511 event["logits_shape"] = {graph.nEdges};
512 event["x"] = std::move(x);
513 event["edge_index"] = std::move(edgeIndex);
514 event["edge_attr"] = std::move(edgeAttr);
515 event["edge_src"] = std::move(edgeSrc);
516 event["edge_dst"] = std::move(edgeDst);
517 event["logits"] = std::move(logits);
518 event["probabilities"] = std::move(probabilities);
519 out << event.dump() << '\n';
520
521 ATH_MSG_DEBUG("Wrote segment-edge debug event " << dumpIndex
522 << " to " << m_debugDumpFile.value());
523
524 return StatusCode::SUCCESS;
525}
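The dump reads graph.edgeIndex as interleaved (src, dst) pairs and writes edge_index in the row-major [2,E] layout fed to ONNX: every source first, then every destination. Below is a minimal standalone sketch of that transposition. It assumes 64-bit integer indices, because the element type of graph.edgeIndex is not shown on this page. toRowMajor2xE is a hypothetical helper, not part of the tool.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper: convert interleaved pairs [s0,d0,s1,d1,...]
// into the row-major [2,E] buffer [s0,s1,...,d0,d1,...].
std::vector<std::int64_t> toRowMajor2xE(const std::vector<std::int64_t>& interleaved) {
    if (interleaved.size() % 2 != 0) {
        throw std::invalid_argument("edge list must contain (src, dst) pairs");
    }
    const std::size_t nEdges = interleaved.size() / 2;
    std::vector<std::int64_t> out(2 * nEdges);
    for (std::size_t e = 0; e < nEdges; ++e) {
        out[e] = interleaved[2 * e];               // row 0: sources
        out[nEdges + e] = interleaved[2 * e + 1];  // row 1: destinations
    }
    return out;
}
```

For example, the edges (0,1), (0,2), (1,2) stored as {0,1,0,2,1,2} become {0,0,1,1,2,2}. The first half is row 0 (sources) and the second half is row 1 (destinations).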

◆ initialize()

StatusCode MuonML::SegmentEdgeClassifierTool::initialize ( )
override

Retrieve the ONNX model and resolve node feature ordering from metadata.

Definition at line 79 of file SegmentEdgeClassifierTool.cxx.

79 {
81
82 // Resolve node feature names from model metadata, matching the ONNX exporter.
83 {
84 Ort::AllocatorWithDefaultOptions allocator;
85 Ort::ModelMetadata meta = model().GetModelMetadata();
86 auto keys = meta.GetCustomMetadataMapKeysAllocated(allocator);
87 std::vector<std::string> keyList;
88 keyList.reserve(keys.size());
89 for (const auto& k : keys) keyList.emplace_back(k.get());
90
91 constexpr std::array<std::string_view, 4> candidates{
92 "x_feature_names", "node_feature_names", "feature_names", "input_feature_names"};
93 std::string usedKey;
94 std::vector<std::string> names;
95 for (std::string_view key : candidates) {
96 const std::string keyStr{key};
97 if (std::find(keyList.begin(), keyList.end(), keyStr) == keyList.end()) continue;
98 names = parseFeatureNames(meta.LookupCustomMetadataMapAllocated(keyStr.c_str(), allocator).get());
99 if (!names.empty()) {
100 usedKey = keyStr;
101 break;
102 }
103 }
104
105 if (names.empty()) {
107 ATH_MSG_WARNING("Model metadata has no usable node feature name key"
108 " (tried x_feature_names/node_feature_names/feature_names/input_feature_names)."
109 " Falling back to default training order.");
110 } else {
111 if (names.size() != kNodeFeatureCount) {
112 ATH_MSG_ERROR("Model metadata key '" << usedKey << "' has " << names.size()
113 << " features, expected " << kNodeFeatureCount);
114 return StatusCode::FAILURE;
115 }
116 for (const std::string& n : names) {
117 if (!nodeFeatureIdFromName(n).has_value()) {
118 ATH_MSG_ERROR("Unsupported node feature name in model metadata ('" << usedKey
119 << "'): '" << n << "'."
120 " Add mapping in SegmentEdgeClassifierTool::nodeFeatureValue().");
121 return StatusCode::FAILURE;
122 }
123 }
124 m_nodeFeatureNames = std::move(names);
125 ATH_MSG_DEBUG("Using node feature names from model metadata key '" << usedKey << "'.");
126 }
127
128 m_nodeFeatureIds.reserve(m_nodeFeatureNames.size());
129 for (const std::string& n : m_nodeFeatureNames) {
130 const auto id = nodeFeatureIdFromName(n);
131 if (!id.has_value()) {
132 ATH_MSG_ERROR("Internal feature-id resolution failed for node feature name '" << n << "'.");
133 return StatusCode::FAILURE;
134 }
135 m_nodeFeatureIds.push_back(*id);
136 }
137
138 std::ostringstream order;
139 order << "Node feature order:";
140 for (std::size_t i = 0; i < m_nodeFeatureNames.size(); ++i) {
141 order << " f" << i << "=" << m_nodeFeatureNames[i];
142 if (i + 1 < m_nodeFeatureNames.size()) order << ",";
143 }
144 ATH_MSG_DEBUG(order.str());
145 }
146
147 if (m_nodeFeatureNames.size() != kNodeFeatureCount) {
148 ATH_MSG_ERROR("Internal node feature setup has " << m_nodeFeatureNames.size()
149 << " entries, expected " << kNodeFeatureCount);
150 return StatusCode::FAILURE;
151 }
152 if (m_nodeFeatureIds.size() != kNodeFeatureCount) {
153 ATH_MSG_ERROR("Internal node feature id setup has " << m_nodeFeatureIds.size()
154 << " entries, expected " << kNodeFeatureCount);
155 return StatusCode::FAILURE;
156 }
157
158 m_cosMin = std::cos(m_maxDeltaThetaDeg.value() * Gaudi::Units::deg);
159
160 if (!m_debugDumpFile.value().empty()) {
161 std::ofstream out{m_debugDumpFile.value(), std::ios::out | std::ios::trunc};
162 if (!out) {
163 ATH_MSG_ERROR("Could not create segment-edge debug dump file: "
164 << m_debugDumpFile.value());
165 return StatusCode::FAILURE;
166 }
167
168 nlohmann::ordered_json metadata;
169 metadata["record_type"] = "metadata";
170 metadata["format_version"] = 1;
171 metadata["tool"] = "SegmentEdgeClassifierTool";
172 metadata["input_names"] = {m_inputNodeName.value(),
173 m_inputEdgeIndexName.value(),
174 m_inputEdgeAttrName.value()};
175 metadata["output_name"] = m_outputName.value();
176 metadata["x_feature_names"] = m_nodeFeatureNames;
177 metadata["edge_attr_feature_names"] = {
178 "deltaPositionX_m", "deltaPositionY_m", "deltaPositionZ_m",
179 "distance_m", "cos_opening_angle", "same_chamber", "same_sector"};
180 metadata["edge_index_layout"] = "row_major_2_by_E";
181 metadata["edge_order"] = "directed src_to_dst; row 0 then row 1";
182 metadata["max_delta_theta_deg"] = m_maxDeltaThetaDeg.value();
183 metadata["max_delta_sector"] = m_maxDeltaSector.value();
184 metadata["sector_modulo"] = m_sectorModulo.value();
185 metadata["debug_dump_max_events"] = m_debugDumpMaxEvents.value();
186 out << metadata.dump() << '\n';
187
188 ATH_MSG_INFO("Writing segment-edge ONNX debug dump to "
189 << m_debugDumpFile.value()
190 << " (DebugDumpMaxEvents="
191 << m_debugDumpMaxEvents.value() << ")");
192 }
193
194 return StatusCode::SUCCESS;
195}
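The metadata lookup in initialize() tries the candidate keys in a fixed priority order. It takes the first key whose value parses to a non-empty list and otherwise falls back to the default training order. The sketch below reproduces that policy under stated assumptions: a std::map stands in for the ONNX custom-metadata map, and a caller-supplied parser stands in for parseFeatureNames(). resolveFeatureNames and ResolvedNames are hypothetical names.

```cpp
#include <array>
#include <functional>
#include <map>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

// Result of the (hypothetical) resolution step.
struct ResolvedNames {
    std::vector<std::string> names;
    std::string usedKey;  // empty when the defaults were used
};

// Sketch of the key-priority policy in initialize(). "metadata" stands in for
// the ONNX custom-metadata map and "parse" for parseFeatureNames().
ResolvedNames resolveFeatureNames(
    const std::map<std::string, std::string>& metadata,
    const std::function<std::vector<std::string>(const std::string&)>& parse,
    const std::vector<std::string>& defaults) {
    constexpr std::array<std::string_view, 4> candidates{
        "x_feature_names", "node_feature_names", "feature_names", "input_feature_names"};
    for (std::string_view key : candidates) {
        const auto it = metadata.find(std::string{key});
        if (it == metadata.end()) continue;  // key absent: try the next candidate
        std::vector<std::string> names = parse(it->second);
        if (!names.empty()) return {std::move(names), std::string{key}};
        // Key present but empty after parsing: keep looking, as initialize() does.
    }
    return {defaults, std::string{}};  // fall back to the default training order
}
```

In the tool, the parsed list must then contain exactly kNodeFeatureCount names, and nodeFeatureIdFromName() must accept every one of them. Otherwise initialize() returns FAILURE.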

◆ model()

Ort::Session & BucketInferenceToolBase::model ( ) const
protected inherited

Definition at line 65 of file BucketInferenceToolBase.cxx.

65 {
66 return m_onnxSessionTool->session();
67}

◆ parseFeatureNames()

std::vector< std::string > BucketInferenceToolBase::parseFeatureNames ( const std::string & raw)
static protected inherited

Definition at line 32 of file BucketInferenceToolBase.cxx.

32 {
33 std::vector<std::string> out;
34 const std::string s = trimFeatureToken(raw);
35 if (s.empty()) return out;
36
37 // Preferred exporter format: JSON list of strings.
38 if (!s.empty() && s.front() == '[') {
39 bool inQuote = false;
40 std::string token;
41 for (char c : s) {
42 if (c == '"') {
43 if (inQuote) {
44 if (!token.empty()) out.push_back(token);
45 token.clear();
46 }
47 inQuote = !inQuote;
48 continue;
49 }
50 if (inQuote) token.push_back(c);
51 }
52 if (!out.empty()) return out;
53 }
54
55 // Backward-compatible format: comma-separated.
56 std::istringstream ss(s);
57 std::string tok;
58 while (std::getline(ss, tok, ',')) {
59 tok = trimFeatureToken(tok);
60 if (!tok.empty()) out.push_back(tok);
61 }
62 return out;
63}
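To see how a given metadata string will be interpreted, both accepted formats can be run through a standalone copy of the same logic. This sketch mirrors the listing above: it tries a JSON-style list of quoted strings first, then falls back to comma-separated tokens. Like the original, it does not handle escaped quotes. trimToken and parseNames are illustrative names.

```cpp
#include <algorithm>
#include <cctype>
#include <sstream>
#include <string>
#include <vector>

// Standalone mirror of trimFeatureToken(): strip leading/trailing whitespace.
std::string trimToken(std::string s) {
    auto notSpace = [](unsigned char c) { return !std::isspace(c); };
    s.erase(s.begin(), std::find_if(s.begin(), s.end(), notSpace));
    s.erase(std::find_if(s.rbegin(), s.rend(), notSpace).base(), s.end());
    return s;
}

// Standalone mirror of parseFeatureNames().
std::vector<std::string> parseNames(const std::string& raw) {
    std::vector<std::string> out;
    const std::string s = trimToken(raw);
    if (s.empty()) return out;
    if (s.front() == '[') {  // JSON-style list: collect the text between quote pairs
        bool inQuote = false;
        std::string token;
        for (char c : s) {
            if (c == '"') {
                if (inQuote && !token.empty()) out.push_back(token);
                token.clear();
                inQuote = !inQuote;
                continue;
            }
            if (inQuote) token.push_back(c);
        }
        if (!out.empty()) return out;
    }
    std::istringstream ss(s);  // fallback: comma-separated tokens
    std::string tok;
    while (std::getline(ss, tok, ',')) {
        tok = trimToken(tok);
        if (!tok.empty()) out.push_back(tok);
    }
    return out;
}
```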

◆ runGraphInference()

StatusCode MuonML::SegmentEdgeClassifierTool::runGraphInference ( const EventContext & ctx,
GraphRawData & graphData ) const
override

Not supported by this tool; returns FAILURE.

Use SegmentEdgeInferenceAlg together with buildGraph() and classifyEdges() instead.

Definition at line 197 of file SegmentEdgeClassifierTool.cxx.

197 {
198 ATH_MSG_ERROR("runGraphInference is not supported by SegmentEdgeClassifierTool. Use SegmentEdgeInferenceAlg + ISegmentEdgeClassifierTool methods.");
199 return StatusCode::FAILURE;
200}

◆ runInference()

StatusCode BucketInferenceToolBase::runInference ( GraphRawData & graphData) const
inherited

Default ONNX run for the GNN case: inputs {"features","edge_index"} -> the single output named by m_outputName (e.g. "logits").

Definition at line 532 of file BucketInferenceToolBase.cxx.

532 {
533 std::vector<const char*> inputNames = {"features", "edge_index"};
534 std::vector<const char*> outputNames = {m_outputName.value().c_str()};
535 return runNamedInference(graphData, inputNames, outputNames);
536}

◆ runNamedInference()

StatusCode BucketInferenceToolBase::runNamedInference ( GraphRawData & graphData,
const std::vector< const char * > & inputNames,
const std::vector< const char * > & outputNames ) const
protected inherited

Generic inference with caller-supplied input and output tensor names, for tools whose models use different I/O conventions.

Definition at line 359 of file BucketInferenceToolBase.cxx.

363{
364 if (!graphData.graph) {
365 ATH_MSG_ERROR("Graph data is not built.");
366 return StatusCode::FAILURE;
367 }
368 if (graphData.graph->dataTensor.empty()) {
369 ATH_MSG_ERROR("No input tensors prepared for inference.");
370 return StatusCode::FAILURE;
371 }
372
373 // Reserve the final size here from the actual I/O lists instead
374 // of hard-coding assumptions in the graph builders.
375 graphData.graph->dataTensor.reserve(inputNames.size() + outputNames.size());
376 if (graphData.graph->dataTensor.size() < inputNames.size()) {
377 ATH_MSG_ERROR("Prepared " << graphData.graph->dataTensor.size()
378 << " tensors but inference expects " << inputNames.size() << " inputs.");
379 return StatusCode::FAILURE;
380 }
381
382 if (msgLvl(MSG::DEBUG)) {
383 // DEBUG: Print actual input tensor data for features tensor
384
385 ATH_MSG_DEBUG("=== DEBUGGING: ONNX Input tensor data ===");
386 if (!graphData.graph->dataTensor.empty()) {
387 const auto& featureTensor = graphData.graph->dataTensor[0];
388 auto featShape = featureTensor.GetTensorTypeAndShapeInfo().GetShape();
389 ATH_MSG_DEBUG("Features tensor shape: [" << featShape[0]
390 << (featShape.size()>1 ? ("," + std::to_string(featShape[1])) : "")
391 << (featShape.size()>2 ? ("," + std::to_string(featShape[2])) : "") << "]");
392
393 float* featData = const_cast<Ort::Value&>(featureTensor).GetTensorMutableData<float>();
394 const size_t totalElements = featureTensor.GetTensorTypeAndShapeInfo().GetElementCount();
395 ATH_MSG_DEBUG("Features tensor total elements: " << totalElements);
396
397 // Print up to 10 nodes; stride = nFeat from tensor shape
398 const size_t nFeat = (featShape.size() > 1 && featShape[1] > 0) ? static_cast<size_t>(featShape[1]) : 1;
399 const size_t nNodes = totalElements / nFeat;
400 const size_t debugNodes = std::min(nNodes, static_cast<size_t>(10));
401
402 // Try to read feature names from model custom metadata.
403 // Prefer x_feature_names (current exporter), then fall back to legacy keys.
404 std::vector<std::string> featNames;
405 {
406 Ort::AllocatorWithDefaultOptions allocator;
407 Ort::ModelMetadata meta = model().GetModelMetadata();
408 auto keys = meta.GetCustomMetadataMapKeysAllocated(allocator);
409 std::vector<std::string> keyNames;
410 keyNames.reserve(keys.size());
411 for (const auto& k : keys) keyNames.emplace_back(k.get());
412 const std::array<std::string, 4> candidates{
413 "x_feature_names", "node_feature_names", "feature_names", "input_feature_names"};
414 for (const std::string& key : candidates) {
415 if (std::find(keyNames.begin(), keyNames.end(), key) != keyNames.end()) {
416 std::string val = meta.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get();
417 featNames = parseFeatureNames(val);
418 break;
419 }
420 }
421 if (featNames.empty()) {
422 ATH_MSG_DEBUG("No usable feature-name metadata key found in model; using generic fN labels.");
423 }
424 }
425 auto featLabel = [&](size_t f) -> std::string {
426 if (f < featNames.size()) return featNames[f];
427 return "f" + std::to_string(f);
428 };
429
430 // Print legend
431 {
432 std::ostringstream legend;
433 legend << "Node feature legend (" << nFeat << " features):";
434 for (size_t f = 0; f < nFeat; ++f) {
435 legend << " f" << f << "=" << featLabel(f);
436 if (f + 1 < nFeat) legend << ",";
437 }
438 ATH_MSG_DEBUG(legend.str());
439 }
440
441 for (size_t n = 0; n < debugNodes; ++n) {
442 std::ostringstream row;
443 row << "ONNXNode[" << n << "]:";
444 for (size_t f = 0; f < nFeat; ++f) {
445 row << " f" << f << "=" << featData[n * nFeat + f];
446 if (f + 1 < nFeat) row << ",";
447 }
448 ATH_MSG_DEBUG(row.str());
449 }
450 }
451 ATH_MSG_DEBUG("=== END DEBUG ONNX INPUT ===");
452 }
453
454 Ort::RunOptions run_options;
455 run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR);
456
457 if (m_isCuda) {
458 // ---- CUDA path: use IoBinding so tensors stay on device ----
459 Ort::IoBinding binding(model());
460 for (std::size_t i = 0; i < inputNames.size(); ++i) {
461 binding.BindInput(inputNames[i], graphData.graph->dataTensor[i]);
462 }
463 // Bind outputs to CPU so predictions are directly readable after sync.
464 Ort::MemoryInfo cpuOut = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
465 for (const char* outName : outputNames) {
466 binding.BindOutput(outName, cpuOut);
467 }
468
469 model().Run(run_options, binding);
470 binding.SynchronizeOutputs();
471
472 std::vector<Ort::Value> outputs = binding.GetOutputValues();
473 if (outputs.empty()) {
474 ATH_MSG_ERROR("IoBinding inference returned empty output.");
475 return StatusCode::FAILURE;
476 }
477
478 float* outData = outputs[0].GetTensorMutableData<float>();
479 const size_t outSize = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
480 ATH_MSG_DEBUG("ONNX (IoBinding) raw output elementCount = " << outSize);
481
482 if (m_sanitizeNonFinitePredictions.value()) {
483 std::span<float> preds(outData, outData + outSize);
484 for (size_t i = 0; i < outSize; ++i) {
485 if (!std::isfinite(preds[i])) {
486 ATH_MSG_WARNING("Non-finite prediction detected at " << i << " -> set to -100.");
487 preds[i] = -100.0f;
488 }
489 }
490 }
491
492 for (auto& v : outputs) {
493 graphData.graph->dataTensor.emplace_back(std::move(v));
494 }
495 return StatusCode::SUCCESS;
496 }
497
498 // ---- CPU path ----
499 std::vector<Ort::Value> outputs =
500 model().Run(run_options,
501 inputNames.data(),
502 graphData.graph->dataTensor.data(),
503 inputNames.size(),
504 outputNames.data(),
505 outputNames.size());
506
507 if (outputs.empty()) {
508 ATH_MSG_ERROR("Inference returned empty output.");
509 return StatusCode::FAILURE;
510 }
511
512 float* outData = outputs[0].GetTensorMutableData<float>();
513 const size_t outSize = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
514 ATH_MSG_DEBUG("ONNX raw output elementCount = " << outSize);
515
516 if (m_sanitizeNonFinitePredictions.value()) {
517 std::span<float> preds(outData, outData + outSize);
518 for (size_t i = 0; i < outSize; ++i) {
519 if (!std::isfinite(preds[i])) {
520 ATH_MSG_WARNING("Non-finite prediction detected at " << i << " -> set to -100.");
521 preds[i] = -100.0f;
522 }
523 }
524 }
525
526 for (auto& v : outputs) {
527 graphData.graph->dataTensor.emplace_back(std::move(v));
528 }
529 return StatusCode::SUCCESS;
530}

◆ setupModel()

StatusCode BucketInferenceToolBase::setupModel ( )
protected inherited

Definition at line 69 of file BucketInferenceToolBase.cxx.

69 {
70 ATH_CHECK(m_onnxSessionTool.retrieve());
71 ATH_CHECK(m_readKey.initialize());
72 ATH_CHECK(m_geoCtxKey.initialize());
73
74 const InferenceUtils::SessionBackend backend = InferenceUtils::sessionBackend(m_onnxSessionTool);
75 m_isCuda = backend.isCuda;
76 m_cudaDeviceId = backend.cudaDeviceId;
77 if (m_isCuda) {
78 ATH_MSG_INFO("ONNX session is running on CUDA device " << m_cudaDeviceId
79 << ". I/O binding will be used.");
80 } else {
81 ATH_MSG_INFO("ONNX session is running on CPU.");
82 }
83
84 return StatusCode::SUCCESS;
85}

◆ trimFeatureToken()

std::string BucketInferenceToolBase::trimFeatureToken ( std::string s)
static protected inherited

Definition at line 25 of file BucketInferenceToolBase.cxx.

25 {
26 auto notSpace = [](unsigned char c) { return !std::isspace(c); };
27 s.erase(s.begin(), std::find_if(s.begin(), s.end(), notSpace));
28 s.erase(std::find_if(s.rbegin(), s.rend(), notSpace).base(), s.end());
29 return s;
30}

Member Data Documentation

◆ kBucketFeatureCount

std::size_t MuonML::BucketInferenceToolBase::kBucketFeatureCount = 6
static constexpr protected inherited

Definition at line 53 of file BucketInferenceToolBase.h.

◆ kDefaultNodeFeatureNames

std::array<std::string_view, kNodeFeatureCount> MuonML::BucketInferenceToolBase::kDefaultNodeFeatureNames
static constexpr protected inherited
Initial value:
= {
"segmentPositionX_m", "segmentPositionY_m", "segmentPositionZ_m",
"segmentDirectionX", "segmentDirectionY", "segmentDirectionZ",
"bucket_chamberIndex", "bucket_layers", "bucket_sector", "bucket_segments"}

Definition at line 56 of file BucketInferenceToolBase.h.

56 {
57 "segmentPositionX_m", "segmentPositionY_m", "segmentPositionZ_m",
58 "segmentDirectionX", "segmentDirectionY", "segmentDirectionZ",
59 "bucket_chamberIndex", "bucket_layers", "bucket_sector", "bucket_segments"};
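initialize() accepts a metadata feature name only if nodeFeatureIdFromName() recognises it. It then caches the resolved ids so that per-node feature filling can switch on an enum instead of comparing strings. Below is a minimal sketch of such a lookup for the ten default names. It assumes one enumerator per default name, in training order. NodeFeatureId, its enumerators and featureIdFromName are hypothetical, not the tool's actual SegmentNodeFeatureId definition. The row-major (N, 10) offset matches the x_shape written by dumpDebugEvent().

```cpp
#include <array>
#include <cstddef>
#include <optional>
#include <string_view>

// Hypothetical enum: one id per default node feature, in training order.
enum class NodeFeatureId : std::size_t {
    PosX, PosY, PosZ, DirX, DirY, DirZ, ChamberIndex, Layers, Sector, Segments
};

constexpr std::array<std::string_view, 10> kDefaultNames{
    "segmentPositionX_m", "segmentPositionY_m", "segmentPositionZ_m",
    "segmentDirectionX", "segmentDirectionY", "segmentDirectionZ",
    "bucket_chamberIndex", "bucket_layers", "bucket_sector", "bucket_segments"};

// Map a metadata feature name to its id; std::nullopt for unknown names.
std::optional<NodeFeatureId> featureIdFromName(std::string_view name) {
    for (std::size_t i = 0; i < kDefaultNames.size(); ++i) {
        if (kDefaultNames[i] == name) return static_cast<NodeFeatureId>(i);
    }
    return std::nullopt;
}

// Row-major (N, 10) buffer: feature f of node n lives at n * 10 + f.
constexpr std::size_t featureOffset(std::size_t node, std::size_t feature) {
    return node * kDefaultNames.size() + feature;
}
```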

◆ kEdgeFeatureCount

std::size_t MuonML::BucketInferenceToolBase::kEdgeFeatureCount = 7
static constexpr protected inherited

Definition at line 55 of file BucketInferenceToolBase.h.
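The seven edge features correspond to the edge_attr_feature_names written to the debug-dump metadata: deltaPositionX/Y/Z_m, distance_m, cos_opening_angle, same_chamber, same_sector. The computation is not shown on this page, so the sketch below rests on several assumptions, all of them labeled:

- The deltas are taken as dst minus src.
- Positions arrive in mm and are converted to metres.
- The cosine is the dot product of the unit directions.
- The two flags are encoded as 1 or 0.

Seg and edgeFeatures are hypothetical names.

```cpp
#include <array>
#include <cmath>

// Hypothetical minimal segment record; the tool itself reads xAOD::MuonSegment.
struct Seg {
    std::array<double, 3> posMm;  // global position in mm (assumed input unit)
    std::array<double, 3> dir;    // unit direction vector
    int chamberIndex;
    int sector;
};

// Assumed feature order follows edge_attr_feature_names in the dump metadata.
std::array<float, 7> edgeFeatures(const Seg& src, const Seg& dst) {
    std::array<double, 3> d{};
    double dist2 = 0.0;
    double cosAngle = 0.0;
    for (int i = 0; i < 3; ++i) {
        d[i] = (dst.posMm[i] - src.posMm[i]) / 1000.0;  // assumed sign dst - src; mm -> m
        dist2 += d[i] * d[i];
        cosAngle += src.dir[i] * dst.dir[i];
    }
    return {static_cast<float>(d[0]), static_cast<float>(d[1]), static_cast<float>(d[2]),
            static_cast<float>(std::sqrt(dist2)), static_cast<float>(cosAngle),
            src.chamberIndex == dst.chamberIndex ? 1.f : 0.f,
            src.sector == dst.sector ? 1.f : 0.f};
}
```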

◆ kNodeFeatureCount

std::size_t MuonML::BucketInferenceToolBase::kNodeFeatureCount = 10
static constexpr protected inherited

Definition at line 54 of file BucketInferenceToolBase.h.

◆ m_cosMin

float MuonML::SegmentEdgeClassifierTool::m_cosMin {0.f}
private

Definition at line 100 of file SegmentEdgeClassifierTool.h.

100{0.f};

◆ m_cudaDeviceId

int MuonML::BucketInferenceToolBase::m_cudaDeviceId {0}
protected inherited

Definition at line 102 of file BucketInferenceToolBase.h.

102{0};

◆ m_debugDumpEvents

std::atomic<unsigned int> MuonML::SegmentEdgeClassifierTool::m_debugDumpEvents {0}
mutable private

Definition at line 107 of file SegmentEdgeClassifierTool.h.

107{0};

◆ m_debugDumpFile

Gaudi::Property<std::string> MuonML::SegmentEdgeClassifierTool::m_debugDumpFile {this, "DebugDumpFile", ""}
private

Definition at line 98 of file SegmentEdgeClassifierTool.h.

98{this, "DebugDumpFile", ""};

◆ m_debugDumpFirstNEdges

Gaudi::Property<unsigned int> MuonML::BucketInferenceToolBase::m_debugDumpFirstNEdges {this, "DebugDumpFirstNEdges", 12}
protected inherited

Definition at line 94 of file BucketInferenceToolBase.h.

94{this, "DebugDumpFirstNEdges", 12};

◆ m_debugDumpFirstNNodes

Gaudi::Property<unsigned int> MuonML::BucketInferenceToolBase::m_debugDumpFirstNNodes {this, "DebugDumpFirstNNodes", 5}
protected inherited

Definition at line 93 of file BucketInferenceToolBase.h.

93{this, "DebugDumpFirstNNodes", 5};

◆ m_debugDumpMaxEvents

Gaudi::Property<unsigned int> MuonML::SegmentEdgeClassifierTool::m_debugDumpMaxEvents {this, "DebugDumpMaxEvents", 0}
private

Definition at line 99 of file SegmentEdgeClassifierTool.h.

99{this, "DebugDumpMaxEvents", 0};
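DebugDumpMaxEvents caps the number of event records that dumpDebugEvent() appends, and 0 means no cap. The tool checks the counter under m_debugDumpMutex and advances it with fetch_add. The hypothetical DumpGate class below isolates that gate. One simplification applies: here the counter advances as soon as the cap check passes, whereas the tool advances it only after the size checks pass and the dump file opens.

```cpp
#include <atomic>
#include <mutex>

// Hypothetical stand-in for the dump gate in dumpDebugEvent().
class DumpGate {
public:
    explicit DumpGate(unsigned int maxEvents) : m_max{maxEvents} {}

    // Returns the dump index to write, or -1 once the cap has been reached.
    long tryAcquire() {
        std::lock_guard<std::mutex> lock{m_mutex};
        if (m_max != 0 && m_count.load(std::memory_order_relaxed) >= m_max) return -1;
        return static_cast<long>(m_count.fetch_add(1, std::memory_order_relaxed));
    }

private:
    const unsigned int m_max;  // 0 = unlimited, as for DebugDumpMaxEvents
    std::mutex m_mutex;
    std::atomic<unsigned int> m_count{0};
};
```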

◆ m_debugDumpMutex

std::mutex MuonML::SegmentEdgeClassifierTool::m_debugDumpMutex
mutable private

Definition at line 106 of file SegmentEdgeClassifierTool.h.

◆ m_geoCtxKey

ActsTrk::GeoContextReadKey_t MuonML::BucketInferenceToolBase::m_geoCtxKey {this, "AlignmentKey", "ActsAlignment", "cond handle key"}
protected inherited

Definition at line 80 of file BucketInferenceToolBase.h.

80{this, "AlignmentKey", "ActsAlignment", "cond handle key"};

◆ m_inputEdgeAttrName

Gaudi::Property<std::string> MuonML::SegmentEdgeClassifierTool::m_inputEdgeAttrName {this, "InputEdgeAttrName", "edge_attr"}
private

Definition at line 96 of file SegmentEdgeClassifierTool.h.

96{this, "InputEdgeAttrName", "edge_attr"};

◆ m_inputEdgeIndexName

Gaudi::Property<std::string> MuonML::SegmentEdgeClassifierTool::m_inputEdgeIndexName {this, "InputEdgeIndexName", "edge_index"}
private

Definition at line 95 of file SegmentEdgeClassifierTool.h.

95{this, "InputEdgeIndexName", "edge_index"};

◆ m_inputNodeName

Gaudi::Property<std::string> MuonML::SegmentEdgeClassifierTool::m_inputNodeName {this, "InputNodeName", "x"}
private

Definition at line 94 of file SegmentEdgeClassifierTool.h.

94{this, "InputNodeName", "x"};

◆ m_isCuda

bool MuonML::BucketInferenceToolBase::m_isCuda {false}
protected inherited

Definition at line 101 of file BucketInferenceToolBase.h.

101{false};

◆ m_maxAbsDz

Gaudi::Property<double> MuonML::BucketInferenceToolBase::m_maxAbsDz {this, "MaxAbsDz", 15000.0}
protected inherited

Definition at line 90 of file BucketInferenceToolBase.h.

90{this, "MaxAbsDz", 15000.0};

◆ m_maxChamberDelta

Gaudi::Property<int> MuonML::BucketInferenceToolBase::m_maxChamberDelta {this, "MaxChamberDelta", 13}
protected inherited

Definition at line 87 of file BucketInferenceToolBase.h.

87{this, "MaxChamberDelta", 13};

◆ m_maxDeltaSector

Gaudi::Property<int> MuonML::SegmentEdgeClassifierTool::m_maxDeltaSector {this, "MaxDeltaSector", 1}
private

Definition at line 92 of file SegmentEdgeClassifierTool.h.

92{this, "MaxDeltaSector", 1};

◆ m_maxDeltaThetaDeg

Gaudi::Property<float> MuonML::SegmentEdgeClassifierTool::m_maxDeltaThetaDeg {this, "MaxDeltaThetaDeg", 35.f}
private

Definition at line 91 of file SegmentEdgeClassifierTool.h.

91{this, "MaxDeltaThetaDeg", 35.f};
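initialize() converts MaxDeltaThetaDeg into m_cosMin = cos(MaxDeltaThetaDeg * Gaudi::Units::deg). Gaudi::Units::deg equals pi/180 radians, so the default of 35 degrees gives a threshold of about 0.819. The sketch below reproduces that conversion. It also adds a hypothetical pair test, which assumes that two unit directions count as compatible when their dot product is at least the threshold. The use of m_cosMin in graph building is not shown on this page. cosThreshold and withinOpeningAngle are illustrative names.

```cpp
#include <cmath>

// Degrees -> cosine threshold, as initialize() does for MaxDeltaThetaDeg.
double cosThreshold(double maxDeltaThetaDeg) {
    const double pi = std::acos(-1.0);
    return std::cos(maxDeltaThetaDeg * pi / 180.0);
}

// Hypothetical pair test: an opening angle <= the maximum is equivalent to
// a dot product of the unit directions >= cos(maximum).
bool withinOpeningAngle(const double a[3], const double b[3], double cosMin) {
    const double dot = a[0] * b[0] + a[1] * b[1] + a[2] * b[2];
    return dot >= cosMin;
}
```

Comparing cosines avoids an acos call per candidate pair. Because cos is decreasing on [0, pi], a smaller angle always gives a larger cosine.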

◆ m_maxDistXY

Gaudi::Property<double> MuonML::BucketInferenceToolBase::m_maxDistXY {this, "MaxDistXY", 6800.0}
protectedinherited

Definition at line 89 of file BucketInferenceToolBase.h.

89{this, "MaxDistXY", 6800.0};

◆ m_maxSectorDelta

Gaudi::Property<int> MuonML::BucketInferenceToolBase::m_maxSectorDelta {this, "MaxSectorDelta", 1}
protectedinherited

Definition at line 88 of file BucketInferenceToolBase.h.

88{this, "MaxSectorDelta", 1};

◆ m_minLayers

Gaudi::Property<int> MuonML::BucketInferenceToolBase::m_minLayers {this, "MinLayersValid", 3}
protectedinherited

Definition at line 86 of file BucketInferenceToolBase.h.

86{this, "MinLayersValid", 3};

◆ m_nodeFeatureIds

std::vector<SegmentNodeFeatureId> MuonML::SegmentEdgeClassifierTool::m_nodeFeatureIds {}
private

Definition at line 104 of file SegmentEdgeClassifierTool.h.

104{};

◆ m_nodeFeatureNames

std::vector<std::string> MuonML::SegmentEdgeClassifierTool::m_nodeFeatureNames {}
private

Node feature order declared in the model metadata (resolved in initialize()).

Definition at line 103 of file SegmentEdgeClassifierTool.h.

103{};

◆ m_onnxSessionTool

ToolHandle<AthOnnx::IOnnxRuntimeSessionTool> MuonML::BucketInferenceToolBase::m_onnxSessionTool
private inherited
Initial value:
{
this, "ModelSession", ""}

Definition at line 105 of file BucketInferenceToolBase.h.

105 {
106 this, "ModelSession", ""};

◆ m_outputName

Gaudi::Property<std::string> MuonML::SegmentEdgeClassifierTool::m_outputName {this, "OutputName", "logits"}
private

Definition at line 97 of file SegmentEdgeClassifierTool.h.

97{this, "OutputName", "logits"};

◆ m_readKey

SG::ReadHandleKey<MuonR4::SpacePointContainer> MuonML::BucketInferenceToolBase::m_readKey {this, "ReadSpacePoints", "MuonSpacePoints"}
protected inherited

Definition at line 79 of file BucketInferenceToolBase.h.

79{this, "ReadSpacePoints", "MuonSpacePoints"};

◆ m_sanitizeNonFinitePredictions

Gaudi::Property<bool> MuonML::BucketInferenceToolBase::m_sanitizeNonFinitePredictions
protected inherited
Initial value:
{
this, "SanitizeNonFinitePredictions", false,
"When true, replace non-finite ONNX outputs with -100 and log a warning."}

Definition at line 96 of file BucketInferenceToolBase.h.

96 {
97 this, "SanitizeNonFinitePredictions", false,
98 "When true, replace non-finite ONNX outputs with -100 and log a warning."};

◆ m_sectorModulo

Gaudi::Property<int> MuonML::SegmentEdgeClassifierTool::m_sectorModulo {this, "SectorModulo", 16}
private

Definition at line 93 of file SegmentEdgeClassifierTool.h.

93{this, "SectorModulo", 16};
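SectorModulo (default 16, matching the 16 sectors of the muon spectrometer) and MaxDeltaSector (default 1) are recorded together in the debug-dump metadata. How the tool combines them is not shown on this page. The sketch below assumes a cyclic sector distance, so that sectors 16 and 1 count as neighbours. cyclicSectorDelta and sectorsCompatible are hypothetical names.

```cpp
#include <cstdlib>

// Hypothetical cyclic sector distance: with modulo 16, sectors 1 and 16 are adjacent.
int cyclicSectorDelta(int a, int b, int modulo) {
    const int d = std::abs(a - b) % modulo;
    return d > modulo - d ? modulo - d : d;
}

// Hypothetical compatibility cut using MaxDeltaSector.
bool sectorsCompatible(int a, int b, int modulo, int maxDelta) {
    return cyclicSectorDelta(a, b, modulo) <= maxDelta;
}
```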

◆ m_validateEdges

Gaudi::Property<bool> MuonML::BucketInferenceToolBase::m_validateEdges {this, "ValidateEdges", true}
protected inherited

Definition at line 95 of file BucketInferenceToolBase.h.

95{this, "ValidateEdges", true};

The documentation for this class was generated from the following files: