ATLAS Offline Software
Loading...
Searching...
No Matches
MuonML::BucketInferenceToolBase Class Reference

#include <BucketInferenceToolBase.h>

Inheritance diagram for MuonML::BucketInferenceToolBase:
Collaboration diagram for MuonML::BucketInferenceToolBase:

Public Member Functions

 ~BucketInferenceToolBase () override=default
StatusCode buildGraph (const EventContext &ctx, GraphRawData &graphData) const
 GNN-style graph builder (features + edges). Kept for tools that want it.
StatusCode runInference (GraphRawData &graphData) const
 Default ONNX run for GNN case: inputs {"features","edge_index"} -> outputs {"logits"}.

Protected Member Functions

StatusCode setupModel ()
Ort::Session & model () const
StatusCode buildFeaturesOnly (const EventContext &ctx, GraphRawData &graphData) const
 Build only features (N,6); attaches one tensor in graph.dataTensor[0].
StatusCode buildTransformerInputs (const EventContext &ctx, GraphRawData &graphData) const
 Build Transformer inputs: features [1,S,6] and pad_mask [1,S] (False = valid), as tensors 0 and 1.
StatusCode runNamedInference (GraphRawData &graphData, const std::vector< const char * > &inputNames, const std::vector< const char * > &outputNames) const
 Generic named inference, for tools with different I/O conventions.

Static Protected Member Functions

static std::string trimFeatureToken (std::string s)
static std::vector< std::string > parseFeatureNames (const std::string &raw)

Protected Attributes

SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey {this, "ReadSpacePoints", "MuonSpacePoints"}
ActsTrk::GeoContextReadKey_t m_geoCtxKey {this, "AlignmentKey", "ActsAlignment", "cond handle key"}
Gaudi::Property< std::string > m_outputName {this, "OutputName", "logits"}
Gaudi::Property< int > m_minLayers {this, "MinLayersValid", 3}
Gaudi::Property< int > m_maxChamberDelta {this, "MaxChamberDelta", 13}
Gaudi::Property< int > m_maxSectorDelta {this, "MaxSectorDelta", 1}
Gaudi::Property< double > m_maxDistXY {this, "MaxDistXY", 6800.0}
Gaudi::Property< double > m_maxAbsDz {this, "MaxAbsDz", 15000.0}
Gaudi::Property< unsigned int > m_debugDumpFirstNNodes {this, "DebugDumpFirstNNodes", 5}
Gaudi::Property< unsigned int > m_debugDumpFirstNEdges {this, "DebugDumpFirstNEdges", 12}
Gaudi::Property< bool > m_validateEdges {this, "ValidateEdges", true}
Gaudi::Property< bool > m_sanitizeNonFinitePredictions
bool m_isCuda {false}
int m_cudaDeviceId {0}

Static Protected Attributes

static constexpr std::size_t kBucketFeatureCount = 6
static constexpr std::size_t kNodeFeatureCount = 10
static constexpr std::size_t kEdgeFeatureCount = 7
static constexpr std::array< std::string_view, kNodeFeatureCount > kDefaultNodeFeatureNames

Private Attributes

ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool

Detailed Description

BucketInferenceToolBase

Common infra to:

  • read buckets & (optionally) geometry
  • build node features
  • (optionally) build GNN sparse edges (via BucketGraphUtils)
  • wrap tensors and run ONNX sessions

GNN-specific operations are in BucketGraphUtils.*. Transformer tools reuse feature building without edges and add a pad mask.

Definition at line 41 of file BucketInferenceToolBase.h.

Constructor & Destructor Documentation

◆ ~BucketInferenceToolBase()

MuonML::BucketInferenceToolBase::~BucketInferenceToolBase ( )
override default

Member Function Documentation

◆ buildFeaturesOnly()

StatusCode BucketInferenceToolBase::buildFeaturesOnly ( const EventContext & ctx,
GraphRawData & graphData ) const
protected

Build only features (N,6); attaches one tensor in graph.dataTensor[0].

Definition at line 87 of file BucketInferenceToolBase.cxx.

// buildFeaturesOnly: clears every buffer of graphData, retrieves the space-point
// buckets and the geometry context, lets BucketGraphUtils::buildNodesAndFeatures
// fill the flat feature vector, and wraps that vector as a single float tensor of
// shape (numNodes, kBucketFeatureCount) in graphData.graph->dataTensor.
// Returns SUCCESS without attaching any tensor when no node is produced, and
// FAILURE when the flat feature buffer does not hold numNodes * kBucketFeatureCount
// values. The two ATH_CHECK calls presumably propagate a failed retrieval as an
// early return -- the macro body is not shown on this page.
88 {
89
// Start from a clean state: drop the previous graph and all auxiliary buffers.
90 graphData.graph.reset();
91 graphData.srcEdges.clear();
92 graphData.desEdges.clear();
93 graphData.edgeIndexPacked.clear();
94 graphData.featureLeaves.clear();
95 graphData.spacePointsInBucket.clear();
96 graphData.graph = std::make_unique<InferenceGraph>();
97 graphData.graph->dataTensor.reserve(1); // features input; outputs are reserved in runNamedInference()
98
// Event data: the space-point buckets read through m_readKey.
99 const MuonR4::SpacePointContainer* buckets{nullptr};
100 ATH_CHECK(SG::get(buckets, m_readKey, ctx));
101
// Geometry context read through m_geoCtxKey; passed on to the feature builder.
102 const ActsTrk::GeometryContext* gctx = nullptr;
103 ATH_CHECK(SG::get(gctx, m_geoCtxKey, ctx));
104
// Fills `nodes`, the flat feature buffer and the per-bucket space-point counts.
105 std::vector<BucketGraphUtils::NodeAux> nodes;
106 BucketGraphUtils::buildNodesAndFeatures(*buckets, *gctx, nodes,
107 graphData.featureLeaves,
108 graphData.spacePointsInBucket); // now int64_t-compatible
109
110 const int64_t numNodes = static_cast<int64_t>(nodes.size());
111 ATH_MSG_DEBUG("Total buckets: " << buckets->size()
112 << " -> nodes (size>0): " << numNodes
113 << " | features.size()=" << graphData.featureLeaves.size());
114
// NOTE(review): this returns SUCCESS with an empty dataTensor; a caller that then
// invokes runNamedInference() gets its "No input tensors prepared" FAILURE, so
// callers presumably have to check for the empty case themselves -- confirm.
115 if (numNodes == 0) {
116 ATH_MSG_WARNING("No valid buckets found (all have size 0.0). Skipping inference.");
117 return StatusCode::SUCCESS;
118 }
119
// Consistency check between the node count and the flat feature buffer length.
120 const int64_t nFeatPerNode = static_cast<int64_t>(kBucketFeatureCount);
121 if (numNodes * nFeatPerNode != static_cast<int64_t>(graphData.featureLeaves.size())) {
122 ATH_MSG_ERROR( "Feature size mismatch: expected " << (numNodes * nFeatPerNode)
123 << " got " << graphData.featureLeaves.size());
124 return StatusCode::FAILURE;
125 }
126
// The tensor is created on top of graphData.featureLeaves.data(), i.e. on the
// caller-owned buffer, so featureLeaves has to stay alive and unmodified for as
// long as the tensor is in use.
127 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
128 std::vector<int64_t> featShape{numNodes, nFeatPerNode};
129 graphData.graph->dataTensor.emplace_back(
130 Ort::Value::CreateTensor<float>(memInfo,
131 graphData.featureLeaves.data(),
132 graphData.featureLeaves.size(),
133 featShape.data(),
134 featShape.size()));
135 return StatusCode::SUCCESS;
136}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
size_type size() const noexcept
Returns the number of elements in the collection.
ActsTrk::GeoContextReadKey_t m_geoCtxKey
static constexpr std::size_t kBucketFeatureCount
SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey
void buildNodesAndFeatures(const MuonR4::SpacePointContainer &buckets, const ActsTrk::GeometryContext &gctx, std::vector< NodeAux > &nodes, std::vector< float > &featuresLeaves, std::vector< int64_t > &spInBucket)
Build nodes + flat features (N,6) and number of SPs per kept bucket.
DataVector< SpacePointBucket > SpacePointContainer
Abbreviation of the space point container type.
const T * get(const ReadCondHandleKey< T > &key, const EventContext &ctx)
Convenience function to retrieve an object given a ReadCondHandleKey.
FeatureVec_t featureLeaves
Vector containing all features.
Definition GraphData.h:30
EdgeCounterVec_t edgeIndexPacked
Packed edge index buffer (kept alive for ONNX tensors that reference it) This stores [srcEdges,...
Definition GraphData.h:42
std::unique_ptr< InferenceGraph > graph
Pointer to the graph to be parsed to ONNX.
Definition GraphData.h:46
EdgeCounterVec_t srcEdges
Vector encoding the source index of each edge.
Definition GraphData.h:32
EdgeCounterVec_t desEdges
Vector encoding the destination index of each edge.
Definition GraphData.h:34
NodeConnectVec_t spacePointsInBucket
Vector keeping track of how many space points are in each parsed bucket.
Definition GraphData.h:36

◆ buildGraph()

StatusCode BucketInferenceToolBase::buildGraph ( const EventContext & ctx,
GraphRawData & graphData ) const

GNN-style graph builder (features + edges). Kept for tools that want it.

Definition at line 200 of file BucketInferenceToolBase.cxx.

// buildGraph: GNN input builder. Produces two tensors in graphData.graph->dataTensor:
//   [0] float features of shape (numNodes, kBucketFeatureCount), backed by featureLeaves
//   [1] int64 edge index of shape (2, Efinal), backed by edgeIndexPacked
// Returns SUCCESS without any tensor when no node is produced, and FAILURE on a
// feature-size mismatch.
// NOTE(review): this listing jumps from source line 249 to 256, so the start of
// the call that fills srcEdges/desEdges is not visible here. It is presumably
// BucketGraphUtils::buildSparseEdges(nodes, m_minLayers, ...) -- confirm against
// BucketInferenceToolBase.cxx.
201 {
202
// Reset all buffers of a previous call before rebuilding.
203 graphData.graph.reset();
204 graphData.srcEdges.clear();
205 graphData.desEdges.clear();
206 graphData.featureLeaves.clear();
207 graphData.spacePointsInBucket.clear();
208 graphData.edgeIndexPacked.clear();
209 graphData.graph = std::make_unique<InferenceGraph>();
210 graphData.graph->dataTensor.reserve(2); // features and edge_index inputs; outputs are reserved in runNamedInference()
211
212 const MuonR4::SpacePointContainer* buckets{nullptr};
213 ATH_CHECK(SG::get(buckets, m_readKey, ctx));
214
215 const ActsTrk::GeometryContext* gctx = nullptr;
216 ATH_CHECK(SG::get(gctx, m_geoCtxKey, ctx));
217
218 std::vector<BucketGraphUtils::NodeAux> nodes;
219
// Same node/feature construction as in buildFeaturesOnly().
220 BucketGraphUtils::buildNodesAndFeatures(*buckets, *gctx, nodes,
221 graphData.featureLeaves,
222 graphData.spacePointsInBucket);
223
224 const int64_t numNodes = static_cast<int64_t>(nodes.size());
225 ATH_MSG_DEBUG("Total buckets: " << buckets->size()
226 << " -> nodes (size>0): " << numNodes
227 << " | features.size()=" << graphData.featureLeaves.size());
228
229 if (numNodes == 0) {
230 ATH_MSG_WARNING("No valid buckets found (all have size 0.0). Skipping graph building.");
231 return StatusCode::SUCCESS;
232 }
233
234 const int64_t nFeatPerNode = static_cast<int64_t>(kBucketFeatureCount);
235 if (numNodes * nFeatPerNode != static_cast<int64_t>(graphData.featureLeaves.size())) {
236 ATH_MSG_ERROR("Feature size mismatch: expected " << (numNodes * nFeatPerNode)
237 << " got " << graphData.featureLeaves.size());
238 return StatusCode::FAILURE;
239 }
240
// Feature tensor built directly on graphData.featureLeaves (no copy is made here).
241 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
242 std::vector<int64_t> featShape{numNodes, nFeatPerNode};
243 graphData.graph->dataTensor.emplace_back(
244 Ort::Value::CreateTensor<float>(memInfo,
245 graphData.featureLeaves.data(),
246 graphData.featureLeaves.size(),
247 featShape.data(),
248 featShape.size()));
249
// (source lines 250-255 are missing from this listing; line 256 is the tail of
// the edge-building call)
256 graphData.srcEdges, graphData.desEdges);
// Optional sanity pass: keep only edges whose two endpoints lie in [0, numNodes).
// Valid edges are compacted in place towards the front (`write` is the next free
// slot); the vectors are shrunk afterwards only if something was dropped.
// NOTE(review): desEdges[k] is indexed with k < srcEdges.size(), i.e. both
// vectors are assumed to have the same length -- confirm the edge builder
// guarantees this.
257 if (m_validateEdges) {
258 size_t bad = 0;
259 size_t write = 0;
260 for (size_t k = 0; k < graphData.srcEdges.size(); ++k) {
261 const int64_t u = graphData.srcEdges[k];
262 const int64_t v = graphData.desEdges[k];
263 const bool okU = (u >= 0 && u < numNodes);
264 const bool okV = (v >= 0 && v < numNodes);
265 if (okU && okV) {
266 graphData.srcEdges[write] = u;
267 graphData.desEdges[write] = v;
268 ++write;
269 } else {
270 ++bad;
271 ATH_MSG_DEBUG( "Drop invalid edge " << k << ": (" << u << "->" << v
272 << "), valid node range [0," << (numNodes-1) << "]");
273 }
274 }
275 if (bad) {
// The warning is emitted before the resize, so the reported total is the
// edge count prior to the removal.
276 ATH_MSG_WARNING( "Removed " << bad << " invalid edges out of "
277 << graphData.srcEdges.size());
278 graphData.srcEdges.resize(write);
279 graphData.desEdges.resize(write);
280 }
281 }
282
283 const size_t E = graphData.srcEdges.size();
284
// Everything in this branch is diagnostics only and runs solely at DEBUG level.
285 if (msgLvl(MSG::DEBUG)) {
286 // DEBUG: Count connections per node
287 ATH_MSG_DEBUG("Edges built: " << E);
288 const size_t dumpE = std::min<std::size_t>(m_debugDumpFirstNEdges.value(), E);
289 for (size_t k = 0; k < dumpE; ++k) {
290 ATH_MSG_DEBUG("EDGE[" << k << "]: "
291 << graphData.srcEdges[k] << " -> "
292 << graphData.desEdges[k]);
293 }
294
// Count, per node, how many edges touch it (as source or destination).
// The range checks are repeated because edge validation may be switched off.
295 std::vector<int> nodeConnections(numNodes, 0);
296 for (size_t k = 0; k < graphData.srcEdges.size(); ++k) {
297 const int64_t u = graphData.srcEdges[k];
298 const int64_t v = graphData.desEdges[k];
299 if (u >= 0 && u < numNodes) nodeConnections[u]++;
300 if (v >= 0 && v < numNodes) nodeConnections[v]++;
301 }
302
// NOTE(review): the node dump is capped by a literal 10 although the class
// declares a DebugDumpFirstNNodes property -- possibly an oversight, confirm.
303 ATH_MSG_DEBUG("=== DEBUGGING: Node Connections (first 10 nodes) ===");
304 const int64_t debugNodeCount = std::min(numNodes, static_cast<int64_t>(10));
305 for (int64_t i = 0; i < debugNodeCount; ++i) {
306 ATH_MSG_DEBUG("Node[" << i << "] connections: " << nodeConnections[i]);
307 }
308 ATH_MSG_DEBUG("=== END DEBUG NODE CONNECTIONS ===");
309
// For each of the first nodes, scan the whole edge list and print its neighbours.
310 ATH_MSG_DEBUG("=== DEBUGGING: Detailed Edge Connections (first 10 nodes) ===");
311 for (int64_t nodeIdx = 0; nodeIdx < debugNodeCount; ++nodeIdx) {
312 std::stringstream connections;
313 connections << "Node[" << nodeIdx << "] connected to: ";
314 bool foundAny = false;
315
316 for (size_t k = 0; k < graphData.srcEdges.size(); ++k) {
317 const int64_t u = graphData.srcEdges[k];
318 const int64_t v = graphData.desEdges[k];
319
320 if (u == nodeIdx) {
321 if (foundAny) connections << ", ";
322 connections << v;
323 foundAny = true;
324 } else if (v == nodeIdx) {
325 if (foundAny) connections << ", ";
326 connections << u;
327 foundAny = true;
328 }
329 }
330
331 if (!foundAny) connections << "none";
332 ATH_MSG_DEBUG(connections.str());
333 }
334 ATH_MSG_DEBUG("=== END DEBUG DETAILED CONNECTIONS ===");
335 }
336
// The node helper structs are not used past this point.
337 nodes = {};
338
// Pack the two edge lists into the single buffer that backs the edge tensor;
// Efinal is the edge count reported by packEdgeIndex.
339 graphData.edgeIndexPacked.clear();
340 const size_t Efinal = BucketGraphUtils::packEdgeIndex(graphData.srcEdges,
341 graphData.desEdges,
342 graphData.edgeIndexPacked);
343
// The unpacked lists are emptied once the packed copy exists.
344 graphData.srcEdges.clear();
345 graphData.desEdges.clear();
346
// Edge tensor of shape (2, Efinal) built on graphData.edgeIndexPacked, which
// therefore has to outlive the tensor.
347 std::vector<int64_t> edgeShape{2, static_cast<int64_t>(Efinal)};
348 graphData.graph->dataTensor.emplace_back(
349 Ort::Value::CreateTensor<int64_t>(memInfo,
350 graphData.edgeIndexPacked.data(),
351 graphData.edgeIndexPacked.size(),
352 edgeShape.data(),
353 edgeShape.size()));
354
355 ATH_MSG_DEBUG("Built sparse bucket graph: N=" << numNodes << ", E=" << Efinal);
356 return StatusCode::SUCCESS;
357}
Gaudi::Property< unsigned int > m_debugDumpFirstNEdges
Gaudi::Property< double > m_maxDistXY
void buildSparseEdges(const std::vector< NodeAux > &nodes, int minLayers, int maxChamberDelta, int maxSectorDelta, double maxDistXY, double maxAbsDz, std::vector< int64_t > &srcEdges, std::vector< int64_t > &dstEdges)
size_t packEdgeIndex(const std::vector< int64_t > &srcEdges, const std::vector< int64_t > &dstEdges, std::vector< int64_t > &edgeIndexPacked)
@ u
Enums for curvilinear frames.
Definition ParamDefs.h:77

