ATLAS Offline Software
MuonML::SegmentEdgeClassifierTool Class Reference (final)

Runs a segment-level GNN on reconstructed muon segments to classify segment-pair edges as "good" or "background".

#include <SegmentEdgeClassifierTool.h>

Inheritance diagram for MuonML::SegmentEdgeClassifierTool:
Collaboration diagram for MuonML::SegmentEdgeClassifierTool:

Public Member Functions

StatusCode initialize () override
 Retrieve the ONNX model and resolve node feature ordering from metadata.
StatusCode runGraphInference (const EventContext &ctx, GraphRawData &graphData) const override
 Not supported by this tool; returns FAILURE.
StatusCode buildGraph (const EventContext &ctx, const xAOD::MuonSegmentContainer &segments, SegmentEdgeGraph &graph) const override
 Build a GNN graph from segments, computing node and edge features and storing the graph structure in graph.
StatusCode classifyEdges (const EventContext &ctx, const SegmentEdgeGraph &graph, std::vector< SegmentEdgeScore > &scores) const override
 Run ONNX inference on graph and populate scores with logit and probability for each edge; called after buildGraph().
StatusCode buildGraph (const EventContext &ctx, GraphRawData &graphData) const
 GNN-style bucket graph builder (feature tensor plus edge_index tensor). Kept for derived tools that need it.
StatusCode runInference (GraphRawData &graphData) const
 Default ONNX run for GNN case: inputs {"features","edge_index"} -> outputs {"logits"}.
 DeclareInterfaceID (ISegmentEdgeClassifierTool, 1, 0)

Protected Member Functions

StatusCode setupModel ()
Ort::Session & model () const
StatusCode buildFeaturesOnly (const EventContext &ctx, GraphRawData &graphData) const
 Build only the bucket features (N,6); attaches one tensor as graphData.graph->dataTensor[0].
StatusCode buildTransformerInputs (const EventContext &ctx, GraphRawData &graphData) const
 Build Transformer inputs: features [1,S,6] and pad_mask [1,S] (False = valid), as tensors 0 and 1.
StatusCode runNamedInference (GraphRawData &graphData, const std::vector< const char * > &inputNames, const std::vector< const char * > &outputNames) const
 Generic named inference, for tools with different I/O conventions.

Static Protected Member Functions

static std::string trimFeatureToken (std::string s)
static std::vector< std::string > parseFeatureNames (const std::string &raw)

Protected Attributes

SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey {this, "ReadSpacePoints", "MuonSpacePoints"}
SG::ReadHandleKey< ActsTrk::GeometryContext > m_geoCtxKey {this, "AlignmentKey", "ActsAlignment", "cond handle key"}
Gaudi::Property< int > m_minLayers {this, "MinLayersValid", 3}
Gaudi::Property< int > m_maxChamberDelta {this, "MaxChamberDelta", 13}
Gaudi::Property< int > m_maxSectorDelta {this, "MaxSectorDelta", 1}
Gaudi::Property< double > m_maxDistXY {this, "MaxDistXY", 6800.0}
Gaudi::Property< double > m_maxAbsDz {this, "MaxAbsDz", 15000.0}
Gaudi::Property< unsigned int > m_debugDumpFirstNNodes {this, "DebugDumpFirstNNodes", 5}
Gaudi::Property< unsigned int > m_debugDumpFirstNEdges {this, "DebugDumpFirstNEdges", 12}
Gaudi::Property< bool > m_validateEdges {this, "ValidateEdges", true}
Gaudi::Property< bool > m_sanitizeNonFinitePredictions
bool m_isCuda {false}
int m_cudaDeviceId {0}

Static Protected Attributes

static constexpr std::size_t kBucketFeatureCount = 6
static constexpr std::size_t kNodeFeatureCount = 10
static constexpr std::size_t kEdgeFeatureCount = 7
static constexpr std::array< std::string_view, kNodeFeatureCount > kDefaultNodeFeatureNames

Private Attributes

Gaudi::Property< float > m_maxDeltaThetaDeg {this, "MaxDeltaThetaDeg", 35.f}
Gaudi::Property< int > m_maxDeltaSector {this, "MaxDeltaSector", 1}
Gaudi::Property< int > m_sectorModulo {this, "SectorModulo", 16}
Gaudi::Property< std::string > m_inputNodeName {this, "InputNodeName", "x"}
Gaudi::Property< std::string > m_inputEdgeIndexName {this, "InputEdgeIndexName", "edge_index"}
Gaudi::Property< std::string > m_inputEdgeAttrName {this, "InputEdgeAttrName", "edge_attr"}
Gaudi::Property< std::string > m_outputName {this, "OutputName", "logits"}
float m_cosMin {0.f}
std::vector< std::string > m_nodeFeatureNames {}
 Node feature order expected by the model metadata (resolved at initialize).
std::vector< SegmentNodeFeatureId > m_nodeFeatureIds {}
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool

Detailed Description

Runs a segment-level GNN on reconstructed muon segments to classify segment-pair edges as "good" or "background".

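The tool reads an xAOD::MuonSegmentContainer and builds a graph where:

  • Nodes are muon segments (null entries are skipped), each with 10 features whose order is taken from the model metadata at initialize(), or from the default training order if the metadata provides none:
    • Position in metres and direction (6 floats)
    • Chamber index, layer count, sector, and segment multiplicity, i.e. the number of segments sharing the same group key (4 floats)
  • Directed edges i -> j connect every ordered pair of distinct segments whose directions satisfy cos(angle) >= cos(MaxDeltaThetaDeg) and whose sector distance (wrapped modulo SectorModulo) is at most MaxDeltaSector, so each accepted pair yields two edges. Each edge carries 7 features:
    • Spatial displacement from segment i to segment j in metres (3 floats: dx, dy, dz)
    • Distance magnitude in metres (1 float)
    • Dot product of the two segment directions, i.e. the cosine of the opening angle for unit directions (1 float)
    • Same-chamber and same-sector flags (2 floats, 1.0 or 0.0)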
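As a minimal sketch of the edge criteria and the seven edge features described above, the following standalone C++ evaluates one ordered pair of segments. The `SegNode` struct, the plain `std::array` vectors (used in place of `Amg::Vector3D` and `xAOD::MuonSegment`) and the `sectorDist()` helper are hypothetical; `sectorDist()` in particular is only an assumption about how a wrap-around sector distance might be computed, not the tool's own `sectorDistance()`. Positions are assumed to be in metres and directions to be unit vectors.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <optional>

// Hypothetical stand-in for the per-segment quantities used by buildGraph().
struct SegNode {
  std::array<double, 3> pos;  // position in metres
  std::array<double, 3> dir;  // unit direction
  int chamberIndex;
  int sector;
};

// Assumed wrap-around sector distance (not the tool's sectorDistance());
// modulo <= 0 disables wrapping.
int sectorDist(int a, int b, int modulo) {
  int d = std::abs(a - b);
  if (modulo > 0) {
    d %= modulo;
    if (modulo - d < d) d = modulo - d;  // going the other way round is shorter
  }
  return d;
}

// Returns {dx, dy, dz, dist, cos, sameChamber, sameSector} for the directed edge
// i -> j, or std::nullopt if the pair fails the sector or angular cut.
std::optional<std::array<float, 7>> edgeFeatures(const SegNode& i, const SegNode& j,
                                                 double maxDeltaThetaDeg,
                                                 int maxDeltaSector, int sectorModulo) {
  if (sectorDist(i.sector, j.sector, sectorModulo) > maxDeltaSector) return std::nullopt;
  const double pi = std::acos(-1.0);
  const float cosMin = static_cast<float>(std::cos(maxDeltaThetaDeg * pi / 180.0));
  const float cosang = static_cast<float>(i.dir[0] * j.dir[0] + i.dir[1] * j.dir[1] +
                                          i.dir[2] * j.dir[2]);
  if (cosang < cosMin) return std::nullopt;
  const double dx = j.pos[0] - i.pos[0];
  const double dy = j.pos[1] - i.pos[1];
  const double dz = j.pos[2] - i.pos[2];
  const double dist = std::sqrt(dx * dx + dy * dy + dz * dz);
  return std::array<float, 7>{static_cast<float>(dx), static_cast<float>(dy),
                              static_cast<float>(dz), static_cast<float>(dist), cosang,
                              static_cast<float>(i.chamberIndex == j.chamberIndex),
                              static_cast<float>(i.sector == j.sector)};
}
```

Evaluating both `edgeFeatures(a, b, ...)` and `edgeFeatures(b, a, ...)` gives two accepted edges with opposite displacements, mirroring the two directed edges stored per accepted pair. The tool itself computes the cosine threshold once (m_cosMin in initialize()); the sketch recomputes it per call only to stay self-contained.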

The tool then runs an ONNX model (typically a GIN or GCN variant) that outputs one logit per edge; classifyEdges() stores each logit together with its sigmoid probability, enabling downstream algorithms to filter low-quality segment associations and improve reconstruction efficiency.

Key difference from GraphBucketFilterTool: operates at segment (edge) level rather than bucket (node) level, and the interface uses discrete graph structures (SegmentEdgeGraph) rather than tensors for input/output.

Note: runGraphInference() is not supported by this tool; use SegmentEdgeInferenceAlg and the ISegmentEdgeClassifierTool methods instead.

Definition at line 59 of file SegmentEdgeClassifierTool.h.

Member Function Documentation

◆ buildFeaturesOnly()

StatusCode BucketInferenceToolBase::buildFeaturesOnly ( const EventContext & ctx,
GraphRawData & graphData ) const
protected, inherited

Build only the bucket features (N,6); attaches one tensor as graphData.graph->dataTensor[0].

Definition at line 89 of file BucketInferenceToolBase.cxx.

90 {
91 graphData.graph = std::make_unique<InferenceGraph>();
92 graphData.srcEdges.clear();
93 graphData.desEdges.clear();
94 graphData.featureLeaves.clear();
95 graphData.spacePointsInBucket.clear();
96
97 const MuonR4::SpacePointContainer* buckets{nullptr};
98 ATH_CHECK(SG::get(buckets, m_readKey, ctx));
99
100 const ActsTrk::GeometryContext* gctx = nullptr;
101 ATH_CHECK(SG::get(gctx, m_geoCtxKey, ctx));
102
103 std::vector<BucketGraphUtils::NodeAux> nodes;
104 BucketGraphUtils::buildNodesAndFeatures(*buckets, *gctx, nodes,
105 graphData.featureLeaves,
106 graphData.spacePointsInBucket); // now int64_t-compatible
107
108 const int64_t numNodes = static_cast<int64_t>(nodes.size());
109 ATH_MSG_DEBUG("Total buckets: " << buckets->size()
110 << " -> nodes (size>0): " << numNodes
111 << " | features.size()=" << graphData.featureLeaves.size());
112
113 if (numNodes == 0) {
114 ATH_MSG_WARNING("No valid buckets found (all have size 0.0). Skipping inference.");
115 return StatusCode::SUCCESS;
116 }
117
118 const int64_t nFeatPerNode = static_cast<int64_t>(kBucketFeatureCount);
119 if (numNodes * nFeatPerNode != static_cast<int64_t>(graphData.featureLeaves.size())) {
120 ATH_MSG_ERROR( "Feature size mismatch: expected " << (numNodes * nFeatPerNode)
121 << " got " << graphData.featureLeaves.size());
122 return StatusCode::FAILURE;
123 }
124
125 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
126 std::vector<int64_t> featShape{numNodes, nFeatPerNode};
127 graphData.graph->dataTensor.emplace_back(
128 Ort::Value::CreateTensor<float>(memInfo,
129 graphData.featureLeaves.data(),
130 graphData.featureLeaves.size(),
131 featShape.data(),
132 featShape.size()));
133 return StatusCode::SUCCESS;
134}
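`Ort::Value::CreateTensor<float>(memInfo, data, count, shape, rank)` wraps the caller's buffer without copying it, which is why the listing keeps the features in `graphData.featureLeaves` and checks `numNodes * kBucketFeatureCount` against its size before creating the tensor. The sketch below mimics that contract with a hypothetical non-owning `TensorView` so it runs without ONNX Runtime; it illustrates the size check and the borrowing semantics and is not ATLAS code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical non-owning view, standing in for an Ort::Value created over user memory.
struct TensorView {
  const float* data;                // borrowed: the owning vector must outlive the view
  std::vector<std::int64_t> shape;  // {N, 6}
};

constexpr std::size_t kFeaturesPerBucket = 6;  // mirrors kBucketFeatureCount

// Wraps a flat (N * 6) buffer as an (N, 6) view. Returns nullopt for zero nodes
// (the tool returns SUCCESS without a tensor) or for a size mismatch (the tool
// returns FAILURE).
std::optional<TensorView> wrapFeatures(const std::vector<float>& flat, std::size_t numNodes) {
  if (numNodes == 0) return std::nullopt;
  if (numNodes * kFeaturesPerBucket != flat.size()) return std::nullopt;
  return TensorView{flat.data(), {static_cast<std::int64_t>(numNodes),
                                  static_cast<std::int64_t>(kFeaturesPerBucket)}};
}
```

Because the view only borrows `flat`, resizing or destroying the vector while the view (or an `Ort::Value` built the same way) is alive would leave a dangling pointer; this is why the tool stores the buffers in GraphRawData alongside the tensors that reference them.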

◆ buildGraph() [1/2]

StatusCode BucketInferenceToolBase::buildGraph ( const EventContext & ctx,
GraphRawData & graphData ) const
inherited

GNN-style bucket graph builder: builds the (N,6) feature tensor via buildFeaturesOnly() and appends a [2,E] edge_index tensor. Kept for derived tools that need it.

Definition at line 196 of file BucketInferenceToolBase.cxx.

197 {
198 ATH_CHECK(buildFeaturesOnly(ctx, graphData));
199
200 const MuonR4::SpacePointContainer* buckets{nullptr};
201 ATH_CHECK(SG::get(buckets, m_readKey, ctx));
202
203 const ActsTrk::GeometryContext* gctx = nullptr;
204 ATH_CHECK(SG::get(gctx, m_geoCtxKey, ctx));
205
206 std::vector<BucketGraphUtils::NodeAux> nodes;
207 std::vector<float> throwawayFeatures;
208 std::vector<int64_t> throwawaySp; // int64_t
209 BucketGraphUtils::buildNodesAndFeatures(*buckets, *gctx, nodes, throwawayFeatures, throwawaySp);
210
211 const int64_t numNodes = static_cast<int64_t>(nodes.size());
212 if (numNodes == 0) {
213 ATH_MSG_WARNING("No valid buckets found (all have size 0.0). Skipping graph building.");
214 return StatusCode::SUCCESS;
215 }
216
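217 std::vector<int64_t> srcEdges, dstEdges;
218 BucketGraphUtils::buildSparseEdges(nodes,
219 m_minLayers,
220 m_maxChamberDelta,
221 m_maxSectorDelta,
222 m_maxDistXY,
223 m_maxAbsDz,
224 srcEdges, dstEdges);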
225 if (m_validateEdges) {
226 size_t bad = 0;
227 std::vector<int64_t> newSrc;
228 std::vector<int64_t> newDst;
229 newSrc.reserve(srcEdges.size());
230 newDst.reserve(dstEdges.size());
231 for (size_t k = 0; k < srcEdges.size(); ++k) {
232 const int64_t u = srcEdges[k];
233 const int64_t v = dstEdges[k];
234 const bool okU = (u >= 0 && u < numNodes);
235 const bool okV = (v >= 0 && v < numNodes);
236 if (okU && okV) {
237 newSrc.push_back(u);
238 newDst.push_back(v);
239 } else {
240 ++bad;
241 ATH_MSG_DEBUG( "Drop invalid edge " << k << ": (" << u << "->" << v
242 << "), valid node range [0," << (numNodes-1) << "]");
243 }
244 }
245 if (bad) {
246 ATH_MSG_WARNING( "Removed " << bad << " invalid edges out of "
247 << srcEdges.size());
248 srcEdges.swap(newSrc);
249 dstEdges.swap(newDst);
250 }
251 }
252
253 const size_t E = srcEdges.size();
254
255 if (msgLvl(MSG::DEBUG)) {
256 // DEBUG: Count connections per node
257 ATH_MSG_DEBUG("Edges built: " << E);
258 const unsigned int dumpE = std::min<unsigned int>(m_debugDumpFirstNEdges, E);
259 for (unsigned int k = 0; k < dumpE; ++k) {
260 ATH_MSG_DEBUG("EDGE[" << k << "]: " << srcEdges[k] << " -> " << dstEdges[k]);
261 }
262 std::vector<int> nodeConnections(numNodes, 0);
263 for (size_t k = 0; k < srcEdges.size(); ++k) {
264 const int64_t u = srcEdges[k];
265 const int64_t v = dstEdges[k];
266 if (u >= 0 && u < numNodes) nodeConnections[u]++;
267 if (v >= 0 && v < numNodes) nodeConnections[v]++;
268 }
269
270 ATH_MSG_DEBUG("=== DEBUGGING: Node Connections (first 10 nodes) ===");
271 const int64_t debugNodeCount = std::min(numNodes, static_cast<int64_t>(10));
272 for (int64_t i = 0; i < debugNodeCount; ++i) {
273 ATH_MSG_DEBUG("Node[" << i << "] connections: " << nodeConnections[i]);
274 }
275 ATH_MSG_DEBUG("=== END DEBUG NODE CONNECTIONS ===");
276
277 // DEBUG: Show detailed edge connections for first 10 nodes
278 ATH_MSG_DEBUG("=== DEBUGGING: Detailed Edge Connections (first 10 nodes) ===");
279 for (int64_t nodeIdx = 0; nodeIdx < debugNodeCount; ++nodeIdx) {
280 std::stringstream connections;
281 connections << "Node[" << nodeIdx << "] connected to: ";
282 bool foundAny = false;
283
284 for (size_t k = 0; k < srcEdges.size(); ++k) {
285 const int64_t u = srcEdges[k];
286 const int64_t v = dstEdges[k];
287
288 if (u == nodeIdx) {
289 if (foundAny) connections << ", ";
290 connections << v;
291 foundAny = true;
292 } else if (v == nodeIdx) {
293 if (foundAny) connections << ", ";
294 connections << u;
295 foundAny = true;
296 }
297 }
298
299 if (!foundAny) connections << "none";
300 ATH_MSG_DEBUG(connections.str());
301 }
302 ATH_MSG_DEBUG("=== END DEBUG DETAILED CONNECTIONS ===");
303 }
304
305 graphData.edgeIndexPacked.clear();
306 const size_t Efinal = BucketGraphUtils::packEdgeIndex(srcEdges, dstEdges, graphData.edgeIndexPacked);
307
308 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
309 std::vector<int64_t> edgeShape{2, static_cast<int64_t>(Efinal)};
310 graphData.graph->dataTensor.emplace_back(
311 Ort::Value::CreateTensor<int64_t>(memInfo,
312 graphData.edgeIndexPacked.data(),
313 graphData.edgeIndexPacked.size(),
314 edgeShape.data(),
315 edgeShape.size()));
316
317 ATH_MSG_DEBUG("Built sparse bucket graph: N=" << numNodes << ", E=" << Efinal);
318 return StatusCode::SUCCESS;
319}
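The edge post-processing in this listing has two steps: dropping edges whose endpoints fall outside `[0, numNodes)` (when `ValidateEdges` is enabled) and packing the two index lists into the row-major `[2, E]` `edge_index` layout. The sketch below assumes that `BucketGraphUtils::packEdgeIndex()` concatenates all sources followed by all destinations, which is the layout `classifyEdges()` builds explicitly; `dropInvalidEdges()` and `packEdges()` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Keeps only edges whose endpoints are valid node indices; returns the number dropped.
std::size_t dropInvalidEdges(std::vector<std::int64_t>& src, std::vector<std::int64_t>& dst,
                             std::int64_t numNodes) {
  assert(src.size() == dst.size());
  std::vector<std::int64_t> keptSrc, keptDst;
  keptSrc.reserve(src.size());
  keptDst.reserve(dst.size());
  for (std::size_t k = 0; k < src.size(); ++k) {
    const bool ok = src[k] >= 0 && src[k] < numNodes && dst[k] >= 0 && dst[k] < numNodes;
    if (ok) {
      keptSrc.push_back(src[k]);
      keptDst.push_back(dst[k]);
    }
  }
  const std::size_t dropped = src.size() - keptSrc.size();
  src.swap(keptSrc);
  dst.swap(keptDst);
  return dropped;
}

// Packs edges into a row-major [2, E] buffer: row 0 = sources, row 1 = destinations.
// Returns E, the second dimension of the tensor shape.
std::size_t packEdges(const std::vector<std::int64_t>& src, const std::vector<std::int64_t>& dst,
                      std::vector<std::int64_t>& packed) {
  assert(src.size() == dst.size());
  packed.clear();
  packed.reserve(2 * src.size());
  packed.insert(packed.end(), src.begin(), src.end());
  packed.insert(packed.end(), dst.begin(), dst.end());
  return src.size();
}
```

The packed buffer must stay alive as long as any tensor created over it, which is why the listing stores it in `graphData.edgeIndexPacked` rather than in a local vector.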

◆ buildGraph() [2/2]

StatusCode MuonML::SegmentEdgeClassifierTool::buildGraph ( const EventContext & ctx,
const xAOD::MuonSegmentContainer & segments,
SegmentEdgeGraph & graph ) const
override, virtual

Build a GNN graph from segments, computing node and edge features and storing the graph structure in graph.

Implements MuonML::ISegmentEdgeClassifierTool.

Definition at line 169 of file SegmentEdgeClassifierTool.cxx.

169 {
170 graph = SegmentEdgeGraph{};
171 graph.nNodes = segments.size();
172 graph.segments.reserve(graph.nNodes);
173 graph.nodeFeatures.reserve(graph.nNodes * kNodeFeatureCount);
174
175 std::vector<Amg::Vector3D> pos, dir;
176 std::vector<BucketSegmentFeatures> bucket;
177 pos.reserve(graph.nNodes); dir.reserve(graph.nNodes); bucket.reserve(graph.nNodes);
178
179 std::map<SegmentGroupKey, int> segmentMultiplicity{};
180 for (const xAOD::MuonSegment* seg : segments) {
181 if (!seg) continue;
182 ++segmentMultiplicity[segmentGroupKey(*seg)];
183 }
184
185 for (const xAOD::MuonSegment* seg : segments) {
186 if (!seg) continue;
187 const Amg::Vector3D p = seg->position();
188 Amg::Vector3D d = seg->direction();
189
190 const int chamberIdx = static_cast<int>(seg->chamberIndex());
191 const int layers = segmentLayerCount(*seg);
192 const int sec = seg->sector();
193 const auto multIt = segmentMultiplicity.find(segmentGroupKey(*seg));
194 const int nSeg = (multIt != segmentMultiplicity.end()) ? multIt->second : 1;
195
196 graph.segments.push_back(seg);
197 pos.emplace_back(p.x() / Gaudi::Units::m,
198 p.y() / Gaudi::Units::m,
199 p.z() / Gaudi::Units::m);
200 dir.emplace_back(d.x(), d.y(), d.z());
201 bucket.emplace_back(BucketSegmentFeatures{chamberIdx, layers, sec, nSeg});
202 for (const SegmentNodeFeatureId featureId : m_nodeFeatureIds) {
203 graph.nodeFeatures.push_back(nodeFeatureValue(featureId, pos.back(), dir.back(), bucket.back()));
204 }
205 }
206 graph.nNodes = graph.segments.size();
207
208 // Consistency check: all vectors must have same size
209 if (pos.size() != graph.nNodes || dir.size() != graph.nNodes || bucket.size() != graph.nNodes) {
210 ATH_MSG_ERROR("Inconsistent vector sizes during graph building: nodes=" << graph.nNodes
211 << ", pos=" << pos.size() << ", dir=" << dir.size() << ", bucket=" << bucket.size());
212 return StatusCode::FAILURE;
213 }
214
215 if (graph.nNodes < 2) {
216 graph.nEdges = 0;
217 return StatusCode::SUCCESS;
218 }
219
220 auto normalizeSector = [&](int s) {
221 // m_sectorModulo > 0: wrap sector to [0, modulo); <=0: disable wrapping
222 if (m_sectorModulo.value() > 0) {
223 s %= m_sectorModulo.value();
224 if (s < 0) s += m_sectorModulo.value();
225 }
226 return s;
227 };
228
229 std::unordered_map<int, std::vector<std::size_t>> nodesBySector;
230 nodesBySector.reserve(graph.nNodes);
231 for (std::size_t i = 0; i < graph.nNodes; ++i) {
232 nodesBySector[normalizeSector(bucket[i].sector)].push_back(i); // same key space as the wrapped look-ups below
233 }
234
235 const std::size_t maxEdges = graph.nNodes * (graph.nNodes - 1);
236 graph.edgeIndex.reserve(2 * maxEdges);
237 graph.edgeFeatures.reserve(kEdgeFeatureCount * maxEdges);
238
239 for (std::size_t i = 0; i < graph.nNodes; ++i) {
240 std::unordered_set<int> targetSectors;
241 targetSectors.reserve(2 * m_maxDeltaSector.value() + 1);
242 for (int delta = -m_maxDeltaSector.value(); delta <= m_maxDeltaSector.value(); ++delta) {
243 targetSectors.insert(normalizeSector(bucket[i].sector + delta));
244 }
245
246 for (const int sec : targetSectors) {
247 auto it = nodesBySector.find(sec);
248 if (it == nodesBySector.end()) continue;
249 for (const std::size_t j : it->second) {
250 if (i == j) continue;
251 if (sectorDistance(bucket[i].sector, bucket[j].sector, m_sectorModulo.value()) > m_maxDeltaSector.value()) continue;
252 const float cosang = static_cast<float>(dir[i].dot(dir[j]));
253 if (cosang < m_cosMin) continue;
254
255 graph.edgeIndex.push_back(static_cast<int64_t>(i));
256 graph.edgeIndex.push_back(static_cast<int64_t>(j));
257
258 const Amg::Vector3D delta = pos[j] - pos[i];
259 const float dx = static_cast<float>(delta.x());
260 const float dy = static_cast<float>(delta.y());
261 const float dz = static_cast<float>(delta.z());
262 const float dist = static_cast<float>(delta.mag());
263 graph.edgeFeatures.insert(graph.edgeFeatures.end(), {dx,dy,dz,dist,cosang, float(bucket[i].chamberIndex==bucket[j].chamberIndex), float(bucket[i].sector==bucket[j].sector)});
264 }
265 }
266 }
267 graph.nEdges = graph.edgeIndex.size() / 2;
268 ATH_MSG_DEBUG("buildGraph: input segments=" << segments.size()
269 << ", kept nodes=" << graph.nNodes
270 << ", built edges=" << graph.nEdges);
271 return StatusCode::SUCCESS;
272}
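The neighbour search in this listing avoids testing all N² pairs by grouping node indices by sector and visiting only the sectors within `MaxDeltaSector` of each node. The sketch below reproduces that idea for sector numbers alone (no angular cut), with hypothetical names `wrapSector()` and `sectorNeighbourEdges()`. It normalises the sector both when filling the map and when looking it up, so that wrap-around neighbours are found whatever the sector numbering, for example sectors 16 and 1 if sectors run from 1 to 16 with `SectorModulo = 16`.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Wraps a sector number into [0, modulo); modulo <= 0 disables wrapping.
int wrapSector(int s, int modulo) {
  if (modulo > 0) {
    s %= modulo;
    if (s < 0) s += modulo;
  }
  return s;
}

// Returns directed edges (i, j), i != j, between nodes whose wrapped sectors differ by
// at most maxDelta. Each node only scans the 2 * maxDelta + 1 neighbouring sector buckets.
std::vector<std::pair<std::size_t, std::size_t>> sectorNeighbourEdges(
    const std::vector<int>& sectors, int maxDelta, int modulo) {
  std::unordered_map<int, std::vector<std::size_t>> bySector;
  for (std::size_t i = 0; i < sectors.size(); ++i) {
    bySector[wrapSector(sectors[i], modulo)].push_back(i);  // same key space as the look-up
  }
  std::vector<std::pair<std::size_t, std::size_t>> edges;
  for (std::size_t i = 0; i < sectors.size(); ++i) {
    std::unordered_set<int> targets;  // deduplicates sectors if 2 * maxDelta + 1 > modulo
    for (int d = -maxDelta; d <= maxDelta; ++d) {
      targets.insert(wrapSector(sectors[i] + d, modulo));
    }
    for (int t : targets) {
      auto it = bySector.find(t);
      if (it == bySector.end()) continue;
      for (std::size_t j : it->second) {
        if (j != i) edges.emplace_back(i, j);
      }
    }
  }
  return edges;
}
```

The cost per node is proportional to the number of nodes in the neighbouring sectors rather than to N, which is what makes the sector bucketing worthwhile for busy events.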

◆ buildTransformerInputs()

StatusCode BucketInferenceToolBase::buildTransformerInputs ( const EventContext & ctx,
GraphRawData & graphData ) const
protected, inherited

Build Transformer inputs: features [1,S,6] and pad_mask [1,S] (False = valid), as tensors 0 and 1.

Definition at line 136 of file BucketInferenceToolBase.cxx.

137 {
138 // Start from (N,6)
139 ATH_CHECK(buildFeaturesOnly(ctx, graphData));
140
141 // Copy features flat buffer for lifetime management
142 std::vector<float> featuresFlat = graphData.featureLeaves;
143 const int64_t S = static_cast<int64_t>(featuresFlat.size() / kBucketFeatureCount);
144
145 if (S == 0) {
146 ATH_MSG_WARNING("No valid features for transformer input. Skipping inference.");
147 return StatusCode::SUCCESS;
148 }
149
150 if (msgLvl(MSG::DEBUG)) {
151 // DEBUG: Print transformer input features for first 10 nodes
152 ATH_MSG_DEBUG("=== DEBUGGING: Transformer input features for first 10 nodes ===");
153 const int64_t debugNodes = std::min(S, static_cast<int64_t>(10));
154 for (int64_t nodeIdx = 0; nodeIdx < debugNodes; ++nodeIdx) {
155 const int64_t baseIdx = nodeIdx * static_cast<int64_t>(kBucketFeatureCount);
156 ATH_MSG_DEBUG("TransformerNode[" << nodeIdx << "]: "
157 << "x=" << featuresFlat[baseIdx + 0] << ", "
158 << "y=" << featuresFlat[baseIdx + 1] << ", "
159 << "z=" << featuresFlat[baseIdx + 2] << ", "
160 << "layers=" << featuresFlat[baseIdx + 3] << ", "
161 << "nSp=" << featuresFlat[baseIdx + 4] << ", "
162 << "bucketSize=" << featuresFlat[baseIdx + 5]);
163 }
164 ATH_MSG_DEBUG("=== END DEBUG TRANSFORMER FEATURES ===");
165 }
166
167 // Rebuild graph with exactly 2 inputs: features [1,S,6], pad_mask [1,S]
168 graphData.graph = std::make_unique<InferenceGraph>();
169
170 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
171
172 // features: [1,S,6] (backed by graphData.featureLeaves to keep alive)
173 std::vector<int64_t> fShape{1, S, static_cast<int64_t>(kBucketFeatureCount)};
174 graphData.featureLeaves.swap(featuresFlat);
175 graphData.graph->dataTensor.emplace_back(
176 Ort::Value::CreateTensor<float>(memInfo,
177 graphData.featureLeaves.data(),
178 graphData.featureLeaves.size(),
179 fShape.data(),
180 fShape.size()));
181
182 // pad_mask: [1,S] (bool). Create ORT-owned tensor and fill with False (=valid).
183 Ort::AllocatorWithDefaultOptions allocator;
184 std::vector<int64_t> mShape{1, S};
185 Ort::Value padVal = Ort::Value::CreateTensor(allocator,
186 mShape.data(),
187 mShape.size(),
188 ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL);
189 bool* maskPtr = padVal.GetTensorMutableData<bool>();
190 for (int64_t i = 0; i < S; ++i) maskPtr[i] = false;
191 graphData.graph->dataTensor.emplace_back(std::move(padVal));
192
193 return StatusCode::SUCCESS;
194}
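The two transformer inputs can be derived from the flat feature buffer alone: S is the buffer size divided by the six features per bucket, the features are reinterpreted as `[1, S, 6]` without copying, and the padding mask is an S-long array of `false` because every position of a single event is a real bucket. The sketch below uses a hypothetical `TransformerInputs` struct in place of the two `Ort::Value` tensors, and `std::vector<char>` for the mask, because the bit-packed `std::vector<bool>` specialisation cannot provide a contiguous `bool*` buffer.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical holder for the shapes and mask that the listing puts into two tensors.
struct TransformerInputs {
  std::vector<std::int64_t> featureShape;  // {1, S, 6}
  std::vector<std::int64_t> maskShape;     // {1, S}
  std::vector<char> padMask;               // 0 (false) = valid position, 1 (true) = padding
};

// Builds shapes and the all-valid padding mask for one event; returns false if the
// flat buffer is empty or not a multiple of featuresPerNode.
bool makeTransformerInputs(const std::vector<float>& flat, std::size_t featuresPerNode,
                           TransformerInputs& out) {
  if (featuresPerNode == 0 || flat.empty() || flat.size() % featuresPerNode != 0) return false;
  const auto S = static_cast<std::int64_t>(flat.size() / featuresPerNode);
  out.featureShape = {1, S, static_cast<std::int64_t>(featuresPerNode)};
  out.maskShape = {1, S};
  out.padMask.assign(static_cast<std::size_t>(S), char{0});  // no padding within one event
  return true;
}
```

The listing itself divides without a remainder check because buildFeaturesOnly() has already verified that the buffer holds exactly N × 6 values.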

◆ classifyEdges()

StatusCode MuonML::SegmentEdgeClassifierTool::classifyEdges ( const EventContext & ctx,
const SegmentEdgeGraph & graph,
std::vector< SegmentEdgeScore > & scores ) const
override, virtual

Run ONNX inference on graph and populate scores with logit and probability for each edge; called after buildGraph().

Implements MuonML::ISegmentEdgeClassifierTool.

Definition at line 274 of file SegmentEdgeClassifierTool.cxx.

274 {
275 scores.clear();
276 if (!graph.nEdges) return StatusCode::SUCCESS;
277
278 if (graph.nodeFeatures.size() != graph.nNodes * kNodeFeatureCount) {
279 ATH_MSG_ERROR("Unexpected node feature size " << graph.nodeFeatures.size()
280 << "; expected " << (graph.nNodes * kNodeFeatureCount));
281 return StatusCode::FAILURE;
282 }
283 if (graph.edgeIndex.size() != 2 * graph.nEdges) {
284 ATH_MSG_ERROR("Unexpected edge index size " << graph.edgeIndex.size()
285 << "; expected " << (2 * graph.nEdges));
286 return StatusCode::FAILURE;
287 }
288 if (graph.edgeFeatures.size() != graph.nEdges * kEdgeFeatureCount) {
289 ATH_MSG_ERROR("Unexpected edge feature size " << graph.edgeFeatures.size()
290 << "; expected " << (graph.nEdges * kEdgeFeatureCount));
291 return StatusCode::FAILURE;
292 }
293
294 GraphRawData raw{};
295 raw.graph = std::make_unique<InferenceGraph>();
296 raw.featureLeaves = graph.nodeFeatures;
297 raw.edgeIndexPacked.reserve(2 * graph.nEdges);
298 raw.srcEdges.reserve(graph.nEdges);
299 raw.desEdges.reserve(graph.nEdges);
300 for (std::size_t e = 0; e < graph.nEdges; ++e) {
301 raw.srcEdges.push_back(graph.edgeIndex[2 * e]);
302 raw.desEdges.push_back(graph.edgeIndex[2 * e + 1]);
303 }
304 raw.edgeIndexPacked.insert(raw.edgeIndexPacked.end(), raw.srcEdges.begin(), raw.srcEdges.end());
305 raw.edgeIndexPacked.insert(raw.edgeIndexPacked.end(), raw.desEdges.begin(), raw.desEdges.end());
306
307 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
308
309 const std::vector<int64_t> nodeShape{static_cast<int64_t>(graph.nNodes), static_cast<int64_t>(kNodeFeatureCount)};
310 raw.graph->dataTensor.emplace_back(
311 Ort::Value::CreateTensor<float>(memInfo,
312 raw.featureLeaves.data(),
313 raw.featureLeaves.size(),
314 nodeShape.data(),
315 nodeShape.size()));
316
317 const std::vector<int64_t> edgeIndexShape{2, static_cast<int64_t>(graph.nEdges)};
318 raw.graph->dataTensor.emplace_back(
319 Ort::Value::CreateTensor<int64_t>(memInfo,
320 raw.edgeIndexPacked.data(),
321 raw.edgeIndexPacked.size(),
322 edgeIndexShape.data(),
323 edgeIndexShape.size()));
324
325 // ONNX Runtime's CreateTensor API takes a non-const pointer, but it does not
326 // mutate input buffers during inference. Avoid copying edge_attr every event.
327 ATLAS_THREAD_SAFE float* edgeFeaturesData = const_cast<float*>(graph.edgeFeatures.data());
328 const std::vector<int64_t> edgeAttrShape{static_cast<int64_t>(graph.nEdges), static_cast<int64_t>(kEdgeFeatureCount)};
329 raw.graph->dataTensor.emplace_back(
330 Ort::Value::CreateTensor<float>(memInfo,
331 edgeFeaturesData,
332 graph.edgeFeatures.size(),
333 edgeAttrShape.data(),
334 edgeAttrShape.size()));
335
336 const std::vector<const char*> inputNames{
337 m_inputNodeName.value().c_str(),
338 m_inputEdgeIndexName.value().c_str(),
339 m_inputEdgeAttrName.value().c_str()};
340 const std::vector<const char*> outputNames{m_outputName.value().c_str()};
341 ATH_MSG_DEBUG("classifyEdges: ONNX inputs shapes x=[" << nodeShape[0] << "," << nodeShape[1]
342 << "], edge_index=[" << edgeIndexShape[0] << "," << edgeIndexShape[1]
343 << "], edge_attr=[" << edgeAttrShape[0] << "," << edgeAttrShape[1] << "]");
344 ATH_CHECK(runNamedInference(raw, inputNames, outputNames));
345
346 if (raw.graph->dataTensor.size() <= inputNames.size()) {
347 ATH_MSG_ERROR("Missing ONNX output tensor for segment edge inference");
348 return StatusCode::FAILURE;
349 }
350
351 const Ort::Value& outTensor = raw.graph->dataTensor[inputNames.size()];
352 const auto outInfo = outTensor.GetTensorTypeAndShapeInfo();
353 const std::vector<int64_t> outShape = outInfo.GetShape();
354 const size_t outSize = outInfo.GetElementCount();
355 if (!outShape.empty()) {
356 ATH_MSG_DEBUG("classifyEdges: ONNX output rank=" << outShape.size()
357 << ", first dim=" << outShape.front()
358 << ", elements=" << outSize);
359 } else {
360 ATH_MSG_DEBUG("classifyEdges: ONNX scalar output, elements=" << outSize);
361 }
362 if (outSize < graph.nEdges) {
363 ATH_MSG_ERROR("ONNX logits tensor has " << outSize << " entries for " << graph.nEdges << " edges");
364 return StatusCode::FAILURE;
365 }
366
367 const float* logits = outTensor.GetTensorData<float>();
368 scores.reserve(graph.nEdges);
369 for (std::size_t e=0; e<graph.nEdges; ++e) {
370 const float l = logits[e];
371 scores.push_back({std::size_t(graph.edgeIndex[2 * e]), std::size_t(graph.edgeIndex[2 * e + 1]), l, sigmoid(l)});
372 }
373 return StatusCode::SUCCESS;
374}
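After inference, each edge's logit is paired with its endpoints and mapped to a probability. The `sigmoid()` helper used in the listing is not shown on this page; the sketch below assumes the standard logistic function 1 / (1 + e^-x) and evaluates it in a form that avoids overflow for large negative logits. `EdgeScore` is a hypothetical stand-in for `SegmentEdgeScore`, whose field names are not documented here, and the edge index is taken in the interleaved `{src0, dst0, src1, dst1, ...}` layout that buildGraph() writes into `SegmentEdgeGraph::edgeIndex`.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

struct EdgeScore {       // hypothetical mirror of SegmentEdgeScore
  std::size_t src, dst;  // node (segment) indices
  float logit, prob;
};

// Logistic function, evaluated so that std::exp never receives a large positive argument.
float stableSigmoid(float x) {
  if (x >= 0.f) return 1.f / (1.f + std::exp(-x));
  const float e = std::exp(x);
  return e / (1.f + e);
}

// Fills one score per edge; returns false when the model produced fewer logits than
// there are edges (the listing reports FAILURE in that case).
bool scoreEdges(const std::vector<std::int64_t>& edgeIndex, const std::vector<float>& logits,
                std::vector<EdgeScore>& scores) {
  scores.clear();
  const std::size_t nEdges = edgeIndex.size() / 2;
  if (edgeIndex.size() % 2 != 0 || logits.size() < nEdges) return false;
  scores.reserve(nEdges);
  for (std::size_t e = 0; e < nEdges; ++e) {
    const float l = logits[e];
    scores.push_back({static_cast<std::size_t>(edgeIndex[2 * e]),
                      static_cast<std::size_t>(edgeIndex[2 * e + 1]), l, stableSigmoid(l)});
  }
  return true;
}
```

Keeping the raw logit next to the probability lets downstream code choose its own working point (for example a logit cut) without re-running the model.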

◆ DeclareInterfaceID()

MuonML::ISegmentEdgeClassifierTool::DeclareInterfaceID ( ISegmentEdgeClassifierTool ,
1 ,
0  )
inherited

◆ initialize()

StatusCode MuonML::SegmentEdgeClassifierTool::initialize ( )
override

Retrieve the ONNX model and resolve node feature ordering from metadata.

Definition at line 80 of file SegmentEdgeClassifierTool.cxx.

80 {
82
83 // Resolve node feature names from model metadata, matching the ONNX exporter.
84 {
85 Ort::AllocatorWithDefaultOptions allocator;
86 Ort::ModelMetadata meta = model().GetModelMetadata();
87 auto keys = meta.GetCustomMetadataMapKeysAllocated(allocator);
88 std::vector<std::string> keyList;
89 keyList.reserve(keys.size());
90 for (const auto& k : keys) keyList.emplace_back(k.get());
91
92 constexpr std::array<std::string_view, 4> candidates{
93 "x_feature_names", "node_feature_names", "feature_names", "input_feature_names"};
94 std::string usedKey;
95 std::vector<std::string> names;
96 for (std::string_view key : candidates) {
97 const std::string keyStr{key};
98 if (std::find(keyList.begin(), keyList.end(), keyStr) == keyList.end()) continue;
99 names = parseFeatureNames(meta.LookupCustomMetadataMapAllocated(keyStr.c_str(), allocator).get());
100 if (!names.empty()) {
101 usedKey = keyStr;
102 break;
103 }
104 }
105
106 if (names.empty()) {
108 ATH_MSG_WARNING("Model metadata has no usable node feature name key"
109 " (tried x_feature_names/node_feature_names/feature_names/input_feature_names)."
110 " Falling back to default training order.");
111 } else {
112 if (names.size() != kNodeFeatureCount) {
113 ATH_MSG_ERROR("Model metadata key '" << usedKey << "' has " << names.size()
114 << " features, expected " << kNodeFeatureCount);
115 return StatusCode::FAILURE;
116 }
117 for (const std::string& n : names) {
118 if (!nodeFeatureIdFromName(n).has_value()) {
119 ATH_MSG_ERROR("Unsupported node feature name in model metadata ('" << usedKey
120 << "'): '" << n << "'."
121 " Add mapping in SegmentEdgeClassifierTool::nodeFeatureValue().");
122 return StatusCode::FAILURE;
123 }
124 }
125 m_nodeFeatureNames = std::move(names);
126 ATH_MSG_DEBUG("Using node feature names from model metadata key '" << usedKey << "'.");
127 }
128
129 m_nodeFeatureIds.reserve(m_nodeFeatureNames.size());
130 for (const std::string& n : m_nodeFeatureNames) {
131 const auto id = nodeFeatureIdFromName(n);
132 if (!id.has_value()) {
133 ATH_MSG_ERROR("Internal feature-id resolution failed for node feature name '" << n << "'.");
134 return StatusCode::FAILURE;
135 }
136 m_nodeFeatureIds.push_back(*id);
137 }
138
139 std::ostringstream order;
140 order << "Node feature order:";
141 for (std::size_t i = 0; i < m_nodeFeatureNames.size(); ++i) {
142 order << " f" << i << "=" << m_nodeFeatureNames[i];
143 if (i + 1 < m_nodeFeatureNames.size()) order << ",";
144 }
145 ATH_MSG_DEBUG(order.str());
146 }
147
148 if (m_nodeFeatureNames.size() != kNodeFeatureCount) {
149 ATH_MSG_ERROR("Internal node feature setup has " << m_nodeFeatureNames.size()
150 << " entries, expected " << kNodeFeatureCount);
151 return StatusCode::FAILURE;
152 }
153 if (m_nodeFeatureIds.size() != kNodeFeatureCount) {
154 ATH_MSG_ERROR("Internal node feature id setup has " << m_nodeFeatureIds.size()
155 << " entries, expected " << kNodeFeatureCount);
156 return StatusCode::FAILURE;
157 }
158
159 m_cosMin = std::cos(m_maxDeltaThetaDeg.value() * Gaudi::Units::deg);
160
161 return StatusCode::SUCCESS;
162}
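The metadata handling in initialize() reduces to: try the candidate keys in a fixed order, take the first one that yields a non-empty name list, otherwise fall back to the default order, then require exactly kNodeFeatureCount names that all map to known feature IDs. The sketch below replaces `Ort::ModelMetadata` with a plain `std::map` of already-parsed name lists; the `FeatureId` enum, its three-entry name table and `featureIdFromName()` are hypothetical, since the real SegmentNodeFeatureId values and nodeFeatureIdFromName() mapping are not shown on this page.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <map>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

enum class FeatureId { PosX, PosY, PosZ };  // hypothetical, shortened to 3 features

std::optional<FeatureId> featureIdFromName(const std::string& n) {
  if (n == "x") return FeatureId::PosX;
  if (n == "y") return FeatureId::PosY;
  if (n == "z") return FeatureId::PosZ;
  return std::nullopt;
}

// Resolves the node feature order; metadata maps key -> parsed name list.
// Returns nullopt (where initialize() would fail) on a wrong count or unknown name.
std::optional<std::vector<FeatureId>> resolveFeatureOrder(
    const std::map<std::string, std::vector<std::string>>& metadata,
    const std::vector<std::string>& defaults, std::size_t expectedCount) {
  constexpr std::array<std::string_view, 4> candidates{
      "x_feature_names", "node_feature_names", "feature_names", "input_feature_names"};
  std::vector<std::string> names;
  for (std::string_view key : candidates) {
    auto it = metadata.find(std::string{key});
    if (it != metadata.end() && !it->second.empty()) {
      names = it->second;
      break;
    }
  }
  if (names.empty()) names = defaults;  // fall back to the default training order
  if (names.size() != expectedCount) return std::nullopt;
  std::vector<FeatureId> ids;
  for (const std::string& n : names) {
    const auto id = featureIdFromName(n);
    if (!id) return std::nullopt;
    ids.push_back(*id);
  }
  return ids;
}
```

Resolving names to IDs once at initialize() means the per-event loop in buildGraph() only iterates over `m_nodeFeatureIds` and never compares strings.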

◆ model()

Ort::Session & BucketInferenceToolBase::model ( ) const
protected, inherited

Definition at line 65 of file BucketInferenceToolBase.cxx.

65 {
66 return m_onnxSessionTool->session();
67}

◆ parseFeatureNames()

std::vector< std::string > BucketInferenceToolBase::parseFeatureNames ( const std::string & raw)
static protected inherited

Definition at line 32 of file BucketInferenceToolBase.cxx.

32 {
33 std::vector<std::string> out;
34 const std::string s = trimFeatureToken(raw);
35 if (s.empty()) return out;
36
37 // Preferred exporter format: JSON list of strings.
38 if (!s.empty() && s.front() == '[') {
39 bool inQuote = false;
40 std::string token;
41 for (char c : s) {
42 if (c == '"') {
43 if (inQuote) {
44 if (!token.empty()) out.push_back(token);
45 token.clear();
46 }
47 inQuote = !inQuote;
48 continue;
49 }
50 if (inQuote) token.push_back(c);
51 }
52 if (!out.empty()) return out;
53 }
54
55 // Backward-compatible format: comma-separated.
56 std::istringstream ss(s);
57 std::string tok;
58 while (std::getline(ss, tok, ',')) {
59 tok = trimFeatureToken(tok);
60 if (!tok.empty()) out.push_back(tok);
61 }
62 return out;
63}
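You can check the two accepted metadata formats outside Athena with a standalone copy of the parsing logic. In this sketch, trim and parseNames are hypothetical stand-ins for trimFeatureToken and parseFeatureNames. Their bodies are intended to follow the listings on this page.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <sstream>
#include <string>
#include <vector>

// Strip leading and trailing whitespace (same idiom as trimFeatureToken).
std::string trim(std::string s) {
  auto notSpace = [](unsigned char c) { return !std::isspace(c); };
  s.erase(s.begin(), std::find_if(s.begin(), s.end(), notSpace));
  s.erase(std::find_if(s.rbegin(), s.rend(), notSpace).base(), s.end());
  return s;
}

// Accept a JSON-style list of quoted names, else fall back to comma separation.
std::vector<std::string> parseNames(const std::string& raw) {
  std::vector<std::string> out;
  const std::string s = trim(raw);
  if (s.empty()) return out;
  if (s.front() == '[') {
    bool inQuote = false;
    std::string token;
    for (char c : s) {
      if (c == '"') {
        if (inQuote) {  // closing quote: emit the collected name
          if (!token.empty()) out.push_back(token);
          token.clear();
        }
        inQuote = !inQuote;
        continue;
      }
      if (inQuote) token.push_back(c);
    }
    if (!out.empty()) return out;
  }
  std::istringstream ss(s);
  std::string tok;
  while (std::getline(ss, tok, ',')) {
    tok = trim(tok);
    if (!tok.empty()) out.push_back(tok);
  }
  return out;
}
```

Two edge cases follow directly from the code. An empty JSON list `[]` contains no quoted token, so it falls through to the comma branch and yields the single name `[]`. Escape sequences such as `\"` inside a name are not interpreted.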

◆ runGraphInference()

StatusCode MuonML::SegmentEdgeClassifierTool::runGraphInference ( const EventContext & ctx,
GraphRawData & graphData ) const
override

Not supported by this tool; returns FAILURE.

Use SegmentEdgeInferenceAlg together with buildGraph() and classifyEdges() instead.

Definition at line 164 of file SegmentEdgeClassifierTool.cxx.

164 {
165 ATH_MSG_ERROR("runGraphInference is not supported by SegmentEdgeClassifierTool. Use SegmentEdgeInferenceAlg + ISegmentEdgeClassifierTool methods.");
166 return StatusCode::FAILURE;
167}

◆ runInference()

StatusCode BucketInferenceToolBase::runInference ( GraphRawData & graphData) const
inherited

Default ONNX run for the GNN case: inputs {"features","edge_index"} -> outputs {"logits"}; the output name is taken from m_outputName.

Definition at line 485 of file BucketInferenceToolBase.cxx.

485 {
486 std::vector<const char*> inputNames = {"features", "edge_index"};
487 std::vector<const char*> outputNames = {m_outputName.value().c_str()};
488 return runNamedInference(graphData, inputNames, outputNames);
489}

◆ runNamedInference()

StatusCode BucketInferenceToolBase::runNamedInference ( GraphRawData & graphData,
const std::vector< const char * > & inputNames,
const std::vector< const char * > & outputNames ) const
protected inherited

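Generic named inference for tools with different I/O conventions. The caller supplies the ONNX input and output names. Input tensors are taken from graphData.graph->dataTensor in the same order, and the outputs are appended to that vector.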

Definition at line 321 of file BucketInferenceToolBase.cxx.

325{
326 if (!graphData.graph) {
327 ATH_MSG_ERROR("Graph data is not built.");
328 return StatusCode::FAILURE;
329 }
330 if (graphData.graph->dataTensor.empty()) {
331 ATH_MSG_ERROR("No input tensors prepared for inference.");
332 return StatusCode::FAILURE;
333 }
334
335 if (msgLvl(MSG::DEBUG)) {
336 // DEBUG: Print actual input tensor data for features tensor
337
338 ATH_MSG_DEBUG("=== DEBUGGING: ONNX Input tensor data ===");
339 if (!graphData.graph->dataTensor.empty()) {
340 const auto& featureTensor = graphData.graph->dataTensor[0];
341 auto featShape = featureTensor.GetTensorTypeAndShapeInfo().GetShape();
342 ATH_MSG_DEBUG("Features tensor shape: [" << featShape[0]
343 << (featShape.size()>1 ? ("," + std::to_string(featShape[1])) : "")
344 << (featShape.size()>2 ? ("," + std::to_string(featShape[2])) : "") << "]");
345
346 float* featData = const_cast<Ort::Value&>(featureTensor).GetTensorMutableData<float>();
347 const size_t totalElements = featureTensor.GetTensorTypeAndShapeInfo().GetElementCount();
348 ATH_MSG_DEBUG("Features tensor total elements: " << totalElements);
349
350 // Print up to 10 nodes; stride = nFeat from tensor shape
351 const size_t nFeat = (featShape.size() > 1 && featShape[1] > 0) ? static_cast<size_t>(featShape[1]) : 1;
352 const size_t nNodes = totalElements / nFeat;
353 const size_t debugNodes = std::min(nNodes, static_cast<size_t>(10));
354
355 // Try to read feature names from model custom metadata.
356 // Prefer x_feature_names (current exporter), then fall back to legacy keys.
357 std::vector<std::string> featNames;
358 {
359 Ort::AllocatorWithDefaultOptions allocator;
360 Ort::ModelMetadata meta = model().GetModelMetadata();
361 auto keys = meta.GetCustomMetadataMapKeysAllocated(allocator);
362 std::vector<std::string> keyNames;
363 keyNames.reserve(keys.size());
364 for (const auto& k : keys) keyNames.emplace_back(k.get());
365 const std::array<std::string, 4> candidates{
366 "x_feature_names", "node_feature_names", "feature_names", "input_feature_names"};
367 for (const std::string& key : candidates) {
368 if (std::find(keyNames.begin(), keyNames.end(), key) != keyNames.end()) {
369 std::string val = meta.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get();
370 featNames = parseFeatureNames(val);
371 break;
372 }
373 }
374 if (featNames.empty()) {
375 ATH_MSG_DEBUG("No usable feature-name metadata key found in model; using generic fN labels.");
376 }
377 }
378 auto featLabel = [&](size_t f) -> std::string {
379 if (f < featNames.size()) return featNames[f];
380 return "f" + std::to_string(f);
381 };
382
383 // Print legend
384 {
385 std::ostringstream legend;
386 legend << "Node feature legend (" << nFeat << " features):";
387 for (size_t f = 0; f < nFeat; ++f) {
388 legend << " f" << f << "=" << featLabel(f);
389 if (f + 1 < nFeat) legend << ",";
390 }
391 ATH_MSG_DEBUG(legend.str());
392 }
393
394 for (size_t n = 0; n < debugNodes; ++n) {
395 std::ostringstream row;
396 row << "ONNXNode[" << n << "]:";
397 for (size_t f = 0; f < nFeat; ++f) {
398 row << " f" << f << "=" << featData[n * nFeat + f];
399 if (f + 1 < nFeat) row << ",";
400 }
401 ATH_MSG_DEBUG(row.str());
402 }
403 }
404 ATH_MSG_DEBUG("=== END DEBUG ONNX INPUT ===");
405 }
406
407 Ort::RunOptions run_options;
408 run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR);
409
410 if (m_isCuda) {
411 // ---- CUDA path: use IoBinding so tensors stay on device ----
412 Ort::IoBinding binding(model());
413 for (std::size_t i = 0; i < inputNames.size(); ++i) {
414 binding.BindInput(inputNames[i], graphData.graph->dataTensor[i]);
415 }
416 // Bind outputs to CPU so predictions are directly readable after sync.
417 Ort::MemoryInfo cpuOut = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
418 for (const char* outName : outputNames) {
419 binding.BindOutput(outName, cpuOut);
420 }
421
422 model().Run(run_options, binding);
423 binding.SynchronizeOutputs();
424
425 std::vector<Ort::Value> outputs = binding.GetOutputValues();
426 if (outputs.empty()) {
427 ATH_MSG_ERROR("IoBinding inference returned empty output.");
428 return StatusCode::FAILURE;
429 }
430
431 float* outData = outputs[0].GetTensorMutableData<float>();
432 const size_t outSize = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
433 ATH_MSG_DEBUG("ONNX (IoBinding) raw output elementCount = " << outSize);
434
435 if (m_sanitizeNonFinitePredictions.value()) {
436 std::span<float> preds(outData, outData + outSize);
437 for (size_t i = 0; i < outSize; ++i) {
438 if (!std::isfinite(preds[i])) {
439 ATH_MSG_WARNING("Non-finite prediction detected at " << i << " -> set to -100.");
440 preds[i] = -100.0f;
441 }
442 }
443 }
444
445 for (auto& v : outputs) {
446 graphData.graph->dataTensor.emplace_back(std::move(v));
447 }
448 return StatusCode::SUCCESS;
449 }
450
451 // ---- CPU path ----
452 std::vector<Ort::Value> outputs =
453 model().Run(run_options,
454 inputNames.data(),
455 graphData.graph->dataTensor.data(),
456 graphData.graph->dataTensor.size(),
457 outputNames.data(),
458 outputNames.size());
459
460 if (outputs.empty()) {
461 ATH_MSG_ERROR("Inference returned empty output.");
462 return StatusCode::FAILURE;
463 }
464
465 float* outData = outputs[0].GetTensorMutableData<float>();
466 const size_t outSize = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
467 ATH_MSG_DEBUG("ONNX raw output elementCount = " << outSize);
468
469 if (m_sanitizeNonFinitePredictions.value()) {
470 std::span<float> preds(outData, outData + outSize);
471 for (size_t i = 0; i < outSize; ++i) {
472 if (!std::isfinite(preds[i])) {
473 ATH_MSG_WARNING("Non-finite prediction detected at " << i << " -> set to -100.");
474 preds[i] = -100.0f;
475 }
476 }
477 }
478
479 for (auto& v : outputs) {
480 graphData.graph->dataTensor.emplace_back(std::move(v));
481 }
482 return StatusCode::SUCCESS;
483}

◆ setupModel()

StatusCode BucketInferenceToolBase::setupModel ( )
protected inherited

Definition at line 69 of file BucketInferenceToolBase.cxx.

69 {
70 ATH_CHECK(m_onnxSessionTool.retrieve());
71 ATH_CHECK(m_readKey.initialize());
72 ATH_CHECK(m_geoCtxKey.initialize());
73
74 // Detect CUDA provider by dynamic-casting the concrete session tool.
75 if (const auto* cudaTool = dynamic_cast<const AthOnnx::OnnxRuntimeSessionToolCUDA*>(
76 m_onnxSessionTool.get())) {
77 m_isCuda = true;
78 m_cudaDeviceId = cudaTool->deviceId();
79 ATH_MSG_INFO("ONNX session is running on CUDA device " << m_cudaDeviceId
80 << ". I/O binding will be used.");
81 } else {
82 m_isCuda = false;
83 ATH_MSG_INFO("ONNX session is running on CPU.");
84 }
85
86 return StatusCode::SUCCESS;
87}
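setupModel() detects the execution provider once, by dynamic-casting the session tool to its concrete CUDA type. It caches the result in m_isCuda and m_cudaDeviceId, so per-event code only tests a bool. Below is the pattern in isolation. ISessionTool, CpuSessionTool, CudaSessionTool and ProviderInfo are hypothetical stand-ins for the AthOnnx classes.

```cpp
#include <cassert>

// Hypothetical stand-ins for AthOnnx::IOnnxRuntimeSessionTool and its implementations.
struct ISessionTool {
  virtual ~ISessionTool() = default;
};
struct CpuSessionTool : ISessionTool {};
struct CudaSessionTool : ISessionTool {
  explicit CudaSessionTool(int id) : m_deviceId(id) {}
  int deviceId() const { return m_deviceId; }
 private:
  int m_deviceId;
};

// Cached result of the one-time provider check.
struct ProviderInfo {
  bool isCuda{false};
  int deviceId{0};
};

// Inspect the concrete type behind the interface; a failed cast means CPU.
ProviderInfo detectProvider(const ISessionTool& tool) {
  if (const auto* cuda = dynamic_cast<const CudaSessionTool*>(&tool)) {
    return {true, cuda->deviceId()};
  }
  return {};
}
```

The cast ties the base class to one concrete implementation. An alternative would be a virtual query on the interface, but that would require changing the interface itself.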

◆ trimFeatureToken()

std::string BucketInferenceToolBase::trimFeatureToken ( std::string s)
static protected inherited

Definition at line 25 of file BucketInferenceToolBase.cxx.

25 {
26 auto notSpace = [](unsigned char c) { return !std::isspace(c); };
27 s.erase(s.begin(), std::find_if(s.begin(), s.end(), notSpace));
28 s.erase(std::find_if(s.rbegin(), s.rend(), notSpace).base(), s.end());
29 return s;
30}

Member Data Documentation

◆ kBucketFeatureCount

std::size_t MuonML::BucketInferenceToolBase::kBucketFeatureCount = 6
static constexpr protected inherited

Definition at line 53 of file BucketInferenceToolBase.h.

◆ kDefaultNodeFeatureNames

std::array<std::string_view, kNodeFeatureCount> MuonML::BucketInferenceToolBase::kDefaultNodeFeatureNames
static constexpr protected inherited
Initial value:
= {
"segmentPositionX_m", "segmentPositionY_m", "segmentPositionZ_m",
"segmentDirectionX", "segmentDirectionY", "segmentDirectionZ",
"bucket_chamberIndex", "bucket_layers", "bucket_sector", "bucket_segments"}

Definition at line 56 of file BucketInferenceToolBase.h.

56 {
57 "segmentPositionX_m", "segmentPositionY_m", "segmentPositionZ_m",
58 "segmentDirectionX", "segmentDirectionY", "segmentDirectionZ",
59 "bucket_chamberIndex", "bucket_layers", "bucket_sector", "bucket_segments"};

◆ kEdgeFeatureCount

std::size_t MuonML::BucketInferenceToolBase::kEdgeFeatureCount = 7
static constexpr protected inherited

Definition at line 55 of file BucketInferenceToolBase.h.

◆ kNodeFeatureCount

std::size_t MuonML::BucketInferenceToolBase::kNodeFeatureCount = 10
static constexpr protected inherited

Definition at line 54 of file BucketInferenceToolBase.h.

◆ m_cosMin

float MuonML::SegmentEdgeClassifierTool::m_cosMin {0.f}
private

Definition at line 92 of file SegmentEdgeClassifierTool.h.

92{0.f};
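initialize() sets m_cosMin = cos(MaxDeltaThetaDeg * Gaudi::Units::deg), where Gaudi::Units::deg is pi/180 in radian units. This suggests that the angular compatibility of two segments is tested in cosine space, which avoids an acos per pair. This page does not show where the value is used, so the following is only a sketch under that assumption. cosAngle and passesAngleCut are hypothetical helpers.

```cpp
#include <array>
#include <cassert>
#include <cmath>

using Vec3 = std::array<double, 3>;

// Cosine of the opening angle between two (not necessarily unit) direction vectors.
double cosAngle(const Vec3& a, const Vec3& b) {
  const double dot = a[0] * b[0] + a[1] * b[1] + a[2] * b[2];
  const double na = std::sqrt(a[0] * a[0] + a[1] * a[1] + a[2] * a[2]);
  const double nb = std::sqrt(b[0] * b[0] + b[1] * b[1] + b[2] * b[2]);
  return dot / (na * nb);
}

// theta <= thetaMax  <=>  cos(theta) >= cos(thetaMax), since cos decreases on [0, pi].
bool passesAngleCut(const Vec3& a, const Vec3& b, double cosMin) {
  return cosAngle(a, b) >= cosMin;
}
```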

◆ m_cudaDeviceId

int MuonML::BucketInferenceToolBase::m_cudaDeviceId {0}
protected inherited

Definition at line 102 of file BucketInferenceToolBase.h.

102{0};

◆ m_debugDumpFirstNEdges

Gaudi::Property<unsigned int> MuonML::BucketInferenceToolBase::m_debugDumpFirstNEdges {this, "DebugDumpFirstNEdges", 12}
protected inherited

Definition at line 94 of file BucketInferenceToolBase.h.

94{this, "DebugDumpFirstNEdges", 12};

◆ m_debugDumpFirstNNodes

Gaudi::Property<unsigned int> MuonML::BucketInferenceToolBase::m_debugDumpFirstNNodes {this, "DebugDumpFirstNNodes", 5}
protected inherited

Definition at line 93 of file BucketInferenceToolBase.h.

93{this, "DebugDumpFirstNNodes", 5};

◆ m_geoCtxKey

SG::ReadHandleKey<ActsTrk::GeometryContext> MuonML::BucketInferenceToolBase::m_geoCtxKey {this, "AlignmentKey", "ActsAlignment", "cond handle key"}
protected inherited

Definition at line 80 of file BucketInferenceToolBase.h.

80{this, "AlignmentKey", "ActsAlignment", "cond handle key"};

◆ m_inputEdgeAttrName

Gaudi::Property<std::string> MuonML::SegmentEdgeClassifierTool::m_inputEdgeAttrName {this, "InputEdgeAttrName", "edge_attr"}
private

Definition at line 90 of file SegmentEdgeClassifierTool.h.

90{this, "InputEdgeAttrName", "edge_attr"};

◆ m_inputEdgeIndexName

Gaudi::Property<std::string> MuonML::SegmentEdgeClassifierTool::m_inputEdgeIndexName {this, "InputEdgeIndexName", "edge_index"}
private

Definition at line 89 of file SegmentEdgeClassifierTool.h.

89{this, "InputEdgeIndexName", "edge_index"};

◆ m_inputNodeName

Gaudi::Property<std::string> MuonML::SegmentEdgeClassifierTool::m_inputNodeName {this, "InputNodeName", "x"}
private

Definition at line 88 of file SegmentEdgeClassifierTool.h.

88{this, "InputNodeName", "x"};

◆ m_isCuda

bool MuonML::BucketInferenceToolBase::m_isCuda {false}
protected inherited

Definition at line 101 of file BucketInferenceToolBase.h.

101{false};

◆ m_maxAbsDz

Gaudi::Property<double> MuonML::BucketInferenceToolBase::m_maxAbsDz {this, "MaxAbsDz", 15000.0}
protected inherited

Definition at line 90 of file BucketInferenceToolBase.h.

90{this, "MaxAbsDz", 15000.0};

◆ m_maxChamberDelta

Gaudi::Property<int> MuonML::BucketInferenceToolBase::m_maxChamberDelta {this, "MaxChamberDelta", 13}
protected inherited

Definition at line 87 of file BucketInferenceToolBase.h.

87{this, "MaxChamberDelta", 13};

◆ m_maxDeltaSector

Gaudi::Property<int> MuonML::SegmentEdgeClassifierTool::m_maxDeltaSector {this, "MaxDeltaSector", 1}
private

Definition at line 86 of file SegmentEdgeClassifierTool.h.

86{this, "MaxDeltaSector", 1};

◆ m_maxDeltaThetaDeg

Gaudi::Property<float> MuonML::SegmentEdgeClassifierTool::m_maxDeltaThetaDeg {this, "MaxDeltaThetaDeg", 35.f}
private

Definition at line 85 of file SegmentEdgeClassifierTool.h.

85{this, "MaxDeltaThetaDeg", 35.f};

◆ m_maxDistXY

Gaudi::Property<double> MuonML::BucketInferenceToolBase::m_maxDistXY {this, "MaxDistXY", 6800.0}
protectedinherited

Definition at line 89 of file BucketInferenceToolBase.h.

89{this, "MaxDistXY", 6800.0};

◆ m_maxSectorDelta

Gaudi::Property<int> MuonML::BucketInferenceToolBase::m_maxSectorDelta {this, "MaxSectorDelta", 1}
protectedinherited

Definition at line 88 of file BucketInferenceToolBase.h.

88{this, "MaxSectorDelta", 1};

◆ m_minLayers

Gaudi::Property<int> MuonML::BucketInferenceToolBase::m_minLayers {this, "MinLayersValid", 3}
protectedinherited

Definition at line 86 of file BucketInferenceToolBase.h.

86{this, "MinLayersValid", 3};

◆ m_nodeFeatureIds

std::vector<SegmentNodeFeatureId> MuonML::SegmentEdgeClassifierTool::m_nodeFeatureIds {}
private

Definition at line 96 of file SegmentEdgeClassifierTool.h.

96{};

◆ m_nodeFeatureNames

std::vector<std::string> MuonML::SegmentEdgeClassifierTool::m_nodeFeatureNames {}
private

Node feature order expected by the model, as declared in its metadata; resolved in initialize().

Definition at line 95 of file SegmentEdgeClassifierTool.h.

95{};

◆ m_onnxSessionTool

ToolHandle<AthOnnx::IOnnxRuntimeSessionTool> MuonML::BucketInferenceToolBase::m_onnxSessionTool
private inherited
Initial value:
{
this, "ModelSession", ""}

Definition at line 105 of file BucketInferenceToolBase.h.

105 {
106 this, "ModelSession", ""};

◆ m_outputName

Gaudi::Property<std::string> MuonML::SegmentEdgeClassifierTool::m_outputName {this, "OutputName", "logits"}
private

Definition at line 91 of file SegmentEdgeClassifierTool.h.

91{this, "OutputName", "logits"};

◆ m_readKey

SG::ReadHandleKey<MuonR4::SpacePointContainer> MuonML::BucketInferenceToolBase::m_readKey {this, "ReadSpacePoints", "MuonSpacePoints"}
protected inherited

Definition at line 79 of file BucketInferenceToolBase.h.

79{this, "ReadSpacePoints", "MuonSpacePoints"};

◆ m_sanitizeNonFinitePredictions

Gaudi::Property<bool> MuonML::BucketInferenceToolBase::m_sanitizeNonFinitePredictions
protected inherited
Initial value:
{
this, "SanitizeNonFinitePredictions", false,
"When true, replace non-finite ONNX outputs with -100 and log a warning."}

Definition at line 96 of file BucketInferenceToolBase.h.

96 {
97 this, "SanitizeNonFinitePredictions", false,
98 "When true, replace non-finite ONNX outputs with -100 and log a warning."};

◆ m_sectorModulo

Gaudi::Property<int> MuonML::SegmentEdgeClassifierTool::m_sectorModulo {this, "SectorModulo", 16}
private

Definition at line 87 of file SegmentEdgeClassifierTool.h.

87{this, "SectorModulo", 16};
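The pairing of MaxDeltaSector with SectorModulo suggests a sector comparison that wraps around the detector. With the defaults (16 sectors, maximum difference 1), sectors 16 and 1 would then count as neighbours. This page does not show the actual formula. The sketch below is one plausible reading, with hypothetical helper names sectorDistance and sectorsCompatible.

```cpp
#include <cassert>
#include <cstdlib>

// Circular distance between two sector numbers on a ring of `modulo` sectors.
int sectorDistance(int s1, int s2, int modulo) {
  const int d = std::abs(s1 - s2) % modulo;
  return d > modulo - d ? modulo - d : d;  // take the shorter way around
}

// Accept a pair when the wrapped sector distance is within maxDelta.
bool sectorsCompatible(int s1, int s2, int modulo, int maxDelta) {
  return sectorDistance(s1, s2, modulo) <= maxDelta;
}
```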

◆ m_validateEdges

Gaudi::Property<bool> MuonML::BucketInferenceToolBase::m_validateEdges {this, "ValidateEdges", true}
protected inherited

Definition at line 95 of file BucketInferenceToolBase.h.

95{this, "ValidateEdges", true};

The documentation for this class was generated from the following files: