ATLAS Offline Software
Runs a segment-level GNN on reconstructed muon segments to classify segment-pair edges as "good" or "background".
#include <SegmentEdgeClassifierTool.h>
Public Member Functions

| Type | Member | Description |
|---|---|---|
| StatusCode | initialize () override | Retrieve the ONNX model and resolve node feature ordering from metadata. |
| StatusCode | runGraphInference (const EventContext &ctx, GraphRawData &graphData) const override | Not supported by this tool; returns FAILURE. |
| StatusCode | buildGraph (const EventContext &ctx, const xAOD::MuonSegmentContainer &segments, SegmentEdgeGraph &graph) const override | Build a GNN graph from segments, computing node and edge features and storing the graph structure in graph. |
| StatusCode | classifyEdges (const EventContext &ctx, const SegmentEdgeGraph &graph, std::vector< SegmentEdgeScore > &scores) const override | Run ONNX inference on graph and populate scores with logit and probability for each edge; called after buildGraph(). |
| StatusCode | buildGraph (const EventContext &ctx, GraphRawData &graphData) const | GNN-style graph builder (features + edges). Kept for tools that want it. |
| StatusCode | runInference (GraphRawData &graphData) const | Default ONNX run for GNN case: inputs {"features","edge_index"} -> outputs {"logits"}. |
| | DeclareInterfaceID (ISegmentEdgeClassifierTool, 1, 0) | |

Protected Member Functions

| Type | Member | Description |
|---|---|---|
| StatusCode | setupModel () | |
| Ort::Session & | model () const | |
| StatusCode | buildFeaturesOnly (const EventContext &ctx, GraphRawData &graphData) const | Build only features (N,6); attaches one tensor in graph.dataTensor[0]. |
| StatusCode | buildTransformerInputs (const EventContext &ctx, GraphRawData &graphData) const | Build Transformer inputs: features [1,S,6] and pad_mask [1,S] (False = valid), as tensors 0 and 1. |
| StatusCode | runNamedInference (GraphRawData &graphData, const std::vector< const char * > &inputNames, const std::vector< const char * > &outputNames) const | Generic named inference, for tools with different I/O conventions. |

Static Protected Member Functions

| Type | Member | Description |
|---|---|---|
| static std::string | trimFeatureToken (std::string s) | |
| static std::vector< std::string > | parseFeatureNames (const std::string &raw) | |

Protected Attributes

| Type | Member | Description |
|---|---|---|
| SG::ReadHandleKey< MuonR4::SpacePointContainer > | m_readKey {this, "ReadSpacePoints", "MuonSpacePoints"} | |
| SG::ReadHandleKey< ActsTrk::GeometryContext > | m_geoCtxKey {this, "AlignmentKey", "ActsAlignment", "cond handle key"} | |
| Gaudi::Property< int > | m_minLayers {this, "MinLayersValid", 3} | |
| Gaudi::Property< int > | m_maxChamberDelta {this, "MaxChamberDelta", 13} | |
| Gaudi::Property< int > | m_maxSectorDelta {this, "MaxSectorDelta", 1} | |
| Gaudi::Property< double > | m_maxDistXY {this, "MaxDistXY", 6800.0} | |
| Gaudi::Property< double > | m_maxAbsDz {this, "MaxAbsDz", 15000.0} | |
| Gaudi::Property< unsigned int > | m_debugDumpFirstNNodes {this, "DebugDumpFirstNNodes", 5} | |
| Gaudi::Property< unsigned int > | m_debugDumpFirstNEdges {this, "DebugDumpFirstNEdges", 12} | |
| Gaudi::Property< bool > | m_validateEdges {this, "ValidateEdges", true} | |
| Gaudi::Property< bool > | m_sanitizeNonFinitePredictions | |
| bool | m_isCuda {false} | |
| int | m_cudaDeviceId {0} | |

Static Protected Attributes

| Type | Member | Description |
|---|---|---|
| static constexpr std::size_t | kBucketFeatureCount = 6 | |
| static constexpr std::size_t | kNodeFeatureCount = 10 | |
| static constexpr std::size_t | kEdgeFeatureCount = 7 | |
| static constexpr std::array< std::string_view, kNodeFeatureCount > | kDefaultNodeFeatureNames | |

Private Attributes

| Type | Member | Description |
|---|---|---|
| Gaudi::Property< float > | m_maxDeltaThetaDeg {this, "MaxDeltaThetaDeg", 35.f} | |
| Gaudi::Property< int > | m_maxDeltaSector {this, "MaxDeltaSector", 1} | |
| Gaudi::Property< int > | m_sectorModulo {this, "SectorModulo", 16} | |
| Gaudi::Property< std::string > | m_inputNodeName {this, "InputNodeName", "x"} | |
| Gaudi::Property< std::string > | m_inputEdgeIndexName {this, "InputEdgeIndexName", "edge_index"} | |
| Gaudi::Property< std::string > | m_inputEdgeAttrName {this, "InputEdgeAttrName", "edge_attr"} | |
| Gaudi::Property< std::string > | m_outputName {this, "OutputName", "logits"} | |
| float | m_cosMin {0.f} | |
| std::vector< std::string > | m_nodeFeatureNames {} | Node feature order expected by the model metadata (resolved at initialize). |
| std::vector< SegmentNodeFeatureId > | m_nodeFeatureIds {} | |
| ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > | m_onnxSessionTool | |


Detailed Description

The tool reads an xAOD::MuonSegmentContainer and builds a graph in which each node is a reconstructed segment and each edge connects a candidate pair of segments.
The tool then runs an ONNX model (typically a GIN or GCN variant) to score each edge; classifyEdges() reports both a logit and a probability per edge, enabling downstream algorithms to filter out low-quality segment associations and improve reconstruction efficiency.
The key difference from GraphBucketFilterTool is that this tool operates at the segment (edge) level rather than the bucket (node) level, and its interface exchanges discrete graph structures (SegmentEdgeGraph in, SegmentEdgeScore out) rather than tensors.
Note: runGraphInference() is not supported by this tool; use SegmentEdgeInferenceAlg together with the ISegmentEdgeClassifierTool methods buildGraph() and classifyEdges() instead.
Definition at line 59 of file SegmentEdgeClassifierTool.h.
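
The two-step contract described above (buildGraph() first, then classifyEdges() on the same graph) can be sketched with stand-in types. Everything in this sketch is hypothetical: StatusCode, SegmentEdgeGraph and SegmentEdgeScore are minimal placeholders rather than the Athena classes, IEdgeClassifier, runPerEvent and FakeTool are invented for illustration, and the event context and segment container arguments are omitted.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-ins; the real Athena types carry much more information.
enum class StatusCode { SUCCESS, FAILURE };
struct SegmentEdgeGraph {
  std::size_t nEdges{0};
};
struct SegmentEdgeScore {
  float logit{0.f};
  float probability{0.f};
};

// Hypothetical interface reduced to the two calls of the contract
// (event context and segment container arguments omitted).
struct IEdgeClassifier {
  virtual ~IEdgeClassifier() = default;
  virtual StatusCode buildGraph(SegmentEdgeGraph& graph) const = 0;
  virtual StatusCode classifyEdges(const SegmentEdgeGraph& graph,
                                   std::vector<SegmentEdgeScore>& scores) const = 0;
};

// Per-event driver in the spirit of SegmentEdgeInferenceAlg (assumed flow).
StatusCode runPerEvent(const IEdgeClassifier& tool, std::vector<SegmentEdgeScore>& scores) {
  SegmentEdgeGraph graph;
  if (tool.buildGraph(graph) != StatusCode::SUCCESS) return StatusCode::FAILURE;
  if (tool.classifyEdges(graph, scores) != StatusCode::SUCCESS) return StatusCode::FAILURE;
  // Assumed sanity check: one score per edge.
  return scores.size() == graph.nEdges ? StatusCode::SUCCESS : StatusCode::FAILURE;
}

// Toy implementation used only to exercise runPerEvent().
struct FakeTool final : IEdgeClassifier {
  std::size_t edges{3};
  StatusCode buildGraph(SegmentEdgeGraph& graph) const override {
    graph.nEdges = edges;
    return StatusCode::SUCCESS;
  }
  StatusCode classifyEdges(const SegmentEdgeGraph& graph,
                           std::vector<SegmentEdgeScore>& scores) const override {
    scores.assign(graph.nEdges, SegmentEdgeScore{0.f, 0.5f});
    return StatusCode::SUCCESS;
  }
};
```

The "one score per edge" check is an assumed safeguard, not behaviour documented on this page.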

Member Function Documentation

[protected, inherited]
Build only features (N,6); attaches one tensor in graph.dataTensor[0].
Definition at line 89 of file BucketInferenceToolBase.cxx.
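
A minimal sketch of packing per-bucket features into the row-major (N,6) layout mentioned above, assuming the six features of each bucket are already computed. FlatTensor and packFeatures are hypothetical; the real tool wraps its buffer in an ONNX tensor stored in graph.dataTensor[0].

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

// Same value as kBucketFeatureCount on this page.
constexpr std::size_t kBucketFeatureCount = 6;

// Hypothetical holder for a flat buffer plus its shape.
struct FlatTensor {
  std::vector<float> data;
  std::vector<std::int64_t> shape;
};

// Pack N rows of six features into one row-major (N, 6) buffer.
FlatTensor packFeatures(const std::vector<std::array<float, kBucketFeatureCount>>& rows) {
  FlatTensor t;
  t.shape = {static_cast<std::int64_t>(rows.size()),
             static_cast<std::int64_t>(kBucketFeatureCount)};
  t.data.reserve(rows.size() * kBucketFeatureCount);
  for (const auto& row : rows) {
    // Row i occupies data[6*i] .. data[6*i + 5].
    t.data.insert(t.data.end(), row.begin(), row.end());
  }
  return t;
}
```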

[inherited]
GNN-style graph builder (features + edges). Kept for tools that want it.
Definition at line 196 of file BucketInferenceToolBase.cxx.
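
The edge part of a GNN input is commonly passed as an int64 edge_index tensor of shape [2, E], with source nodes in the first row and target nodes in the second (the PyTorch Geometric convention). That this tool uses the same layout is an assumption suggested only by the "edge_index" input name; toEdgeIndex is a hypothetical helper.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Flatten (source, target) node pairs into a row-major [2, E] buffer:
// the first E entries are sources, the next E are targets (assumed layout).
std::vector<std::int64_t> toEdgeIndex(
    const std::vector<std::pair<std::int64_t, std::int64_t>>& edges) {
  const std::size_t nEdges = edges.size();
  std::vector<std::int64_t> edgeIndex(2 * nEdges);
  for (std::size_t e = 0; e < nEdges; ++e) {
    edgeIndex[e] = edges[e].first;            // row 0: source node
    edgeIndex[nEdges + e] = edges[e].second;  // row 1: target node
  }
  return edgeIndex;
}
```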

[override, virtual]
Build a GNN graph from segments, computing node and edge features and storing the graph structure in graph.
Implements MuonML::ISegmentEdgeClassifierTool.
Definition at line 169 of file SegmentEdgeClassifierTool.cxx.
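
This page does not spell out which segment pairs become edges. The private properties MaxDeltaSector (default 1), SectorModulo (default 16) and MaxDeltaThetaDeg (default 35) suggest a sector-proximity cut with wrap-around and an opening-angle cut; the sketch below implements that reading under stated assumptions (cyclic sector numbering, unit direction vectors, an all-pairs loop). Seg, sectorDistance and candidateEdges are hypothetical, and the seven edge features (kEdgeFeatureCount) are not computed.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <utility>
#include <vector>

// Hypothetical minimal segment record: cyclic sector index and unit direction.
struct Seg {
  int sector;
  double dx, dy, dz;  // assumed already normalised to unit length
};

// Cyclic sector distance: with modulo 16, sectors 1 and 16 are neighbours.
int sectorDistance(int a, int b, int modulo) {
  const int d = std::abs(a - b) % modulo;
  return std::min(d, modulo - d);
}

// All pairs (i < j) passing a sector-proximity cut and an opening-angle cut.
std::vector<std::pair<std::size_t, std::size_t>> candidateEdges(
    const std::vector<Seg>& segs, int maxDeltaSector, int sectorModulo,
    double maxDeltaThetaDeg) {
  constexpr double kPi = 3.14159265358979323846;
  // Compare cosines instead of angles; the threshold is computed once.
  const double cosMin = std::cos(maxDeltaThetaDeg * kPi / 180.0);
  std::vector<std::pair<std::size_t, std::size_t>> edges;
  for (std::size_t i = 0; i < segs.size(); ++i) {
    for (std::size_t j = i + 1; j < segs.size(); ++j) {
      if (sectorDistance(segs[i].sector, segs[j].sector, sectorModulo) > maxDeltaSector) {
        continue;
      }
      const double cosAngle = segs[i].dx * segs[j].dx + segs[i].dy * segs[j].dy +
                              segs[i].dz * segs[j].dz;
      if (cosAngle < cosMin) {
        continue;
      }
      edges.emplace_back(i, j);
    }
  }
  return edges;
}
```

Comparing cosines avoids an acos() per pair; caching the threshold once, as m_cosMin presumably does, keeps the inner loop cheap.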

[protected, inherited]
Build Transformer inputs: features [1,S,6] and pad_mask [1,S] (False = valid), as tensors 0 and 1.
Definition at line 136 of file BucketInferenceToolBase.cxx.
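
A sketch of the padded Transformer input layout: features of shape [1,S,6] and a pad_mask of shape [1,S] in which false (0) marks a valid position. The target length S, zero-filling of padded rows and truncation when there are more than S buckets are assumptions, and makeTransformerInputs is hypothetical. The mask is kept one byte per element, since ONNX bool tensors use one byte per element while std::vector<bool> is bit-packed.

```cpp
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

// Same value as kBucketFeatureCount on this page.
constexpr std::size_t kBucketFeatureCount = 6;

struct TransformerInputs {
  std::vector<float> features;        // shape [1, S, 6], row-major
  std::vector<std::uint8_t> padMask;  // shape [1, S]; 0 = valid, 1 = padding
};

TransformerInputs makeTransformerInputs(
    const std::vector<std::array<float, kBucketFeatureCount>>& rows, std::size_t seqLen) {
  TransformerInputs in;
  in.features.assign(seqLen * kBucketFeatureCount, 0.f);  // padded rows stay zero
  in.padMask.assign(seqLen, 1);                           // start fully padded
  // Assumed policy: keep at most seqLen buckets.
  const std::size_t nValid = std::min(rows.size(), seqLen);
  for (std::size_t s = 0; s < nValid; ++s) {
    std::copy(rows[s].begin(), rows[s].end(),
              in.features.begin() + static_cast<std::ptrdiff_t>(s * kBucketFeatureCount));
    in.padMask[s] = 0;  // False = valid, as documented
  }
  return in;
}
```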

[override, virtual]
Run ONNX inference on graph and populate scores with logit and probability for each edge; called after buildGraph().
Implements MuonML::ISegmentEdgeClassifierTool.
Definition at line 274 of file SegmentEdgeClassifierTool.cxx.
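
Turning a logit into a probability is typically done with the logistic sigmoid, p = 1 / (1 + exp(-logit)); that classifyEdges() uses exactly this mapping is an assumption. The sketch uses a numerically stable form and, when sanitising is enabled, replaces non-finite logits with 0 (probability 0.5). That replacement policy is hypothetical, loosely inspired by the m_sanitizeNonFinitePredictions attribute; SegmentEdgeScore here is a stand-in, and sigmoid and toScores are illustrative names.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

struct SegmentEdgeScore {  // hypothetical stand-in
  float logit{0.f};
  float probability{0.f};
};

// Numerically stable logistic sigmoid: never evaluates exp() of a large positive value.
float sigmoid(float x) {
  if (x >= 0.f) {
    return 1.f / (1.f + std::exp(-x));
  }
  const float e = std::exp(x);
  return e / (1.f + e);
}

// Convert raw logits to scores; optionally map non-finite logits to 0 (assumed policy).
std::vector<SegmentEdgeScore> toScores(const std::vector<float>& logits, bool sanitize) {
  std::vector<SegmentEdgeScore> scores;
  scores.reserve(logits.size());
  for (float l : logits) {
    if (sanitize && !std::isfinite(l)) l = 0.f;
    scores.push_back({l, sigmoid(l)});
  }
  return scores;
}
```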

[inherited]

[override]
Retrieve the ONNX model and resolve node feature ordering from metadata.
Definition at line 80 of file SegmentEdgeClassifierTool.cxx.
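
Resolving the node feature order at initialisation amounts to parsing a list of names from the model metadata and mapping each one onto a feature the tool knows how to compute, failing on unknown names. The sketch assumes the metadata value is a comma-separated list, possibly wrapped in brackets and quotes. trimToken, parseNames and resolveOrder are hypothetical counterparts of the static helpers trimFeatureToken() and parseFeatureNames(), whose actual behaviour is not documented here, and the feature names in the test are made up.

```cpp
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Strip whitespace, quotes and brackets around one token (assumed format).
std::string trimToken(std::string s) {
  auto isJunk = [](unsigned char c) {
    return std::isspace(c) != 0 || c == '"' || c == '\'' || c == '[' || c == ']';
  };
  while (!s.empty() && isJunk(static_cast<unsigned char>(s.front()))) s.erase(s.begin());
  while (!s.empty() && isJunk(static_cast<unsigned char>(s.back()))) s.pop_back();
  return s;
}

// Split a comma-separated metadata value into clean, non-empty names.
std::vector<std::string> parseNames(const std::string& raw) {
  std::vector<std::string> names;
  std::stringstream ss(raw);
  std::string token;
  while (std::getline(ss, token, ',')) {
    token = trimToken(token);
    if (!token.empty()) names.push_back(token);
  }
  return names;
}

// Map each model-order name to its index in the tool's known feature list.
// Returns false on an unknown name, so initialize() could report FAILURE.
bool resolveOrder(const std::vector<std::string>& modelNames,
                  const std::vector<std::string>& known, std::vector<std::size_t>& ids) {
  ids.clear();
  for (const auto& name : modelNames) {
    const auto it = std::find(known.begin(), known.end(), name);
    if (it == known.end()) return false;
    ids.push_back(static_cast<std::size_t>(it - known.begin()));
  }
  return true;
}
```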

[protected, inherited]
Definition at line 65 of file BucketInferenceToolBase.cxx.

[static, protected, inherited]
Definition at line 32 of file BucketInferenceToolBase.cxx.

[override]
Not supported by this tool; returns FAILURE.
Use SegmentEdgeInferenceAlg with buildGraph() followed by classifyEdges() instead.
Definition at line 164 of file SegmentEdgeClassifierTool.cxx.

[inherited]
Default ONNX run for GNN case: inputs {"features","edge_index"} -> outputs {"logits"}.
Definition at line 485 of file BucketInferenceToolBase.cxx.

[protected, inherited]
Generic named inference, for tools with different I/O conventions.
Definition at line 321 of file BucketInferenceToolBase.cxx.
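
The default runInference() hard-codes inputs {"features","edge_index"} and output {"logits"}, whereas this tool's properties default to "x", "edge_index", "edge_attr" and "logits", which is presumably why a named variant exists. Because runNamedInference() takes names as const char* pointers, they usually borrow from std::string storage such as configuration values. The hypothetical helpers toCStrings and namesMatchTensors build such a list and check it against the number of prepared input tensors.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Build the const char* list an ONNX-Runtime-style Run() call expects.
// The pointers borrow from `names`, so `names` must outlive the result.
std::vector<const char*> toCStrings(const std::vector<std::string>& names) {
  std::vector<const char*> out;
  out.reserve(names.size());
  for (const auto& n : names) out.push_back(n.c_str());
  return out;
}

// Hypothetical guard: one input name per prepared tensor, at least one output.
bool namesMatchTensors(const std::vector<const char*>& inputNames, std::size_t nInputTensors,
                       const std::vector<const char*>& outputNames) {
  return inputNames.size() == nInputTensors && !outputNames.empty();
}
```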

[protected, inherited]
Definition at line 69 of file BucketInferenceToolBase.cxx.

[static, protected, inherited]
Definition at line 25 of file BucketInferenceToolBase.cxx.

Member Data Documentation

[static, constexpr, protected, inherited]
Definition at line 53 of file BucketInferenceToolBase.h.

[static, constexpr, protected, inherited]
Definition at line 56 of file BucketInferenceToolBase.h.

[static, constexpr, protected, inherited]
Definition at line 55 of file BucketInferenceToolBase.h.

[static, constexpr, protected, inherited]
Definition at line 54 of file BucketInferenceToolBase.h.

[private]
Definition at line 92 of file SegmentEdgeClassifierTool.h.

[protected, inherited]
Definition at line 102 of file BucketInferenceToolBase.h.

[protected, inherited]
Definition at line 94 of file BucketInferenceToolBase.h.

[protected, inherited]
Definition at line 93 of file BucketInferenceToolBase.h.

[protected, inherited]
Definition at line 80 of file BucketInferenceToolBase.h.

[private]
Definition at line 90 of file SegmentEdgeClassifierTool.h.

[private]
Definition at line 89 of file SegmentEdgeClassifierTool.h.

[private]
Definition at line 88 of file SegmentEdgeClassifierTool.h.

[protected, inherited]
Definition at line 101 of file BucketInferenceToolBase.h.

[protected, inherited]
Definition at line 90 of file BucketInferenceToolBase.h.

[protected, inherited]
Definition at line 87 of file BucketInferenceToolBase.h.

[private]
Definition at line 86 of file SegmentEdgeClassifierTool.h.

[private]
Definition at line 85 of file SegmentEdgeClassifierTool.h.

[protected, inherited]
Definition at line 89 of file BucketInferenceToolBase.h.

[protected, inherited]
Definition at line 88 of file BucketInferenceToolBase.h.

[protected, inherited]
Definition at line 86 of file BucketInferenceToolBase.h.

[private]
Definition at line 96 of file SegmentEdgeClassifierTool.h.

[private]
Node feature order expected by the model metadata (resolved at initialize).
Definition at line 95 of file SegmentEdgeClassifierTool.h.

[private, inherited]
Definition at line 105 of file BucketInferenceToolBase.h.

[private]
Definition at line 91 of file SegmentEdgeClassifierTool.h.

[protected, inherited]
Definition at line 79 of file BucketInferenceToolBase.h.

[protected, inherited]
Definition at line 96 of file BucketInferenceToolBase.h.

[private]
Definition at line 87 of file SegmentEdgeClassifierTool.h.

[protected, inherited]
Definition at line 95 of file BucketInferenceToolBase.h.