ATLAS Offline Software
MuonML::DVInferenceToolBase Class Reference

Athena tool for DisplacedVertex graph-level ONNX inference.

#include <DVInferenceToolBase.h>


Classes

struct  InputTensorSpec

Public Member Functions

 ~DVInferenceToolBase () override=default
StatusCode initialize () override
StatusCode runGraphInference (const EventContext &ctx, GraphRawData &graphData) const override
 IGraphInferenceTool entry point: build the DV event graph and run ONNX.
StatusCode buildGraph (const EventContext &ctx, GraphRawData &graphData) const
 Build the DV ONNX input tensors: x, edge_index, edge_attr, n_muon_nodes.
StatusCode runInference (GraphRawData &graphData) const
 Run the configured ONNX session on a graph already built by buildGraph.
StatusCode inferEvent (const EventContext &ctx, DVInferenceResult &result) const
 Convenience event-classifier API used by DVInferenceAlg.

Protected Member Functions

StatusCode setupModel ()
Ort::Session & model () const
std::vector< std::string > modelInputNames () const
std::vector< std::string > modelOutputNames () const
StatusCode runNamedInference (GraphRawData &graphData, const std::vector< InputTensorSpec > &inputSpecs, const std::vector< std::string > &outputNames) const
float probabilityFromOutput (const Ort::Value &output, float &rawOutput) const

Protected Attributes

SG::ReadHandleKey< xAOD::MuonSegmentContainer > m_segmentKey
SG::ReadHandleKeyArray< MuonR4::SpacePointContainer > m_spacePointKeys
SG::ReadHandleKey< CaloTowerContainer > m_towerKey
Gaudi::Property< float > m_minTowerEnergyMeV
Gaudi::Property< float > m_maxTowerSegmentDR
Gaudi::Property< float > m_caloRMaxMm
Gaudi::Property< float > m_caloZMaxMm
Gaudi::Property< int > m_sectorModulo
Gaudi::Property< bool > m_requireEdges
Gaudi::Property< bool > m_useBucketSegmentSelection
Gaudi::Property< bool > m_fallbackToAllSegments
Gaudi::Property< int > m_maxEdges {this, "MaxEdges", -1, "Maximum number of directed segment-tower edges to create; negative means no cap"}
Gaudi::Property< std::string > m_inputNodeName {this, "InputNodeName", "x"}
Gaudi::Property< std::string > m_inputEdgeIndexName {this, "InputEdgeIndexName", "edge_index"}
Gaudi::Property< std::string > m_inputEdgeAttrName {this, "InputEdgeAttrName", "edge_attr"}
Gaudi::Property< std::string > m_inputNMuonNodesName {this, "InputNMuonNodesName", "n_muon_nodes"}
Gaudi::Property< std::string > m_outputName {this, "OutputName", "logits"}
Gaudi::Property< std::string > m_singleOutputMode
Gaudi::Property< unsigned int > m_debugDumpFirstNNodes {this, "DebugDumpFirstNNodes", 0}
Gaudi::Property< unsigned int > m_debugDumpFirstNEdges {this, "DebugDumpFirstNEdges", 0}
Gaudi::Property< bool > m_sanitizeNonFiniteInputs
Gaudi::Property< bool > m_sanitizeNonFinitePredictions
bool m_isCuda {false}
int m_cudaDeviceId {0}

Static Protected Attributes

static constexpr std::size_t kNodeFeatureCount = 7
static constexpr std::size_t kEdgeFeatureCount = 5
static constexpr std::size_t kInputTensorCount = 4
static constexpr std::array< std::string_view, kNodeFeatureCount > kDefaultNodeFeatureNames
static constexpr std::array< std::string_view, kEdgeFeatureCount > kDefaultEdgeFeatureNames

Private Attributes

ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool

Detailed Description

Athena tool for DisplacedVertex graph-level ONNX inference.

The current exported model embeds any training normalization and consumes raw tensors matching dv_converter_utils.py:

x            [num_nodes, 7]
edge_index   [2, num_edges]
edge_attr    [num_edges, 5]
n_muon_nodes [1]
logits       [1]

Nodes are ordered as muon-segment nodes first, followed by calorimeter tower nodes, because the model-side normalizer uses n_muon_nodes to split the raw x tensor into muon and calo node blocks.
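
The layout described above can be illustrated with a minimal sketch of how the flat buffers behind x and edge_index might be packed. The names NodeFeatures, packNodes and packEdgeIndex are hypothetical helpers introduced only for this example; only the 7-feature node width, the muon-first ordering, and the [2, num_edges] shape come from this page.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins; only the 7-feature width is taken from the documentation.
constexpr std::size_t kNodeFeatureCount = 7;
using NodeFeatures = std::array<float, kNodeFeatureCount>;

// Flatten muon nodes first, then calo nodes, into a row-major [num_nodes, 7] buffer.
std::vector<float> packNodes(const std::vector<NodeFeatures>& muonNodes,
                             const std::vector<NodeFeatures>& caloNodes) {
  std::vector<float> x;
  x.reserve((muonNodes.size() + caloNodes.size()) * kNodeFeatureCount);
  for (const auto& n : muonNodes) x.insert(x.end(), n.begin(), n.end());
  for (const auto& n : caloNodes) x.insert(x.end(), n.begin(), n.end());
  return x;
}

// Pack directed edges as [src_0..src_{E-1}, dst_0..dst_{E-1}],
// which is the row-major form of a [2, num_edges] tensor.
std::vector<std::int64_t> packEdgeIndex(const std::vector<std::int64_t>& src,
                                        const std::vector<std::int64_t>& dst) {
  assert(src.size() == dst.size());
  std::vector<std::int64_t> packed;
  packed.reserve(2 * src.size());
  packed.insert(packed.end(), src.begin(), src.end());
  packed.insert(packed.end(), dst.begin(), dst.end());
  return packed;
}
```

With this ordering, n_muon_nodes is simply muonNodes.size(), and rows [0, n_muon_nodes) of x are the muon block while the remaining rows are the calo block.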

Definition at line 56 of file DVInferenceToolBase.h.

Constructor & Destructor Documentation

◆ ~DVInferenceToolBase()

MuonML::DVInferenceToolBase::~DVInferenceToolBase ( )
override default

Member Function Documentation

◆ buildGraph()

StatusCode DVInferenceToolBase::buildGraph ( const EventContext & ctx,
GraphRawData & graphData ) const

Build the DV ONNX input tensors: x, edge_index, edge_attr, n_muon_nodes.

Definition at line 259 of file DVInferenceToolBase.cxx.

260 {
261 graphData.graph.reset();
262 graphData.featureLeaves.clear();
263 graphData.srcEdges.clear();
264 graphData.desEdges.clear();
265 graphData.edgeIndexPacked.clear();
266 graphData.spacePointsInBucket.clear();
267 graphData.graph = std::make_unique<InferenceGraph>();
268 graphData.graph->dataTensor.reserve(kInputTensorCount);
269
270 std::vector<DVNodeAux> nodes;
271
272 const xAOD::MuonSegmentContainer* segments{nullptr};
273 ATH_CHECK(SG::get(segments, m_segmentKey, ctx));
274
275 nodes.reserve(segments ? segments->size() : 0u);
276
277 if (segments && m_useBucketSegmentSelection.value() && !m_spacePointKeys.empty()) {
278 using SegmentsPerBucket_t =
279 std::unordered_map<const MuonR4::SpacePointBucket*, SegmentList>;
280
281 using SegmentsPerBucketSignature_t =
282 std::unordered_map<std::string, SegmentList>;
283
284 SegmentsPerBucket_t segmentsPerBucket{};
285 SegmentsPerBucketSignature_t segmentsPerBucketSignature{};
286 for (const xAOD::MuonSegment* seg : *segments) {
287 const auto* detailed = MuonR4::detailedSegment(*seg);
288 const MuonR4::SpacePointBucket* parentBucket = detailed->parent()->parentBucket();
289 appendUniqueSegment(segmentsPerBucket[parentBucket], seg);
290 const std::string parentSig = bucketSignatureKey(*parentBucket);
291 if (!parentSig.empty()) appendUniqueSegment(segmentsPerBucketSignature[parentSig], seg);
292 }
293 std::size_t nSignatureMatchedBuckets = 0u;
294 for (const SG::ReadHandleKey<MuonR4::SpacePointContainer>& spKey : m_spacePointKeys) {
295 const MuonR4::SpacePointContainer* spContainer{nullptr};
296 ATH_CHECK(SG::get(spContainer, spKey, ctx));
297
298 for (const MuonR4::SpacePointBucket* bucket : *spContainer) {
299 const auto it = segmentsPerBucket.find(bucket);
300 const SegmentList* matchedSegments{nullptr};
301 if (it != segmentsPerBucket.end() && !it->second.empty()) {
302 matchedSegments = &it->second;
303 } else {
304 const std::string sig = bucketSignatureKey(*bucket);
305 const auto sigIt = sig.empty() ? segmentsPerBucketSignature.end()
306 : segmentsPerBucketSignature.find(sig);
307 if (sigIt != segmentsPerBucketSignature.end() && !sigIt->second.empty()) {
308 matchedSegments = &sigIt->second;
309 ++nSignatureMatchedBuckets;
310 }
311 }
312 if (!matchedSegments) continue;
313
314 const int bucketSector = bucket->msSector() ? static_cast<int>(bucket->msSector()->sector()) : -1;
315 const uint16_t bucketLayers = countLayersInBucket(*bucket);
316 ATH_MSG_VERBOSE("DV bucket segment node source: key=" << spKey.key()
317 << " sector=" << bucketSector
318 << " layers=" << bucketLayers
319 << " segments=" << matchedSegments->size());
320
321 for (const xAOD::MuonSegment* seg : *matchedSegments) {
322 appendMuonSegmentNode(*seg, bucketSector, nodes);
323 }
324 }
325 }
326
327 ATH_MSG_DEBUG("DV graph built " << nodes.size()
328 << " muon nodes from BucketDumper-style SpacePointBucket-associated segments"
329 << " (signature-matched filtered buckets=" << nSignatureMatchedBuckets << ")");
330 }
331
332 if (segments && nodes.empty() &&
334 if (m_useBucketSegmentSelection.value()) {
335 ATH_MSG_WARNING("No bucket-associated segments were found for DV graph building; "
336 "falling back to all segments from " << m_segmentKey.key()
337 << ". This does not match the training converter exactly.");
338 }
339 for (const xAOD::MuonSegment* seg : *segments) {
340 appendMuonSegmentNode(*seg, static_cast<int>(seg->sector()), nodes);
341 }
342 }
343
344 if (segments && nodes.empty() && m_useBucketSegmentSelection.value() &&
345 !m_fallbackToAllSegments.value()) {
346 ATH_MSG_WARNING("No bucket-associated segments were found for DV graph building. "
347 "Not falling back to all segments because that does not match the training converter.");
348 }
349
350 const std::size_t nMuonNodes = nodes.size();
351
352 if (!m_towerKey.empty() && nMuonNodes > 0u) {
353 const CaloTowerContainer* towers{nullptr};
354 ATH_CHECK(SG::get(towers, m_towerKey, ctx));
355
356 nodes.reserve(nodes.size() + towers->size());
357 for (const CaloTower* tower : *towers) {
358 const float energyMeV = static_cast<float>(tower->energy());
359 if (energyMeV < m_minTowerEnergyMeV) continue;
360
361 const float eta = static_cast<float>(tower->eta());
362 const float phi = static_cast<float>(tower->phi());
363 float minDR = std::numeric_limits<float>::max();
364 for (std::size_t i = 0; i < nMuonNodes; ++i) {
365 minDR = std::min(
366 minDR,
367 static_cast<float>(xAOD::P4Helpers::deltaR(eta, phi, nodes[i].eta, nodes[i].phi)));
368 }
369
370 if (minDR >= m_maxTowerSegmentDR) continue;
371 const Amg::Vector3D direction = Acts::makeDirectionFromPhiEta(
372 static_cast<double>(phi), static_cast<double>(eta));
373 const std::optional<Amg::Vector3D> posMm =
374 firstIntersectionWithEnvelope(direction, m_caloRMaxMm.value(), m_caloZMaxMm.value());
375 if (!posMm) continue;
376
377 const Amg::Vector3D posM = (*posMm) / Gaudi::Units::m;
378
379 DVNodeAux node{};
380 node.kind = NodeKind::Calo;
381 node.features[0] = static_cast<float>(posM.mag());
382 node.features[1] = static_cast<float>(posM.theta());
383 node.features[2] = static_cast<float>(posM.phi());
384 node.features[3] = static_cast<float>(direction.theta());
385 node.features[4] = static_cast<float>(direction.phi());
386 node.features[5] = energyMeV;
387 node.features[6] = static_cast<float>(tower->size());
388 node.eta = eta;
389 node.phi = phi;
390 node.energyLike = energyMeV;
391 node.direction = direction;
392 node.sector = static_cast<int>(
393 MuonR4::ExpandedSector{CxxUtils::wrapToPi(static_cast<double>(phi))}.msSector());
394 nodes.push_back(node);
395 }
396 }
397
398 const std::size_t nCaloNodes = nodes.size() - nMuonNodes;
399 const std::size_t nNodes = nodes.size();
400 graphData.spacePointsInBucket.push_back(static_cast<int64_t>(nMuonNodes));
401 graphData.spacePointsInBucket.push_back(static_cast<int64_t>(nCaloNodes));
402
403 if (nNodes == 0u) {
404 ATH_MSG_WARNING("No muon segment or calo tower nodes found. Skipping DV inference.");
405 return StatusCode::SUCCESS;
406 }
407
408 graphData.featureLeaves.reserve(nNodes * kNodeFeatureCount);
409 for (const DVNodeAux& node : nodes) {
410 graphData.featureLeaves.insert(graphData.featureLeaves.end(),
411 node.features.begin(), node.features.end());
412 }
413
414 const int maxEdges = m_maxEdges.value();
415 std::vector<float> edgeAttr;
416 edgeAttr.reserve(2u * nMuonNodes * std::max<std::size_t>(nCaloNodes, 1u) * kEdgeFeatureCount);
417
418 auto addEdge = [&graphData, &edgeAttr, maxEdges](std::size_t src,
419 std::size_t dst,
420 const DVNodeAux& a,
421 const DVNodeAux& b) -> bool {
422 if (maxEdges >= 0 && static_cast<int>(graphData.srcEdges.size()) >= maxEdges) return false;
423 const float dPhi = CxxUtils::deltaPhi(b.phi, a.phi);
424 const float dEta = b.eta - a.eta;
425 const float cosAng = std::clamp(static_cast<float>(a.direction.dot(b.direction)), -1.f, 1.f);
426 std::array<float, kEdgeFeatureCount> attr{
427 b.energyLike - a.energyLike,
428 dPhi,
429 dEta,
430 cosAng,
431 (a.sector == b.sector) ? 1.f : 0.f};
432
433 graphData.srcEdges.push_back(static_cast<int64_t>(src));
434 graphData.desEdges.push_back(static_cast<int64_t>(dst));
435 edgeAttr.insert(edgeAttr.end(), attr.begin(), attr.end());
436 return true;
437 };
438
439 bool edgeCapReached = false;
440 for (std::size_t im = 0; im < nMuonNodes && !edgeCapReached; ++im) {
441 for (std::size_t ic = nMuonNodes; ic < nNodes; ++ic) {
442 if (xAOD::P4Helpers::deltaR(nodes[im].eta, nodes[im].phi,
443 nodes[ic].eta, nodes[ic].phi) >= m_maxTowerSegmentDR.value()) continue;
444 if (!addEdge(im, ic, nodes[im], nodes[ic])) {
445 edgeCapReached = true;
446 break;
447 }
448 if (!addEdge(ic, im, nodes[ic], nodes[im])) {
449 edgeCapReached = true;
450 break;
451 }
452 }
453 }
454
455 const std::size_t nEdges = graphData.srcEdges.size();
456 if (m_requireEdges.value() && nEdges == 0u) {
457 ATH_MSG_DEBUG("DV graph has no segment-tower edges and RequireEdges=True; skip inference.");
458 graphData.graph.reset();
459 return StatusCode::SUCCESS;
460 }
461
462 if (edgeAttr.size() != nEdges * kEdgeFeatureCount) {
463 ATH_MSG_ERROR("DV edge attribute size mismatch: E=" << nEdges
464 << " edge_attr.size=" << edgeAttr.size());
465 return StatusCode::FAILURE;
466 }
467
468 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
469 std::vector<int64_t> nodeShape{static_cast<int64_t>(nNodes), static_cast<int64_t>(kNodeFeatureCount)};
470 graphData.graph->dataTensor.emplace_back(
471 Ort::Value::CreateTensor<float>(memInfo,
472 graphData.featureLeaves.data(),
473 graphData.featureLeaves.size(),
474 nodeShape.data(),
475 nodeShape.size()));
476
477 graphData.edgeIndexPacked.clear();
478 graphData.edgeIndexPacked.reserve(2u * nEdges);
479 graphData.edgeIndexPacked.insert(graphData.edgeIndexPacked.end(), graphData.srcEdges.begin(), graphData.srcEdges.end());
480 graphData.edgeIndexPacked.insert(graphData.edgeIndexPacked.end(), graphData.desEdges.begin(), graphData.desEdges.end());
481
482 std::vector<int64_t> edgeIndexShape{2, static_cast<int64_t>(nEdges)};
483 graphData.graph->dataTensor.emplace_back(
484 Ort::Value::CreateTensor<int64_t>(memInfo,
485 graphData.edgeIndexPacked.data(),
486 graphData.edgeIndexPacked.size(),
487 edgeIndexShape.data(),
488 edgeIndexShape.size()));
489
490 Ort::AllocatorWithDefaultOptions allocator;
491 std::vector<int64_t> edgeAttrShape{static_cast<int64_t>(nEdges), static_cast<int64_t>(kEdgeFeatureCount)};
492 Ort::Value edgeAttrTensor = Ort::Value::CreateTensor(allocator,
493 edgeAttrShape.data(),
494 edgeAttrShape.size(),
495 ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
496 if (!edgeAttr.empty()) {
497 float* edgeAttrData = edgeAttrTensor.GetTensorMutableData<float>();
498 std::copy(edgeAttr.begin(), edgeAttr.end(), edgeAttrData);
499 }
500 graphData.graph->dataTensor.emplace_back(std::move(edgeAttrTensor));
501
502 std::vector<int64_t> nMuonShape{1};
503 graphData.graph->dataTensor.emplace_back(
504 Ort::Value::CreateTensor<int64_t>(memInfo,
505 graphData.spacePointsInBucket.data(),
506 1,
507 nMuonShape.data(),
508 nMuonShape.size()));
509
510 if (msgLvl(MSG::DEBUG)) {
511 ATH_MSG_DEBUG("Built DV graph: N=" << nNodes << " (muon=" << nMuonNodes
512 << ", calo=" << nCaloNodes << "), E=" << nEdges
513 << ", n_muon_nodes=" << graphData.spacePointsInBucket[0]);
514 const std::size_t dumpNodes = std::min<std::size_t>(m_debugDumpFirstNNodes.value(), nNodes);
515 for (std::size_t i = 0; i < dumpNodes; ++i) {
516 std::ostringstream row;
517 row << "DVNode[" << i << "] kind=" << (nodes[i].kind == NodeKind::Muon ? "muon" : "calo") << ":";
518 for (std::size_t f = 0; f < kNodeFeatureCount; ++f) {
519 row << " f" << f << "=" << graphData.featureLeaves[i * kNodeFeatureCount + f];
520 }
521 ATH_MSG_DEBUG(row.str());
522 }
523 const std::size_t dumpEdges = std::min<std::size_t>(m_debugDumpFirstNEdges.value(), nEdges);
524 for (std::size_t e = 0; e < dumpEdges; ++e) {
525 ATH_MSG_DEBUG("DVEdge[" << e << "]: " << graphData.srcEdges[e]
526 << " -> " << graphData.desEdges[e]
527 << " edge_attr=["
528 << edgeAttr[e * kEdgeFeatureCount + 0] << ", "
529 << edgeAttr[e * kEdgeFeatureCount + 1] << ", "
530 << edgeAttr[e * kEdgeFeatureCount + 2] << ", "
531 << edgeAttr[e * kEdgeFeatureCount + 3] << ", "
532 << edgeAttr[e * kEdgeFeatureCount + 4] << "]");
533 }
534 }
535
536 graphData.srcEdges.clear();
537 graphData.desEdges.clear();
538 return StatusCode::SUCCESS;
539}
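
The five edge attributes assembled by the addEdge lambda in the listing can be reproduced in isolation. The following is a minimal sketch, assuming a reduced Node struct and a local deltaPhi that stands in for CxxUtils::deltaPhi; both are hypothetical simplifications, and the direction vectors are assumed to be unit vectors.

```cpp
#include <algorithm>
#include <array>
#include <cmath>

// Hypothetical reduced node; field names mirror those used in the listing.
struct Node {
  float eta;
  float phi;
  float energyLike;
  std::array<double, 3> direction;  // assumed to be a unit vector
  int sector;
};

constexpr double kPi = 3.14159265358979323846;

// Difference phiA - phiB wrapped into [-pi, pi]; stand-in for CxxUtils::deltaPhi.
inline float deltaPhi(float phiA, float phiB) {
  const double diff = static_cast<double>(phiA) - static_cast<double>(phiB);
  return static_cast<float>(std::remainder(diff, 2.0 * kPi));
}

// Edge attributes for a directed edge a -> b, in the order used by the listing:
// energy difference, dPhi, dEta, cosine of the opening angle, same-sector flag.
std::array<float, 5> edgeAttr(const Node& a, const Node& b) {
  const double dot = a.direction[0] * b.direction[0] +
                     a.direction[1] * b.direction[1] +
                     a.direction[2] * b.direction[2];
  return {b.energyLike - a.energyLike,
          deltaPhi(b.phi, a.phi),
          b.eta - a.eta,
          std::clamp(static_cast<float>(dot), -1.f, 1.f),
          (a.sector == b.sector) ? 1.f : 0.f};
}
```

Because each segment-tower pair is added in both directions, the first three attributes of the reverse edge are the negatives of the forward edge, while the cosine and the sector flag are symmetric.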

◆ inferEvent()

StatusCode DVInferenceToolBase::inferEvent ( const EventContext & ctx,
DVInferenceResult & result ) const

Convenience event-classifier API used by DVInferenceAlg.

Definition at line 224 of file DVInferenceToolBase.cxx.

225 {
226 result = DVInferenceResult{};
227 GraphRawData graphData{};
228 ATH_CHECK(buildGraph(ctx, graphData));
229 if (!graphData.graph || graphData.graph->dataTensor.size() < kInputTensorCount) {
230 ATH_MSG_WARNING("DV graph is empty; no event-classifier output will be produced.");
231 return StatusCode::SUCCESS;
232 }
233
234 const auto xShape = graphData.graph->dataTensor[0].GetTensorTypeAndShapeInfo().GetShape();
235 const auto edgeShape = graphData.graph->dataTensor[1].GetTensorTypeAndShapeInfo().GetShape();
236 result.nNodes = !xShape.empty() && xShape[0] > 0 ? static_cast<std::size_t>(xShape[0]) : 0u;
237 result.nEdges = edgeShape.size() > 1 && edgeShape[1] > 0 ? static_cast<std::size_t>(edgeShape[1]) : 0u;
238 if (graphData.spacePointsInBucket.size() >= 2) {
239 result.nMuonNodes = static_cast<std::size_t>(std::max<int64_t>(graphData.spacePointsInBucket[0], 0));
240 result.nCaloNodes = static_cast<std::size_t>(std::max<int64_t>(graphData.spacePointsInBucket[1], 0));
241 }
242
243 ATH_CHECK(runInference(graphData));
244 if (graphData.graph->dataTensor.size() <= kInputTensorCount) {
245 ATH_MSG_ERROR("DV inference finished without an output tensor.");
246 return StatusCode::FAILURE;
247 }
248
249 result.probability = probabilityFromOutput(graphData.graph->dataTensor.back(), result.rawOutput);
250 result.valid = std::isfinite(result.probability);
251 ATH_MSG_DEBUG("DV event classifier: N=" << result.nNodes
252 << " (muon=" << result.nMuonNodes << ", calo=" << result.nCaloNodes
253 << "), E=" << result.nEdges
254 << ", raw=" << result.rawOutput
255 << ", probability=" << result.probability);
256 return StatusCode::SUCCESS;
257}

◆ initialize()

StatusCode DVInferenceToolBase::initialize ( )
override

Definition at line 136 of file DVInferenceToolBase.cxx.

136 {
138 if (m_singleOutputMode.value() != "auto" &&
139 m_singleOutputMode.value() != "logit" &&
140 m_singleOutputMode.value() != "prob") {
141 ATH_MSG_ERROR("SingleOutputMode must be one of auto, logit, prob; got " << m_singleOutputMode.value());
142 return StatusCode::FAILURE;
143 }
144
145 if (m_useBucketSegmentSelection.value() && m_spacePointKeys.size() > 1u) {
146 ATH_MSG_WARNING("DV SpacePointKeys has " << m_spacePointKeys.size()
147 << " entries. The MuonBucketDump training samples use the "
148 << "default SegmentKey array with segments attached only to "
149 << "the first SpacePointKeys entry. For parity with training, "
150 << "configure SpacePointKeys=['MuonSpacePoints'] unless the "
151 << "training dump was produced with matching segment keys for all entries.");
152 }
153
154 ATH_MSG_INFO("Initialized DVInferenceToolBase with SegmentKey=" << m_segmentKey.key()
155 << ", SpacePointKeys=" << m_spacePointKeys.size()
157 << ", TowerContainerKey="
158 << (m_towerKey.empty() ? std::string("<disabled>") : m_towerKey.key())
159 << ", " << m_minTowerEnergyMeV
160 << ", " << m_maxTowerSegmentDR
161 << ", " << m_caloRMaxMm
162 << ", " << m_caloZMaxMm
163 << ", " << m_fallbackToAllSegments
164 << ", " << m_singleOutputMode);
165 return StatusCode::SUCCESS;
166}

◆ model()

Ort::Session & DVInferenceToolBase::model ( ) const
protected

Definition at line 168 of file DVInferenceToolBase.cxx.

168 {
169 return m_onnxSessionTool->session();
170}

◆ modelInputNames()

std::vector< std::string > DVInferenceToolBase::modelInputNames ( ) const
protected

Definition at line 172 of file DVInferenceToolBase.cxx.

172 {
173 std::vector<std::string> names{};
174 Ort::AllocatorWithDefaultOptions allocator;
175 const std::size_t nInputs = model().GetInputCount();
176 names.reserve(nInputs);
177 for (std::size_t i = 0; i < nInputs; ++i) {
178 auto name = model().GetInputNameAllocated(i, allocator);
179 if (name) names.emplace_back(name.get());
180 }
181 return names;
182}

◆ modelOutputNames()

std::vector< std::string > DVInferenceToolBase::modelOutputNames ( ) const
protected

Definition at line 184 of file DVInferenceToolBase.cxx.

184 {
185 std::vector<std::string> names{};
186 Ort::AllocatorWithDefaultOptions allocator;
187 const std::size_t nOutputs = model().GetOutputCount();
188 names.reserve(nOutputs);
189 for (std::size_t i = 0; i < nOutputs; ++i) {
190 auto name = model().GetOutputNameAllocated(i, allocator);
191 if (name) names.emplace_back(name.get());
192 }
193 return names;
194}

◆ probabilityFromOutput()

float DVInferenceToolBase::probabilityFromOutput ( const Ort::Value & output,
float & rawOutput ) const
protected

Definition at line 728 of file DVInferenceToolBase.cxx.

728 {
729 rawOutput = 0.f;
730 const float* data = output.GetTensorData<float>();
731 const auto shapeInfo = output.GetTensorTypeAndShapeInfo();
732 const std::size_t nElem = shapeInfo.GetElementCount();
733 if (nElem == 0 || data == nullptr) return std::numeric_limits<float>::quiet_NaN();
734
735 if (nElem == 1) {
736 rawOutput = data[0];
737 if (m_singleOutputMode.value() == "prob") return rawOutput;
738 if (m_singleOutputMode.value() == "logit" || m_singleOutputMode.value() == "auto") {
739 return InferenceUtils::sigmoid(rawOutput);
740 }
741 return InferenceUtils::sigmoid(rawOutput);
742 }
743
744 if (nElem == 2) {
745 rawOutput = data[1];
746 const float z0 = data[0] - std::max(data[0], data[1]);
747 const float z1 = data[1] - std::max(data[0], data[1]);
748 const float e0 = std::exp(z0);
749 const float e1 = std::exp(z1);
750 return e1 / (e0 + e1);
751 }
752
753 ATH_MSG_WARNING("DV output tensor has " << nElem
754 << " elements; using element 0 with SingleOutputMode=" << m_singleOutputMode.value());
755 rawOutput = data[0];
756 if (m_singleOutputMode.value() == "prob") return rawOutput;
757 if (m_singleOutputMode.value() == "logit" || m_singleOutputMode.value() == "auto") {
758 return InferenceUtils::sigmoid(rawOutput);
759 }
760 return InferenceUtils::sigmoid(rawOutput);
761}
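
The branches in the listing reduce to three cases: an empty tensor yields NaN, a two-element tensor is treated as two-class logits and passed through a max-shifted softmax, and any other size uses element 0 with a sigmoid unless SingleOutputMode is "prob". A minimal sketch of that mapping on a plain float buffer follows; probabilityFromRaw is a hypothetical name, and the local sigmoid is an assumed logistic function standing in for InferenceUtils::sigmoid, whose implementation is not shown on this page.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <string>

// Logistic function; assumed equivalent of InferenceUtils::sigmoid.
inline float sigmoid(float x) { return 1.f / (1.f + std::exp(-x)); }

// Map a raw output buffer to a signal probability.
float probabilityFromRaw(const float* data, std::size_t nElem, const std::string& mode) {
  if (data == nullptr || nElem == 0) return std::numeric_limits<float>::quiet_NaN();
  if (nElem == 2) {
    // Two-class logits: subtract the maximum before exponentiating
    // so that std::exp cannot overflow for large logits.
    const float m = std::max(data[0], data[1]);
    const float e0 = std::exp(data[0] - m);
    const float e1 = std::exp(data[1] - m);
    return e1 / (e0 + e1);
  }
  // One element (or the first of many): "prob" passes the value through,
  // while "logit" and "auto" both apply the sigmoid.
  return mode == "prob" ? data[0] : sigmoid(data[0]);
}
```

Note that in the listing "auto" behaves exactly like "logit"; the sketch keeps that behaviour rather than attempting any detection of the output type.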

◆ runGraphInference()

StatusCode DVInferenceToolBase::runGraphInference ( const EventContext & ctx,
GraphRawData & graphData ) const
override

IGraphInferenceTool entry point: build the DV event graph and run ONNX.

Definition at line 214 of file DVInferenceToolBase.cxx.

215 {
216 ATH_CHECK(buildGraph(ctx, graphData));
217 if (!graphData.graph || graphData.graph->dataTensor.empty()) {
218 ATH_MSG_DEBUG("DV graph has no input tensors; skip inference for this event.");
219 return StatusCode::SUCCESS;
220 }
221 return runInference(graphData);
222}

◆ runInference()

StatusCode DVInferenceToolBase::runInference ( GraphRawData & graphData) const

Run the configured ONNX session on a graph already built by buildGraph.

Definition at line 655 of file DVInferenceToolBase.cxx.

655 {
656 const std::vector<std::string> availableInputs = modelInputNames();
657 if (availableInputs.empty()) {
658 ATH_MSG_ERROR("DV ONNX model has no inputs.");
659 return StatusCode::FAILURE;
660 }
661
662 ATH_MSG_DEBUG("DV ONNX model inputs: " << joinNames(availableInputs));
663
664 const std::string nodeName = m_inputNodeName.value();
665 const std::string edgeIndexName = m_inputEdgeIndexName.value();
666 const std::string edgeAttrName = m_inputEdgeAttrName.value();
667 const std::string nMuonNodesName = m_inputNMuonNodesName.value();
668
669 std::vector<InputTensorSpec> inputSpecs{};
670 inputSpecs.reserve(kInputTensorCount);
671
672 auto addIfPresent = [&availableInputs, &inputSpecs](const std::string& name, std::size_t tensorIndex) {
673 if (hasName(availableInputs, name)) {
674 inputSpecs.push_back(InputTensorSpec{name, tensorIndex});
675 return true;
676 }
677 return false;
678 };
679
680 const bool hasNodeInput = addIfPresent(nodeName, 0u);
681 const bool hasEdgeIndexInput = addIfPresent(edgeIndexName, 1u);
682 const bool hasEdgeAttrInput = addIfPresent(edgeAttrName, 2u);
683 const bool hasNMuonNodesInput = addIfPresent(nMuonNodesName, 3u);
684
685 if (!hasNodeInput || !hasEdgeIndexInput) {
686 ATH_MSG_ERROR("DV ONNX model is missing required inputs. Expected at least "
687 << nodeName << " and " << edgeIndexName
688 << "; model inputs are: " << joinNames(availableInputs));
689 return StatusCode::FAILURE;
690 }
691
692 if (!hasEdgeAttrInput) {
693 ATH_MSG_DEBUG("DV ONNX model has no input named " << edgeAttrName
694 << "; not binding edge_attr. This is expected for exports where "
695 << "the architecture does not consume edge attributes and ONNX pruned the input.");
696 }
697 if (!hasNMuonNodesInput) {
698 ATH_MSG_DEBUG("DV ONNX model has no input named " << nMuonNodesName
699 << "; not binding n_muon_nodes. This is expected only if the exported "
700 << "model does not need model-side muon/calo normalization.");
701 }
702
703 for (const std::string& inputName : availableInputs) {
704 if (inputName != nodeName && inputName != edgeIndexName &&
705 inputName != edgeAttrName && inputName != nMuonNodesName) {
706 ATH_MSG_ERROR("DV ONNX model has unsupported input " << inputName
707 << ". Configure the input-name properties or update the tool mapping.");
708 return StatusCode::FAILURE;
709 }
710 }
711
712 std::vector<std::string> outputNames{};
713 const std::vector<std::string> availableOutputs = modelOutputNames();
714 if (hasName(availableOutputs, m_outputName.value())) {
715 outputNames.push_back(m_outputName.value());
716 } else if (!availableOutputs.empty()) {
717 ATH_MSG_WARNING("DV ONNX model has no output named " << m_outputName.value()
718 << "; using first model output " << availableOutputs.front() << ".");
719 outputNames.push_back(availableOutputs.front());
720 } else {
721 ATH_MSG_ERROR("DV ONNX model has no outputs.");
722 return StatusCode::FAILURE;
723 }
724
725 return runNamedInference(graphData, inputSpecs, outputNames);
726}
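
The input-matching rules in the listing can be summarised as: x and edge_index are required, edge_attr and n_muon_nodes are bound only if the model declares them, and any model input outside the four configured names is an error. The sketch below reproduces that decision on plain string vectors. InputBinding and selectInputs are hypothetical names for this example, and the configured list is assumed to be ordered as x, edge_index, edge_attr, n_muon_nodes so that the position doubles as the tensor index.

```cpp
#include <algorithm>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical analogue of InputTensorSpec: an input name and the index
// of the prepared tensor that should be bound to it.
struct InputBinding {
  std::string name;
  std::size_t tensorIndex;
};

// Returns the bindings to use, or std::nullopt when the model cannot be served.
std::optional<std::vector<InputBinding>> selectInputs(
    const std::vector<std::string>& modelInputs,
    const std::vector<std::string>& configured) {
  auto has = [](const std::vector<std::string>& v, const std::string& n) {
    return std::find(v.begin(), v.end(), n) != v.end();
  };
  // Any model input that is not one of the configured names is unsupported.
  for (const auto& name : modelInputs) {
    if (!has(configured, name)) return std::nullopt;
  }
  std::vector<InputBinding> bindings;
  for (std::size_t i = 0; i < configured.size(); ++i) {
    if (has(modelInputs, configured[i])) {
      bindings.push_back({configured[i], i});
    } else if (i < 2) {
      return std::nullopt;  // the first two entries (x, edge_index) are required
    }
  }
  return bindings;
}
```

The bindings are returned in the configured order rather than the order in which the model declares its inputs, which matches how the listing fills inputSpecs.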

◆ runNamedInference()

StatusCode DVInferenceToolBase::runNamedInference ( GraphRawData & graphData,
const std::vector< InputTensorSpec > & inputSpecs,
const std::vector< std::string > & outputNames ) const
protected

Definition at line 541 of file DVInferenceToolBase.cxx.

544 {
545     if (!graphData.graph) {
546         ATH_MSG_ERROR("Graph data is not built.");
547         return StatusCode::FAILURE;
548     }
549     if (inputSpecs.empty()) {
550         ATH_MSG_ERROR("No DV ONNX inputs were selected for inference.");
551         return StatusCode::FAILURE;
552     }
553
554     for (const InputTensorSpec& spec : inputSpecs) {
555         if (spec.tensorIndex >= graphData.graph->dataTensor.size()) {
556             ATH_MSG_ERROR("Input " << spec.name << " requests tensor index " << spec.tensorIndex
557                           << " but only " << graphData.graph->dataTensor.size()
558                           << " tensors were prepared.");
559             return StatusCode::FAILURE;
560         }
561     }
562
563     std::vector<const char*> inputNamePtrs{};
564     inputNamePtrs.reserve(inputSpecs.size());
565     for (const InputTensorSpec& spec : inputSpecs) {
566         inputNamePtrs.push_back(spec.name.c_str());
567     }
568
569     std::vector<const char*> outputNamePtrs{};
570     outputNamePtrs.reserve(outputNames.size());
571     for (const std::string& name : outputNames) {
572         outputNamePtrs.push_back(name.c_str());
573     }
574
575     graphData.graph->dataTensor.reserve(graphData.graph->dataTensor.size() + outputNamePtrs.size());
576
577     Ort::RunOptions runOptions;
578     runOptions.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR);
579
580     if (m_isCuda) {
581         Ort::IoBinding binding(model());
582         for (const InputTensorSpec& spec : inputSpecs) {
583             binding.BindInput(spec.name.c_str(), graphData.graph->dataTensor[spec.tensorIndex]);
584         }
585         Ort::MemoryInfo cpuOut = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
586         for (const char* outName : outputNamePtrs) {
587             binding.BindOutput(outName, cpuOut);
588         }
589
590         model().Run(runOptions, binding);
591         binding.SynchronizeOutputs();
592
593         std::vector<Ort::Value> outputs = binding.GetOutputValues();
594         if (outputs.empty()) {
595             ATH_MSG_ERROR("IoBinding inference returned empty output.");
596             return StatusCode::FAILURE;
597         }
598
599         if (m_sanitizeNonFinitePredictions.value()) {
600             float* outData = outputs[0].GetTensorMutableData<float>();
601             const std::size_t outSize = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
602             for (std::size_t i = 0; i < outSize; ++i) {
603                 if (!std::isfinite(outData[i])) {
604                     ATH_MSG_WARNING("Non-finite DV prediction detected at " << i << " -> set to -100.");
605                     outData[i] = -100.f;
606                 }
607             }
608         }
609
610         for (auto& v : outputs) {
611             graphData.graph->dataTensor.emplace_back(std::move(v));
612         }
613         return StatusCode::SUCCESS;
614     }
615
616     std::vector<Ort::Value> orderedInputs{};
617     orderedInputs.reserve(inputSpecs.size());
618     for (const InputTensorSpec& spec : inputSpecs) {
619         orderedInputs.emplace_back(std::move(graphData.graph->dataTensor[spec.tensorIndex]));
620     }
621
622     std::vector<Ort::Value> outputs =
623         model().Run(runOptions,
624                     inputNamePtrs.data(),
625                     orderedInputs.data(),
626                     inputNamePtrs.size(),
627                     outputNamePtrs.data(),
628                     outputNamePtrs.size());
629
630     if (outputs.empty()) {
631         ATH_MSG_ERROR("Inference returned empty output.");
632         return StatusCode::FAILURE;
633     }
634
635     ATH_MSG_DEBUG("DV ONNX raw output elementCount = "
636                   << outputs[0].GetTensorTypeAndShapeInfo().GetElementCount());
637
638     if (m_sanitizeNonFinitePredictions.value()) {
639         float* outData = outputs[0].GetTensorMutableData<float>();
640         const std::size_t outSize = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
641         for (std::size_t i = 0; i < outSize; ++i) {
642             if (!std::isfinite(outData[i])) {
643                 ATH_MSG_WARNING("Non-finite DV prediction detected at " << i << " -> set to -100.");
644                 outData[i] = -100.f;
645             }
646         }
647     }
648
649     for (auto& v : outputs) {
650         graphData.graph->dataTensor.emplace_back(std::move(v));
651     }
652     return StatusCode::SUCCESS;
653 }
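
The CUDA and CPU branches of the listing above apply the same clean-up loop to the first output tensor when SanitizeNonFinitePredictions is set. As an illustration only, that loop can be written as a small free function. The name sanitizeNonFinite, its signature, and the returned count are assumptions made for this sketch and are not part of DVInferenceToolBase; the replacement value -100 is the one used in the listing.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical helper (not part of DVInferenceToolBase): replaces every
// non-finite value (NaN, +inf, -inf) in place and returns how many values
// were replaced. The default replacement mirrors the -100 used in the listing.
std::size_t sanitizeNonFinite(float* data, std::size_t size, float replacement = -100.f) {
    assert(data != nullptr || size == 0);
    std::size_t nReplaced = 0;
    for (std::size_t i = 0; i < size; ++i) {
        if (!std::isfinite(data[i])) {
            data[i] = replacement;
            ++nReplaced;
        }
    }
    return nReplaced;
}
```

In the tool itself the buffer and element count would come from GetTensorMutableData<float>() and GetElementCount() on the output tensor, as the listing shows; the warning message per replaced element is omitted here.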

◆ setupModel()

StatusCode DVInferenceToolBase::setupModel ( )
protected

Definition at line 196 of file DVInferenceToolBase.cxx.

196 {
197     ATH_CHECK(m_onnxSessionTool.retrieve());
198     ATH_CHECK(m_segmentKey.initialize());
199     ATH_CHECK(m_spacePointKeys.initialize());
201
202     const InferenceUtils::SessionBackend backend = InferenceUtils::sessionBackend(m_onnxSessionTool);
203     m_isCuda = backend.isCuda;
204     m_cudaDeviceId = backend.cudaDeviceId;
205     if (m_isCuda) {
206         ATH_MSG_INFO("ONNX session is running on CUDA device " << m_cudaDeviceId
207                      << ". I/O binding will be used.");
208     } else {
209         ATH_MSG_INFO("ONNX session is running on CPU.");
210     }
211     return StatusCode::SUCCESS;
212 }

Member Data Documentation

◆ kDefaultEdgeFeatureNames

std::array<std::string_view, kEdgeFeatureCount> MuonML::DVInferenceToolBase::kDefaultEdgeFeatureNames
static constexpr protected
Initial value:
= {
"d_energy_like", "d_phi", "d_eta", "cos_angle", "same_sector"}

Definition at line 82 of file DVInferenceToolBase.h.

82 {
83 "d_energy_like", "d_phi", "d_eta", "cos_angle", "same_sector"};

◆ kDefaultNodeFeatureNames

std::array<std::string_view, kNodeFeatureCount> MuonML::DVInferenceToolBase::kDefaultNodeFeatureNames
static constexpr protected
Initial value:
= {
"r_pos", "theta_pos", "phi_pos", "theta_dir", "phi_dir", "energy_like", "nCells_or_DoF"}

Definition at line 80 of file DVInferenceToolBase.h.

80 {
81 "r_pos", "theta_pos", "phi_pos", "theta_dir", "phi_dir", "energy_like", "nCells_or_DoF"};

◆ kEdgeFeatureCount

std::size_t MuonML::DVInferenceToolBase::kEdgeFeatureCount = 5
static constexpr protected

Definition at line 77 of file DVInferenceToolBase.h.

◆ kInputTensorCount

std::size_t MuonML::DVInferenceToolBase::kInputTensorCount = 4
static constexpr protected

Definition at line 78 of file DVInferenceToolBase.h.

◆ kNodeFeatureCount

std::size_t MuonML::DVInferenceToolBase::kNodeFeatureCount = 7
static constexpr protected

Definition at line 76 of file DVInferenceToolBase.h.
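
The feature-count constants fix the width of each node row (7) and each edge row (5). The sketch below shows one way such fixed-width rows could be packed into a contiguous buffer with a two-dimensional shape. The row-major {nNodes, kNodeFeatureCount} layout and the helper names flattenNodeRows, nodeTensorShape, and nodeFeature are assumptions for illustration; this page does not show how buildGraph lays out its tensors.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Mirrors the documented constant; redeclared so the sketch is self-contained.
constexpr std::size_t kNodeFeatureCount = 7;

// One node is one row of kNodeFeatureCount features, in the order given by
// kDefaultNodeFeatureNames (r_pos, theta_pos, ..., nCells_or_DoF).
using NodeRow = std::array<float, kNodeFeatureCount>;

// Hypothetical helper: copies the rows into one contiguous row-major buffer.
std::vector<float> flattenNodeRows(const std::vector<NodeRow>& rows) {
    std::vector<float> flat;
    flat.reserve(rows.size() * kNodeFeatureCount);
    for (const NodeRow& row : rows) {
        flat.insert(flat.end(), row.begin(), row.end());
    }
    return flat;
}

// Hypothetical helper: shape {nNodes, kNodeFeatureCount} for the buffer above
// (assumed layout, not confirmed by this page).
std::array<std::int64_t, 2> nodeTensorShape(std::size_t nNodes) {
    return {static_cast<std::int64_t>(nNodes),
            static_cast<std::int64_t>(kNodeFeatureCount)};
}

// Hypothetical accessor: reads feature `feature` of node `node` from the flat buffer.
float nodeFeature(const std::vector<float>& flat, std::size_t node, std::size_t feature) {
    assert(feature < kNodeFeatureCount);
    assert(node * kNodeFeatureCount + feature < flat.size());
    return flat[node * kNodeFeatureCount + feature];
}
```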

◆ m_caloRMaxMm

Gaudi::Property<float> MuonML::DVInferenceToolBase::m_caloRMaxMm
protected
Initial value:
{
this, "CaloRMaxMm", 4250.f, "Barrel radius used for the calo-envelope intersection in mm"}

Definition at line 114 of file DVInferenceToolBase.h.

114 {
115 this, "CaloRMaxMm", 4250.f, "Barrel radius used for the calo-envelope intersection in mm"};

◆ m_caloZMaxMm

Gaudi::Property<float> MuonML::DVInferenceToolBase::m_caloZMaxMm
protected
Initial value:
{
this, "CaloZMaxMm", 6500.f, "Endcap |z| used for the calo-envelope intersection in mm"}

Definition at line 116 of file DVInferenceToolBase.h.

116 {
117 this, "CaloZMaxMm", 6500.f, "Endcap |z| used for the calo-envelope intersection in mm"};

◆ m_cudaDeviceId

int MuonML::DVInferenceToolBase::m_cudaDeviceId {0}
protected

Definition at line 145 of file DVInferenceToolBase.h.

145{0};

◆ m_debugDumpFirstNEdges

Gaudi::Property<unsigned int> MuonML::DVInferenceToolBase::m_debugDumpFirstNEdges {this, "DebugDumpFirstNEdges", 0}
protected

Definition at line 138 of file DVInferenceToolBase.h.

138{this, "DebugDumpFirstNEdges", 0};

◆ m_debugDumpFirstNNodes

Gaudi::Property<unsigned int> MuonML::DVInferenceToolBase::m_debugDumpFirstNNodes {this, "DebugDumpFirstNNodes", 0}
protected

Definition at line 137 of file DVInferenceToolBase.h.

137{this, "DebugDumpFirstNNodes", 0};

◆ m_fallbackToAllSegments

Gaudi::Property<bool> MuonML::DVInferenceToolBase::m_fallbackToAllSegments
protected
Initial value:
{
this, "FallbackToAllSegments", false, "If bucket-segment matching fails, fall back to all SegmentKey segments."}

Definition at line 124 of file DVInferenceToolBase.h.

124 {
125 this, "FallbackToAllSegments", false, "If bucket-segment matching fails, fall back to all SegmentKey segments."};

◆ m_inputEdgeAttrName

Gaudi::Property<std::string> MuonML::DVInferenceToolBase::m_inputEdgeAttrName {this, "InputEdgeAttrName", "edge_attr"}
protected

Definition at line 130 of file DVInferenceToolBase.h.

130{this, "InputEdgeAttrName", "edge_attr"};

◆ m_inputEdgeIndexName

Gaudi::Property<std::string> MuonML::DVInferenceToolBase::m_inputEdgeIndexName {this, "InputEdgeIndexName", "edge_index"}
protected

Definition at line 129 of file DVInferenceToolBase.h.

129{this, "InputEdgeIndexName", "edge_index"};

◆ m_inputNMuonNodesName

Gaudi::Property<std::string> MuonML::DVInferenceToolBase::m_inputNMuonNodesName {this, "InputNMuonNodesName", "n_muon_nodes"}
protected

Definition at line 131 of file DVInferenceToolBase.h.

131{this, "InputNMuonNodesName", "n_muon_nodes"};

◆ m_inputNodeName

Gaudi::Property<std::string> MuonML::DVInferenceToolBase::m_inputNodeName {this, "InputNodeName", "x"}
protected

Definition at line 128 of file DVInferenceToolBase.h.

128{this, "InputNodeName", "x"};

◆ m_isCuda

bool MuonML::DVInferenceToolBase::m_isCuda {false}
protected

Definition at line 144 of file DVInferenceToolBase.h.

144{false};

◆ m_maxEdges

Gaudi::Property<int> MuonML::DVInferenceToolBase::m_maxEdges {this, "MaxEdges", -1, "Maximum number of directed segment-tower edges to create; negative means no cap"}
protected

Definition at line 126 of file DVInferenceToolBase.h.

126{this, "MaxEdges", -1, "Maximum number of directed segment-tower edges to create; negative means no cap"};
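
The MaxEdges description treats a negative value as "no cap". A minimal sketch of that convention follows. The function name cappedEdgeCount is hypothetical, and which edges are kept once the cap is reached is not documented on this page, so the sketch only computes how many edges survive.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Hypothetical helper: number of edges kept for a given MaxEdges setting.
// A negative maxEdges disables the cap, following the property description.
std::size_t cappedEdgeCount(std::size_t nCandidateEdges, int maxEdges) {
    if (maxEdges < 0) {
        return nCandidateEdges;  // no cap requested
    }
    const std::size_t kept =
        std::min(nCandidateEdges, static_cast<std::size_t>(maxEdges));
    assert(kept <= nCandidateEdges);
    return kept;
}
```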

◆ m_maxTowerSegmentDR

Gaudi::Property<float> MuonML::DVInferenceToolBase::m_maxTowerSegmentDR
protected
Initial value:
{
this, "MaxTowerSegmentDR", 0.4f, "Maximum segment-calo deltaR used in the converter"}

Definition at line 112 of file DVInferenceToolBase.h.

112 {
113 this, "MaxTowerSegmentDR", 0.4f, "Maximum segment-calo deltaR used in the converter"};
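
MaxTowerSegmentDR is a cut on the angular distance deltaR = sqrt(dEta^2 + dPhi^2), where dPhi has to be wrapped into (-pi, pi] before squaring. The sketch below illustrates that computation. The function names are hypothetical, and whether the tool applies the cut as "less than" or "less than or equal" is not shown on this page; the sketch assumes the inclusive form.

```cpp
#include <cassert>
#include <cmath>

// Wraps an azimuthal difference into (-pi, pi].
double wrapDeltaPhi(double dPhi) {
    const double pi = std::acos(-1.0);
    dPhi = std::fmod(dPhi, 2.0 * pi);  // now in (-2pi, 2pi)
    if (dPhi > pi) {
        dPhi -= 2.0 * pi;
    } else if (dPhi <= -pi) {
        dPhi += 2.0 * pi;
    }
    return dPhi;
}

// Angular distance in the (eta, phi) plane.
double deltaR(double eta1, double phi1, double eta2, double phi2) {
    return std::hypot(eta1 - eta2, wrapDeltaPhi(phi1 - phi2));
}

// Hypothetical predicate: true if the pair passes the cut (inclusive, assumed).
// The default of 0.4 is the documented default of MaxTowerSegmentDR.
bool withinMaxDR(double eta1, double phi1, double eta2, double phi2, double maxDR = 0.4) {
    assert(maxDR >= 0.0);
    return deltaR(eta1, phi1, eta2, phi2) <= maxDR;
}
```

The wrapping matters near the phi = ±pi boundary: a segment at phi = 3.0 and a tower at phi = -3.0 are about 0.28 apart in phi, not 6.0.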

◆ m_minTowerEnergyMeV

Gaudi::Property<float> MuonML::DVInferenceToolBase::m_minTowerEnergyMeV
protected
Initial value:
{
this, "MinTowerEnergyMeV", 1000.f, "Minimum calo tower energy used as a DV graph node"}

Definition at line 110 of file DVInferenceToolBase.h.

110 {
111 this, "MinTowerEnergyMeV", 1000.f, "Minimum calo tower energy used as a DV graph node"};

◆ m_onnxSessionTool

ToolHandle<AthOnnx::IOnnxRuntimeSessionTool> MuonML::DVInferenceToolBase::m_onnxSessionTool
private
Initial value:
{
this, "ModelSession", "", "ONNX Runtime session tool for the DV classifier"}

Definition at line 148 of file DVInferenceToolBase.h.

148 {
149 this, "ModelSession", "", "ONNX Runtime session tool for the DV classifier"};

◆ m_outputName

Gaudi::Property<std::string> MuonML::DVInferenceToolBase::m_outputName {this, "OutputName", "logits"}
protected

Definition at line 132 of file DVInferenceToolBase.h.

132{this, "OutputName", "logits"};

◆ m_requireEdges

Gaudi::Property<bool> MuonML::DVInferenceToolBase::m_requireEdges
protected
Initial value:
{
this, "RequireEdges", false, "Skip inference when the event graph has no segment-tower edges"}

Definition at line 120 of file DVInferenceToolBase.h.

120 {
121 this, "RequireEdges", false, "Skip inference when the event graph has no segment-tower edges"};

◆ m_sanitizeNonFiniteInputs

Gaudi::Property<bool> MuonML::DVInferenceToolBase::m_sanitizeNonFiniteInputs
protected
Initial value:
{
this, "SanitizeNonFiniteInputs", true, "Replace non-finite input features with zero before creating ONNX tensors"}

Definition at line 139 of file DVInferenceToolBase.h.

139 {
140 this, "SanitizeNonFiniteInputs", true, "Replace non-finite input features with zero before creating ONNX tensors"};

◆ m_sanitizeNonFinitePredictions

Gaudi::Property<bool> MuonML::DVInferenceToolBase::m_sanitizeNonFinitePredictions
protected
Initial value:
{
this, "SanitizeNonFinitePredictions", false, "Replace non-finite ONNX outputs with -100 and log a warning"}

Definition at line 141 of file DVInferenceToolBase.h.

141 {
142 this, "SanitizeNonFinitePredictions", false, "Replace non-finite ONNX outputs with -100 and log a warning"};

◆ m_sectorModulo

Gaudi::Property<int> MuonML::DVInferenceToolBase::m_sectorModulo
protected
Initial value:
{
this, "SectorModulo", 16, "Number of sectors used by the calo phi->sector converter"}

Definition at line 118 of file DVInferenceToolBase.h.

118 {
119 this, "SectorModulo", 16, "Number of sectors used by the calo phi->sector converter"};
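
The phi->sector converter itself is not shown on this page. Purely as an illustration of what a converter with SectorModulo sectors might look like, the sketch below bins phi uniformly over [0, 2pi). The function name phiToSector, the zero-based numbering, and the choice of sector 0 starting at phi = 0 are all assumptions and may differ from the tool's actual convention.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical converter: maps phi (radians, any range) to a sector index in
// [0, nSectors). Uniform binning starting at phi = 0 is an assumption.
int phiToSector(double phi, int nSectors = 16) {
    assert(nSectors > 0);
    const double twoPi = 2.0 * std::acos(-1.0);
    double wrapped = std::fmod(phi, twoPi);
    if (wrapped < 0.0) {
        wrapped += twoPi;  // now in [0, 2pi], up to rounding
    }
    const int sector = static_cast<int>(wrapped / twoPi * nSectors);
    return sector % nSectors;  // guards against rounding up to nSectors
}
```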

◆ m_segmentKey

SG::ReadHandleKey<xAOD::MuonSegmentContainer> MuonML::DVInferenceToolBase::m_segmentKey
protected
Initial value:
{
this, "SegmentKey", "MuonSegmentsFromR4", "Input R4 muon segment container"}

Definition at line 102 of file DVInferenceToolBase.h.

102 {
103 this, "SegmentKey", "MuonSegmentsFromR4", "Input R4 muon segment container"};

◆ m_singleOutputMode

Gaudi::Property<std::string> MuonML::DVInferenceToolBase::m_singleOutputMode
protected
Initial value:
{
this, "SingleOutputMode", "logit", "How to interpret a one-value output: auto, logit, or prob"}

Definition at line 134 of file DVInferenceToolBase.h.

134 {
135 this, "SingleOutputMode", "logit", "How to interpret a one-value output: auto, logit, or prob"};
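
For the "logit" and "prob" settings, the usual interpretation is that a logit is passed through the logistic (sigmoid) function while a probability is used as is. The sketch below illustrates that reading; it is not the body of probabilityFromOutput, which is not shown on this page. The helper names are hypothetical, and the "auto" mode is deliberately left out because its behaviour is not documented here.

```cpp
#include <cassert>
#include <cmath>
#include <string_view>

// Numerically stable logistic function: avoids overflow of exp() for large |x|.
float sigmoid(float x) {
    if (x >= 0.f) {
        return 1.f / (1.f + std::exp(-x));
    }
    const float e = std::exp(x);
    return e / (1.f + e);
}

// Hypothetical helper: interprets a single raw output value.
// Only "logit" and "prob" are handled; "auto" is not covered by this sketch.
float probabilityFromSingleValue(float raw, std::string_view mode) {
    assert(mode == "logit" || mode == "prob");
    return mode == "logit" ? sigmoid(raw) : raw;
}
```

Under this reading, an output replaced by -100 through SanitizeNonFinitePredictions maps to a probability that is effectively zero in "logit" mode.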

◆ m_spacePointKeys

SG::ReadHandleKeyArray<MuonR4::SpacePointContainer> MuonML::DVInferenceToolBase::m_spacePointKeys
protected
Initial value:
{
this, "SpacePointKeys", {"MuonSpacePoints"},
"Default is MuonSpacePoints only, matching the training MuonBucketDump SegmentKey alignment."}

Definition at line 104 of file DVInferenceToolBase.h.

104 {
105 this, "SpacePointKeys", {"MuonSpacePoints"},
106 "Default is MuonSpacePoints only, matching the training MuonBucketDump SegmentKey alignment."};

◆ m_towerKey

SG::ReadHandleKey<CaloTowerContainer> MuonML::DVInferenceToolBase::m_towerKey
protected
Initial value:
{
this, "TowerContainerKey", "CombinedTower", "Input calorimeter tower container"}

Definition at line 107 of file DVInferenceToolBase.h.

107 {
108 this, "TowerContainerKey", "CombinedTower", "Input calorimeter tower container"};

◆ m_useBucketSegmentSelection

Gaudi::Property<bool> MuonML::DVInferenceToolBase::m_useBucketSegmentSelection
protected
Initial value:
{
this, "UseBucketSegmentSelection", true, "Build muon nodes from segment-parent SpacePoint buckets"}

Definition at line 122 of file DVInferenceToolBase.h.

122 {
123 this, "UseBucketSegmentSelection", true, "Build muon nodes from segment-parent SpacePoint buckets"};

The documentation for this class was generated from the following files: