ATLAS Offline Software
Loading...
Searching...
No Matches
BucketInferenceToolBase.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN
3 for the benefit of the ATLAS collaboration
4*/
5#ifndef MUONINFERENCETOOLS_BUCKETINFERENCETOOLBASE_H
6#define MUONINFERENCETOOLS_BUCKETINFERENCETOOLBASE_H
7
12
16
17#include <onnxruntime_cxx_api.h>
18#include <array>
19#include <cstddef>
20#include <cstdint>
21#include <string>
22#include <string_view>
23#include <vector>
24
25class ActsGeometryContext;
26
27namespace MuonML {
28
41class BucketInferenceToolBase : public extends<AthAlgTool, IGraphInferenceTool> {
42public:
43 using base_class::base_class;
44 ~BucketInferenceToolBase() override = default;
45
47 StatusCode buildGraph(const EventContext& ctx, GraphRawData& graphData) const;
48
50 StatusCode runInference(GraphRawData& graphData) const;
51
52protected:
53 static constexpr std::size_t kBucketFeatureCount = 6;
54 static constexpr std::size_t kNodeFeatureCount = 10;
55 static constexpr std::size_t kEdgeFeatureCount = 7;
56 static constexpr std::array<std::string_view, kNodeFeatureCount> kDefaultNodeFeatureNames = {
57 "segmentPositionX_m", "segmentPositionY_m", "segmentPositionZ_m",
58 "segmentDirectionX", "segmentDirectionY", "segmentDirectionZ",
59 "bucket_chamberIndex", "bucket_layers", "bucket_sector", "bucket_segments"};
60
61 StatusCode setupModel();
62 Ort::Session& model() const;
63
64 static std::string trimFeatureToken(std::string s);
65 static std::vector<std::string> parseFeatureNames(const std::string& raw);
66
68 StatusCode buildFeaturesOnly(const EventContext& ctx, GraphRawData& graphData) const;
69
72 StatusCode buildTransformerInputs(const EventContext& ctx, GraphRawData& graphData) const;
73
75 StatusCode runNamedInference(GraphRawData& graphData,
76 const std::vector<const char*>& inputNames,
77 const std::vector<const char*>& outputNames) const;
78
79 SG::ReadHandleKey<MuonR4::SpacePointContainer> m_readKey{this, "ReadSpacePoints", "MuonSpacePoints"};
80 SG::ReadHandleKey<ActsTrk::GeometryContext> m_geoCtxKey{this, "AlignmentKey", "ActsAlignment", "cond handle key"};
81
82 // ONNX I/O name for the GNN output tensor (model-dependent)
83 Gaudi::Property<std::string> m_outputName{this, "OutputName", "logits"};
84
85 // Sparse-graph parameters (GNN)
86 Gaudi::Property<int> m_minLayers{this, "MinLayersValid", 3};
87 Gaudi::Property<int> m_maxChamberDelta{this, "MaxChamberDelta", 13};
88 Gaudi::Property<int> m_maxSectorDelta{this, "MaxSectorDelta", 1};
89 Gaudi::Property<double> m_maxDistXY{this, "MaxDistXY", 6800.0};
90 Gaudi::Property<double> m_maxAbsDz{this, "MaxAbsDz", 15000.0};
91
92 // Debug/validation knobs
93 Gaudi::Property<unsigned int> m_debugDumpFirstNNodes{this, "DebugDumpFirstNNodes", 5};
94 Gaudi::Property<unsigned int> m_debugDumpFirstNEdges{this, "DebugDumpFirstNEdges", 12};
95 Gaudi::Property<bool> m_validateEdges{this, "ValidateEdges", true};
96 Gaudi::Property<bool> m_sanitizeNonFinitePredictions{
97 this, "SanitizeNonFinitePredictions", false,
98 "When true, replace non-finite ONNX outputs with -100 and log a warning."};
99
100 // CUDA / IoBinding state (set in setupModel via dynamic_cast)
101 bool m_isCuda{false};
103
104private:
105 ToolHandle<AthOnnx::IOnnxRuntimeSessionTool> m_onnxSessionTool{
106 this, "ModelSession", ""};
107};
108
109} // namespace MuonML
110
111#endif
112
Property holding a SG store/key/clid from which a ReadHandle is made.
StatusCode buildGraph(const EventContext &ctx, GraphRawData &graphData) const
GNN-style graph builder (features + edges). Kept for tools that want it.
StatusCode buildFeaturesOnly(const EventContext &ctx, GraphRawData &graphData) const
Build only features (N,6); attaches one tensor in graph.dataTensor[0].
static std::string trimFeatureToken(std::string s)
StatusCode buildTransformerInputs(const EventContext &ctx, GraphRawData &graphData) const
Build Transformer inputs: features [1,S,6] and pad_mask [1,S] (False = valid), as tensors 0 and 1.
static constexpr std::array< std::string_view, kNodeFeatureCount > kDefaultNodeFeatureNames
Gaudi::Property< std::string > m_outputName
static constexpr std::size_t kEdgeFeatureCount
Gaudi::Property< unsigned int > m_debugDumpFirstNEdges
Gaudi::Property< unsigned int > m_debugDumpFirstNNodes
static constexpr std::size_t kBucketFeatureCount
static std::vector< std::string > parseFeatureNames(const std::string &raw)
StatusCode runNamedInference(GraphRawData &graphData, const std::vector< const char * > &inputNames, const std::vector< const char * > &outputNames) const
Generic named inference, for tools with different I/O conventions.
Gaudi::Property< bool > m_sanitizeNonFinitePredictions
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool
StatusCode runInference(GraphRawData &graphData) const
Default ONNX run for GNN case: inputs {"features","edge_index"} -> outputs {"logits"}...
SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey
~BucketInferenceToolBase() override=default
static constexpr std::size_t kNodeFeatureCount
Gaudi::Property< double > m_maxDistXY
SG::ReadHandleKey< ActsTrk::GeometryContext > m_geoCtxKey
Property holding a SG store/key/clid from which a ReadHandle is made.
Helper struct to ship the Graph from the space point buckets to ONNX.
Definition GraphData.h:25