ATLAS Offline Software
Loading...
Searching...
No Matches
SPInferenceToolBase.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
5
10#include "Acts/Utilities/MathHelpers.hpp"
12#include <span>
13
namespace {
    /// Stream a std::pair as "(first, second)".
    /// Declared *before* the std::vector overload on purpose: inside the vector
    /// template, `ostr << val` is a dependent call, and two-phase lookup only sees
    /// declarations preceding the template definition plus ADL (which for
    /// std::pair<int64_t, int64_t> searches namespace std only).
    template <typename T1, typename T2>
    std::ostream& operator<<(std::ostream& ostr, const std::pair<T1, T2>& pair) {
        ostr << "(" << pair.first << ", " << pair.second << ")";
        return ostr;
    }

    /// Stream a std::vector as "[a, b, c]" (no trailing separator; empty -> "[]").
    template <typename T>
    std::ostream& operator<<(std::ostream& ostr, const std::vector<T>& vec) {
        ostr << "[";
        const char* separator = "";
        for (const T& val : vec) {
            ostr << separator << val;
            separator = ", ";
        }
        ostr << "]";
        return ostr;
    }

    /// Split a flat edge list laid out as [src_0..src_{E-1}, dst_0..dst_{E-1}]
    /// into (src, dst) pairs, with E = edges.size() / 2.
    /// @param edges flat edge list; an even size is expected (with an odd size the
    ///              last element is never read).
    /// @return one (src, dst) pair per edge, in input order.
    std::vector<std::pair<int64_t, int64_t>> makeIndexPairs(const std::vector<int64_t>& edges) {
        const size_t nEdges = edges.size() / 2;
        std::vector<std::pair<int64_t, int64_t>> indexPairs;
        indexPairs.reserve(nEdges);
        for (size_t i = 0; i < nEdges; ++i) {
            indexPairs.emplace_back(edges[i], edges[i + nEdges]);
        }
        return indexPairs;
    }

    /// Return the flat edge list with its edges sorted by src, then by dst,
    /// keeping the [sorted_src..., sorted_dst...] layout of the input.
    std::vector<int64_t> makeSortedEdges(const std::vector<int64_t>& edges) {
        std::vector<std::pair<int64_t, int64_t>> indexPairs = makeIndexPairs(edges);
        // std::pair's operator< is already lexicographic: by src, then by dst.
        std::sort(indexPairs.begin(), indexPairs.end());

        const size_t nEdges = indexPairs.size();
        std::vector<int64_t> sortedEdges(2 * nEdges);
        for (size_t i = 0; i < nEdges; ++i) {
            sortedEdges[i] = indexPairs[i].first;
            sortedEdges[i + nEdges] = indexPairs[i].second;
        }
        return sortedEdges;
    }

