ATLAS Offline Software
GraphInferenceToolBase.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
5 
10 #include "Acts/Utilities/MathHelpers.hpp"
11 #include <span>
12 
13 namespace {
// Forward declarations of both stream helpers. The operand types (std::ostream,
// std::vector, std::pair) all live in namespace std, so argument-dependent
// lookup never reaches overloads declared in this file. Declaring both up front
// lets each template see the other, so nested containers such as
// std::vector<std::pair<...>> or std::pair<std::vector<...>, ...> can be streamed.
template <typename T>
std::ostream& operator<<(std::ostream& ostr, const std::vector<T>& vec);
template <typename T1, typename T2>
std::ostream& operator<<(std::ostream& ostr, const std::pair<T1, T2>& pair);

/// Stream a std::vector as "[a, b, c]"; an empty vector is printed as "[]".
/// @param ostr Stream to write to.
/// @param vec  Vector whose elements are streamed with their own operator<<.
/// @return The stream passed in, to allow chaining.
template <typename T>
std::ostream& operator<<(std::ostream& ostr, const std::vector<T>& vec) {
    ostr << "[";
    // The separator is emitted before every element except the first, so no
    // dangling ", " is left in front of the closing bracket.
    const char* separator = "";
    for (const T& val : vec) {
        ostr << separator << val;
        separator = ", ";
    }
    ostr << "]";
    return ostr;
}

/// Stream a std::pair as "(first, second)".
/// @param ostr Stream to write to.
/// @param pair Pair whose members are streamed with their own operator<<.
/// @return The stream passed in, to allow chaining.
template <typename T1, typename T2>
std::ostream& operator<<(std::ostream& ostr, const std::pair<T1, T2>& pair) {
    ostr << "(" << pair.first << ", " << pair.second << ")";
    return ostr;
}
28 
/// Convert a flat edge buffer into a list of (source, destination) pairs.
///
/// The buffer is read as two halves of equal length: the first half holds the
/// source indices, the second half the destination indices, i.e.
/// [src_0, ..., src_{E-1}, dst_0, ..., dst_{E-1}] with E = edges.size() / 2.
/// If the buffer has an odd number of entries, the last entry is not used.
///
/// @param edges Flat edge buffer in the layout described above.
/// @return E pairs, the i-th being (edges[i], edges[i + E]).
std::vector<std::pair<int64_t, int64_t>> makeIndexPairs(const std::vector<int64_t>& edges) {
    const std::size_t nEdges = edges.size() / 2;
    std::vector<std::pair<int64_t, int64_t>> indexPairs;
    indexPairs.reserve(nEdges);
    for (std::size_t i = 0; i < nEdges; ++i) {
        // Construct the pair in place; no intermediate std::make_pair temporary.
        indexPairs.emplace_back(edges[i], edges[i + nEdges]);
    }
    return indexPairs;
}
38 
/// Return a copy of a flat edge buffer with its edges sorted by source index
/// and, for equal sources, by destination index.
///
/// Input and output share the same layout, [src_0..src_{E-1}, dst_0..dst_{E-1}]
/// with E = edges.size() / 2. If the input has an odd number of entries, the
/// last entry is not used and the result has 2 * E entries.
///
/// @param edges Flat edge buffer in the layout described above.
/// @return Flat buffer of the same layout holding the sorted edges.
std::vector<int64_t> makeSortedEdges(const std::vector<int64_t>& edges) {
    const std::size_t nEdges = edges.size() / 2;

    // Gather the (src, dst) pairs from the two halves of the buffer.
    std::vector<std::pair<int64_t, int64_t>> indexPairs;
    indexPairs.reserve(nEdges);
    for (std::size_t i = 0; i < nEdges; ++i) {
        indexPairs.emplace_back(edges[i], edges[i + nEdges]);
    }

    // std::pair compares lexicographically (first, then second), which is
    // exactly the "by src, then dst" ordering required here.
    std::sort(indexPairs.begin(), indexPairs.end());

    // Scatter the sorted pairs back into the two-halves layout in one pass.
    std::vector<int64_t> sortedEdges(2 * nEdges);
    for (std::size_t i = 0; i < nEdges; ++i) {
        sortedEdges[i] = indexPairs[i].first;
        sortedEdges[i + nEdges] = indexPairs[i].second;
    }
    return sortedEdges;
}
66 
/// Render a flat, node-major feature buffer as human-readable text.
///
/// The buffer is read as consecutive groups of numFeaturesPerNode values, one
/// group per node. The output starts with two header lines ("Number of nodes"
/// and "Features per node") followed by one "Node[i]: [f0, f1, ...]" line per
/// node. Values left over after the last complete group are not printed.
///
/// @param featureLeaves      Flat feature buffer.
/// @param numFeaturesPerNode Number of values belonging to each node. A value
///                           of zero yields zero nodes instead of dividing by zero.
/// @return The formatted text, each line terminated by '\n'.
std::string formatNodeFeatures(const std::vector<float>& featureLeaves, size_t numFeaturesPerNode) {
    std::ostringstream oss;

    // Guard the division: with no features per node there is nothing to group.
    const std::size_t numNodes =
        numFeaturesPerNode == 0 ? 0 : featureLeaves.size() / numFeaturesPerNode;
    oss << "Number of nodes: " << numNodes << "\n";
    oss << "Features per node: " << numFeaturesPerNode << "\n";

    for (std::size_t nodeIdx = 0; nodeIdx < numNodes; ++nodeIdx) {
        const std::size_t offset = nodeIdx * numFeaturesPerNode;
        oss << "Node[" << nodeIdx << "]: [";
        for (std::size_t f = 0; f < numFeaturesPerNode; ++f) {
            if (f > 0) {
                oss << ", ";
            }
            oss << featureLeaves[offset + f];
        }
        oss << "]\n";
    }

    return oss.str();
}
85 }
86 namespace MuonML{
// Tail of GraphInferenceToolBase::model(); its signature line is not part of
// this listing. Returns the session provided by the ONNX session tool handle.
88  return m_onnxSessionTool->session();
89  }
// Body of GraphInferenceToolBase::setupModel(); its signature line is not part
// of this listing. Retrieves the ONNX session tool, initialises the read key
// and configures m_graphFeatures from the model's custom metadata.
// Returns FAILURE if the resulting feature list is not valid.
91  ATH_CHECK(m_onnxSessionTool.retrieve());
92  ATH_CHECK(m_readKey.initialize());
93 
// Look up the "feature_names" entry in the model's custom metadata map.
94  Ort::ModelMetadata metadata = model().GetModelMetadata();
95  Ort::AllocatorWithDefaultOptions allocator;
96  Ort::AllocatedStringPtr feature_json_ptr = metadata.LookupCustomMetadataMapAllocated("feature_names", allocator);
97 
// If the entry exists, parse it as JSON and register every element as a
// feature name. The bool returned by addFeature() is not checked here.
// NOTE(review): nlohmann::json::parse presumably throws on malformed input;
// nothing in this block catches that - confirm this is intended.
98  if (feature_json_ptr) {
99  std::string feature_json = feature_json_ptr.get();
100  nlohmann::json json_obj = nlohmann::json::parse(feature_json);
101  for (const auto& feature : json_obj) {
102  m_graphFeatures.addFeature(feature.get<std::string>(), msgStream());
103  }
104  }
// The connector name is hard-coded; the bool returned by setConnector() is
// not checked, the isValid() test below is the only validation.
105  m_graphFeatures.setConnector("fullyConnected", msgStream());
106 
107  if (!m_graphFeatures.isValid()) {
108  ATH_MSG_FATAL("No graph features have been parsed. Please check the model: "<<m_graphFeatures.featureNames());
109  return StatusCode::FAILURE;
110  }
111  return StatusCode::SUCCESS;
112  }
113 
// Continuation of GraphInferenceToolBase::buildGraph(const EventContext& ctx,
// GraphRawData& graphData) const; the first signature line is not part of this
// listing. Fills graphData from the space point buckets and stores a feature
// tensor and an edge tensor in graphData.graph->dataTensor.
115  GraphRawData& graphData) const {
116 
// A graph built for a different feature list is discarded.
118  if (graphData.previousList && (*graphData.previousList) != m_graphFeatures) {
119  graphData.graph.reset();
120  }
// An existing graph is reused as is.
122  if (graphData.graph) {
123  return StatusCode::SUCCESS;
124  }
125  if (!m_graphFeatures.isValid()) {
126  ATH_MSG_ERROR("The feature list is in complete. Either it has no features or no node connector set");
127  return StatusCode::FAILURE;
128  }
129 
130  graphData.graph = std::make_unique<InferenceGraph>();
131 
// `spacePoints` is declared on a line missing from this listing; presumably a
// read handle built from m_readKey and ctx - TODO confirm.
133  ATH_CHECK(spacePoints.isPresent());
134 
// nNodes accumulates the bucket sizes; possConn accumulates
// Acts::sumUpToN(bucket size) and is only used to reserve edge capacity.
135  int64_t nNodes{0}, possConn{0};
136  graphData.spacePointsInBucket.clear();
137  graphData.spacePointsInBucket.reserve(spacePoints->size());
138 
139  for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
// emplace_back returns a reference to the stored bucket size, which is added to nNodes.
140  nNodes += graphData.spacePointsInBucket.emplace_back(bucket->size());
141  possConn += Acts::sumUpToN(graphData.spacePointsInBucket.back());
142  }
143 
// Reset the fill state and size the feature buffer to nNodes * numFeatures().
144  graphData.nodeIndex = 0;
145  graphData.featureLeaves.resize(nNodes * m_graphFeatures.numFeatures());
146  graphData.currLeave = graphData.featureLeaves.begin();
147 
148  graphData.srcEdges.reserve(possConn);
149  graphData.desEdges.reserve(possConn);
150 
// fillInData() is defined elsewhere; presumably it writes the node features and
// the src/des edge indices of each bucket into graphData - TODO confirm.
152  for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
153  const LayerSpBucket mlBucket{*bucket};
154  m_graphFeatures.fillInData(mlBucket, graphData);
155  }
156 
// Append the destination indices behind the source indices, so that srcEdges
// holds the flat [src..., dst...] buffer used for the (2, E) edge tensor below.
157  graphData.srcEdges.insert(graphData.srcEdges.end(), std::make_move_iterator(graphData.desEdges.begin()),
158  std::make_move_iterator(graphData.desEdges.end()));
159 
161  ATH_MSG_DEBUG(formatNodeFeatures(graphData.featureLeaves, m_graphFeatures.numFeatures()));
162  ATH_MSG_DEBUG("Edge indices:"<<makeIndexPairs(makeSortedEdges(graphData.srcEdges)));
163 
164  std::vector<int64_t> featShape{nNodes, static_cast<int64_t>(m_graphFeatures.numFeatures())}; // (N, nFeatures)
165  std::vector<int64_t> edgeShape{2, static_cast<int64_t>(graphData.srcEdges.size() / 2)}; // (2, E)
166 
167 
// NOTE(review): both tensors are created from pointers into featureLeaves and
// srcEdges, and those vectors are clear()ed at the end of this function.
// Presumably CreateTensor wraps the caller's buffer rather than copying it -
// confirm the tensor data is still valid when runInference() reads it.
168  Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
169  graphData.graph->dataTensor.emplace_back(Ort::Value::CreateTensor<float>(memInfo,
170  graphData.featureLeaves.data(), graphData.featureLeaves.size(),
171  featShape.data(), featShape.size()));
172 
173 
174  Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>(memInfo, graphData.srcEdges.data(),
175  graphData.srcEdges.size(), edgeShape.data(), edgeShape.size());
176 
177  graphData.graph->dataTensor.emplace_back(std::move(edge_tensor));
178 
// Remember which feature list this graph was built for (checked at the top).
179  graphData.previousList = &m_graphFeatures;
180 
181  graphData.srcEdges.clear();
182  graphData.desEdges.clear();
183  graphData.featureLeaves.clear();
184  ATH_MSG_DEBUG("Graph data built successfully.");
185  return StatusCode::SUCCESS;
186  }
187 
// Body of GraphInferenceToolBase::runInference(GraphRawData& graphData) const;
// its signature line is not part of this listing. Runs the ONNX session on the
// tensors stored in graphData.graph->dataTensor and appends the first output
// tensor to that same vector. Returns FAILURE if the feature list is invalid,
// the graph is missing, fewer than two tensors are stored, or the session
// returns no output.
// Note: the condition checked here is the validity of the feature list, while
// the message speaks about the ONNX model not being loaded.
189  if (!m_graphFeatures.isValid()) {
190  ATH_MSG_ERROR("ONNX model is not loaded. Please call setupModel()");
191  return StatusCode::FAILURE;
192  }
193  if (!graphData.graph) {
194  ATH_MSG_ERROR("Graph data is not built.");
195  return StatusCode::FAILURE;
196  }
197 
198  if (graphData.graph->dataTensor.size() < 2) {
199  ATH_MSG_ERROR("Data tensor does not contain both feature and edge tensors.");
200  return StatusCode::FAILURE;
201  }
202 
// Input and output names are hard-coded here.
203  std::vector<const char*> inputNames = {"features", "edge_index"};
204  std::vector<const char*> outputNames = {"output"};
205 
206  Ort::RunOptions run_options;
207  run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_WARNING);
208 
// NOTE(review): the input count passed below is dataTensor.size(), whereas
// inputNames holds two names. The output tensor is appended to dataTensor at
// the end of this function, so a second call on the same graphData would pass
// a count of three - confirm whether repeated calls can occur.
209  std::vector<Ort::Value> outputTensors = model().Run(run_options,
210  inputNames.data(), // input tensor names
211  graphData.graph->dataTensor.data(), // pointer to the tensor vector
212  graphData.graph->dataTensor.size(), // size of the tensor vector
213  outputNames.data(), // output tensor names
214  outputNames.size()); // number of output tensors
215 
216  if (outputTensors.empty()) {
217  ATH_MSG_ERROR("Inference returned empty output.");
218  return StatusCode::FAILURE;
219  }
220 
// View the first output tensor's data as a mutable span of floats.
221  float* output_data = outputTensors[0].GetTensorMutableData<float>();
222  size_t output_size = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
223 
224  std::span<float> predictions(output_data, output_data + output_size);
225 
// Non-finite predictions are overwritten in place with -100; a warning is
// issued for each such element.
226  for (size_t i = 0; i < output_size; i++) {
227  if (!std::isfinite(predictions[i])) {
228  ATH_MSG_WARNING("Non-finite prediction detected! Setting to -100..");
229  predictions[i] = -100.0f;
230  }
231  }
232 
// Store the (sanitised) output tensor behind the two input tensors.
233  graphData.graph->dataTensor.emplace_back(std::move(outputTensors[0]));
234 
235  return StatusCode::SUCCESS;
236  }
237 
238 }
MuonML::GraphInferenceToolBase::m_readKey
SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey
Input space points to filter
Definition: GraphInferenceToolBase.h:43
MuonML::GraphInferenceToolBase::buildGraph
StatusCode buildGraph(const EventContext &ctx, GraphRawData &graphData) const
Fill up the GraphRawData and construct the graph for the ML inference with ONNX.
Definition: GraphInferenceToolBase.cxx:114
MuonML::GraphRawData::spacePointsInBucket
NodeConnectVec_t spacePointsInBucket
Vector keeping track of how many space points are in each parsed bucket.
Definition: GraphData.h:36
ATH_MSG_FATAL
#define ATH_MSG_FATAL(x)
Definition: AthMsgStreamMacros.h:34
MuonR4::SpacePointBucket
: The muon space point bucket represents a collection of points that will be processed together in t...
Definition: MuonSpectrometer/MuonPhaseII/Event/MuonSpacePoint/MuonSpacePoint/SpacePointContainer.h:21
MuonML::NodeFeatureList::numFeatures
size_t numFeatures() const
Returns the number of features in the list.
Definition: NodeFeatureList.cxx:40
AthMsgStreamMacros.h
json
nlohmann::json json
Definition: HistogramDef.cxx:9
SG::ReadHandle
Definition: StoreGate/StoreGate/ReadHandle.h:67
MuonML::NodeFeatureList::fillInData
void fillInData(const Bucket_t &bucket, GraphRawData &graphData) const
Definition: NodeFeatureList.cxx:75
parse
std::map< std::string, std::string > parse(const std::string &list)
Definition: egammaLayerRecalibTool.cxx:1113
MuonML::GraphInferenceToolBase::m_onnxSessionTool
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool
Definition: GraphInferenceToolBase.h:47
MuonML::GraphRawData::graph
std::unique_ptr< InferenceGraph > graph
Pointer to the graph to be parsed to ONNX.
Definition: GraphData.h:40
vec
std::vector< size_t > vec
Definition: CombinationsGeneratorTest.cxx:9
python.oracle.Session
Session
Definition: oracle.py:76
MuonML::GraphRawData::desEdges
EdgeCounterVec_t desEdges
Vect
Definition: GraphData.h:34
MuonML
Definition: GraphBucketFilterTool.cxx:9
MuonML::NodeFeatureList::featureNames
std::vector< std::string > featureNames() const
Returns the name of the features in the list.
Definition: NodeFeatureList.cxx:45
python.checkMetadata.metadata
metadata
Definition: checkMetadata.py:175
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
lumiFormat.i
int i
Definition: lumiFormat.py:85
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
MuonML::NodeFeatureList::isValid
bool isValid() const
Returns whether the NodeFeatureList is complete, i.e.
Definition: NodeFeatureList.cxx:15
MuonML::LayerSpBucket
The LayerSpBucket is a space pointbucket where the points are internally sorted by their layer number...
Definition: LayerBucket.h:14
MuonML::NodeFeatureList::addFeature
bool addFeature(const std::string &featName, MsgStream &msg)
Tries to add a new feature to the list using the predefined list of features in the GraphFeatureFacto...
Definition: NodeFeatureList.cxx:52
MuonML::NodeFeatureList::setConnector
bool setConnector(const std::string &conName, MsgStream &msg)
Tries to set the graph connector based on the connector name.
Definition: NodeFeatureList.cxx:18
MuonML::GraphRawData::srcEdges
EdgeCounterVec_t srcEdges
Vector encoding the source index of the.
Definition: GraphData.h:32
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
MuonML::GraphRawData::featureLeaves
FeatureVec_t featureLeaves
Vector containing all features.
Definition: GraphData.h:30
hist_file_dump.f
f
Definition: hist_file_dump.py:140
MuonML::GraphRawData
Helper struct to ship the Graph from the space point buckets to ONNX.
Definition: GraphData.h:25
F600IntegrationConfig.spacePoints
spacePoints
Definition: F600IntegrationConfig.py:122
MuonML::GraphInferenceToolBase::model
Ort::Session & model() const
Definition: GraphInferenceToolBase.cxx:87
PathResolver.h
plotBeamSpotMon.b
b
Definition: plotBeamSpotMon.py:76
XMLtoHeader.outputNames
outputNames
Definition: XMLtoHeader.py:17
GraphData.h
operator<<
std::ostream & operator<<(std::ostream &lhs, const TestGaudiProperty &rhs)
Definition: TestGaudiProperty.cxx:69
MuonML::GraphRawData::previousList
const NodeFeatureList * previousList
Pointer to the latest parsed NodeFeatureList.
Definition: GraphData.h:38
GraphInferenceToolBase.h
MuonML::GraphInferenceToolBase::setupModel
StatusCode setupModel()
Definition: GraphInferenceToolBase.cxx:90
a
TList * a
Definition: liststreamerinfos.cxx:10
NodeFeatureList.h
std::sort
void sort(typename std::reverse_iterator< DataModel_detail::iterator< DVL > > beg, typename std::reverse_iterator< DataModel_detail::iterator< DVL > > end, const Compare &comp)
Specialization of sort for DataVector/List.
Definition: DVL_algorithms.h:623
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
Pythia8_RapidityOrderMPI.val
val
Definition: Pythia8_RapidityOrderMPI.py:14
MuonML::GraphRawData::currLeave
std::vector< float >::iterator currLeave
The following variables are needed to consistently fill the raw data for the Graph Building.
Definition: GraphData.h:44
MuonML::GraphInferenceToolBase::runInference
StatusCode runInference(GraphRawData &graphData) const
Definition: GraphInferenceToolBase.cxx:188
MuonML::GraphInferenceToolBase::m_graphFeatures
NodeFeatureList m_graphFeatures
List of features to be used for the inference.
Definition: GraphInferenceToolBase.h:46
TSU::T
unsigned long long T
Definition: L1TopoDataTypes.h:35
MuonML::GraphRawData::nodeIndex
unsigned int nodeIndex
Number of the already filled nodes.
Definition: GraphData.h:46