ATLAS Offline Software
Loading...
Searching...
No Matches
MuonML::SPInferenceToolBase Class Reference

Baseline tool to build the space-point graph and run the ONNX inference on it. More...

#include <SPInferenceToolBase.h>

Inheritance diagram for MuonML::SPInferenceToolBase:
Collaboration diagram for MuonML::SPInferenceToolBase:

Public Member Functions

StatusCode buildGraph (const EventContext &ctx, GraphRawData &graphData) const
 Fill up the GraphRawData and construct the graph for the ML inference with ONNX.
StatusCode runInference (GraphRawData &graphData) const

Protected Member Functions

StatusCode setupModel ()
Ort::Session & model () const

Protected Attributes

SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey {this, "ReadSpacePoints", "MuonSpacePoints"}
 Input space points to filter.

Private Attributes

NodeFeatureList m_graphFeatures {}
 List of features to be used for the inference.
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool {this, "ModelSession", "" }
bool m_isCuda {false}
int m_cudaDeviceId {0}

Detailed Description

Baseline tool to build the space-point graph and run the ONNX inference on it.

Definition at line 20 of file SPInferenceToolBase.h.

Member Function Documentation

◆ buildGraph()

StatusCode MuonML::SPInferenceToolBase::buildGraph ( const EventContext & ctx,
GraphRawData & graphData ) const

Fill up the GraphRawData and construct the graph for the ML inference with ONNX.

If the graph has already been built by another inference tool with the same feature list as this one, the rebuild is skipped.

Parameters
ctx EventContext used to access the space point container from StoreGate
graphData Reference to the data object to be filled.

Check whether the graph needs a rebuild

Don't launch the rebuild of the graph

Fill the graph edge features and all their respective connections

Definition at line 127 of file SPInferenceToolBase.cxx.

128 {
129
131 if (graphData.previousList && (*graphData.previousList) != m_graphFeatures) {
132 graphData.graph.reset();
133 }
135 if (graphData.graph) {
136 return StatusCode::SUCCESS;
137 }
138 if (!m_graphFeatures.isValid()) {
139 ATH_MSG_ERROR("The feature list is in complete. Either it has no features or no node connector set");
140 return StatusCode::FAILURE;
141 }
142
143 graphData.graph = std::make_unique<InferenceGraph>();
144
145 SG::ReadHandle spacePoints{m_readKey, ctx};
146 ATH_CHECK(spacePoints.isPresent());
147
148 int64_t nNodes{0}, possConn{0};
149 graphData.spacePointsInBucket.clear();
150 graphData.spacePointsInBucket.reserve(spacePoints->size());
151
152 for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
153 nNodes += graphData.spacePointsInBucket.emplace_back(bucket->size());
154 possConn += Acts::sumUpToN(graphData.spacePointsInBucket.back());
155 }
156
157 graphData.nodeIndex = 0;
158 graphData.featureLeaves.resize(nNodes * m_graphFeatures.numFeatures());
159 graphData.currLeave = graphData.featureLeaves.begin();
160
161 graphData.srcEdges.reserve(possConn);
162 graphData.desEdges.reserve(possConn);
163
165 for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
166 const LayerSpBucket mlBucket{*bucket};
167 m_graphFeatures.fillInData(mlBucket, graphData);
168 }
169
170 graphData.srcEdges.insert(graphData.srcEdges.end(), std::make_move_iterator(graphData.desEdges.begin()),
171 std::make_move_iterator(graphData.desEdges.end()));
172
173 ATH_MSG_DEBUG("Features:"<<m_graphFeatures.featureNames());
174 ATH_MSG_DEBUG(formatNodeFeatures(graphData.featureLeaves, m_graphFeatures.numFeatures()));
175 ATH_MSG_DEBUG("Edge indices:"<<makeIndexPairs(makeSortedEdges(graphData.srcEdges)));
176
177 std::vector<int64_t> featShape{nNodes, static_cast<int64_t>(m_graphFeatures.numFeatures())}; // (N, nFeatures)
178 std::vector<int64_t> edgeShape{2, static_cast<int64_t>(graphData.srcEdges.size() / 2)}; // (2, E)
179
180
181 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
182 graphData.graph->dataTensor.emplace_back(Ort::Value::CreateTensor<float>(memInfo,
183 graphData.featureLeaves.data(), graphData.featureLeaves.size(),
184 featShape.data(), featShape.size()));
185
186
187 Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>(memInfo, graphData.srcEdges.data(),
188 graphData.srcEdges.size(), edgeShape.data(), edgeShape.size());
189
190 graphData.graph->dataTensor.emplace_back(std::move(edge_tensor));
191
192 graphData.previousList = &m_graphFeatures;
193
194 graphData.srcEdges.clear();
195 graphData.desEdges.clear();
196 graphData.featureLeaves.clear();
197 ATH_MSG_DEBUG("Graph data built successfully.");
198 return StatusCode::SUCCESS;
199 }
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_DEBUG(x)
NodeFeatureList m_graphFeatures
List of features to be used for the inference.
SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey
Input space points to filter.

◆ model()

Ort::Session & MuonML::SPInferenceToolBase::model ( ) const
protected

Definition at line 88 of file SPInferenceToolBase.cxx.

88 {
89 return m_onnxSessionTool->session();
90 }
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool

◆ runInference()

StatusCode MuonML::SPInferenceToolBase::runInference ( GraphRawData & graphData) const

Definition at line 201 of file SPInferenceToolBase.cxx.

201 {
202 if (!m_graphFeatures.isValid()) {
203 ATH_MSG_ERROR("ONNX model is not loaded. Please call setupModel()");
204 return StatusCode::FAILURE;
205 }
206 if (!graphData.graph) {
207 ATH_MSG_ERROR("Graph data is not built.");
208 return StatusCode::FAILURE;
209 }
210
211 if (graphData.graph->dataTensor.size() < 2) {
212 ATH_MSG_ERROR("Data tensor does not contain both feature and edge tensors.");
213 return StatusCode::FAILURE;
214 }
215
216 std::vector<const char*> inputNames = {"features", "edge_index"};
217 std::vector<const char*> outputNames = {"output"};
218
219 Ort::RunOptions run_options;
220 run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR);
221
222 if (m_isCuda) {
223 // ---- CUDA path: use IoBinding ----
224 Ort::IoBinding binding(model());
225 for (std::size_t i = 0; i < inputNames.size(); ++i) {
226 binding.BindInput(inputNames[i], graphData.graph->dataTensor[i]);
227 }
228 // Bind output to CPU so logits are directly readable after sync.
229 Ort::MemoryInfo cpuOut = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
230 for (const char* outName : outputNames) {
231 binding.BindOutput(outName, cpuOut);
232 }
233
234 model().Run(run_options, binding);
235 binding.SynchronizeOutputs();
236
237 std::vector<Ort::Value> outputTensors = binding.GetOutputValues();
238 if (outputTensors.empty()) {
239 ATH_MSG_ERROR("IoBinding inference returned empty output.");
240 return StatusCode::FAILURE;
241 }
242
243 float* output_data = outputTensors[0].GetTensorMutableData<float>();
244 size_t output_size = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
245
246 std::span<float> predictions(output_data, output_data + output_size);
247 for (size_t i = 0; i < output_size; i++) {
248 if (!std::isfinite(predictions[i])) {
249 ATH_MSG_WARNING("Non-finite prediction detected! Setting to -100..");
250 predictions[i] = -100.0f;
251 }
252 }
253 graphData.graph->dataTensor.emplace_back(std::move(outputTensors[0]));
254 return StatusCode::SUCCESS;
255 }
256
257 // ---- CPU path (unchanged) ----
258 std::vector<Ort::Value> outputTensors = model().Run(run_options,
259 inputNames.data(), // input tensor names
260 graphData.graph->dataTensor.data(), // pointer to the tensor vector
261 graphData.graph->dataTensor.size(), // size of the tensor vector
262 outputNames.data(), // output tensor names
263 outputNames.size()); // number of output tensors
264
265 if (outputTensors.empty()) {
266 ATH_MSG_ERROR("Inference returned empty output.");
267 return StatusCode::FAILURE;
268 }
269
270 float* output_data = outputTensors[0].GetTensorMutableData<float>();
271 size_t output_size = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
272
273 std::span<float> predictions(output_data, output_data + output_size);
274
275 for (size_t i = 0; i < output_size; i++) {
276 if (!std::isfinite(predictions[i])) {
277 ATH_MSG_WARNING("Non-finite prediction detected! Setting to -100..");
278 predictions[i] = -100.0f;
279 }
280 }
281
282 graphData.graph->dataTensor.emplace_back(std::move(outputTensors[0]));
283
284 return StatusCode::SUCCESS;
285 }
#define ATH_MSG_WARNING(x)

◆ setupModel()

StatusCode MuonML::SPInferenceToolBase::setupModel ( )
protected

Definition at line 91 of file SPInferenceToolBase.cxx.

91 {
92 ATH_CHECK(m_onnxSessionTool.retrieve());
93 ATH_CHECK(m_readKey.initialize());
94
95 // Detect CUDA provider by dynamic-casting the concrete session tool.
96 if (const auto* cudaTool = dynamic_cast<const AthOnnx::OnnxRuntimeSessionToolCUDA*>(
97 m_onnxSessionTool.get())) {
98 m_isCuda = true;
99 m_cudaDeviceId = cudaTool->deviceId();
100 ATH_MSG_DEBUG("ONNX session is running on CUDA device " << m_cudaDeviceId
101 << ". I/O binding will be used.");
102 } else {
103 m_isCuda = false;
104 ATH_MSG_DEBUG("ONNX session is running on CPU.");
105 }
106
107 Ort::ModelMetadata metadata = model().GetModelMetadata();
108 Ort::AllocatorWithDefaultOptions allocator;
109 Ort::AllocatedStringPtr feature_json_ptr = metadata.LookupCustomMetadataMapAllocated("feature_names", allocator);
110
111 if (feature_json_ptr) {
112 std::string feature_json = feature_json_ptr.get();
113 nlohmann::json json_obj = nlohmann::json::parse(feature_json);
114 for (const auto& feature : json_obj) {
115 m_graphFeatures.addFeature(feature.get<std::string>(), msgStream());
116 }
117 }
118 m_graphFeatures.setConnector("fullyConnected", msgStream());
119
120 if (!m_graphFeatures.isValid()) {
121 ATH_MSG_FATAL("No graph features have been parsed. Please check the model: "<<m_graphFeatures.featureNames());
122 return StatusCode::FAILURE;
123 }
124 return StatusCode::SUCCESS;
125 }
#define ATH_MSG_FATAL(x)

Member Data Documentation

◆ m_cudaDeviceId

int MuonML::SPInferenceToolBase::m_cudaDeviceId {0}
private

Definition at line 48 of file SPInferenceToolBase.h.

48{0};

◆ m_graphFeatures

NodeFeatureList MuonML::SPInferenceToolBase::m_graphFeatures {}
private

List of features to be used for the inference.

Definition at line 44 of file SPInferenceToolBase.h.

44{};

◆ m_isCuda

bool MuonML::SPInferenceToolBase::m_isCuda {false}
private

Definition at line 47 of file SPInferenceToolBase.h.

47{false};

◆ m_onnxSessionTool

ToolHandle<AthOnnx::IOnnxRuntimeSessionTool> MuonML::SPInferenceToolBase::m_onnxSessionTool {this, "ModelSession", "" }
private

Definition at line 45 of file SPInferenceToolBase.h.

45{this, "ModelSession", "" };

◆ m_readKey

SG::ReadHandleKey<MuonR4::SpacePointContainer> MuonML::SPInferenceToolBase::m_readKey {this, "ReadSpacePoints", "MuonSpacePoints"}
protected

Input space points to filter.

Definition at line 41 of file SPInferenceToolBase.h.

41{this, "ReadSpacePoints", "MuonSpacePoints"};

The documentation for this class was generated from the following files: