ATLAS Offline Software
Loading...
Searching...
No Matches
BucketInferenceToolBase.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN
3 for the benefit of the ATLAS collaboration
4*/
5#ifndef MUONINFERENCETOOLS_BUCKETINFERENCETOOLBASE_H
6#define MUONINFERENCETOOLS_BUCKETINFERENCETOOLBASE_H
7
12
16
17#include <onnxruntime_cxx_api.h>
18#include <vector>
19
20class ActsGeometryContext;
21
22namespace MuonML {
23
36class BucketInferenceToolBase : public extends<AthAlgTool, IGraphInferenceTool> {
37public:
38 using base_class::base_class;
39 ~BucketInferenceToolBase() override = default;
40
42 StatusCode buildGraph(const EventContext& ctx, GraphRawData& graphData) const;
43
45 StatusCode runInference(GraphRawData& graphData) const;
46
47protected:
48 StatusCode setupModel();
49 Ort::Session& model() const;
50
52 StatusCode buildFeaturesOnly(const EventContext& ctx, GraphRawData& graphData) const;
53
56 StatusCode buildTransformerInputs(const EventContext& ctx, GraphRawData& graphData) const;
57
59 StatusCode runNamedInference(GraphRawData& graphData,
60 const std::vector<const char*>& inputNames,
61 const std::vector<const char*>& outputNames) const;
62
63 SG::ReadHandleKey<MuonR4::SpacePointContainer> m_readKey{this, "ReadSpacePoints", "MuonSpacePoints"};
64 SG::ReadHandleKey<ActsTrk::GeometryContext> m_geoCtxKey{this, "AlignmentKey", "ActsAlignment", "cond handle key"};
65
66 // Sparse-graph parameters (GNN)
67 Gaudi::Property<int> m_minLayers{this, "MinLayersValid", 3};
68 Gaudi::Property<int> m_maxChamberDelta{this, "MaxChamberDelta", 13};
69 Gaudi::Property<int> m_maxSectorDelta{this, "MaxSectorDelta", 1};
70 Gaudi::Property<double> m_maxDistXY{this, "MaxDistXY", 6800.0};
71 Gaudi::Property<double> m_maxAbsDz{this, "MaxAbsDz", 15000.0};
72
73 // Debug/validation knobs
74 Gaudi::Property<unsigned int> m_debugDumpFirstNNodes{this, "DebugDumpFirstNNodes", 5};
75 Gaudi::Property<unsigned int> m_debugDumpFirstNEdges{this, "DebugDumpFirstNEdges", 12};
76 Gaudi::Property<bool> m_validateEdges{this, "ValidateEdges", true};
77
78private:
79 ToolHandle<AthOnnx::IOnnxRuntimeSessionTool> m_onnxSessionTool{
80 this, "ModelSession", ""};
81};
82
83} // namespace MuonML
84
85#endif
86
Property holding a SG store/key/clid from which a ReadHandle is made.
StatusCode buildGraph(const EventContext &ctx, GraphRawData &graphData) const
GNN-style graph builder (features + edges). Kept for tools that want it.
StatusCode buildFeaturesOnly(const EventContext &ctx, GraphRawData &graphData) const
Build only features (N,6); attaches one tensor in graph.dataTensor[0].
StatusCode buildTransformerInputs(const EventContext &ctx, GraphRawData &graphData) const
Build Transformer inputs: features [1,S,6] and pad_mask [1,S] (False = valid), as tensors 0 and 1.
Gaudi::Property< unsigned int > m_debugDumpFirstNEdges
Gaudi::Property< unsigned int > m_debugDumpFirstNNodes
StatusCode runNamedInference(GraphRawData &graphData, const std::vector< const char * > &inputNames, const std::vector< const char * > &outputNames) const
Generic named inference, for tools with different I/O conventions.
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool
StatusCode runInference(GraphRawData &graphData) const
Default ONNX run for GNN case: inputs {"features","edge_index"} -> outputs {"output"}...
SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey
~BucketInferenceToolBase() override=default
Gaudi::Property< double > m_maxDistXY
SG::ReadHandleKey< ActsTrk::GeometryContext > m_geoCtxKey
Property holding a SG store/key/clid from which a ReadHandle is made.
Helper struct to ship the Graph from the space point buckets to ONNX.
Definition GraphData.h:25