ATLAS Offline Software
Loading...
Searching...
No Matches
MuonML::SPInferenceToolBase Class Reference

Baseline tool to handle the graph construction and ONNX inference on muon space points. More...

#include <SPInferenceToolBase.h>

Inheritance diagram for MuonML::SPInferenceToolBase:
Collaboration diagram for MuonML::SPInferenceToolBase:

Public Member Functions

StatusCode buildGraph (const EventContext &ctx, GraphRawData &graphData) const
 Fill up the GraphRawData and construct the graph for the ML inference with ONNX.
StatusCode runInference (GraphRawData &graphData) const

Protected Member Functions

StatusCode setupModel ()
Ort::Session & model () const

Protected Attributes

SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey {this, "ReadSpacePoints", "MuonSpacePoints"}
 Input space points to filter.

Private Attributes

NodeFeatureList m_graphFeatures {}
 List of features to be used for the inference.
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool {this, "ModelSession", "" }

Detailed Description

Baseline tool to handle the graph construction and ONNX inference on muon space points.

Definition at line 20 of file SPInferenceToolBase.h.

Member Function Documentation

◆ buildGraph()

StatusCode MuonML::SPInferenceToolBase::buildGraph ( const EventContext & ctx,
GraphRawData & graphData ) const

Fill up the GraphRawData and construct the graph for the ML inference with ONNX.

If the graph has already been built by another inference tool and would be the same as this one, the rebuild is skipped.

Parameters
ctx  EventContext to access the space point container from StoreGate
graphData  Reference to the data object to be filled.

Check whether the graph needs a rebuild

Don't launch the rebuild of the graph

Fill the graph edge features and all their respective connections

Definition at line 114 of file SPInferenceToolBase.cxx.

115 {
116
118 if (graphData.previousList && (*graphData.previousList) != m_graphFeatures) {
119 graphData.graph.reset();
120 }
122 if (graphData.graph) {
123 return StatusCode::SUCCESS;
124 }
125 if (!m_graphFeatures.isValid()) {
126 ATH_MSG_ERROR("The feature list is in complete. Either it has no features or no node connector set");
127 return StatusCode::FAILURE;
128 }
129
130 graphData.graph = std::make_unique<InferenceGraph>();
131
132 SG::ReadHandle spacePoints{m_readKey, ctx};
133 ATH_CHECK(spacePoints.isPresent());
134
135 int64_t nNodes{0}, possConn{0};
136 graphData.spacePointsInBucket.clear();
137 graphData.spacePointsInBucket.reserve(spacePoints->size());
138
139 for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
140 nNodes += graphData.spacePointsInBucket.emplace_back(bucket->size());
141 possConn += Acts::sumUpToN(graphData.spacePointsInBucket.back());
142 }
143
144 graphData.nodeIndex = 0;
145 graphData.featureLeaves.resize(nNodes * m_graphFeatures.numFeatures());
146 graphData.currLeave = graphData.featureLeaves.begin();
147
148 graphData.srcEdges.reserve(possConn);
149 graphData.desEdges.reserve(possConn);
150
152 for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
153 const LayerSpBucket mlBucket{*bucket};
154 m_graphFeatures.fillInData(mlBucket, graphData);
155 }
156
157 graphData.srcEdges.insert(graphData.srcEdges.end(), std::make_move_iterator(graphData.desEdges.begin()),
158 std::make_move_iterator(graphData.desEdges.end()));
159
160 ATH_MSG_DEBUG("Features:"<<m_graphFeatures.featureNames());
161 ATH_MSG_DEBUG(formatNodeFeatures(graphData.featureLeaves, m_graphFeatures.numFeatures()));
162 ATH_MSG_DEBUG("Edge indices:"<<makeIndexPairs(makeSortedEdges(graphData.srcEdges)));
163
164 std::vector<int64_t> featShape{nNodes, static_cast<int64_t>(m_graphFeatures.numFeatures())}; // (N, nFeatures)
165 std::vector<int64_t> edgeShape{2, static_cast<int64_t>(graphData.srcEdges.size() / 2)}; // (2, E)
166
167
168 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
169 graphData.graph->dataTensor.emplace_back(Ort::Value::CreateTensor<float>(memInfo,
170 graphData.featureLeaves.data(), graphData.featureLeaves.size(),
171 featShape.data(), featShape.size()));
172
173
174 Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>(memInfo, graphData.srcEdges.data(),
175 graphData.srcEdges.size(), edgeShape.data(), edgeShape.size());
176
177 graphData.graph->dataTensor.emplace_back(std::move(edge_tensor));
178
179 graphData.previousList = &m_graphFeatures;
180
181 graphData.srcEdges.clear();
182 graphData.desEdges.clear();
183 graphData.featureLeaves.clear();
184 ATH_MSG_DEBUG("Graph data built successfully.");
185 return StatusCode::SUCCESS;
186 }
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_DEBUG(x)
NodeFeatureList m_graphFeatures
List of features to be used for the inference.
SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey
Input space points to filter.