    /// Render the flat node-feature buffer as one line per node (debug printout).
    /// @param featureLeaves      node-major feature values, numFeaturesPerNode per node.
    /// @param numFeaturesPerNode features per node; 0 yields a summary with 0 nodes.
    std::string formatNodeFeatures(const std::vector<float>& featureLeaves, size_t numFeaturesPerNode) {
        std::ostringstream oss;

        // Guard the division: without features per node there is nothing to split.
        const size_t numNodes = numFeaturesPerNode > 0 ? featureLeaves.size() / numFeaturesPerNode : 0;
        oss << "Number of nodes: " << numNodes << "\n";
        oss << "Features per node: " << numFeaturesPerNode << "\n";

        for (size_t nodeIdx = 0; nodeIdx < numNodes; ++nodeIdx) {
            oss << "Node[" << nodeIdx << "]: [";
            for (size_t f = 0; f < numFeaturesPerNode; ++f) {
                if (f > 0) oss << ", ";
                oss << featureLeaves[nodeIdx * numFeaturesPerNode + f];
            }
            oss << "]\n";
        }

        return oss.str();
    }
}
87namespace MuonML{
88 Ort::Session& SPInferenceToolBase::model() const {
89 return m_onnxSessionTool->session();
90 }
92 ATH_CHECK(m_onnxSessionTool.retrieve());
93 ATH_CHECK(m_readKey.initialize());
94
95 // Detect CUDA provider by dynamic-casting the concrete session tool.
96 if (const auto* cudaTool = dynamic_cast<const AthOnnx::OnnxRuntimeSessionToolCUDA*>(
97 m_onnxSessionTool.get())) {
98 m_isCuda = true;
99 m_cudaDeviceId = cudaTool->deviceId();
100 ATH_MSG_DEBUG("ONNX session is running on CUDA device " << m_cudaDeviceId
101 << ". I/O binding will be used.");
102 } else {
103 m_isCuda = false;
104 ATH_MSG_DEBUG("ONNX session is running on CPU.");
105 }
106
107 Ort::ModelMetadata metadata = model().GetModelMetadata();
108 Ort::AllocatorWithDefaultOptions allocator;
109 Ort::AllocatedStringPtr feature_json_ptr = metadata.LookupCustomMetadataMapAllocated("feature_names", allocator);
110
111 if (feature_json_ptr) {
112 std::string feature_json = feature_json_ptr.get();
113 nlohmann::json json_obj = nlohmann::json::parse(feature_json);
114 for (const auto& feature : json_obj) {
115 m_graphFeatures.addFeature(feature.get<std::string>(), msgStream());
116 }
117 }
118 m_graphFeatures.setConnector("fullyConnected", msgStream());
119
120 if (!m_graphFeatures.isValid()) {
121 ATH_MSG_FATAL("No graph features have been parsed. Please check the model: "<<m_graphFeatures.featureNames());
122 return StatusCode::FAILURE;
123 }
124 return StatusCode::SUCCESS;
125 }
126
127 StatusCode SPInferenceToolBase::buildGraph( const EventContext& ctx,
128 GraphRawData& graphData) const {
129
131 if (graphData.previousList && (*graphData.previousList) != m_graphFeatures) {
132 graphData.graph.reset();
133 }
135 if (graphData.graph) {
136 return StatusCode::SUCCESS;
137 }
138 if (!m_graphFeatures.isValid()) {
139 ATH_MSG_ERROR("The feature list is in complete. Either it has no features or no node connector set");
140 return StatusCode::FAILURE;
141 }
142
143 graphData.graph = std::make_unique<InferenceGraph>();
144
145 SG::ReadHandle spacePoints{m_readKey, ctx};
146 ATH_CHECK(spacePoints.isPresent());
147
148 int64_t nNodes{0}, possConn{0};
149 graphData.spacePointsInBucket.clear();
150 graphData.spacePointsInBucket.reserve(spacePoints->size());
151
152 for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
153 nNodes += graphData.spacePointsInBucket.emplace_back(bucket->size());
154 possConn += Acts::sumUpToN(graphData.spacePointsInBucket.back());
155 }
156
157 graphData.nodeIndex = 0;
158 graphData.featureLeaves.resize(nNodes * m_graphFeatures.numFeatures());
159 graphData.currLeave = graphData.featureLeaves.begin();
160
161 graphData.srcEdges.reserve(possConn);
162 graphData.desEdges.reserve(possConn);
163
165 for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
166 const LayerSpBucket mlBucket{*bucket};
167 m_graphFeatures.fillInData(mlBucket, graphData);
168 }
169
170 graphData.srcEdges.insert(graphData.srcEdges.end(), std::make_move_iterator(graphData.desEdges.begin()),
171 std::make_move_iterator(graphData.desEdges.end()));
172
173 ATH_MSG_DEBUG("Features:"<<m_graphFeatures.featureNames());
174 ATH_MSG_DEBUG(formatNodeFeatures(graphData.featureLeaves, m_graphFeatures.numFeatures()));
175 ATH_MSG_DEBUG("Edge indices:"<<makeIndexPairs(makeSortedEdges(graphData.srcEdges)));
176
177 std::vector<int64_t> featShape{nNodes, static_cast<int64_t>(m_graphFeatures.numFeatures())}; // (N, nFeatures)
178 std::vector<int64_t> edgeShape{2, static_cast<int64_t>(graphData.srcEdges.size() / 2)}; // (2, E)
179
180
181 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
182 graphData.graph->dataTensor.emplace_back(Ort::Value::CreateTensor<float>(memInfo,
183 graphData.featureLeaves.data(), graphData.featureLeaves.size(),
184 featShape.data(), featShape.size()));
185
186
187 Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>(memInfo, graphData.srcEdges.data(),
188 graphData.srcEdges.size(), edgeShape.data(), edgeShape.size());
189
190 graphData.graph->dataTensor.emplace_back(std::move(edge_tensor));
191
192 graphData.previousList = &m_graphFeatures;
193
194 graphData.srcEdges.clear();
195 graphData.desEdges.clear();
196 graphData.featureLeaves.clear();
197 ATH_MSG_DEBUG("Graph data built successfully.");
198 return StatusCode::SUCCESS;
199 }
200
201 StatusCode SPInferenceToolBase::runInference(GraphRawData& graphData) const {
202 if (!m_graphFeatures.isValid()) {
203 ATH_MSG_ERROR("ONNX model is not loaded. Please call setupModel()");
204 return StatusCode::FAILURE;
205 }
206 if (!graphData.graph) {
207 ATH_MSG_ERROR("Graph data is not built.");
208 return StatusCode::FAILURE;
209 }
210
211 if (graphData.graph->dataTensor.size() < 2) {
212 ATH_MSG_ERROR("Data tensor does not contain both feature and edge tensors.");
213 return StatusCode::FAILURE;
214 }
215
216 std::vector<const char*> inputNames = {"features", "edge_index"};
217 std::vector<const char*> outputNames = {"output"};
218
219 Ort::RunOptions run_options;
220 run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR);
221
222 if (m_isCuda) {
223 // ---- CUDA path: use IoBinding ----
224 Ort::IoBinding binding(model());
225 for (std::size_t i = 0; i < inputNames.size(); ++i) {
226 binding.BindInput(inputNames[i], graphData.graph->dataTensor[i]);
227 }
228 // Bind output to CPU so logits are directly readable after sync.
229 Ort::MemoryInfo cpuOut = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
230 for (const char* outName : outputNames) {
231 binding.BindOutput(outName, cpuOut);
232 }
233
234 model().Run(run_options, binding);
235 binding.SynchronizeOutputs();
236
237 std::vector<Ort::Value> outputTensors = binding.GetOutputValues();
238 if (outputTensors.empty()) {
239 ATH_MSG_ERROR("IoBinding inference returned empty output.");
240 return StatusCode::FAILURE;
241 }
242
243 float* output_data = outputTensors[0].GetTensorMutableData<float>();
244 size_t output_size = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
245
246 std::span<float> predictions(output_data, output_data + output_size);
247 for (size_t i = 0; i < output_size; i++) {
248 if (!std::isfinite(predictions[i])) {
249 ATH_MSG_WARNING("Non-finite prediction detected! Setting to -100..");
250 predictions[i] = -100.0f;
251 }
252 }
253 graphData.graph->dataTensor.emplace_back(std::move(outputTensors[0]));
254 return StatusCode::SUCCESS;
255 }
256
257 // ---- CPU path (unchanged) ----
258 std::vector<Ort::Value> outputTensors = model().Run(run_options,
259 inputNames.data(), // input tensor names
260 graphData.graph->dataTensor.data(), // pointer to the tensor vector
261 graphData.graph->dataTensor.size(), // size of the tensor vector
262 outputNames.data(), // output tensor names
263 outputNames.size()); // number of output tensors
264
265 if (outputTensors.empty()) {
266 ATH_MSG_ERROR("Inference returned empty output.");
267 return StatusCode::FAILURE;
268 }
269
270 float* output_data = outputTensors[0].GetTensorMutableData<float>();
271 size_t output_size = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
272
273 std::span<float> predictions(output_data, output_data + output_size);
274
275 for (size_t i = 0; i < output_size; i++) {
276 if (!std::isfinite(predictions[i])) {
277 ATH_MSG_WARNING("Non-finite prediction detected! Setting to -100..");
278 predictions[i] = -100.0f;
279 }
280 }
281
282 graphData.graph->dataTensor.emplace_back(std::move(outputTensors[0]));
283
284 return StatusCode::SUCCESS;
285 }
286
287}
288
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_FATAL(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
std::vector< size_t > vec
static Double_t a
std::ostream & operator<<(std::ostream &lhs, const TestGaudiProperty &rhs)
The LayerSpBucket is a space point bucket where the points are internally sorted by their layer number...
Definition LayerBucket.h:14
NodeFeatureList m_graphFeatures
List of features to be used for the inference.
StatusCode runInference(GraphRawData &graphData) const
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool
StatusCode buildGraph(const EventContext &ctx, GraphRawData &graphData) const
Fill up the GraphRawData and construct the graph for the ML inference with ONNX.
SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey
Input space points to filter.
: The muon space point bucket represents a collection of points that will be processed together in t...
STL class.
void sort(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of sort for DataVector/List.
Helper struct to ship the Graph from the space point buckets to ONNX.
Definition GraphData.h:25
FeatureVec_t featureLeaves
Vector containing all features.
Definition GraphData.h:30
std::unique_ptr< InferenceGraph > graph
Pointer to the graph to be parsed to ONNX.
Definition GraphData.h:46
unsigned int nodeIndex
Number of the already filled nodes.
Definition GraphData.h:52
const NodeFeatureList * previousList
Pointer to the latest parsed NodeFeatureList.
Definition GraphData.h:44
std::vector< float >::iterator currLeave
The following variables are needed to consistently fill the raw data for the graph building.
Definition GraphData.h:50
EdgeCounterVec_t srcEdges
Vector encoding the source index of each edge.
Definition GraphData.h:32
EdgeCounterVec_t desEdges
Vector encoding the destination index of each edge.
Definition GraphData.h:34
NodeConnectVec_t spacePointsInBucket
Vector keeping track of how many space points are in each parsed bucket.
Definition GraphData.h:36