ATLAS Offline Software
Loading...
Searching...
No Matches
SPInferenceToolBase.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
5
10#include "Acts/Utilities/MathHelpers.hpp"
11#include <span>
12
namespace {
    /// @brief Streams a std::pair as "(first, second)".
    /// Declared BEFORE the std::vector overload on purpose: the vector overload's
    /// dependent call `ostr << val` only finds non-ADL candidates that are visible
    /// at the template definition, and ADL on std::pair only searches namespace std.
    /// With the reverse order, streaming a std::vector<std::pair<...>> fails to compile.
    template <typename T1, typename T2>
    std::ostream& operator<<(std::ostream& ostr, const std::pair<T1, T2>& pair) {
        ostr << "(" << pair.first << ", " << pair.second << ")";
        return ostr;
    }

    /// @brief Streams a std::vector as "[a, b, c]" (no trailing separator).
    template <typename T>
    std::ostream& operator<<(std::ostream& ostr, const std::vector<T>& vec) {
        ostr << "[";
        const char* separator = "";
        for (const T& val : vec) {
            ostr << separator << val;
            separator = ", ";
        }
        ostr << "]";
        return ostr;
    }

    /// @brief Splits a flat edge list [src_0 .. src_{E-1}, dst_0 .. dst_{E-1}]
    ///        into E (src, dst) pairs.
    /// @param edges Flat edge list; a trailing element of an odd-sized input is ignored.
    /// @return The (src, dst) pairs in their original order.
    std::vector<std::pair<int64_t, int64_t>> makeIndexPairs(const std::vector<int64_t>& edges) {
        const size_t nEdges = edges.size() / 2;
        std::vector<std::pair<int64_t, int64_t>> indexPairs;
        indexPairs.reserve(nEdges);
        for (size_t i = 0; i < nEdges; ++i) {
            indexPairs.emplace_back(edges[i], edges[i + nEdges]);
        }
        return indexPairs;
    }

    /// @brief Sorts a flat edge list by source index, then by destination index.
    /// @param edges Flat edge list [src..., dst...] (see makeIndexPairs).
    /// @return The sorted edges in the same flat layout: [sorted_src..., sorted_dst...].
    std::vector<int64_t> makeSortedEdges(const std::vector<int64_t>& edges) {
        std::vector<std::pair<int64_t, int64_t>> indexPairs = makeIndexPairs(edges);
        // std::pair's operator< is lexicographic: by src first, then by dst.
        std::sort(indexPairs.begin(), indexPairs.end());

        const size_t nEdges = indexPairs.size();
        std::vector<int64_t> sortedEdges(2 * nEdges);
        for (size_t i = 0; i < nEdges; ++i) {
            sortedEdges[i] = indexPairs[i].first;
            sortedEdges[i + nEdges] = indexPairs[i].second;
        }
        return sortedEdges;
    }

    /// @brief Renders the flat node-feature buffer as a human-readable table,
    ///        one line per node.
    /// @param featureLeaves      Flat feature buffer, node-major.
    /// @param numFeaturesPerNode Number of features stored per node.
    /// @return Multi-line description: node count, features per node, one line per node.
    std::string formatNodeFeatures(const std::vector<float>& featureLeaves, size_t numFeaturesPerNode) {
        std::ostringstream oss;

        // Guard against an integer division by zero for an empty feature list.
        const size_t numNodes = numFeaturesPerNode ? featureLeaves.size() / numFeaturesPerNode : 0;
        oss << "Number of nodes: " << numNodes << "\n";
        oss << "Features per node: " << numFeaturesPerNode << "\n";

        for (size_t nodeIdx = 0; nodeIdx < numNodes; ++nodeIdx) {
            oss << "Node[" << nodeIdx << "]: [";
            for (size_t f = 0; f < numFeaturesPerNode; ++f) {
                if (f > 0) oss << ", ";
                oss << featureLeaves[nodeIdx * numFeaturesPerNode + f];
            }
            oss << "]\n";
        }

        return oss.str();
    }
}
86namespace MuonML{
87 Ort::Session& SPInferenceToolBase::model() const {
88 return m_onnxSessionTool->session();
89 }
91 ATH_CHECK(m_onnxSessionTool.retrieve());
92 ATH_CHECK(m_readKey.initialize());
93
94 Ort::ModelMetadata metadata = model().GetModelMetadata();
95 Ort::AllocatorWithDefaultOptions allocator;
96 Ort::AllocatedStringPtr feature_json_ptr = metadata.LookupCustomMetadataMapAllocated("feature_names", allocator);
97
98 if (feature_json_ptr) {
99 std::string feature_json = feature_json_ptr.get();
100 nlohmann::json json_obj = nlohmann::json::parse(feature_json);
101 for (const auto& feature : json_obj) {
102 m_graphFeatures.addFeature(feature.get<std::string>(), msgStream());
103 }
104 }
105 m_graphFeatures.setConnector("fullyConnected", msgStream());
106
107 if (!m_graphFeatures.isValid()) {
108 ATH_MSG_FATAL("No graph features have been parsed. Please check the model: "<<m_graphFeatures.featureNames());
109 return StatusCode::FAILURE;
110 }
111 return StatusCode::SUCCESS;
112 }
113
114 StatusCode SPInferenceToolBase::buildGraph( const EventContext& ctx,
115 GraphRawData& graphData) const {
116
118 if (graphData.previousList && (*graphData.previousList) != m_graphFeatures) {
119 graphData.graph.reset();
120 }
122 if (graphData.graph) {
123 return StatusCode::SUCCESS;
124 }
125 if (!m_graphFeatures.isValid()) {
126 ATH_MSG_ERROR("The feature list is in complete. Either it has no features or no node connector set");
127 return StatusCode::FAILURE;
128 }
129
130 graphData.graph = std::make_unique<InferenceGraph>();
131
132 SG::ReadHandle spacePoints{m_readKey, ctx};
133 ATH_CHECK(spacePoints.isPresent());
134
135 int64_t nNodes{0}, possConn{0};
136 graphData.spacePointsInBucket.clear();
137 graphData.spacePointsInBucket.reserve(spacePoints->size());
138
139 for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
140 nNodes += graphData.spacePointsInBucket.emplace_back(bucket->size());
141 possConn += Acts::sumUpToN(graphData.spacePointsInBucket.back());
142 }
143
144 graphData.nodeIndex = 0;
145 graphData.featureLeaves.resize(nNodes * m_graphFeatures.numFeatures());
146 graphData.currLeave = graphData.featureLeaves.begin();
147
148 graphData.srcEdges.reserve(possConn);
149 graphData.desEdges.reserve(possConn);
150
152 for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
153 const LayerSpBucket mlBucket{*bucket};
154 m_graphFeatures.fillInData(mlBucket, graphData);
155 }
156
157 graphData.srcEdges.insert(graphData.srcEdges.end(), std::make_move_iterator(graphData.desEdges.begin()),
158 std::make_move_iterator(graphData.desEdges.end()));
159
160 ATH_MSG_DEBUG("Features:"<<m_graphFeatures.featureNames());
161 ATH_MSG_DEBUG(formatNodeFeatures(graphData.featureLeaves, m_graphFeatures.numFeatures()));
162 ATH_MSG_DEBUG("Edge indices:"<<makeIndexPairs(makeSortedEdges(graphData.srcEdges)));
163
164 std::vector<int64_t> featShape{nNodes, static_cast<int64_t>(m_graphFeatures.numFeatures())}; // (N, nFeatures)
165 std::vector<int64_t> edgeShape{2, static_cast<int64_t>(graphData.srcEdges.size() / 2)}; // (2, E)
166
167
168 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
169 graphData.graph->dataTensor.emplace_back(Ort::Value::CreateTensor<float>(memInfo,
170 graphData.featureLeaves.data(), graphData.featureLeaves.size(),
171 featShape.data(), featShape.size()));
172
173
174 Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>(memInfo, graphData.srcEdges.data(),
175 graphData.srcEdges.size(), edgeShape.data(), edgeShape.size());
176
177 graphData.graph->dataTensor.emplace_back(std::move(edge_tensor));
178
179 graphData.previousList = &m_graphFeatures;
180
181 graphData.srcEdges.clear();
182 graphData.desEdges.clear();
183 graphData.featureLeaves.clear();
184 ATH_MSG_DEBUG("Graph data built successfully.");
185 return StatusCode::SUCCESS;
186 }
187
188 StatusCode SPInferenceToolBase::runInference(GraphRawData& graphData) const {
189 if (!m_graphFeatures.isValid()) {
190 ATH_MSG_ERROR("ONNX model is not loaded. Please call setupModel()");
191 return StatusCode::FAILURE;
192 }
193 if (!graphData.graph) {
194 ATH_MSG_ERROR("Graph data is not built.");
195 return StatusCode::FAILURE;
196 }
197
198 if (graphData.graph->dataTensor.size() < 2) {
199 ATH_MSG_ERROR("Data tensor does not contain both feature and edge tensors.");
200 return StatusCode::FAILURE;
201 }
202
203 std::vector<const char*> inputNames = {"features", "edge_index"};
204 std::vector<const char*> outputNames = {"output"};
205
206 Ort::RunOptions run_options;
207 run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_WARNING);
208
209 std::vector<Ort::Value> outputTensors = model().Run(run_options,
210 inputNames.data(), // input tensor names
211 graphData.graph->dataTensor.data(), // pointer to the tensor vector
212 graphData.graph->dataTensor.size(), // size of the tensor vector
213 outputNames.data(), // output tensor names
214 outputNames.size()); // number of output tensors
215
216 if (outputTensors.empty()) {
217 ATH_MSG_ERROR("Inference returned empty output.");
218 return StatusCode::FAILURE;
219 }
220
221 float* output_data = outputTensors[0].GetTensorMutableData<float>();
222 size_t output_size = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
223
224 std::span<float> predictions(output_data, output_data + output_size);
225
226 for (size_t i = 0; i < output_size; i++) {
227 if (!std::isfinite(predictions[i])) {
228 ATH_MSG_WARNING("Non-finite prediction detected! Setting to -100..");
229 predictions[i] = -100.0f;
230 }
231 }
232
233 graphData.graph->dataTensor.emplace_back(std::move(outputTensors[0]));
234
235 return StatusCode::SUCCESS;
236 }
237
238}
239
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_FATAL(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
std::vector< size_t > vec
static Double_t a
std::ostream & operator<<(std::ostream &lhs, const TestGaudiProperty &rhs)
The LayerSpBucket is a space pointbucket where the points are internally sorted by their layer number...
Definition LayerBucket.h:14
NodeFeatureList m_graphFeatures
List of features to be used for the inference.
StatusCode runInference(GraphRawData &graphData) const
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool
StatusCode buildGraph(const EventContext &ctx, GraphRawData &graphData) const
Fill up the GraphRawData and construct the graph for the ML inference with ONNX.
SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey
Input space points to filter.
: The muon space point bucket represents a collection of points that will be processed together in t...
STL class.
void sort(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of sort for DataVector/List.
Helper struct to ship the Graph from the space point buckets to ONNX.
Definition GraphData.h:25
FeatureVec_t featureLeaves
Vector containing all features.
Definition GraphData.h:30
std::unique_ptr< InferenceGraph > graph
Pointer to the graph to be parsed to ONNX.
Definition GraphData.h:46
unsigned int nodeIndex
Number of the already filled nodes.
Definition GraphData.h:52
const NodeFeatureList * previousList
Pointer to the latest parsed NodeFeatureList.
Definition GraphData.h:44
std::vector< float >::iterator currLeave
The following variables are needed to consistently fill the raw data for the graph building.
Definition GraphData.h:50
EdgeCounterVec_t srcEdges
Vector encoding the source index of the.
Definition GraphData.h:32
EdgeCounterVec_t desEdges
Vect.
Definition GraphData.h:34
NodeConnectVec_t spacePointsInBucket
Vector keeping track of how many space points are in each parsed bucket.
Definition GraphData.h:36