◆ buildTransformerInputs()

StatusCode BucketInferenceToolBase::buildTransformerInputs ( const EventContext & ctx,
GraphRawData & graphData ) const
protected

Build Transformer inputs: features [1,S,6] and pad_mask [1,S] (False = valid), as tensors 0 and 1.

Definition at line 138 of file BucketInferenceToolBase.cxx.

// buildTransformerInputs: prepares the two transformer inputs in
// graphData.graph->dataTensor:
//   [0] float features of shape (1, S, kBucketFeatureCount), backed by featureLeaves
//   [1] bool pad mask of shape (1, S), allocated by ONNX Runtime, all entries false
// S is the number of feature rows produced by buildFeaturesOnly(). Returns SUCCESS
// without rebuilding the graph when S is zero.
139 {
140 // Start from (N,6)
141 ATH_CHECK(buildFeaturesOnly(ctx, graphData));
142
// NOTE(review): the features are copied here and swapped back into
// graphData.featureLeaves further down, after the graph built by
// buildFeaturesOnly() has been reset. The copy looks avoidable (the new tensor
// could presumably be created on featureLeaves directly) -- confirm before changing.
143 // Copy features flat buffer for lifetime management
144 std::vector<float> featuresFlat = graphData.featureLeaves;
145 const int64_t S = static_cast<int64_t>(featuresFlat.size() / kBucketFeatureCount);
146
147 if (S == 0) {
148 ATH_MSG_WARNING("No valid features for transformer input. Skipping inference.");
149 return StatusCode::SUCCESS;
150 }
151
// Diagnostics only: dump the feature values of at most the first 10 rows.
152 if (msgLvl(MSG::DEBUG)) {
153 // DEBUG: Print transformer input features for first 10 nodes
154 ATH_MSG_DEBUG("=== DEBUGGING: Transformer input features for first 10 nodes ===");
155 const int64_t debugNodes = std::min(S, static_cast<int64_t>(10));
156 for (int64_t nodeIdx = 0; nodeIdx < debugNodes; ++nodeIdx) {
157 const int64_t baseIdx = nodeIdx * static_cast<int64_t>(kBucketFeatureCount);
158 ATH_MSG_DEBUG("TransformerNode[" << nodeIdx << "]: "
159 << "x=" << featuresFlat[baseIdx + 0] << ", "
160 << "y=" << featuresFlat[baseIdx + 1] << ", "
161 << "z=" << featuresFlat[baseIdx + 2] << ", "
162 << "layers=" << featuresFlat[baseIdx + 3] << ", "
163 << "nSp=" << featuresFlat[baseIdx + 4] << ", "
164 << "bucketSize=" << featuresFlat[baseIdx + 5]);
165 }
166 ATH_MSG_DEBUG("=== END DEBUG TRANSFORMER FEATURES ===");
167 }
168
// The 2D feature tensor from buildFeaturesOnly() is discarded together with its graph.
169 // Rebuild graph with exactly 2 inputs: features [1,S,6], pad_mask [1,S]
170 graphData.graph.reset();
171 graphData.graph = std::make_unique<InferenceGraph>();
172 graphData.graph->dataTensor.reserve(2); // features and pad_mask inputs; outputs are reserved in runNamedInference()
173
174 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
175
// After the swap graphData.featureLeaves holds the copied values and the new
// tensor is created on that buffer; the previous buffer ends up in the local
// featuresFlat and is released when this function returns.
176 // features: [1,S,6] (backed by graphData.featureLeaves to keep alive)
177 std::vector<int64_t> fShape{1, S, static_cast<int64_t>(kBucketFeatureCount)};
178 graphData.featureLeaves.swap(featuresFlat);
179 graphData.graph->dataTensor.emplace_back(
180 Ort::Value::CreateTensor<float>(memInfo,
181 graphData.featureLeaves.data(),
182 graphData.featureLeaves.size(),
183 fShape.data(),
184 fShape.size()));
185
// Unlike the feature tensor, the mask storage is requested from the ORT allocator.
186 // pad_mask: [1,S] (bool). Create ORT-owned tensor and fill with False (=valid).
187 Ort::AllocatorWithDefaultOptions allocator;
188 std::vector<int64_t> mShape{1, S};
189 Ort::Value padVal = Ort::Value::CreateTensor(allocator,
190 mShape.data(),
191 mShape.size(),
192 ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL);
193 bool* maskPtr = padVal.GetTensorMutableData<bool>();
194 for (int64_t i = 0; i < S; ++i) maskPtr[i] = false;
195 graphData.graph->dataTensor.emplace_back(std::move(padVal));
196
197 return StatusCode::SUCCESS;
198}
StatusCode buildFeaturesOnly(const EventContext &ctx, GraphRawData &graphData) const
Build only features (N,6); attaches one tensor in graph.dataTensor[0].

◆ model()

Ort::Session & BucketInferenceToolBase::model ( ) const
protected

Definition at line 65 of file BucketInferenceToolBase.cxx.

// model: accessor for the ONNX Runtime session provided by m_onnxSessionTool.
// NOTE(review): the handle is retrieved in setupModel(), so this presumably must
// not be called before setupModel() has succeeded -- confirm.
65 {
66 return m_onnxSessionTool->session();
67}
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool

◆ parseFeatureNames()

std::vector< std::string > BucketInferenceToolBase::parseFeatureNames ( const std::string & raw)
static protected

Definition at line 32 of file BucketInferenceToolBase.cxx.

// parseFeatureNames: turns a raw metadata string into a list of feature names.
// Two input formats are accepted:
//   * a string starting with '[' -- every non-empty run of characters enclosed in
//     a pair of double quotes becomes one name;
//   * otherwise (or if no quoted name was found) -- comma-separated tokens, each
//     trimmed with trimFeatureToken(), empty tokens skipped.
// Returns an empty vector for an empty or whitespace-only input.
32 {
33 std::vector<std::string> out;
34 const std::string s = trimFeatureToken(raw);
35 if (s.empty()) return out;
36
// NOTE(review): `!s.empty()` is always true here because of the early return above.
37 // Preferred exporter format: JSON list of strings.
38 if (!s.empty() && s.front() == '[') {
// Minimal scanner: every '"' toggles the quoted state. A closing quote flushes
// the collected token. Backslash escapes are not interpreted, so an escaped
// quote inside a name would end that name early.
39 bool inQuote = false;
40 std::string token;
41 for (char c : s) {
42 if (c == '"') {
43 if (inQuote) {
44 if (!token.empty()) out.push_back(token);
45 token.clear();
46 }
47 inQuote = !inQuote;
48 continue;
49 }
50 if (inQuote) token.push_back(c);
51 }
// NOTE(review): when the bracketed input contains no quoted name (e.g. "[]"),
// control falls through to the comma-separated path, which then returns the
// bracket text itself as a name -- probably unintended, confirm.
52 if (!out.empty()) return out;
53 }
54
55 // Backward-compatible format: comma-separated.
56 std::istringstream ss(s);
57 std::string tok;
58 while (std::getline(ss, tok, ',')) {
59 tok = trimFeatureToken(tok);
60 if (!tok.empty()) out.push_back(tok);
61 }
62 return out;
63}
static Double_t ss
static std::string trimFeatureToken(std::string s)

◆ runInference()

StatusCode BucketInferenceToolBase::runInference ( GraphRawData & graphData) const

Default ONNX run for GNN case: inputs {"features","edge_index"} -> outputs {"logits"}.

Definition at line 532 of file BucketInferenceToolBase.cxx.

// runInference: convenience wrapper around runNamedInference() for the GNN
// convention -- inputs "features" and "edge_index", one output whose name is
// taken from the OutputName property.
532 {
533 std::vector<const char*> inputNames = {"features", "edge_index"};
// The pointer refers to the string stored in m_outputName; it is only used
// for the duration of the call below.
534 std::vector<const char*> outputNames = {m_outputName.value().c_str()};
535 return runNamedInference(graphData, inputNames, outputNames);
536}
Gaudi::Property< std::string > m_outputName
StatusCode runNamedInference(GraphRawData &graphData, const std::vector< const char * > &inputNames, const std::vector< const char * > &outputNames) const
Generic named inference, for tools with different I/O conventions.

◆ runNamedInference()

StatusCode BucketInferenceToolBase::runNamedInference ( GraphRawData & graphData,
const std::vector< const char * > & inputNames,
const std::vector< const char * > & outputNames ) const
protected

Generic named inference, for tools with different I/O conventions.

Definition at line 359 of file BucketInferenceToolBase.cxx.

// runNamedInference: runs the ONNX session on the tensors already stored in
// graphData.graph->dataTensor. The first inputNames.size() tensors are used as
// inputs, in order; every output returned by the session is appended to the same
// dataTensor vector. Returns FAILURE when the graph is missing, when no tensor or
// fewer tensors than input names are available, or when the session returns no
// output. When SanitizeNonFinitePredictions is set, non-finite values of the
// FIRST output are overwritten with -100.
// (The listing starts at source line 363; the signature occupies lines 359-362.)
363{
364 if (!graphData.graph) {
365 ATH_MSG_ERROR("Graph data is not built.");
366 return StatusCode::FAILURE;
367 }
368 if (graphData.graph->dataTensor.empty()) {
369 ATH_MSG_ERROR("No input tensors prepared for inference.");
370 return StatusCode::FAILURE;
371 }
372
373 // Reserve the final size here from the actual I/O lists instead
374 // of hard-coding assumptions in the graph builders.
375 graphData.graph->dataTensor.reserve(inputNames.size() + outputNames.size());
376 if (graphData.graph->dataTensor.size() < inputNames.size()) {
377 ATH_MSG_ERROR("Prepared " << graphData.graph->dataTensor.size()
378 << " tensors but inference expects " << inputNames.size() << " inputs.");
379 return StatusCode::FAILURE;
380 }
381
// Diagnostics only: dump shape, feature legend and the first rows of tensor 0.
382 if (msgLvl(MSG::DEBUG)) {
383 // DEBUG: Print actual input tensor data for features tensor
384
385 ATH_MSG_DEBUG("=== DEBUGGING: ONNX Input tensor data ===");
386 if (!graphData.graph->dataTensor.empty()) {
387 const auto& featureTensor = graphData.graph->dataTensor[0];
388 auto featShape = featureTensor.GetTensorTypeAndShapeInfo().GetShape();
// NOTE(review): featShape[0] is read without checking that the shape is
// non-empty; the builders on this page always create rank >= 2 tensors, but a
// derived tool supplying a scalar tensor would index out of range -- confirm.
389 ATH_MSG_DEBUG("Features tensor shape: [" << featShape[0]
390 << (featShape.size()>1 ? ("," + std::to_string(featShape[1])) : "")
391 << (featShape.size()>2 ? ("," + std::to_string(featShape[2])) : "") << "]");
392
// The const_cast is only needed to obtain a data pointer; the values are read,
// not modified, in this block. Tensor 0 is assumed to hold float data.
393 float* featData = const_cast<Ort::Value&>(featureTensor).GetTensorMutableData<float>();
394 const size_t totalElements = featureTensor.GetTensorTypeAndShapeInfo().GetElementCount();
395 ATH_MSG_DEBUG("Features tensor total elements: " << totalElements);
396
// NOTE(review): the stride is taken from dimension 1. For the 3D transformer
// input built as [1,S,6] this is S, not the per-node feature count, so the
// legend and the row dump below would be mis-strided in that case; the last
// dimension looks like the intended stride -- confirm.
397 // Print up to 10 nodes; stride = nFeat from tensor shape
398 const size_t nFeat = (featShape.size() > 1 && featShape[1] > 0) ? static_cast<size_t>(featShape[1]) : 1;
399 const size_t nNodes = totalElements / nFeat;
400 const size_t debugNodes = std::min(nNodes, static_cast<size_t>(10));
401
402 // Try to read feature names from model custom metadata.
403 // Prefer x_feature_names (current exporter), then fall back to legacy keys.
404 std::vector<std::string> featNames;
405 {
406 Ort::AllocatorWithDefaultOptions allocator;
407 Ort::ModelMetadata meta = model().GetModelMetadata();
408 auto keys = meta.GetCustomMetadataMapKeysAllocated(allocator);
409 std::vector<std::string> keyNames;
410 keyNames.reserve(keys.size());
411 for (const auto& k : keys) keyNames.emplace_back(k.get());
// Candidate keys are tried in this order; the first one present in the model
// metadata wins, even if parsing its value yields no names.
412 const std::array<std::string, 4> candidates{
413 "x_feature_names", "node_feature_names", "feature_names", "input_feature_names"};
414 for (const std::string& key : candidates) {
415 if (std::find(keyNames.begin(), keyNames.end(), key) != keyNames.end()) {
416 std::string val = meta.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get();
417 featNames = parseFeatureNames(val);
418 break;
419 }
420 }
421 if (featNames.empty()) {
422 ATH_MSG_DEBUG("No usable feature-name metadata key found in model; using generic fN labels.");
423 }
424 }
// Label for feature column f: the metadata name if available, otherwise "f<f>".
425 auto featLabel = [&](size_t f) -> std::string {
426 if (f < featNames.size()) return featNames[f];
427 return "f" + std::to_string(f);
428 };
429
430 // Print legend
431 {
432 std::ostringstream legend;
433 legend << "Node feature legend (" << nFeat << " features):";
434 for (size_t f = 0; f < nFeat; ++f) {
435 legend << " f" << f << "=" << featLabel(f);
436 if (f + 1 < nFeat) legend << ",";
437 }
438 ATH_MSG_DEBUG(legend.str());
439 }
440
441 for (size_t n = 0; n < debugNodes; ++n) {
442 std::ostringstream row;
443 row << "ONNXNode[" << n << "]:";
444 for (size_t f = 0; f < nFeat; ++f) {
445 row << " f" << f << "=" << featData[n * nFeat + f];
446 if (f + 1 < nFeat) row << ",";
447 }
448 ATH_MSG_DEBUG(row.str());
449 }
450 }
451 ATH_MSG_DEBUG("=== END DEBUG ONNX INPUT ===");
452 }
453
454 Ort::RunOptions run_options;
455 run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR);
456
// m_isCuda is set in setupModel() from the session backend.
457 if (m_isCuda) {
458 // ---- CUDA path: use IoBinding so tensors stay on device ----
459 Ort::IoBinding binding(model());
460 for (std::size_t i = 0; i < inputNames.size(); ++i) {
461 binding.BindInput(inputNames[i], graphData.graph->dataTensor[i]);
462 }
463 // Bind outputs to CPU so predictions are directly readable after sync.
464 Ort::MemoryInfo cpuOut = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
465 for (const char* outName : outputNames) {
466 binding.BindOutput(outName, cpuOut);
467 }
468
469 model().Run(run_options, binding);
470 binding.SynchronizeOutputs();
471
472 std::vector<Ort::Value> outputs = binding.GetOutputValues();
473 if (outputs.empty()) {
474 ATH_MSG_ERROR("IoBinding inference returned empty output.");
475 return StatusCode::FAILURE;
476 }
477
// Assumes the first output holds float data -- the element type is not checked.
478 float* outData = outputs[0].GetTensorMutableData<float>();
479 const size_t outSize = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
480 ATH_MSG_DEBUG("ONNX (IoBinding) raw output elementCount = " << outSize);
481
// NOTE(review): only outputs[0] is sanitised, and this block is duplicated
// verbatim in the CPU path below -- a shared helper would keep the two in sync.
482 if (m_sanitizeNonFinitePredictions.value()) {
483 std::span<float> preds(outData, outData + outSize);
484 for (size_t i = 0; i < outSize; ++i) {
485 if (!std::isfinite(preds[i])) {
486 ATH_MSG_WARNING("Non-finite prediction detected at " << i << " -> set to -100.");
487 preds[i] = -100.0f;
488 }
489 }
490 }
491
// Hand the outputs over to the graph, after the input tensors.
492 for (auto& v : outputs) {
493 graphData.graph->dataTensor.emplace_back(std::move(v));
494 }
495 return StatusCode::SUCCESS;
496 }
497
// The first inputNames.size() entries of dataTensor are passed as inputs.
498 // ---- CPU path ----
499 std::vector<Ort::Value> outputs =
500 model().Run(run_options,
501 inputNames.data(),
502 graphData.graph->dataTensor.data(),
503 inputNames.size(),
504 outputNames.data(),
505 outputNames.size());
506
507 if (outputs.empty()) {
508 ATH_MSG_ERROR("Inference returned empty output.");
509 return StatusCode::FAILURE;
510 }
511
// Assumes the first output holds float data -- the element type is not checked.
512 float* outData = outputs[0].GetTensorMutableData<float>();
513 const size_t outSize = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
514 ATH_MSG_DEBUG("ONNX raw output elementCount = " << outSize);
515
// Same sanitising step as in the CUDA path; applies to outputs[0] only.
516 if (m_sanitizeNonFinitePredictions.value()) {
517 std::span<float> preds(outData, outData + outSize);
518 for (size_t i = 0; i < outSize; ++i) {
519 if (!std::isfinite(preds[i])) {
520 ATH_MSG_WARNING("Non-finite prediction detected at " << i << " -> set to -100.");
521 preds[i] = -100.0f;
522 }
523 }
524 }
525
526 for (auto& v : outputs) {
527 graphData.graph->dataTensor.emplace_back(std::move(v));
528 }
529 return StatusCode::SUCCESS;
530}
static std::vector< std::string > parseFeatureNames(const std::string &raw)
Gaudi::Property< bool > m_sanitizeNonFinitePredictions
row
Appending html table to final .html summary file.

◆ setupModel()

StatusCode BucketInferenceToolBase::setupModel ( )
protected

Definition at line 69 of file BucketInferenceToolBase.cxx.

// setupModel: one-time setup used by derived tools. Retrieves the ONNX session
// tool, initialises the two read keys, and records whether the session runs on
// CUDA (and on which device) so that runNamedInference() can pick the IoBinding
// path. Any failing ATH_CHECK presumably aborts the setup with a failure status
// -- the macro body is not shown on this page.
69 {
70 ATH_CHECK(m_onnxSessionTool.retrieve());
71 ATH_CHECK(m_readKey.initialize());
72 ATH_CHECK(m_geoCtxKey.initialize());
73
// Query the backend of the retrieved session and cache the result in members.
74 const InferenceUtils::SessionBackend backend = InferenceUtils::sessionBackend(m_onnxSessionTool);
75 m_isCuda = backend.isCuda;
76 m_cudaDeviceId = backend.cudaDeviceId;
77 if (m_isCuda) {
78 ATH_MSG_INFO("ONNX session is running on CUDA device " << m_cudaDeviceId
79 << ". I/O binding will be used.");
80 } else {
81 ATH_MSG_INFO("ONNX session is running on CPU.");
82 }
83
84 return StatusCode::SUCCESS;
85}
#define ATH_MSG_INFO(x)
SessionBackend sessionBackend(const SessionToolHandle &sessionTool)

◆ trimFeatureToken()

std::string BucketInferenceToolBase::trimFeatureToken ( std::string s)
static protected

Definition at line 25 of file BucketInferenceToolBase.cxx.

// trimFeatureToken: returns `s` with leading and trailing whitespace (as
// classified by std::isspace) removed. The argument is taken by value and
// trimmed in place; a string consisting only of whitespace becomes empty.
25 {
// The predicate takes unsigned char so that std::isspace is never called with
// a negative value.
26 auto notSpace = [](unsigned char c) { return !std::isspace(c); };
// Erase everything before the first non-space character ...
27 s.erase(s.begin(), std::find_if(s.begin(), s.end(), notSpace));
// ... and everything after the last one (reverse search, converted back via base()).
28 s.erase(std::find_if(s.rbegin(), s.rend(), notSpace).base(), s.end());
29 return s;
30}

Member Data Documentation

◆ kBucketFeatureCount

std::size_t MuonML::BucketInferenceToolBase::kBucketFeatureCount = 6
static constexpr protected

Definition at line 53 of file BucketInferenceToolBase.h.

◆ kDefaultNodeFeatureNames

std::array<std::string_view, kNodeFeatureCount> MuonML::BucketInferenceToolBase::kDefaultNodeFeatureNames
static constexpr protected
Initial value:
= {
"segmentPositionX_m", "segmentPositionY_m", "segmentPositionZ_m",
"segmentDirectionX", "segmentDirectionY", "segmentDirectionZ",
"bucket_chamberIndex", "bucket_layers", "bucket_sector", "bucket_segments"}

Definition at line 56 of file BucketInferenceToolBase.h.

56 {
57 "segmentPositionX_m", "segmentPositionY_m", "segmentPositionZ_m",
58 "segmentDirectionX", "segmentDirectionY", "segmentDirectionZ",
59 "bucket_chamberIndex", "bucket_layers", "bucket_sector", "bucket_segments"};

◆ kEdgeFeatureCount

std::size_t MuonML::BucketInferenceToolBase::kEdgeFeatureCount = 7
static constexpr protected

Definition at line 55 of file BucketInferenceToolBase.h.

◆ kNodeFeatureCount

std::size_t MuonML::BucketInferenceToolBase::kNodeFeatureCount = 10
static constexpr protected

Definition at line 54 of file BucketInferenceToolBase.h.

◆ m_cudaDeviceId

int MuonML::BucketInferenceToolBase::m_cudaDeviceId {0}
protected

Definition at line 102 of file BucketInferenceToolBase.h.

102{0};

◆ m_debugDumpFirstNEdges

Gaudi::Property<unsigned int> MuonML::BucketInferenceToolBase::m_debugDumpFirstNEdges {this, "DebugDumpFirstNEdges", 12}
protected

Definition at line 94 of file BucketInferenceToolBase.h.

94{this, "DebugDumpFirstNEdges", 12};

◆ m_debugDumpFirstNNodes

Gaudi::Property<unsigned int> MuonML::BucketInferenceToolBase::m_debugDumpFirstNNodes {this, "DebugDumpFirstNNodes", 5}
protected

Definition at line 93 of file BucketInferenceToolBase.h.

93{this, "DebugDumpFirstNNodes", 5};

◆ m_geoCtxKey

ActsTrk::GeoContextReadKey_t MuonML::BucketInferenceToolBase::m_geoCtxKey {this, "AlignmentKey", "ActsAlignment", "cond handle key"}
protected

Definition at line 80 of file BucketInferenceToolBase.h.

80{this, "AlignmentKey", "ActsAlignment", "cond handle key"};

◆ m_isCuda

bool MuonML::BucketInferenceToolBase::m_isCuda {false}
protected

Definition at line 101 of file BucketInferenceToolBase.h.

101{false};

◆ m_maxAbsDz

Gaudi::Property<double> MuonML::BucketInferenceToolBase::m_maxAbsDz {this, "MaxAbsDz", 15000.0}
protected

Definition at line 90 of file BucketInferenceToolBase.h.

90{this, "MaxAbsDz", 15000.0};

◆ m_maxChamberDelta

Gaudi::Property<int> MuonML::BucketInferenceToolBase::m_maxChamberDelta {this, "MaxChamberDelta", 13}
protected

Definition at line 87 of file BucketInferenceToolBase.h.

87{this, "MaxChamberDelta", 13};

◆ m_maxDistXY

Gaudi::Property<double> MuonML::BucketInferenceToolBase::m_maxDistXY {this, "MaxDistXY", 6800.0}
protected

Definition at line 89 of file BucketInferenceToolBase.h.

89{this, "MaxDistXY", 6800.0};

◆ m_maxSectorDelta

Gaudi::Property<int> MuonML::BucketInferenceToolBase::m_maxSectorDelta {this, "MaxSectorDelta", 1}
protected

Definition at line 88 of file BucketInferenceToolBase.h.

88{this, "MaxSectorDelta", 1};

◆ m_minLayers

Gaudi::Property<int> MuonML::BucketInferenceToolBase::m_minLayers {this, "MinLayersValid", 3}
protected

Definition at line 86 of file BucketInferenceToolBase.h.

86{this, "MinLayersValid", 3};

◆ m_onnxSessionTool

ToolHandle<AthOnnx::IOnnxRuntimeSessionTool> MuonML::BucketInferenceToolBase::m_onnxSessionTool
private
Initial value:
{
this, "ModelSession", ""}

Definition at line 105 of file BucketInferenceToolBase.h.

105 {
106 this, "ModelSession", ""};

◆ m_outputName

Gaudi::Property<std::string> MuonML::BucketInferenceToolBase::m_outputName {this, "OutputName", "logits"}
protected

Definition at line 83 of file BucketInferenceToolBase.h.

83{this, "OutputName", "logits"};

◆ m_readKey

SG::ReadHandleKey<MuonR4::SpacePointContainer> MuonML::BucketInferenceToolBase::m_readKey {this, "ReadSpacePoints", "MuonSpacePoints"}
protected

Definition at line 79 of file BucketInferenceToolBase.h.

79{this, "ReadSpacePoints", "MuonSpacePoints"};

◆ m_sanitizeNonFinitePredictions

Gaudi::Property<bool> MuonML::BucketInferenceToolBase::m_sanitizeNonFinitePredictions
protected
Initial value:
{
this, "SanitizeNonFinitePredictions", false,
"When true, replace non-finite ONNX outputs with -100 and log a warning."}

Definition at line 96 of file BucketInferenceToolBase.h.

96 {
97 this, "SanitizeNonFinitePredictions", false,
98 "When true, replace non-finite ONNX outputs with -100 and log a warning."};

◆ m_validateEdges

Gaudi::Property<bool> MuonML::BucketInferenceToolBase::m_validateEdges {this, "ValidateEdges", true}
protected

Definition at line 95 of file BucketInferenceToolBase.h.

95{this, "ValidateEdges", true};

The documentation for this class was generated from the following files: