ATLAS Offline Software
BucketInferenceToolBase.cxx
/*
  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
*/

#include "BucketInferenceToolBase.h"

#include "StoreGate/ReadHandle.h"

#include "BucketGraphUtils.h"

#include <algorithm>
#include <cmath>
#include <limits>
#include <span>
#include <sstream>

using namespace MuonML;

Ort::Session& BucketInferenceToolBase::model() const {
    return m_onnxSessionTool->session();
}

StatusCode BucketInferenceToolBase::setupModel() {
    ATH_CHECK(m_onnxSessionTool.retrieve());
    ATH_CHECK(m_readKey.initialize());
    ATH_CHECK(m_geoCtxKey.initialize());
    return StatusCode::SUCCESS;
}
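// A minimal usage sketch (not part of this file): a concrete tool derived from
// BucketInferenceToolBase is expected to call setupModel() from its own initialize(),
// so that the ONNX session tool and the StoreGate read keys are ready before the first event.
// The class name "MyBucketMLTool" below is purely illustrative.
//
//   StatusCode MyBucketMLTool::initialize() {
//       ATH_CHECK(setupModel());
//       return StatusCode::SUCCESS;
//   }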
StatusCode BucketInferenceToolBase::buildFeaturesOnly(const EventContext& ctx,
                                                      GraphRawData& graphData) const {
    graphData.graph = std::make_unique<InferenceGraph>();
    graphData.srcEdges.clear();
    graphData.desEdges.clear();
    graphData.featureLeaves.clear();
    graphData.spacePointsInBucket.clear();

    const MuonR4::SpacePointContainer* buckets{nullptr};
    ATH_CHECK(SG::get(buckets, m_readKey, ctx));

    const ActsTrk::GeometryContext* gctx = nullptr;
    ATH_CHECK(SG::get(gctx, m_geoCtxKey, ctx));

    std::vector<BucketGraphUtils::NodeAux> nodes;
    BucketGraphUtils::buildNodesAndFeatures(*buckets, *gctx, nodes,
                                            graphData.featureLeaves,
                                            graphData.spacePointsInBucket);  // spacePointsInBucket is filled as int64_t

    const int64_t numNodes = static_cast<int64_t>(nodes.size());
    ATH_MSG_DEBUG("Total buckets: " << buckets->size()
                  << " -> nodes (size>0): " << numNodes
                  << " | features.size()=" << graphData.featureLeaves.size());

    if (numNodes == 0) {
        ATH_MSG_WARNING("No valid buckets found (all have size 0.0). Skipping inference.");
        return StatusCode::SUCCESS;
    }

    const int64_t nFeatPerNode = 6;
    if (numNodes * nFeatPerNode != static_cast<int64_t>(graphData.featureLeaves.size())) {
        ATH_MSG_ERROR("Feature size mismatch: expected " << (numNodes * nFeatPerNode)
                      << " got " << graphData.featureLeaves.size());
        return StatusCode::FAILURE;
    }

    Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    std::vector<int64_t> featShape{numNodes, nFeatPerNode};
    graphData.graph->dataTensor.emplace_back(
        Ort::Value::CreateTensor<float>(memInfo,
                                        graphData.featureLeaves.data(),
                                        graphData.featureLeaves.size(),
                                        featShape.data(),
                                        featShape.size()));
    return StatusCode::SUCCESS;
}
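// Feature layout note: each kept bucket becomes one node carrying 6 floats,
// [x, y, z, layers, nSp, bucketSize] (the same order printed by the debug blocks below),
// flattened row-wise into graphData.featureLeaves and exposed as an (N,6) float tensor in
// graphData.graph->dataTensor[0]. A hedged sketch of how a caller could inspect that tensor:
//
//   const Ort::Value& feat = graphData.graph->dataTensor.front();
//   const std::vector<int64_t> shape = feat.GetTensorTypeAndShapeInfo().GetShape();  // {N, 6}
//   const float* data = feat.GetTensorData<float>();  // row i starts at data + 6 * i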
StatusCode BucketInferenceToolBase::buildTransformerInputs(const EventContext& ctx,
                                                           GraphRawData& graphData) const {
    // Start from the flat (N,6) feature buffer
    ATH_CHECK(buildFeaturesOnly(ctx, graphData));

    // Copy the flat feature buffer for lifetime management
    std::vector<float> featuresFlat = graphData.featureLeaves;
    const int64_t S = static_cast<int64_t>(featuresFlat.size() / 6);

    if (S == 0) {
        ATH_MSG_WARNING("No valid features for transformer input. Skipping inference.");
        return StatusCode::SUCCESS;
    }

    if (msgLvl(MSG::DEBUG)) {
        // DEBUG: print the transformer input features for the first 10 nodes
        ATH_MSG_DEBUG("=== DEBUGGING: Transformer input features for first 10 nodes ===");
        const int64_t debugNodes = std::min(S, static_cast<int64_t>(10));
        for (int64_t nodeIdx = 0; nodeIdx < debugNodes; ++nodeIdx) {
            const int64_t baseIdx = nodeIdx * 6;
            ATH_MSG_DEBUG("TransformerNode[" << nodeIdx << "]: "
                          << "x=" << featuresFlat[baseIdx + 0] << ", "
                          << "y=" << featuresFlat[baseIdx + 1] << ", "
                          << "z=" << featuresFlat[baseIdx + 2] << ", "
                          << "layers=" << featuresFlat[baseIdx + 3] << ", "
                          << "nSp=" << featuresFlat[baseIdx + 4] << ", "
                          << "bucketSize=" << featuresFlat[baseIdx + 5]);
        }
        ATH_MSG_DEBUG("=== END DEBUG TRANSFORMER FEATURES ===");
    }

    // Rebuild the graph with exactly 2 inputs: features [1,S,6], pad_mask [1,S]
    graphData.graph = std::make_unique<InferenceGraph>();

    Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);

    // features: [1,S,6] (backed by graphData.featureLeaves to keep the buffer alive)
    std::vector<int64_t> fShape{1, S, 6};
    graphData.featureLeaves.swap(featuresFlat);
    graphData.graph->dataTensor.emplace_back(
        Ort::Value::CreateTensor<float>(memInfo,
                                        graphData.featureLeaves.data(),
                                        graphData.featureLeaves.size(),
                                        fShape.data(),
                                        fShape.size()));

    // pad_mask: [1,S] (bool). Create an ORT-owned tensor and fill it with false (= valid).
    Ort::AllocatorWithDefaultOptions allocator;
    std::vector<int64_t> mShape{1, S};
    Ort::Value padVal = Ort::Value::CreateTensor(allocator,
                                                 mShape.data(),
                                                 mShape.size(),
                                                 ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL);
    bool* maskPtr = padVal.GetTensorMutableData<bool>();
    for (int64_t i = 0; i < S; ++i) maskPtr[i] = false;
    graphData.graph->dataTensor.emplace_back(std::move(padVal));

    return StatusCode::SUCCESS;
}
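// Usage sketch for the transformer path: pair buildTransformerInputs with runNamedInference.
// The input/output tensor names ("features", "pad_mask", "output") are illustrative assumptions
// here; the real names are fixed by the exported ONNX model, not by this base class.
//
//   GraphRawData graphData{};
//   ATH_CHECK(buildTransformerInputs(ctx, graphData));
//   if (!graphData.graph->dataTensor.empty()) {  // empty when no valid bucket was found
//       ATH_CHECK(runNamedInference(graphData, {"features", "pad_mask"}, {"output"}));
//   }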
StatusCode BucketInferenceToolBase::buildGraph(const EventContext& ctx,
                                               GraphRawData& graphData) const {
    ATH_CHECK(buildFeaturesOnly(ctx, graphData));

    const MuonR4::SpacePointContainer* buckets{nullptr};
    ATH_CHECK(SG::get(buckets, m_readKey, ctx));

    const ActsTrk::GeometryContext* gctx = nullptr;
    ATH_CHECK(SG::get(gctx, m_geoCtxKey, ctx));

    std::vector<BucketGraphUtils::NodeAux> nodes;
    std::vector<float> throwawayFeatures;
    std::vector<int64_t> throwawaySp;
    BucketGraphUtils::buildNodesAndFeatures(*buckets, *gctx, nodes, throwawayFeatures, throwawaySp);

    const int64_t numNodes = static_cast<int64_t>(nodes.size());
    if (numNodes == 0) {
        ATH_MSG_WARNING("No valid buckets found (all have size 0.0). Skipping graph building.");
        return StatusCode::SUCCESS;
    }

    std::vector<int64_t> srcEdges, dstEdges;
    BucketGraphUtils::buildSparseEdges(nodes,
                                       m_minLayers,
                                       m_maxChamberDelta,
                                       m_maxSectorDelta,
                                       m_maxDistXY,
                                       m_maxAbsDz,
                                       srcEdges, dstEdges);
    if (m_validateEdges) {
        size_t bad = 0;
        std::vector<int64_t> newSrc;
        std::vector<int64_t> newDst;
        newSrc.reserve(srcEdges.size());
        newDst.reserve(dstEdges.size());
        for (size_t k = 0; k < srcEdges.size(); ++k) {
            const int64_t u = srcEdges[k];
            const int64_t v = dstEdges[k];
            const bool okU = (u >= 0 && u < numNodes);
            const bool okV = (v >= 0 && v < numNodes);
            if (okU && okV) {
                newSrc.push_back(u);
                newDst.push_back(v);
            } else {
                ++bad;
                ATH_MSG_DEBUG("Drop invalid edge " << k << ": (" << u << "->" << v
                              << "), valid node range [0," << (numNodes - 1) << "]");
            }
        }
        if (bad) {
            ATH_MSG_WARNING("Removed " << bad << " invalid edges out of "
                            << srcEdges.size());
            srcEdges.swap(newSrc);
            dstEdges.swap(newDst);
        }
    }

    const size_t E = srcEdges.size();

    if (msgLvl(MSG::DEBUG)) {
        // DEBUG: count connections per node
        ATH_MSG_DEBUG("Edges built: " << E);
        const unsigned int dumpE = std::min<unsigned int>(m_debugDumpFirstNEdges, E);
        for (unsigned int k = 0; k < dumpE; ++k) {
            ATH_MSG_DEBUG("EDGE[" << k << "]: " << srcEdges[k] << " -> " << dstEdges[k]);
        }
        std::vector<int> nodeConnections(numNodes, 0);
        for (size_t k = 0; k < srcEdges.size(); ++k) {
            const int64_t u = srcEdges[k];
            const int64_t v = dstEdges[k];
            if (u >= 0 && u < numNodes) nodeConnections[u]++;
            if (v >= 0 && v < numNodes) nodeConnections[v]++;
        }

        ATH_MSG_DEBUG("=== DEBUGGING: Node Connections (first 10 nodes) ===");
        const int64_t debugNodeCount = std::min(numNodes, static_cast<int64_t>(10));
        for (int64_t i = 0; i < debugNodeCount; ++i) {
            ATH_MSG_DEBUG("Node[" << i << "] connections: " << nodeConnections[i]);
        }
        ATH_MSG_DEBUG("=== END DEBUG NODE CONNECTIONS ===");

        // DEBUG: show detailed edge connections for the first 10 nodes
        ATH_MSG_DEBUG("=== DEBUGGING: Detailed Edge Connections (first 10 nodes) ===");
        for (int64_t nodeIdx = 0; nodeIdx < debugNodeCount; ++nodeIdx) {
            std::stringstream connections;
            connections << "Node[" << nodeIdx << "] connected to: ";
            bool foundAny = false;

            for (size_t k = 0; k < srcEdges.size(); ++k) {
                const int64_t u = srcEdges[k];
                const int64_t v = dstEdges[k];

                if (u == nodeIdx) {
                    if (foundAny) connections << ", ";
                    connections << v;
                    foundAny = true;
                } else if (v == nodeIdx) {
                    if (foundAny) connections << ", ";
                    connections << u;
                    foundAny = true;
                }
            }

            if (!foundAny) connections << "none";
            ATH_MSG_DEBUG(connections.str());
        }
        ATH_MSG_DEBUG("=== END DEBUG DETAILED CONNECTIONS ===");
    }

    graphData.edgeIndexPacked.clear();
    const size_t Efinal = BucketGraphUtils::packEdgeIndex(srcEdges, dstEdges, graphData.edgeIndexPacked);

    Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    std::vector<int64_t> edgeShape{2, static_cast<int64_t>(Efinal)};
    graphData.graph->dataTensor.emplace_back(
        Ort::Value::CreateTensor<int64_t>(memInfo,
                                          graphData.edgeIndexPacked.data(),
                                          graphData.edgeIndexPacked.size(),
                                          edgeShape.data(),
                                          edgeShape.size()));

    ATH_MSG_DEBUG("Built sparse bucket graph: N=" << numNodes << ", E=" << Efinal);
    return StatusCode::SUCCESS;
}
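// Edge tensor layout note: packEdgeIndex stores the buffer as [src_0..src_{E-1}, dst_0..dst_{E-1}]
// and the tensor above is created with shape {2, E}, i.e. row 0 holds the source node indices and
// row 1 the destination node indices (the usual edge_index convention). A hedged reading sketch:
//
//   const Ort::Value& edgeIndex = graphData.graph->dataTensor[1];  // tensor 0 is the (N,6) features
//   const int64_t* ei = edgeIndex.GetTensorData<int64_t>();
//   const int64_t nEdges = edgeIndex.GetTensorTypeAndShapeInfo().GetShape()[1];
//   // edge k connects node ei[k] -> node ei[nEdges + k]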
StatusCode BucketInferenceToolBase::runNamedInference(GraphRawData& graphData,
                                                      const std::vector<const char*>& inputNames,
                                                      const std::vector<const char*>& outputNames) const {
    if (!graphData.graph) {
        ATH_MSG_ERROR("Graph data is not built.");
        return StatusCode::FAILURE;
    }
    if (graphData.graph->dataTensor.empty()) {
        ATH_MSG_ERROR("No input tensors prepared for inference.");
        return StatusCode::FAILURE;
    }

    if (msgLvl(MSG::DEBUG)) {
        // DEBUG: print the actual input data of the features tensor
        ATH_MSG_DEBUG("=== DEBUGGING: ONNX Input tensor data ===");
        if (!graphData.graph->dataTensor.empty()) {
            const auto& featureTensor = graphData.graph->dataTensor[0];
            auto featShape = featureTensor.GetTensorTypeAndShapeInfo().GetShape();
            ATH_MSG_DEBUG("Features tensor shape: [" << featShape[0]
                          << (featShape.size() > 1 ? ("," + std::to_string(featShape[1])) : "")
                          << (featShape.size() > 2 ? ("," + std::to_string(featShape[2])) : "") << "]");

            float* featData = const_cast<Ort::Value&>(featureTensor).GetTensorMutableData<float>();
            const size_t totalElements = featureTensor.GetTensorTypeAndShapeInfo().GetElementCount();
            ATH_MSG_DEBUG("Features tensor total elements: " << totalElements);

            // Print the first 10 nodes (60 values, 6 per node)
            const size_t debugElements = std::min(totalElements, static_cast<size_t>(60));
            for (size_t i = 0; i < debugElements; i += 6) {
                if (i + 5 < totalElements) {
                    ATH_MSG_DEBUG("ONNXNode[" << (i / 6) << "]: "
                                  << "x=" << featData[i + 0] << ", "
                                  << "y=" << featData[i + 1] << ", "
                                  << "z=" << featData[i + 2] << ", "
                                  << "layers=" << featData[i + 3] << ", "
                                  << "nSp=" << featData[i + 4] << ", "
                                  << "bucketSize=" << featData[i + 5]);
                }
            }
        }
        ATH_MSG_DEBUG("=== END DEBUG ONNX INPUT ===");
    }

    Ort::RunOptions run_options;
    run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_WARNING);

    std::vector<Ort::Value> outputs =
        model().Run(run_options,
                    inputNames.data(),
                    graphData.graph->dataTensor.data(),
                    graphData.graph->dataTensor.size(),
                    outputNames.data(),
                    outputNames.size());

    if (outputs.empty()) {
        ATH_MSG_ERROR("Inference returned empty output.");
        return StatusCode::FAILURE;
    }

    float* outData = outputs[0].GetTensorMutableData<float>();
    const size_t outSize = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
    ATH_MSG_DEBUG("ONNX raw output elementCount = " << outSize);

    // Sanitize non-finite predictions in place
    std::span<float> preds(outData, outData + outSize);
    for (size_t i = 0; i < outSize; ++i) {
        if (!std::isfinite(preds[i])) {
            ATH_MSG_WARNING("Non-finite prediction detected at " << i << " -> set to -100.");
            preds[i] = -100.0f;
        }
    }

    // Keep the output tensors alive by appending them behind the input tensors
    for (auto& v : outputs) {
        graphData.graph->dataTensor.emplace_back(std::move(v));
    }
    return StatusCode::SUCCESS;
}
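// Output handling note: after a successful Run() the output tensors are appended to
// graphData.graph->dataTensor behind the input tensors, so their buffers stay alive together
// with the graph. A hedged sketch for reading back the first output (here assumed to be a
// single float tensor, as in the GNN default below):
//
//   const Ort::Value& out = graphData.graph->dataTensor.at(inputNames.size());
//   const float* scores = out.GetTensorData<float>();
//   const size_t nScores = out.GetTensorTypeAndShapeInfo().GetElementCount();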
StatusCode BucketInferenceToolBase::runInference(GraphRawData& graphData) const {
    std::vector<const char*> inputNames = {"features", "edge_index"};
    std::vector<const char*> outputNames = {"output"};
    return runNamedInference(graphData, inputNames, outputNames);
}
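// End-to-end sketch for the default GNN path, e.g. from a derived tool's execute(). Only
// buildGraph(), runInference() and the GraphRawData members are taken from this file; the
// surrounding tool ("MyBucketMLTool") and how the scores are consumed are assumptions.
//
//   StatusCode MyBucketMLTool::execute(const EventContext& ctx) const {
//       MuonML::GraphRawData graphData{};
//       ATH_CHECK(buildGraph(ctx, graphData));      // features (N,6) + edge_index (2,E)
//       if (graphData.graph->dataTensor.empty()) {  // no valid buckets in this event
//           return StatusCode::SUCCESS;
//       }
//       ATH_CHECK(runInference(graphData));         // appends the "output" tensor
//       const Ort::Value& out = graphData.graph->dataTensor.back();
//       const float* scores = out.GetTensorData<float>();  // model predictions, non-finite values already set to -100
//       // ... decode/attach the scores as the concrete tool requires
//       return StatusCode::SUCCESS;
//   }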