ATLAS Offline Software
Loading...
Searching...
No Matches
DVInferenceToolBase.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4#ifndef MUONINFERENCETOOLS_DVINFERENCETOOLBASE_H
5#define MUONINFERENCETOOLS_DVINFERENCETOOLBASE_H
6
10
14
15#include "CaloEvent/CaloTowerContainer.h"
18
19#include <onnxruntime_cxx_api.h>
20
21#include <array>
22#include <cstddef>
23#include <cstdint>
24#include <string>
25#include <string_view>
26#include <vector>
27
28namespace MuonML {
29
/// @brief Per-event summary of one DV graph-classifier evaluation.
///
/// Filled by DVInferenceToolBase::inferEvent(). All members default to an
/// "empty / not evaluated" state, so a default-constructed result is safe to
/// inspect even when inference was skipped.
struct DVInferenceResult {
  /// Whether this result holds a usable evaluation (false by default).
  /// NOTE(review): the exact conditions that set this are in the .cpp — confirm there.
  bool valid{false};
  /// Total number of nodes in the event graph
  /// (presumably nMuonNodes + nCaloNodes — confirm in buildGraph).
  std::size_t nNodes{0};
  /// Number of muon-segment nodes in the graph.
  std::size_t nMuonNodes{0};
  /// Number of calorimeter-tower nodes in the graph.
  std::size_t nCaloNodes{0};
  /// Number of directed edges in the graph.
  std::size_t nEdges{0};
  /// Raw model output value (the value probabilityFromOutput() reports via its rawOutput argument).
  float rawOutput{0.f};
  /// Classifier probability derived from the model output.
  float probability{0.f};
};
39
56class DVInferenceToolBase : public extends<AthAlgTool, IGraphInferenceTool> {
57public:
58 using base_class::base_class;
59 ~DVInferenceToolBase() override = default;
60
61 StatusCode initialize() override;
62
64 StatusCode runGraphInference(const EventContext& ctx, GraphRawData& graphData) const override;
65
67 StatusCode buildGraph(const EventContext& ctx, GraphRawData& graphData) const;
68
70 StatusCode runInference(GraphRawData& graphData) const;
71
73 StatusCode inferEvent(const EventContext& ctx, DVInferenceResult& result) const;
74
75protected:
76 static constexpr std::size_t kNodeFeatureCount = 7;
77 static constexpr std::size_t kEdgeFeatureCount = 5;
78 static constexpr std::size_t kInputTensorCount = 4;
79
80 static constexpr std::array<std::string_view, kNodeFeatureCount> kDefaultNodeFeatureNames = {
81 "r_pos", "theta_pos", "phi_pos", "theta_dir", "phi_dir", "energy_like", "nCells_or_DoF"};
82 static constexpr std::array<std::string_view, kEdgeFeatureCount> kDefaultEdgeFeatureNames = {
83 "d_energy_like", "d_phi", "d_eta", "cos_angle", "same_sector"};
84
85 StatusCode setupModel();
86 Ort::Session& model() const;
87
89 std::string name{};
90 std::size_t tensorIndex{0};
91 };
92
93 std::vector<std::string> modelInputNames() const;
94 std::vector<std::string> modelOutputNames() const;
95
96 StatusCode runNamedInference(GraphRawData& graphData,
97 const std::vector<InputTensorSpec>& inputSpecs,
98 const std::vector<std::string>& outputNames) const;
99
100 float probabilityFromOutput(const Ort::Value& output, float& rawOutput) const;
101
103 this, "SegmentKey", "MuonSegmentsFromR4", "Input R4 muon segment container"};
105 this, "SpacePointKeys", {"MuonSpacePoints"},
106 "Default is MuonSpacePoints only, matching the training MuonBucketDump SegmentKey alignment."};
108 this, "TowerContainerKey", "CombinedTower", "Input calorimeter tower container"};
109
110 Gaudi::Property<float> m_minTowerEnergyMeV{
111 this, "MinTowerEnergyMeV", 1000.f, "Minimum calo tower energy used as a DV graph node"};
112 Gaudi::Property<float> m_maxTowerSegmentDR{
113 this, "MaxTowerSegmentDR", 0.4f, "Maximum segment-calo deltaR used in the converter"};
114 Gaudi::Property<float> m_caloRMaxMm{
115 this, "CaloRMaxMm", 4250.f, "Barrel radius used for the calo-envelope intersection in mm"};
116 Gaudi::Property<float> m_caloZMaxMm{
117 this, "CaloZMaxMm", 6500.f, "Endcap |z| used for the calo-envelope intersection in mm"};
118 Gaudi::Property<int> m_sectorModulo{
119 this, "SectorModulo", 16, "Number of sectors used by the calo phi->sector converter"};
120 Gaudi::Property<bool> m_requireEdges{
121 this, "RequireEdges", false, "Skip inference when the event graph has no segment-tower edges"};
122 Gaudi::Property<bool> m_useBucketSegmentSelection{
123 this, "UseBucketSegmentSelection", true, "Build muon nodes from segment-parent SpacePoint buckets"};
124 Gaudi::Property<bool> m_fallbackToAllSegments{
125 this, "FallbackToAllSegments", false, "If bucket-segment matching fails, fall back to all SegmentKey segments."};
126 Gaudi::Property<int> m_maxEdges{this, "MaxEdges", -1, "Maximum number of directed segment-tower edges to create; negative means no cap"};
127
128 Gaudi::Property<std::string> m_inputNodeName{this, "InputNodeName", "x"};
129 Gaudi::Property<std::string> m_inputEdgeIndexName{this, "InputEdgeIndexName", "edge_index"};
130 Gaudi::Property<std::string> m_inputEdgeAttrName{this, "InputEdgeAttrName", "edge_attr"};
131 Gaudi::Property<std::string> m_inputNMuonNodesName{this, "InputNMuonNodesName", "n_muon_nodes"};
132 Gaudi::Property<std::string> m_outputName{this, "OutputName", "logits"};
133
134 Gaudi::Property<std::string> m_singleOutputMode{
135 this, "SingleOutputMode", "logit", "How to interpret a one-value output: auto, logit, or prob"};
136
137 Gaudi::Property<unsigned int> m_debugDumpFirstNNodes{this, "DebugDumpFirstNNodes", 0};
138 Gaudi::Property<unsigned int> m_debugDumpFirstNEdges{this, "DebugDumpFirstNEdges", 0};
139 Gaudi::Property<bool> m_sanitizeNonFiniteInputs{
140 this, "SanitizeNonFiniteInputs", true, "Replace non-finite input features with zero before creating ONNX tensors"};
141 Gaudi::Property<bool> m_sanitizeNonFinitePredictions{
142 this, "SanitizeNonFinitePredictions", false, "Replace non-finite ONNX outputs with -100 and log a warning"};
143
144 bool m_isCuda{false};
146
147private:
148 ToolHandle<AthOnnx::IOnnxRuntimeSessionTool> m_onnxSessionTool{
149 this, "ModelSession", "", "ONNX Runtime session tool for the DV classifier"};
150};
151
152} // namespace MuonML
153
154#endif
Property holding a SG store/key/clid from which a ReadHandle is made.
Athena tool for DisplacedVertex graph-level ONNX inference.
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool
Gaudi::Property< float > m_caloRMaxMm
Gaudi::Property< std::string > m_inputNodeName
Gaudi::Property< float > m_maxTowerSegmentDR
Gaudi::Property< int > m_sectorModulo
SG::ReadHandleKey< xAOD::MuonSegmentContainer > m_segmentKey
Gaudi::Property< std::string > m_outputName
std::vector< std::string > modelOutputNames() const
static constexpr std::size_t kNodeFeatureCount
SG::ReadHandleKey< CaloTowerContainer > m_towerKey
Gaudi::Property< float > m_minTowerEnergyMeV
std::vector< std::string > modelInputNames() const
Gaudi::Property< bool > m_requireEdges
Gaudi::Property< int > m_maxEdges
Gaudi::Property< unsigned int > m_debugDumpFirstNNodes
StatusCode runInference(GraphRawData &graphData) const
Run the configured ONNX session on a graph already built by buildGraph.
Gaudi::Property< bool > m_useBucketSegmentSelection
Gaudi::Property< unsigned int > m_debugDumpFirstNEdges
Gaudi::Property< float > m_caloZMaxMm
Gaudi::Property< std::string > m_inputEdgeIndexName
static constexpr std::array< std::string_view, kEdgeFeatureCount > kDefaultEdgeFeatureNames
static constexpr std::size_t kEdgeFeatureCount
~DVInferenceToolBase() override=default
Gaudi::Property< bool > m_fallbackToAllSegments
StatusCode runNamedInference(GraphRawData &graphData, const std::vector< InputTensorSpec > &inputSpecs, const std::vector< std::string > &outputNames) const
StatusCode runGraphInference(const EventContext &ctx, GraphRawData &graphData) const override
IGraphInferenceTool entry point: build the DV event graph and run ONNX.
Gaudi::Property< std::string > m_inputEdgeAttrName
SG::ReadHandleKeyArray< MuonR4::SpacePointContainer > m_spacePointKeys
StatusCode buildGraph(const EventContext &ctx, GraphRawData &graphData) const
Build the DV ONNX input tensors: x, edge_index, edge_attr, n_muon_nodes.
float probabilityFromOutput(const Ort::Value &output, float &rawOutput) const
static constexpr std::size_t kInputTensorCount
StatusCode inferEvent(const EventContext &ctx, DVInferenceResult &result) const
Convenience event-classifier API used by DVInferenceAlg.
Gaudi::Property< std::string > m_singleOutputMode
Gaudi::Property< bool > m_sanitizeNonFiniteInputs
Gaudi::Property< bool > m_sanitizeNonFinitePredictions
Gaudi::Property< std::string > m_inputNMuonNodesName
static constexpr std::array< std::string_view, kNodeFeatureCount > kDefaultNodeFeatureNames
Property holding a SG store/key/clid from which a ReadHandle is made.
HandleKeyArray< ReadHandle< T >, ReadHandleKey< T >, Gaudi::DataHandle::Reader > ReadHandleKeyArray
Helper struct to ship the Graph from the space point buckets to ONNX.
Definition GraphData.h:25