Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
GraphInferenceToolBase.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
5 
11 
12 #include <span>
13 
namespace {
    // Both streaming overloads are declared up front so that each template can
    // find the other by ordinary (definition-context) lookup. Without this,
    // streaming a std::vector<std::pair<...>> into a std::ostream does not
    // compile: when the vector overload is defined the pair overload is not yet
    // declared, and argument-dependent lookup only searches namespace std.
    template <typename T>
    std::ostream& operator<<(std::ostream& ostr, const std::vector<T>& vec);
    template <typename T1, typename T2>
    std::ostream& operator<<(std::ostream& ostr, const std::pair<T1, T2>& pair);

    /// @brief Streams a vector as "[a, b, c]"; an empty vector prints as "[]".
    template <typename T>
    std::ostream& operator<<(std::ostream& ostr, const std::vector<T>& vec) {
        ostr << "[";
        const char* separator = "";
        for (const T& val : vec) {
            ostr << separator << val;
            separator = ", ";  // no separator before the first element / after the last
        }
        ostr << "]";
        return ostr;
    }

    /// @brief Streams a pair as "(first, second)".
    template <typename T1, typename T2>
    std::ostream& operator<<(std::ostream& ostr, const std::pair<T1, T2>& pair) {
        ostr << "(" << pair.first << ", " << pair.second << ")";
        return ostr;
    }

    /// @brief Splits a flat edge list into (source, destination) pairs.
    /// @param edges Flat list laid out as [src_0 .. src_{E-1}, dst_0 .. dst_{E-1}].
    ///              For an odd length the trailing element is ignored.
    /// @return E pairs (src_i, dst_i), in the original order.
    std::vector<std::pair<int64_t, int64_t>> makeIndexPairs(const std::vector<int64_t>& edges) {
        const size_t nEdges = edges.size() / 2;
        std::vector<std::pair<int64_t, int64_t>> indexPairs;
        indexPairs.reserve(nEdges);
        for (size_t i = 0; i < nEdges; ++i) {
            indexPairs.emplace_back(edges[i], edges[i + nEdges]);
        }
        return indexPairs;
    }

    /// @brief Returns a copy of a flat edge list with its edges sorted by
    ///        source index, then by destination index.
    /// @param edges Flat list laid out as [src..., dst...] (see makeIndexPairs).
    /// @return Sorted edges in the same flat layout: [sorted_src..., sorted_dst...].
    std::vector<int64_t> makeSortedEdges(const std::vector<int64_t>& edges) {
        std::vector<std::pair<int64_t, int64_t>> indexPairs = makeIndexPairs(edges);

        // std::pair's operator< is lexicographic: by src first, then by dst.
        std::sort(indexPairs.begin(), indexPairs.end());

        const size_t nEdges = indexPairs.size();
        std::vector<int64_t> sortedEdges(2 * nEdges);
        for (size_t i = 0; i < nEdges; ++i) {
            sortedEdges[i] = indexPairs[i].first;
            sortedEdges[i + nEdges] = indexPairs[i].second;
        }
        return sortedEdges;
    }

    /// @brief Formats a flat, node-major feature buffer as human-readable text.
    /// @param featureLeaves      Features of all nodes, numFeaturesPerNode values per node.
    /// @param numFeaturesPerNode Number of features per node. A value of 0 yields
    ///                           zero nodes instead of dividing by zero.
    /// @return A header with the node / feature counts followed by one
    ///         "Node[i]: [f0, f1, ...]" line per complete node.
    std::string formatNodeFeatures(const std::vector<float>& featureLeaves, size_t numFeaturesPerNode) {
        std::ostringstream oss;

        // Guard the division: an empty feature list must not crash a debug printout.
        const size_t numNodes = numFeaturesPerNode != 0 ? featureLeaves.size() / numFeaturesPerNode : 0;
        oss << "Number of nodes: " << numNodes << "\n";
        oss << "Features per node: " << numFeaturesPerNode << "\n";

        for (size_t nodeIdx = 0; nodeIdx < numNodes; ++nodeIdx) {
            oss << "Node[" << nodeIdx << "]: [";
            for (size_t f = 0; f < numFeaturesPerNode; ++f) {
                if (f > 0) oss << ", ";
                oss << featureLeaves[nodeIdx * numFeaturesPerNode + f];
            }
            oss << "]\n";
        }

        return oss.str();
    }
}
87 namespace MuonML{
// ---------------------------------------------------------------------------
// model(): returns the ONNX Runtime session provided by m_onnxSessionTool.
// ---------------------------------------------------------------------------
89  return m_onnxSessionTool->session();
90  }
// ---------------------------------------------------------------------------
// setupModel(): retrieves the ONNX session tool, initialises the space-point
// read key and configures m_graphFeatures from the model's own metadata:
//  - the custom metadata entry "feature_names" is parsed as JSON and each
//    element (read as a string) is registered via addFeature();
//  - the node connector is always set to "fullyConnected".
// Returns FAILURE with a FATAL message if the resulting feature list is not
// valid (presumably the case when the model has no "feature_names" entry,
// since then no feature is added).
// NOTE(review): nlohmann::json::parse throws on malformed JSON and nothing
// here catches it; the bool results of addFeature()/setConnector() are also
// ignored, only the final isValid() check decides success.
// ---------------------------------------------------------------------------
92  ATH_CHECK(m_onnxSessionTool.retrieve());
93  ATH_CHECK(m_readKey.initialize());
94 
95  Ort::ModelMetadata metadata = model().GetModelMetadata();
96  Ort::AllocatorWithDefaultOptions allocator;
97  Ort::AllocatedStringPtr feature_json_ptr = metadata.LookupCustomMetadataMapAllocated("feature_names", allocator);
98 
99  if (feature_json_ptr) {
100  std::string feature_json = feature_json_ptr.get();
101  nlohmann::json json_obj = nlohmann::json::parse(feature_json);
102  for (const auto& feature : json_obj) {
103  m_graphFeatures.addFeature(feature.get<std::string>(), msgStream());
104  }
105  }
106  m_graphFeatures.setConnector("fullyConnected", msgStream());
107 
108  if (!m_graphFeatures.isValid()) {
109  ATH_MSG_FATAL("No graph features have been parsed. Please check the model: "<<m_graphFeatures.featureNames());
110  return StatusCode::FAILURE;
111  }
112  return StatusCode::SUCCESS;
113  }
114 
// ---------------------------------------------------------------------------
// buildGraph(ctx, graphData): fills graphData from the space-point buckets
// read via m_readKey and wraps the result into two ONNX input tensors that are
// appended to graphData.graph->dataTensor:
//   [0] float features,    shape (N, nFeatures), N = total number of space points
//   [1] int64 edge index,  shape (2, E), flat layout [src..., dst...]
// The graph is cached: if graphData.graph already exists and was built with
// this tool's feature list, the function returns SUCCESS immediately.
// ---------------------------------------------------------------------------
116  GraphRawData& graphData) const {
117 
// A graph built with a different NodeFeatureList cannot be reused: drop it.
119  if (graphData.previousList && (*graphData.previousList) != m_graphFeatures) {
120  graphData.graph.reset();
121  }
// Graph already built with the current feature list -> nothing to do.
123  if (graphData.graph) {
124  return StatusCode::SUCCESS;
125  }
// NOTE(review): the message says "in complete"; presumably "incomplete" was meant.
126  if (!m_graphFeatures.isValid()) {
127  ATH_MSG_ERROR("The feature list is in complete. Either it has no features or no node connector set");
128  return StatusCode::FAILURE;
129  }
130 
131  graphData.graph = std::make_unique<InferenceGraph>();
132 
134  ATH_CHECK(spacePoints.isPresent());
135 
// First pass: one node per space point. The size of each bucket is recorded
// in spacePointsInBucket; possConn accumulates MuonR4::sumUp(bucket size) and
// is only used as the reserve size of the edge vectors below.
136  int64_t nNodes{0}, possConn{0};
137  graphData.spacePointsInBucket.clear();
138  graphData.spacePointsInBucket.reserve(spacePoints->size());
139 
140  for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
141  nNodes += graphData.spacePointsInBucket.emplace_back(bucket->size());
142  possConn += MuonR4::sumUp(graphData.spacePointsInBucket.back());
143  }
144 
// Pre-size the flat feature buffer (N * nFeatures values) and reset the fill
// cursors; fillInData() presumably advances currLeave / nodeIndex as it writes.
145  graphData.nodeIndex = 0;
146  graphData.featureLeaves.resize(nNodes * m_graphFeatures.numFeatures());
147  graphData.currLeave = graphData.featureLeaves.begin();
148 
149  graphData.srcEdges.reserve(possConn);
150  graphData.desEdges.reserve(possConn);
151 
// Second pass: let the feature list fill node features and edges bucket by
// bucket, using the layer-sorted view of each bucket.
153  for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
154  const LayerSpBucket mlBucket{*bucket};
155  m_graphFeatures.fillInData(mlBucket, graphData);
156  }
157 
// Append the destination indices behind the source indices: srcEdges now
// holds the flat [src..., dst...] layout that backs the (2, E) edge tensor.
158  graphData.srcEdges.insert(graphData.srcEdges.end(), std::make_move_iterator(graphData.desEdges.begin()),
159  std::make_move_iterator(graphData.desEdges.end()));
160 
// Debug printout only; makeSortedEdges() sorts a copy, srcEdges is untouched.
162  ATH_MSG_DEBUG(formatNodeFeatures(graphData.featureLeaves, m_graphFeatures.numFeatures()));
163  ATH_MSG_DEBUG("Edge indices:"<<makeIndexPairs(makeSortedEdges(graphData.srcEdges)));
164 
165  std::vector<int64_t> featShape{nNodes, static_cast<int64_t>(m_graphFeatures.numFeatures())}; // (N, nFeatures)
166  std::vector<int64_t> edgeShape{2, static_cast<int64_t>(graphData.srcEdges.size() / 2)}; // (2, E)
167 
168 
// NOTE(review): CreateTensor with a caller-provided buffer does not copy the
// data; the two tensors below point into graphData.featureLeaves and
// graphData.srcEdges. Those vectors must outlive graphData.graph and must not
// be reallocated while the tensors are in use.
169  Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
170  graphData.graph->dataTensor.emplace_back(Ort::Value::CreateTensor<float>(memInfo,
171  graphData.featureLeaves.data(), graphData.featureLeaves.size(),
172  featShape.data(), featShape.size()));
173 
174 
175  Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>(memInfo, graphData.srcEdges.data(),
176  graphData.srcEdges.size(), edgeShape.data(), edgeShape.size());
177 
178  graphData.graph->dataTensor.emplace_back(std::move(edge_tensor));
179 
// Remember which feature list produced this graph (checked on the next call).
180  graphData.previousList = &m_graphFeatures;
181 
// NOTE(review): featureLeaves and srcEdges back the tensors created above.
// clear() keeps their capacity, so the memory stays allocated, but the
// elements are formally destroyed while the tensors still reference them.
// Confirm this is intended; otherwise only desEdges should be cleared here.
182  graphData.srcEdges.clear();
183  graphData.desEdges.clear();
184  graphData.featureLeaves.clear();
185  ATH_MSG_DEBUG("Graph data built successfully.");
186  return StatusCode::SUCCESS;
187  }
188 
// ---------------------------------------------------------------------------
// runInference(graphData): runs the ONNX model on the tensors produced by
// buildGraph(), binding them to the inputs "features" and "edge_index" and
// requesting the output "output". Non-finite predictions are overwritten in
// place with -100, and the output tensor is appended to
// graphData.graph->dataTensor (at index 2 after the first call).
// NOTE(review): Run() receives dataTensor.size() as the input count while
// inputNames has only two entries. Because the output tensor is appended,
// a second call on the same graph would pass three tensors and read past the
// end of inputNames; passing inputNames.size() instead would avoid this.
// NOTE(review): "ONNX model is not loaded" is reported when the feature list
// is invalid, which is not the same condition as a missing session.
// ---------------------------------------------------------------------------
190  if (!m_graphFeatures.isValid()) {
191  ATH_MSG_ERROR("ONNX model is not loaded. Please call setupModel()");
192  return StatusCode::FAILURE;
193  }
194  if (!graphData.graph) {
195  ATH_MSG_ERROR("Graph data is not built.");
196  return StatusCode::FAILURE;
197  }
198 
199  if (graphData.graph->dataTensor.size() < 2) {
200  ATH_MSG_ERROR("Data tensor does not contain both feature and edge tensors.");
201  return StatusCode::FAILURE;
202  }
203 
204  std::vector<const char*> inputNames = {"features", "edge_index"};
205  std::vector<const char*> outputNames = {"output"};
206 
207  Ort::RunOptions run_options;
208  run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_WARNING);
209 
210  std::vector<Ort::Value> outputTensors = model().Run(run_options,
211  inputNames.data(), // input tensor names
212  graphData.graph->dataTensor.data(), // pointer to the tensor vector
213  graphData.graph->dataTensor.size(), // size of the tensor vector
214  outputNames.data(), // output tensor names
215  outputNames.size()); // number of output tensors
216 
217  if (outputTensors.empty()) {
218  ATH_MSG_ERROR("Inference returned empty output.");
219  return StatusCode::FAILURE;
220  }
221 
// Sanitise the raw float predictions in place: NaN / inf become -100.
222  float* output_data = outputTensors[0].GetTensorMutableData<float>();
223  size_t output_size = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
224 
225  std::span<float> predictions(output_data, output_data + output_size);
226 
227  for (size_t i = 0; i < output_size; i++) {
228  if (!std::isfinite(predictions[i])) {
229  ATH_MSG_WARNING("Non-finite prediction detected! Setting to -100..");
230  predictions[i] = -100.0f;
231  }
232  }
233 
// Keep the output alongside the inputs so callers can read it from the graph.
234  graphData.graph->dataTensor.emplace_back(std::move(outputTensors[0]));
235 
236  return StatusCode::SUCCESS;
237  }
238 
239 }
MuonML::GraphInferenceToolBase::m_readKey
SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey
Input space points to filter
Definition: GraphInferenceToolBase.h:43
MuonML::GraphInferenceToolBase::buildGraph
StatusCode buildGraph(const EventContext &ctx, GraphRawData &graphData) const
Fill up the GraphRawData and construct the graph for the ML inference with ONNX.
Definition: GraphInferenceToolBase.cxx:115
MuonML::GraphRawData::spacePointsInBucket
NodeConnectVec_t spacePointsInBucket
Vector keeping track of how many space points are in each parsed bucket.
Definition: GraphData.h:36
ATH_MSG_FATAL
#define ATH_MSG_FATAL(x)
Definition: AthMsgStreamMacros.h:34
MuonR4::SpacePointBucket
: The muon space point bucket represents a collection of points that will be processed together in t...
Definition: MuonSpectrometer/MuonPhaseII/Event/MuonSpacePoint/MuonSpacePoint/SpacePointContainer.h:21
MuonML::NodeFeatureList::numFeatures
size_t numFeatures() const
Returns the number of features in the list.
Definition: NodeFeatureList.cxx:40
AthMsgStreamMacros.h
json
nlohmann::json json
Definition: HistogramDef.cxx:9
SG::ReadHandle
Definition: StoreGate/StoreGate/ReadHandle.h:67
MuonML::NodeFeatureList::fillInData
void fillInData(const Bucket_t &bucket, GraphRawData &graphData) const
Definition: NodeFeatureList.cxx:75
parse
std::map< std::string, std::string > parse(const std::string &list)
Definition: egammaLayerRecalibTool.cxx:1080
MuonML::GraphInferenceToolBase::m_onnxSessionTool
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool
Definition: GraphInferenceToolBase.h:47
MuonML::GraphRawData::graph
std::unique_ptr< InferenceGraph > graph
Pointer to the graph to be parsed to ONNX.
Definition: GraphData.h:40
vec
std::vector< size_t > vec
Definition: CombinationsGeneratorTest.cxx:9
python.oracle.Session
Session
Definition: oracle.py:78
MatrixUtils.h
MuonML::GraphRawData::desEdges
EdgeCounterVec_t desEdges
Vect
Definition: GraphData.h:34
MuonML
Definition: GraphBucketFilterTool.cxx:9
MuonML::NodeFeatureList::featureNames
std::vector< std::string > featureNames() const
Returns the name of the features in the list.
Definition: NodeFeatureList.cxx:45
python.checkMetadata.metadata
metadata
Definition: checkMetadata.py:175
MuonR4::sumUp
constexpr unsigned int sumUp(unsigned k)
Calculates the sum of 1 + 2 + 3 + 4 + ...
Definition: MatrixUtils.h:15
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
lumiFormat.i
int i
Definition: lumiFormat.py:85
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
MuonML::NodeFeatureList::isValid
bool isValid() const
Returns whether the NodeFeatureList is complete, i.e.
Definition: NodeFeatureList.cxx:15
MuonML::LayerSpBucket
The LayerSpBucket is a space point bucket where the points are internally sorted by their layer number...
Definition: LayerBucket.h:14
MuonML::NodeFeatureList::addFeature
bool addFeature(const std::string &featName, MsgStream &msg)
Tries to add a new feature to the list using the predefined list of features in the GraphFeatureFacto...
Definition: NodeFeatureList.cxx:52
MuonML::NodeFeatureList::setConnector
bool setConnector(const std::string &conName, MsgStream &msg)
Tries to set the graph connector based on the connector name.
Definition: NodeFeatureList.cxx:18
MuonML::GraphRawData::srcEdges
EdgeCounterVec_t srcEdges
Vector encoding the source index of the.
Definition: GraphData.h:32
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
MuonML::GraphRawData::featureLeaves
FeatureVec_t featureLeaves
Vector containing all features.
Definition: GraphData.h:30
hist_file_dump.f
f
Definition: hist_file_dump.py:141
MuonML::GraphRawData
Helper struct to ship the Graph from the space point buckets to ONNX.
Definition: GraphData.h:25
MuonML::GraphInferenceToolBase::model
Ort::Session & model() const
Definition: GraphInferenceToolBase.cxx:88
PathResolver.h
plotBeamSpotMon.b
b
Definition: plotBeamSpotMon.py:77
XMLtoHeader.outputNames
outputNames
Definition: XMLtoHeader.py:18
GraphData.h
operator<<
std::ostream & operator<<(std::ostream &lhs, const TestGaudiProperty &rhs)
Definition: TestGaudiProperty.cxx:69
MuonML::GraphRawData::previousList
const NodeFeatureList * previousList
Pointer to the latest parsed NodeFeatureList.
Definition: GraphData.h:38
GraphInferenceToolBase.h
MuonML::GraphInferenceToolBase::setupModel
StatusCode setupModel()
Definition: GraphInferenceToolBase.cxx:91
a
TList * a
Definition: liststreamerinfos.cxx:10
NodeFeatureList.h
std::sort
void sort(typename std::reverse_iterator< DataModel_detail::iterator< DVL > > beg, typename std::reverse_iterator< DataModel_detail::iterator< DVL > > end, const Compare &comp)
Specialization of sort for DataVector/List.
Definition: DVL_algorithms.h:623
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
Pythia8_RapidityOrderMPI.val
val
Definition: Pythia8_RapidityOrderMPI.py:14
MuonML::GraphRawData::currLeave
std::vector< float >::iterator currLeave
The following variables are needed to consistently fill the raw data for the graph building.
Definition: GraphData.h:44
MuonML::GraphInferenceToolBase::runInference
StatusCode runInference(GraphRawData &graphData) const
Definition: GraphInferenceToolBase.cxx:189
MuonML::GraphInferenceToolBase::m_graphFeatures
NodeFeatureList m_graphFeatures
List of features to be used for the inference.
Definition: GraphInferenceToolBase.h:46
python.FPGATrackSimAnalysisConfig.spacePoints
spacePoints
Definition: FPGATrackSimAnalysisConfig.py:642
TSU::T
unsigned long long T
Definition: L1TopoDataTypes.h:35
MuonML::GraphRawData::nodeIndex
unsigned int nodeIndex
Number of the already filled nodes.
Definition: GraphData.h:46