◆ model()

Ort::Session & MuonML::SPInferenceToolBase::model ( ) const
protected

Definition at line 87 of file SPInferenceToolBase.cxx.

87 {
88 return m_onnxSessionTool->session();
89 }
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool

◆ runInference()

StatusCode MuonML::SPInferenceToolBase::runInference ( GraphRawData & graphData) const

Definition at line 188 of file SPInferenceToolBase.cxx.

188 {
189 if (!m_graphFeatures.isValid()) {
190 ATH_MSG_ERROR("ONNX model is not loaded. Please call setupModel()");
191 return StatusCode::FAILURE;
192 }
193 if (!graphData.graph) {
194 ATH_MSG_ERROR("Graph data is not built.");
195 return StatusCode::FAILURE;
196 }
197
198 if (graphData.graph->dataTensor.size() < 2) {
199 ATH_MSG_ERROR("Data tensor does not contain both feature and edge tensors.");
200 return StatusCode::FAILURE;
201 }
202
203 std::vector<const char*> inputNames = {"features", "edge_index"};
204 std::vector<const char*> outputNames = {"output"};
205
206 Ort::RunOptions run_options;
207 run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_WARNING);
208
209 std::vector<Ort::Value> outputTensors = model().Run(run_options,
210 inputNames.data(), // input tensor names
211 graphData.graph->dataTensor.data(), // pointer to the tensor vector
212 graphData.graph->dataTensor.size(), // size of the tensor vector
213 outputNames.data(), // output tensor names
214 outputNames.size()); // number of output tensors
215
216 if (outputTensors.empty()) {
217 ATH_MSG_ERROR("Inference returned empty output.");
218 return StatusCode::FAILURE;
219 }
220
221 float* output_data = outputTensors[0].GetTensorMutableData<float>();
222 size_t output_size = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
223
224 std::span<float> predictions(output_data, output_data + output_size);
225
226 for (size_t i = 0; i < output_size; i++) {
227 if (!std::isfinite(predictions[i])) {
228 ATH_MSG_WARNING("Non-finite prediction detected! Setting to -100..");
229 predictions[i] = -100.0f;
230 }
231 }
232
233 graphData.graph->dataTensor.emplace_back(std::move(outputTensors[0]));
234
235 return StatusCode::SUCCESS;
236 }
#define ATH_MSG_WARNING(x)

◆ setupModel()

StatusCode MuonML::SPInferenceToolBase::setupModel ( )
protected

Definition at line 90 of file SPInferenceToolBase.cxx.

90 {
91 ATH_CHECK(m_onnxSessionTool.retrieve());
92 ATH_CHECK(m_readKey.initialize());
93
94 Ort::ModelMetadata metadata = model().GetModelMetadata();
95 Ort::AllocatorWithDefaultOptions allocator;
96 Ort::AllocatedStringPtr feature_json_ptr = metadata.LookupCustomMetadataMapAllocated("feature_names", allocator);
97
98 if (feature_json_ptr) {
99 std::string feature_json = feature_json_ptr.get();
100 nlohmann::json json_obj = nlohmann::json::parse(feature_json);
101 for (const auto& feature : json_obj) {
102 m_graphFeatures.addFeature(feature.get<std::string>(), msgStream());
103 }
104 }
105 m_graphFeatures.setConnector("fullyConnected", msgStream());
106
107 if (!m_graphFeatures.isValid()) {
108 ATH_MSG_FATAL("No graph features have been parsed. Please check the model: "<<m_graphFeatures.featureNames());
109 return StatusCode::FAILURE;
110 }
111 return StatusCode::SUCCESS;
112 }
#define ATH_MSG_FATAL(x)

Member Data Documentation

◆ m_graphFeatures

NodeFeatureList MuonML::SPInferenceToolBase::m_graphFeatures {}
private

List of features to be used for the inference.

Definition at line 44 of file SPInferenceToolBase.h.

44{};

◆ m_onnxSessionTool

ToolHandle<AthOnnx::IOnnxRuntimeSessionTool> MuonML::SPInferenceToolBase::m_onnxSessionTool {this, "ModelSession", "" }
private

Definition at line 45 of file SPInferenceToolBase.h.

45{this, "ModelSession", "" };

◆ m_readKey

SG::ReadHandleKey<MuonR4::SpacePointContainer> MuonML::SPInferenceToolBase::m_readKey {this, "ReadSpacePoints", "MuonSpacePoints"}
protected

Input space points to filter.

Definition at line 41 of file SPInferenceToolBase.h.

41{this, "ReadSpacePoints", "MuonSpacePoints"};

The documentation for this class was generated from the following files: