ATLAS Offline Software
Loading...
Searching...
No Matches
DVInferenceToolBase.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
5
6#include "InferenceUtils.h"
9#include "GaudiKernel/SystemOfUnits.h"
11
12#include "CaloEvent/CaloTower.h"
13#include "CxxUtils/phihelper.h"
18
19#include <algorithm>
20#include <array>
21#include <cmath>
22#include <iomanip>
23#include <limits>
24#include <optional>
25#include <set>
26#include <sstream>
27#include <unordered_map>
28#include <utility>
29
30using namespace MuonML;
31
32namespace {
33
34enum class NodeKind : uint8_t { Muon, Calo };
35
36struct DVNodeAux {
37 NodeKind kind{NodeKind::Muon};
38 std::array<float, 7> features{};
39 float eta{0.f};
40 float phi{0.f};
41 float energyLike{0.f};
42 Amg::Vector3D direction{0., 0., 1.};
43 int sector{-1};
44};
45
46using SegmentList = std::vector<const xAOD::MuonSegment*>;
47
48void appendUniqueSegment(SegmentList& segments, const xAOD::MuonSegment* seg) {
49 if (seg && !Acts::rangeContainsValue(segments, seg)) segments.push_back(seg);
50}
51
52uint16_t countLayersInBucket(const MuonR4::SpacePointBucket& bucket) {
54 std::set<unsigned int> uniqueLayers{};
55 for (const MuonR4::SpacePointBucket::value_type& sp : bucket) {
56 uniqueLayers.insert(sorter.sectorLayerNum(*sp));
57 }
58 return static_cast<uint16_t>(uniqueLayers.size());
59}
60
61std::string bucketSignatureKey(const MuonR4::SpacePointBucket& bucket) {
62
63 std::ostringstream sig;
64 sig << static_cast<const void*>(bucket.msSector()) << '|'
65 << bucket.msSector()->sector() << '|'
66 << bucket.msSector()->side() << '|'
67 << std::fixed << std::setprecision(3)
68 << bucket.coveredMin() << '|' << bucket.coveredMax();
69 return sig.str();
70}
71
72void appendMuonSegmentNode(const xAOD::MuonSegment& seg, int bucketSector, std::vector<DVNodeAux>& nodes) {
73 const Amg::Vector3D posMm = seg.position();
74 const Amg::Vector3D dir = seg.direction();
75 const Amg::Vector3D posM = posMm / Gaudi::Units::m;
76
77 DVNodeAux node{};
78 node.kind = NodeKind::Muon;
79 node.features[0] = static_cast<float>(posM.mag());
80 node.features[1] = static_cast<float>(posM.theta());
81 node.features[2] = static_cast<float>(posM.phi());
82 node.features[3] = static_cast<float>(dir.theta());
83 node.features[4] = static_cast<float>(dir.phi());
84 node.features[5] = 0.f;
85 node.features[6] = static_cast<float>(seg.numberDoF());
86 node.eta = static_cast<float>(posMm.eta());
87 node.phi = static_cast<float>(posMm.phi());
88 node.energyLike = 0.f;
89 node.direction = dir;
90 node.sector = bucketSector;
91 nodes.push_back(node);
92}
93
/// @return true if @p needle occurs (exact match) anywhere in @p names.
bool hasName(const std::vector<std::string>& names, const std::string& needle) {
  return std::any_of(names.begin(), names.end(),
                     [&needle](const std::string& candidate) { return candidate == needle; });
}
97
/// Join @p names with ", " between consecutive entries (for log messages).
/// An empty list yields an empty string.
std::string joinNames(const std::vector<std::string>& names) {
  std::string joined;
  const char* separator = "";
  for (const std::string& name : names) {
    joined += separator;
    joined += name;
    separator = ", ";
  }
  return joined;
}
105}
106
107std::optional<Amg::Vector3D> firstIntersectionWithEnvelope(const Amg::Vector3D& direction,
108 float rMaxMm,
109 float zMaxMm) {
110 const float ux = static_cast<float>(direction.x());
111 const float uy = static_cast<float>(direction.y());
112 const float uz = static_cast<float>(direction.z());
113 std::optional<float> bestT{};
114 const float ur = static_cast<float>(direction.perp());
115 if (rMaxMm > 0.f && ur > 0.f) {
116 const float tBarrel = rMaxMm / ur;
117 const float zBarrel = tBarrel * uz;
118 if (std::abs(zBarrel) <= zMaxMm) bestT = tBarrel;
119 }
120
121 if (zMaxMm > 0.f && std::abs(uz) > 0.f) {
122 const float tEndcap = zMaxMm / std::abs(uz);
123 const float xEnd = tEndcap * ux;
124 const float yEnd = tEndcap * uy;
125 if (std::hypot(xEnd, yEnd) <= rMaxMm && (!bestT || tEndcap < *bestT)) {
126 bestT = tEndcap;
127 }
128 }
129
130 if (!bestT) return std::nullopt;
131 return (*bestT) * direction;
132}
133
134} // namespace
135
138 if (m_singleOutputMode.value() != "auto" &&
139 m_singleOutputMode.value() != "logit" &&
140 m_singleOutputMode.value() != "prob") {
141 ATH_MSG_ERROR("SingleOutputMode must be one of auto, logit, prob; got " << m_singleOutputMode.value());
142 return StatusCode::FAILURE;
143 }
144
145 if (m_useBucketSegmentSelection.value() && m_spacePointKeys.size() > 1u) {
146 ATH_MSG_WARNING("DV SpacePointKeys has " << m_spacePointKeys.size()
147 << " entries. The MuonBucketDump training samples use the "
148 << "default SegmentKey array with segments attached only to "
149 << "the first SpacePointKeys entry. For parity with training, "
150 << "configure SpacePointKeys=['MuonSpacePoints'] unless the "
151 << "training dump was produced with matching segment keys for all entries.");
152 }
153
154 ATH_MSG_INFO("Initialized DVInferenceToolBase with SegmentKey=" << m_segmentKey.key()
155 << ", SpacePointKeys=" << m_spacePointKeys.size()
157 << ", TowerContainerKey="
158 << (m_towerKey.empty() ? std::string("<disabled>") : m_towerKey.key())
159 << ", " << m_minTowerEnergyMeV
160 << ", " << m_maxTowerSegmentDR
161 << ", " << m_caloRMaxMm
162 << ", " << m_caloZMaxMm
163 << ", " << m_fallbackToAllSegments
164 << ", " << m_singleOutputMode);
165 return StatusCode::SUCCESS;
166}
167
168Ort::Session& DVInferenceToolBase::model() const {
169 return m_onnxSessionTool->session();
170}
171
172std::vector<std::string> DVInferenceToolBase::modelInputNames() const {
173 std::vector<std::string> names{};
174 Ort::AllocatorWithDefaultOptions allocator;
175 const std::size_t nInputs = model().GetInputCount();
176 names.reserve(nInputs);
177 for (std::size_t i = 0; i < nInputs; ++i) {
178 auto name = model().GetInputNameAllocated(i, allocator);
179 if (name) names.emplace_back(name.get());
180 }
181 return names;
182}
183
184std::vector<std::string> DVInferenceToolBase::modelOutputNames() const {
185 std::vector<std::string> names{};
186 Ort::AllocatorWithDefaultOptions allocator;
187 const std::size_t nOutputs = model().GetOutputCount();
188 names.reserve(nOutputs);
189 for (std::size_t i = 0; i < nOutputs; ++i) {
190 auto name = model().GetOutputNameAllocated(i, allocator);
191 if (name) names.emplace_back(name.get());
192 }
193 return names;
194}
195
197 ATH_CHECK(m_onnxSessionTool.retrieve());
198 ATH_CHECK(m_segmentKey.initialize());
199 ATH_CHECK(m_spacePointKeys.initialize());
201
203 m_isCuda = backend.isCuda;
204 m_cudaDeviceId = backend.cudaDeviceId;
205 if (m_isCuda) {
206 ATH_MSG_INFO("ONNX session is running on CUDA device " << m_cudaDeviceId
207 << ". I/O binding will be used.");
208 } else {
209 ATH_MSG_INFO("ONNX session is running on CPU.");
210 }
211 return StatusCode::SUCCESS;
212}
213
214StatusCode DVInferenceToolBase::runGraphInference(const EventContext& ctx,
215 GraphRawData& graphData) const {
216 ATH_CHECK(buildGraph(ctx, graphData));
217 if (!graphData.graph || graphData.graph->dataTensor.empty()) {
218 ATH_MSG_DEBUG("DV graph has no input tensors; skip inference for this event.");
219 return StatusCode::SUCCESS;
220 }
221 return runInference(graphData);
222}
223
224StatusCode DVInferenceToolBase::inferEvent(const EventContext& ctx,
225 DVInferenceResult& result) const {
226 result = DVInferenceResult{};
227 GraphRawData graphData{};
228 ATH_CHECK(buildGraph(ctx, graphData));
229 if (!graphData.graph || graphData.graph->dataTensor.size() < kInputTensorCount) {
230 ATH_MSG_WARNING("DV graph is empty; no event-classifier output will be produced.");
231 return StatusCode::SUCCESS;
232 }
233
234 const auto xShape = graphData.graph->dataTensor[0].GetTensorTypeAndShapeInfo().GetShape();
235 const auto edgeShape = graphData.graph->dataTensor[1].GetTensorTypeAndShapeInfo().GetShape();
236 result.nNodes = !xShape.empty() && xShape[0] > 0 ? static_cast<std::size_t>(xShape[0]) : 0u;
237 result.nEdges = edgeShape.size() > 1 && edgeShape[1] > 0 ? static_cast<std::size_t>(edgeShape[1]) : 0u;
238 if (graphData.spacePointsInBucket.size() >= 2) {
239 result.nMuonNodes = static_cast<std::size_t>(std::max<int64_t>(graphData.spacePointsInBucket[0], 0));
240 result.nCaloNodes = static_cast<std::size_t>(std::max<int64_t>(graphData.spacePointsInBucket[1], 0));
241 }
242
243 ATH_CHECK(runInference(graphData));
244 if (graphData.graph->dataTensor.size() <= kInputTensorCount) {
245 ATH_MSG_ERROR("DV inference finished without an output tensor.");
246 return StatusCode::FAILURE;
247 }
248
249 result.probability = probabilityFromOutput(graphData.graph->dataTensor.back(), result.rawOutput);
250 result.valid = std::isfinite(result.probability);
251 ATH_MSG_DEBUG("DV event classifier: N=" << result.nNodes
252 << " (muon=" << result.nMuonNodes << ", calo=" << result.nCaloNodes
253 << "), E=" << result.nEdges
254 << ", raw=" << result.rawOutput
255 << ", probability=" << result.probability);
256 return StatusCode::SUCCESS;
257}
258
259StatusCode DVInferenceToolBase::buildGraph(const EventContext& ctx,
260 GraphRawData& graphData) const {
261 graphData.graph.reset();
262 graphData.featureLeaves.clear();
263 graphData.srcEdges.clear();
264 graphData.desEdges.clear();
265 graphData.edgeIndexPacked.clear();
266 graphData.spacePointsInBucket.clear();
267 graphData.graph = std::make_unique<InferenceGraph>();
268 graphData.graph->dataTensor.reserve(kInputTensorCount);
269
270 std::vector<DVNodeAux> nodes;
271
272 const xAOD::MuonSegmentContainer* segments{nullptr};
273 ATH_CHECK(SG::get(segments, m_segmentKey, ctx));
274
275 nodes.reserve(segments ? segments->size() : 0u);
276
277 if (segments && m_useBucketSegmentSelection.value() && !m_spacePointKeys.empty()) {
278 using SegmentsPerBucket_t =
279 std::unordered_map<const MuonR4::SpacePointBucket*, SegmentList>;
280
281 using SegmentsPerBucketSignature_t =
282 std::unordered_map<std::string, SegmentList>;
283
284 SegmentsPerBucket_t segmentsPerBucket{};
285 SegmentsPerBucketSignature_t segmentsPerBucketSignature{};
286 for (const xAOD::MuonSegment* seg : *segments) {
287 const auto* detailed = MuonR4::detailedSegment(*seg);
288 const MuonR4::SpacePointBucket* parentBucket = detailed->parent()->parentBucket();
289 appendUniqueSegment(segmentsPerBucket[parentBucket], seg);
290 const std::string parentSig = bucketSignatureKey(*parentBucket);
291 if (!parentSig.empty()) appendUniqueSegment(segmentsPerBucketSignature[parentSig], seg);
292 }
293 std::size_t nSignatureMatchedBuckets = 0u;
295 const MuonR4::SpacePointContainer* spContainer{nullptr};
296 ATH_CHECK(SG::get(spContainer, spKey, ctx));
297
298 for (const MuonR4::SpacePointBucket* bucket : *spContainer) {
299 const auto it = segmentsPerBucket.find(bucket);
300 const SegmentList* matchedSegments{nullptr};
301 if (it != segmentsPerBucket.end() && !it->second.empty()) {
302 matchedSegments = &it->second;
303 } else {
304 const std::string sig = bucketSignatureKey(*bucket);
305 const auto sigIt = sig.empty() ? segmentsPerBucketSignature.end()
306 : segmentsPerBucketSignature.find(sig);
307 if (sigIt != segmentsPerBucketSignature.end() && !sigIt->second.empty()) {
308 matchedSegments = &sigIt->second;
309 ++nSignatureMatchedBuckets;
310 }
311 }
312 if (!matchedSegments) continue;
313
314 const int bucketSector = bucket->msSector() ? static_cast<int>(bucket->msSector()->sector()) : -1;
315 const uint16_t bucketLayers = countLayersInBucket(*bucket);
316 ATH_MSG_VERBOSE("DV bucket segment node source: key=" << spKey.key()
317 << " sector=" << bucketSector
318 << " layers=" << bucketLayers
319 << " segments=" << matchedSegments->size());
320
321 for (const xAOD::MuonSegment* seg : *matchedSegments) {
322 appendMuonSegmentNode(*seg, bucketSector, nodes);
323 }
324 }
325 }
326
327 ATH_MSG_DEBUG("DV graph built " << nodes.size()
328 << " muon nodes from BucketDumper-style SpacePointBucket-associated segments"
329 << " (signature-matched filtered buckets=" << nSignatureMatchedBuckets << ")");
330 }
331
332 if (segments && nodes.empty() &&
334 if (m_useBucketSegmentSelection.value()) {
335 ATH_MSG_WARNING("No bucket-associated segments were found for DV graph building; "
336 "falling back to all segments from " << m_segmentKey.key()
337 << ". This does not match the training converter exactly.");
338 }
339 for (const xAOD::MuonSegment* seg : *segments) {
340 appendMuonSegmentNode(*seg, static_cast<int>(seg->sector()), nodes);
341 }
342 }
343
344 if (segments && nodes.empty() && m_useBucketSegmentSelection.value() &&
345 !m_fallbackToAllSegments.value()) {
346 ATH_MSG_WARNING("No bucket-associated segments were found for DV graph building. "
347 "Not falling back to all segments because that does not match the training converter.");
348 }
349
350 const std::size_t nMuonNodes = nodes.size();
351
352 if (!m_towerKey.empty() && nMuonNodes > 0u) {
353 const CaloTowerContainer* towers{nullptr};
354 ATH_CHECK(SG::get(towers, m_towerKey, ctx));
355
356 nodes.reserve(nodes.size() + towers->size());
357 for (const CaloTower* tower : *towers) {
358 const float energyMeV = static_cast<float>(tower->energy());
359 if (energyMeV < m_minTowerEnergyMeV) continue;
360
361 const float eta = static_cast<float>(tower->eta());
362 const float phi = static_cast<float>(tower->phi());
363 float minDR = std::numeric_limits<float>::max();
364 for (std::size_t i = 0; i < nMuonNodes; ++i) {
365 minDR = std::min(
366 minDR,
367 static_cast<float>(xAOD::P4Helpers::deltaR(eta, phi, nodes[i].eta, nodes[i].phi)));
368 }
369
370 if (minDR >= m_maxTowerSegmentDR) continue;
371 const Amg::Vector3D direction = Acts::makeDirectionFromPhiEta(
372 static_cast<double>(phi), static_cast<double>(eta));
373 const std::optional<Amg::Vector3D> posMm =
374 firstIntersectionWithEnvelope(direction, m_caloRMaxMm.value(), m_caloZMaxMm.value());
375 if (!posMm) continue;
376
377 const Amg::Vector3D posM = (*posMm) / Gaudi::Units::m;
378
379 DVNodeAux node{};
380 node.kind = NodeKind::Calo;
381 node.features[0] = static_cast<float>(posM.mag());
382 node.features[1] = static_cast<float>(posM.theta());
383 node.features[2] = static_cast<float>(posM.phi());
384 node.features[3] = static_cast<float>(direction.theta());
385 node.features[4] = static_cast<float>(direction.phi());
386 node.features[5] = energyMeV;
387 node.features[6] = static_cast<float>(tower->size());
388 node.eta = eta;
389 node.phi = phi;
390 node.energyLike = energyMeV;
391 node.direction = direction;
392 node.sector = static_cast<int>(
393 MuonR4::ExpandedSector{CxxUtils::wrapToPi(static_cast<double>(phi))}.msSector());
394 nodes.push_back(node);
395 }
396 }
397
398 const std::size_t nCaloNodes = nodes.size() - nMuonNodes;
399 const std::size_t nNodes = nodes.size();
400 graphData.spacePointsInBucket.push_back(static_cast<int64_t>(nMuonNodes));
401 graphData.spacePointsInBucket.push_back(static_cast<int64_t>(nCaloNodes));
402
403 if (nNodes == 0u) {
404 ATH_MSG_WARNING("No muon segment or calo tower nodes found. Skipping DV inference.");
405 return StatusCode::SUCCESS;
406 }
407
408 graphData.featureLeaves.reserve(nNodes * kNodeFeatureCount);
409 for (const DVNodeAux& node : nodes) {
410 graphData.featureLeaves.insert(graphData.featureLeaves.end(),
411 node.features.begin(), node.features.end());
412 }
413
414 const int maxEdges = m_maxEdges.value();
415 std::vector<float> edgeAttr;
416 edgeAttr.reserve(2u * nMuonNodes * std::max<std::size_t>(nCaloNodes, 1u) * kEdgeFeatureCount);
417
418 auto addEdge = [&graphData, &edgeAttr, maxEdges](std::size_t src,
419 std::size_t dst,
420 const DVNodeAux& a,
421 const DVNodeAux& b) -> bool {
422 if (maxEdges >= 0 && static_cast<int>(graphData.srcEdges.size()) >= maxEdges) return false;
423 const float dPhi = CxxUtils::deltaPhi(b.phi, a.phi);
424 const float dEta = b.eta - a.eta;
425 const float cosAng = std::clamp(static_cast<float>(a.direction.dot(b.direction)), -1.f, 1.f);
426 std::array<float, kEdgeFeatureCount> attr{
427 b.energyLike - a.energyLike,
428 dPhi,
429 dEta,
430 cosAng,
431 (a.sector == b.sector) ? 1.f : 0.f};
432
433 graphData.srcEdges.push_back(static_cast<int64_t>(src));
434 graphData.desEdges.push_back(static_cast<int64_t>(dst));
435 edgeAttr.insert(edgeAttr.end(), attr.begin(), attr.end());
436 return true;
437 };
438
439 bool edgeCapReached = false;
440 for (std::size_t im = 0; im < nMuonNodes && !edgeCapReached; ++im) {
441 for (std::size_t ic = nMuonNodes; ic < nNodes; ++ic) {
442 if (xAOD::P4Helpers::deltaR(nodes[im].eta, nodes[im].phi,
443 nodes[ic].eta, nodes[ic].phi) >= m_maxTowerSegmentDR.value()) continue;
444 if (!addEdge(im, ic, nodes[im], nodes[ic])) {
445 edgeCapReached = true;
446 break;
447 }
448 if (!addEdge(ic, im, nodes[ic], nodes[im])) {
449 edgeCapReached = true;
450 break;
451 }
452 }
453 }
454
455 const std::size_t nEdges = graphData.srcEdges.size();
456 if (m_requireEdges.value() && nEdges == 0u) {
457 ATH_MSG_DEBUG("DV graph has no segment-tower edges and RequireEdges=True; skip inference.");
458 graphData.graph.reset();
459 return StatusCode::SUCCESS;
460 }
461
462 if (edgeAttr.size() != nEdges * kEdgeFeatureCount) {
463 ATH_MSG_ERROR("DV edge attribute size mismatch: E=" << nEdges
464 << " edge_attr.size=" << edgeAttr.size());
465 return StatusCode::FAILURE;
466 }
467
468 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
469 std::vector<int64_t> nodeShape{static_cast<int64_t>(nNodes), static_cast<int64_t>(kNodeFeatureCount)};
470 graphData.graph->dataTensor.emplace_back(
471 Ort::Value::CreateTensor<float>(memInfo,
472 graphData.featureLeaves.data(),
473 graphData.featureLeaves.size(),
474 nodeShape.data(),
475 nodeShape.size()));
476
477 graphData.edgeIndexPacked.clear();
478 graphData.edgeIndexPacked.reserve(2u * nEdges);
479 graphData.edgeIndexPacked.insert(graphData.edgeIndexPacked.end(), graphData.srcEdges.begin(), graphData.srcEdges.end());
480 graphData.edgeIndexPacked.insert(graphData.edgeIndexPacked.end(), graphData.desEdges.begin(), graphData.desEdges.end());
481
482 std::vector<int64_t> edgeIndexShape{2, static_cast<int64_t>(nEdges)};
483 graphData.graph->dataTensor.emplace_back(
484 Ort::Value::CreateTensor<int64_t>(memInfo,
485 graphData.edgeIndexPacked.data(),
486 graphData.edgeIndexPacked.size(),
487 edgeIndexShape.data(),
488 edgeIndexShape.size()));
489
490 Ort::AllocatorWithDefaultOptions allocator;
491 std::vector<int64_t> edgeAttrShape{static_cast<int64_t>(nEdges), static_cast<int64_t>(kEdgeFeatureCount)};
492 Ort::Value edgeAttrTensor = Ort::Value::CreateTensor(allocator,
493 edgeAttrShape.data(),
494 edgeAttrShape.size(),
495 ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
496 if (!edgeAttr.empty()) {
497 float* edgeAttrData = edgeAttrTensor.GetTensorMutableData<float>();
498 std::copy(edgeAttr.begin(), edgeAttr.end(), edgeAttrData);
499 }
500 graphData.graph->dataTensor.emplace_back(std::move(edgeAttrTensor));
501
502 std::vector<int64_t> nMuonShape{1};
503 graphData.graph->dataTensor.emplace_back(
504 Ort::Value::CreateTensor<int64_t>(memInfo,
505 graphData.spacePointsInBucket.data(),
506 1,
507 nMuonShape.data(),
508 nMuonShape.size()));
509
510 if (msgLvl(MSG::DEBUG)) {
511 ATH_MSG_DEBUG("Built DV graph: N=" << nNodes << " (muon=" << nMuonNodes
512 << ", calo=" << nCaloNodes << "), E=" << nEdges
513 << ", n_muon_nodes=" << graphData.spacePointsInBucket[0]);
514 const std::size_t dumpNodes = std::min<std::size_t>(m_debugDumpFirstNNodes.value(), nNodes);
515 for (std::size_t i = 0; i < dumpNodes; ++i) {
516 std::ostringstream row;
517 row << "DVNode[" << i << "] kind=" << (nodes[i].kind == NodeKind::Muon ? "muon" : "calo") << ":";
518 for (std::size_t f = 0; f < kNodeFeatureCount; ++f) {
519 row << " f" << f << "=" << graphData.featureLeaves[i * kNodeFeatureCount + f];
520 }
521 ATH_MSG_DEBUG(row.str());
522 }
523 const std::size_t dumpEdges = std::min<std::size_t>(m_debugDumpFirstNEdges.value(), nEdges);
524 for (std::size_t e = 0; e < dumpEdges; ++e) {
525 ATH_MSG_DEBUG("DVEdge[" << e << "]: " << graphData.srcEdges[e]
526 << " -> " << graphData.desEdges[e]
527 << " edge_attr=["
528 << edgeAttr[e * kEdgeFeatureCount + 0] << ", "
529 << edgeAttr[e * kEdgeFeatureCount + 1] << ", "
530 << edgeAttr[e * kEdgeFeatureCount + 2] << ", "
531 << edgeAttr[e * kEdgeFeatureCount + 3] << ", "
532 << edgeAttr[e * kEdgeFeatureCount + 4] << "]");
533 }
534 }
535
536 graphData.srcEdges.clear();
537 graphData.desEdges.clear();
538 return StatusCode::SUCCESS;
539}
540
542 GraphRawData& graphData,
543 const std::vector<InputTensorSpec>& inputSpecs,
544 const std::vector<std::string>& outputNames) const {
545 if (!graphData.graph) {
546 ATH_MSG_ERROR("Graph data is not built.");
547 return StatusCode::FAILURE;
548 }
549 if (inputSpecs.empty()) {
550 ATH_MSG_ERROR("No DV ONNX inputs were selected for inference.");
551 return StatusCode::FAILURE;
552 }
553
554 for (const InputTensorSpec& spec : inputSpecs) {
555 if (spec.tensorIndex >= graphData.graph->dataTensor.size()) {
556 ATH_MSG_ERROR("Input " << spec.name << " requests tensor index " << spec.tensorIndex
557 << " but only " << graphData.graph->dataTensor.size()
558 << " tensors were prepared.");
559 return StatusCode::FAILURE;
560 }
561 }
562
563 std::vector<const char*> inputNamePtrs{};
564 inputNamePtrs.reserve(inputSpecs.size());
565 for (const InputTensorSpec& spec : inputSpecs) {
566 inputNamePtrs.push_back(spec.name.c_str());
567 }
568
569 std::vector<const char*> outputNamePtrs{};
570 outputNamePtrs.reserve(outputNames.size());
571 for (const std::string& name : outputNames) {
572 outputNamePtrs.push_back(name.c_str());
573 }
574
575 graphData.graph->dataTensor.reserve(graphData.graph->dataTensor.size() + outputNamePtrs.size());
576
577 Ort::RunOptions runOptions;
578 runOptions.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR);
579
580 if (m_isCuda) {
581 Ort::IoBinding binding(model());
582 for (const InputTensorSpec& spec : inputSpecs) {
583 binding.BindInput(spec.name.c_str(), graphData.graph->dataTensor[spec.tensorIndex]);
584 }
585 Ort::MemoryInfo cpuOut = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
586 for (const char* outName : outputNamePtrs) {
587 binding.BindOutput(outName, cpuOut);
588 }
589
590 model().Run(runOptions, binding);
591 binding.SynchronizeOutputs();
592
593 std::vector<Ort::Value> outputs = binding.GetOutputValues();
594 if (outputs.empty()) {
595 ATH_MSG_ERROR("IoBinding inference returned empty output.");
596 return StatusCode::FAILURE;
597 }
598
599 if (m_sanitizeNonFinitePredictions.value()) {
600 float* outData = outputs[0].GetTensorMutableData<float>();
601 const std::size_t outSize = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
602 for (std::size_t i = 0; i < outSize; ++i) {
603 if (!std::isfinite(outData[i])) {
604 ATH_MSG_WARNING("Non-finite DV prediction detected at " << i << " -> set to -100.");
605 outData[i] = -100.f;
606 }
607 }
608 }
609
610 for (auto& v : outputs) {
611 graphData.graph->dataTensor.emplace_back(std::move(v));
612 }
613 return StatusCode::SUCCESS;
614 }
615
616 std::vector<Ort::Value> orderedInputs{};
617 orderedInputs.reserve(inputSpecs.size());
618 for (const InputTensorSpec& spec : inputSpecs) {
619 orderedInputs.emplace_back(std::move(graphData.graph->dataTensor[spec.tensorIndex]));
620 }
621
622 std::vector<Ort::Value> outputs =
623 model().Run(runOptions,
624 inputNamePtrs.data(),
625 orderedInputs.data(),
626 inputNamePtrs.size(),
627 outputNamePtrs.data(),
628 outputNamePtrs.size());
629
630 if (outputs.empty()) {
631 ATH_MSG_ERROR("Inference returned empty output.");
632 return StatusCode::FAILURE;
633 }
634
635 ATH_MSG_DEBUG("DV ONNX raw output elementCount = "
636 << outputs[0].GetTensorTypeAndShapeInfo().GetElementCount());
637
638 if (m_sanitizeNonFinitePredictions.value()) {
639 float* outData = outputs[0].GetTensorMutableData<float>();
640 const std::size_t outSize = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
641 for (std::size_t i = 0; i < outSize; ++i) {
642 if (!std::isfinite(outData[i])) {
643 ATH_MSG_WARNING("Non-finite DV prediction detected at " << i << " -> set to -100.");
644 outData[i] = -100.f;
645 }
646 }
647 }
648
649 for (auto& v : outputs) {
650 graphData.graph->dataTensor.emplace_back(std::move(v));
651 }
652 return StatusCode::SUCCESS;
653}
654
656 const std::vector<std::string> availableInputs = modelInputNames();
657 if (availableInputs.empty()) {
658 ATH_MSG_ERROR("DV ONNX model has no inputs.");
659 return StatusCode::FAILURE;
660 }
661
662 ATH_MSG_DEBUG("DV ONNX model inputs: " << joinNames(availableInputs));
663
664 const std::string nodeName = m_inputNodeName.value();
665 const std::string edgeIndexName = m_inputEdgeIndexName.value();
666 const std::string edgeAttrName = m_inputEdgeAttrName.value();
667 const std::string nMuonNodesName = m_inputNMuonNodesName.value();
668
669 std::vector<InputTensorSpec> inputSpecs{};
670 inputSpecs.reserve(kInputTensorCount);
671
672 auto addIfPresent = [&availableInputs, &inputSpecs](const std::string& name, std::size_t tensorIndex) {
673 if (hasName(availableInputs, name)) {
674 inputSpecs.push_back(InputTensorSpec{name, tensorIndex});
675 return true;
676 }
677 return false;
678 };
679
680 const bool hasNodeInput = addIfPresent(nodeName, 0u);
681 const bool hasEdgeIndexInput = addIfPresent(edgeIndexName, 1u);
682 const bool hasEdgeAttrInput = addIfPresent(edgeAttrName, 2u);
683 const bool hasNMuonNodesInput = addIfPresent(nMuonNodesName, 3u);
684
685 if (!hasNodeInput || !hasEdgeIndexInput) {
686 ATH_MSG_ERROR("DV ONNX model is missing required inputs. Expected at least "
687 << nodeName << " and " << edgeIndexName
688 << "; model inputs are: " << joinNames(availableInputs));
689 return StatusCode::FAILURE;
690 }
691
692 if (!hasEdgeAttrInput) {
693 ATH_MSG_DEBUG("DV ONNX model has no input named " << edgeAttrName
694 << "; not binding edge_attr. This is expected for exports where "
695 << "the architecture does not consume edge attributes and ONNX pruned the input.");
696 }
697 if (!hasNMuonNodesInput) {
698 ATH_MSG_DEBUG("DV ONNX model has no input named " << nMuonNodesName
699 << "; not binding n_muon_nodes. This is expected only if the exported "
700 << "model does not need model-side muon/calo normalization.");
701 }
702
703 for (const std::string& inputName : availableInputs) {
704 if (inputName != nodeName && inputName != edgeIndexName &&
705 inputName != edgeAttrName && inputName != nMuonNodesName) {
706 ATH_MSG_ERROR("DV ONNX model has unsupported input " << inputName
707 << ". Configure the input-name properties or update the tool mapping.");
708 return StatusCode::FAILURE;
709 }
710 }
711
712 std::vector<std::string> outputNames{};
713 const std::vector<std::string> availableOutputs = modelOutputNames();
714 if (hasName(availableOutputs, m_outputName.value())) {
715 outputNames.push_back(m_outputName.value());
716 } else if (!availableOutputs.empty()) {
717 ATH_MSG_WARNING("DV ONNX model has no output named " << m_outputName.value()
718 << "; using first model output " << availableOutputs.front() << ".");
719 outputNames.push_back(availableOutputs.front());
720 } else {
721 ATH_MSG_ERROR("DV ONNX model has no outputs.");
722 return StatusCode::FAILURE;
723 }
724
725 return runNamedInference(graphData, inputSpecs, outputNames);
726}
727
728float DVInferenceToolBase::probabilityFromOutput(const Ort::Value& output, float& rawOutput) const {
729 rawOutput = 0.f;
730 const float* data = output.GetTensorData<float>();
731 const auto shapeInfo = output.GetTensorTypeAndShapeInfo();
732 const std::size_t nElem = shapeInfo.GetElementCount();
733 if (nElem == 0 || data == nullptr) return std::numeric_limits<float>::quiet_NaN();
734
735 if (nElem == 1) {
736 rawOutput = data[0];
737 if (m_singleOutputMode.value() == "prob") return rawOutput;
738 if (m_singleOutputMode.value() == "logit" || m_singleOutputMode.value() == "auto") {
739 return InferenceUtils::sigmoid(rawOutput);
740 }
741 return InferenceUtils::sigmoid(rawOutput);
742 }
743
744 if (nElem == 2) {
745 rawOutput = data[1];
746 const float z0 = data[0] - std::max(data[0], data[1]);
747 const float z1 = data[1] - std::max(data[0], data[1]);
748 const float e0 = std::exp(z0);
749 const float e1 = std::exp(z1);
750 return e1 / (e0 + e1);
751 }
752
753 ATH_MSG_WARNING("DV output tensor has " << nElem
754 << " elements; using element 0 with SingleOutputMode=" << m_singleOutputMode.value());
755 rawOutput = data[0];
756 if (m_singleOutputMode.value() == "prob") return rawOutput;
757 if (m_singleOutputMode.value() == "logit" || m_singleOutputMode.value() == "auto") {
758 return InferenceUtils::sigmoid(rawOutput);
759 }
760 return InferenceUtils::sigmoid(rawOutput);
761}
Scalar eta() const
pseudorapidity method
Scalar phi() const
phi method
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_VERBOSE(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
static Double_t sp
static Double_t a
Handle class for reading from StoreGate.
Storable container class for CaloTower.
Data class for calorimeter cell towers.
size_type size() const noexcept
Returns the number of elements in the collection.
int8_t side() const
Returns the side of the MS-sector 1 -> A side ; -1 -> C side.
int sector() const
Returns the sector of the MS-sector.
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool
Gaudi::Property< float > m_caloRMaxMm
Gaudi::Property< std::string > m_inputNodeName
Gaudi::Property< float > m_maxTowerSegmentDR
SG::ReadHandleKey< xAOD::MuonSegmentContainer > m_segmentKey
Gaudi::Property< std::string > m_outputName
std::vector< std::string > modelOutputNames() const
static constexpr std::size_t kNodeFeatureCount
SG::ReadHandleKey< CaloTowerContainer > m_towerKey
Gaudi::Property< float > m_minTowerEnergyMeV
std::vector< std::string > modelInputNames() const
Gaudi::Property< bool > m_requireEdges
Gaudi::Property< int > m_maxEdges
Gaudi::Property< unsigned int > m_debugDumpFirstNNodes
StatusCode runInference(GraphRawData &graphData) const
Run the configured ONNX session on a graph already built by buildGraph.
Gaudi::Property< bool > m_useBucketSegmentSelection
Gaudi::Property< unsigned int > m_debugDumpFirstNEdges
Gaudi::Property< float > m_caloZMaxMm
Gaudi::Property< std::string > m_inputEdgeIndexName
static constexpr std::size_t kEdgeFeatureCount
Gaudi::Property< bool > m_fallbackToAllSegments
StatusCode runNamedInference(GraphRawData &graphData, const std::vector< InputTensorSpec > &inputSpecs, const std::vector< std::string > &outputNames) const
StatusCode runGraphInference(const EventContext &ctx, GraphRawData &graphData) const override
IGraphInferenceTool entry point: build the DV event graph and run ONNX.
Gaudi::Property< std::string > m_inputEdgeAttrName
SG::ReadHandleKeyArray< MuonR4::SpacePointContainer > m_spacePointKeys
StatusCode buildGraph(const EventContext &ctx, GraphRawData &graphData) const
Build the DV ONNX input tensors: x, edge_index, edge_attr, n_muon_nodes.
float probabilityFromOutput(const Ort::Value &output, float &rawOutput) const
static constexpr std::size_t kInputTensorCount
StatusCode inferEvent(const EventContext &ctx, DVInferenceResult &result) const
Convenience event-classifier API used by DVInferenceAlg.
Gaudi::Property< std::string > m_singleOutputMode
Gaudi::Property< bool > m_sanitizeNonFinitePredictions
Gaudi::Property< std::string > m_inputNMuonNodesName
unsigned msSector() const
Returns the ms sector corresponding to the expanded sector.
: The muon space point bucket represents a collection of points that will be processed together in t...
const MuonGMR4::SpectrometerSector * msSector() const
Returns the associated muonChamber.
The SpacePointPerLayerSorter sorts two given space points by their layer Identifier.
Property holding a SG store/key/clid from which a ReadHandle is made.
List of segments.
Definition node.h:24
float numberDoF() const
Returns the numberDoF.
Amg::Vector3D direction() const
Returns the direction as Amg::Vector.
Amg::Vector3D position() const
Returns the position as Amg::Vector.
Eigen::Matrix< double, 3, 1 > Vector3D
T wrapToPi(T phi)
Wrap angle in radians to [-pi, pi].
Definition phihelper.h:24
T deltaPhi(T phiA, T phiB)
Return difference phiA - phiB in range [-pi, pi].
Definition phihelper.h:42
SessionBackend sessionBackend(const SessionToolHandle &sessionTool)
DataVector< SpacePointBucket > SpacePointContainer
Abbreviation of the space point container type.
const Segment * detailedSegment(const xAOD::MuonSegment &seg)
Helper function to navigate from the xAOD::MuonSegment to the MuonR4::Segment.
NRpcCablingAlg reads raw condition data and writes derived condition data to the condition store.
const T * get(const ReadCondHandleKey< T > &key, const EventContext &ctx)
Convenience function to retrieve an object given a ReadCondHandleKey.
const Amg::Vector3D & direction() const
Method to retrieve the direction at the Intersection.
double deltaR(double rapidity1, double phi1, double rapidity2, double phi2)
from bare rapidity, phi
MuonSegmentContainer_v1 MuonSegmentContainer
Definition of the current "MuonSegment container version".
setWord1 uint16_t
MuonSegment_v1 MuonSegment
Reference the current persistent version:
Helper for azimuthal angle calculations.
Helper struct to ship the Graph from the space point buckets to ONNX.
Definition GraphData.h:25
FeatureVec_t featureLeaves
Vector containing all features.
Definition GraphData.h:30
EdgeCounterVec_t edgeIndexPacked
Packed edge index buffer (kept alive for ONNX tensors that reference it) This stores [srcEdges,...
Definition GraphData.h:42
std::unique_ptr< InferenceGraph > graph
Pointer to the graph to be parsed to ONNX.
Definition GraphData.h:46
EdgeCounterVec_t srcEdges
Vector encoding the source index of the.
Definition GraphData.h:32
EdgeCounterVec_t desEdges
Vect.
Definition GraphData.h:34
NodeConnectVec_t spacePointsInBucket
Vector keeping track of how many space points are in each parsed bucket.
Definition GraphData.h:36