ATLAS Offline Software
BucketInferenceToolBase.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN
3  for the benefit of the ATLAS collaboration
4 */
5 #ifndef MUONINFERENCETOOLS_BUCKETINFERENCETOOLBASE_H
6 #define MUONINFERENCETOOLS_BUCKETINFERENCETOOLBASE_H
7 
12 
16 
17 #include <onnxruntime_cxx_api.h>
18 #include <vector>
19 
20 class ActsGeometryContext;
21 
22 namespace MuonML {
23 
36 class BucketInferenceToolBase : public extends<AthAlgTool, IGraphInferenceTool> {
37 public:
38  using base_class::base_class;
39  ~BucketInferenceToolBase() override = default;
40 
42  StatusCode buildGraph(const EventContext& ctx, GraphRawData& graphData) const;
43 
45  StatusCode runInference(GraphRawData& graphData) const;
46 
47 protected:
49  Ort::Session& model() const;
50 
52  StatusCode buildFeaturesOnly(const EventContext& ctx, GraphRawData& graphData) const;
53 
56  StatusCode buildTransformerInputs(const EventContext& ctx, GraphRawData& graphData) const;
57 
60  const std::vector<const char*>& inputNames,
61  const std::vector<const char*>& outputNames) const;
62 
63  SG::ReadHandleKey<MuonR4::SpacePointContainer> m_readKey{this, "ReadSpacePoints", "MuonSpacePoints"};
64  SG::ReadHandleKey<ActsTrk::GeometryContext> m_geoCtxKey{this, "AlignmentKey", "ActsAlignment", "cond handle key"};
65 
66  // Sparse-graph parameters (GNN)
67  Gaudi::Property<int> m_minLayers{this, "MinLayersValid", 3};
68  Gaudi::Property<int> m_maxChamberDelta{this, "MaxChamberDelta", 13};
69  Gaudi::Property<int> m_maxSectorDelta{this, "MaxSectorDelta", 1};
70  Gaudi::Property<double> m_maxDistXY{this, "MaxDistXY", 6800.0};
71  Gaudi::Property<double> m_maxAbsDz{this, "MaxAbsDz", 15000.0};
72 
73  // Debug/validation knobs
74  Gaudi::Property<unsigned int> m_debugDumpFirstNNodes{this, "DebugDumpFirstNNodes", 5};
75  Gaudi::Property<unsigned int> m_debugDumpFirstNEdges{this, "DebugDumpFirstNEdges", 12};
76  Gaudi::Property<bool> m_validateEdges{this, "ValidateEdges", true};
77 
78 private:
79  ToolHandle<AthOnnx::IOnnxRuntimeSessionTool> m_onnxSessionTool{
80  this, "ModelSession", ""};
81 };
82 
83 } // namespace MuonML
84 
85 #endif
86 
ReadHandleKey.h
Property holding a SG store/key/clid from which a ReadHandle is made.
IGraphInferenceTool.h
MuonML::BucketInferenceToolBase::model
Ort::Session & model() const
Definition: BucketInferenceToolBase.cxx:21
MuonML::BucketInferenceToolBase::buildGraph
StatusCode buildGraph(const EventContext &ctx, GraphRawData &graphData) const
GNN-style graph builder (features + edges); kept for tools that need the full graph.
Definition: BucketInferenceToolBase.cxx:139
MuonML::BucketInferenceToolBase::m_debugDumpFirstNEdges
Gaudi::Property< unsigned int > m_debugDumpFirstNEdges
Definition: BucketInferenceToolBase.h:75
MuonML::BucketInferenceToolBase::buildTransformerInputs
StatusCode buildTransformerInputs(const EventContext &ctx, GraphRawData &graphData) const
Build Transformer inputs: features [1,S,6] and pad_mask [1,S] (False = valid), as tensors 0 and 1.
Definition: BucketInferenceToolBase.cxx:79
SG::ReadHandleKey
Property holding a SG store/key/clid from which a ReadHandle is made.
Definition: StoreGate/StoreGate/ReadHandleKey.h:39
python.oracle.Session
Session
Definition: oracle.py:76
MuonML::BucketInferenceToolBase::runNamedInference
StatusCode runNamedInference(GraphRawData &graphData, const std::vector< const char * > &inputNames, const std::vector< const char * > &outputNames) const
Generic inference with caller-supplied input and output tensor names, for tools with different I/O conventions.
Definition: BucketInferenceToolBase.cxx:264
MuonML
Definition: BucketGraphUtils.h:19
MuonML::BucketInferenceToolBase::m_debugDumpFirstNNodes
Gaudi::Property< unsigned int > m_debugDumpFirstNNodes
Definition: BucketInferenceToolBase.h:74
MuonML::BucketInferenceToolBase::~BucketInferenceToolBase
~BucketInferenceToolBase() override=default
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
AthAlgTool.h
MuonML::GraphRawData
Helper struct to ship the Graph from the space point buckets to ONNX.
Definition: GraphData.h:25
MuonML::BucketInferenceToolBase::m_validateEdges
Gaudi::Property< bool > m_validateEdges
Definition: BucketInferenceToolBase.h:76
IOnnxRuntimeSessionTool.h
MuonML::BucketInferenceToolBase::m_maxSectorDelta
Gaudi::Property< int > m_maxSectorDelta
Definition: BucketInferenceToolBase.h:69
ReadCondHandleKey.h
MuonML::BucketInferenceToolBase::setupModel
StatusCode setupModel()
Definition: BucketInferenceToolBase.cxx:25
MuonML::BucketInferenceToolBase::m_geoCtxKey
SG::ReadHandleKey< ActsTrk::GeometryContext > m_geoCtxKey
Definition: BucketInferenceToolBase.h:64
XMLtoHeader.outputNames
outputNames
Definition: XMLtoHeader.py:17
MuonML::BucketInferenceToolBase::m_maxDistXY
Gaudi::Property< double > m_maxDistXY
Definition: BucketInferenceToolBase.h:70
GraphData.h
MuonML::BucketInferenceToolBase::buildFeaturesOnly
StatusCode buildFeaturesOnly(const EventContext &ctx, GraphRawData &graphData) const
Build only the feature tensor (N,6) and attach it as graph.dataTensor[0].
Definition: BucketInferenceToolBase.cxx:32
MuonML::BucketInferenceToolBase::m_readKey
SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey
Definition: BucketInferenceToolBase.h:63
MuonML::BucketInferenceToolBase::m_onnxSessionTool
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool
Definition: BucketInferenceToolBase.h:79
MuonML::BucketInferenceToolBase::m_minLayers
Gaudi::Property< int > m_minLayers
Definition: BucketInferenceToolBase.h:67
SpacePointContainer.h
MuonML::BucketInferenceToolBase::runInference
StatusCode runInference(GraphRawData &graphData) const
Default ONNX run for GNN case: inputs {"features","edge_index"} -> outputs {"output"}.
Definition: BucketInferenceToolBase.cxx:344
MuonML::BucketInferenceToolBase
Definition: BucketInferenceToolBase.h:36
MuonML::BucketInferenceToolBase::m_maxChamberDelta
Gaudi::Property< int > m_maxChamberDelta
Definition: BucketInferenceToolBase.h:68
MuonML::BucketInferenceToolBase::m_maxAbsDz
Gaudi::Property< double > m_maxAbsDz
Definition: BucketInferenceToolBase.h:71