ATLAS Offline Software
MuonML::GraphInferenceToolBase Class Reference

Baseline tool to handle the

#include <GraphInferenceToolBase.h>

Inheritance diagram for MuonML::GraphInferenceToolBase:
Collaboration diagram for MuonML::GraphInferenceToolBase:

Public Member Functions

StatusCode buildGraph (const EventContext &ctx, GraphRawData &graphData) const
 Fill up the GraphRawData and construct the graph for the ML inference with ONNX.
 
StatusCode runInference (GraphRawData &graphData) const
 

Protected Member Functions

StatusCode setupModel ()
 
Ort::Session & model () const
 

Protected Attributes

SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey {this, "ReadSpacePoints", "MuonSpacePoints"}
 Input space points to filter
 

Private Attributes

NodeFeatureList m_graphFeatures {}
 List of features to be used for the inference.
 
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool {this, "ModelSession", "" }
 

Detailed Description

Baseline tool to handle the

Definition at line 22 of file GraphInferenceToolBase.h.

Member Function Documentation

◆ buildGraph()

StatusCode MuonML::GraphInferenceToolBase::buildGraph ( const EventContext &  ctx,
GraphRawData &  graphData 
) const

Fill up the GraphRawData and construct the graph for the ML inference with ONNX.

If the graph has already been built by another inference tool and would be identical to the one produced here, the rebuild is skipped.

Parameters
ctx	EventContext to access the space point container from StoreGate
graphData	Reference to the data object to be filled.

Check whether the graph needs a rebuild

Don't launch the rebuild of the graph

Fill the graph edge features and all their respective connections

Definition at line 114 of file GraphInferenceToolBase.cxx.

  {
      // Check whether the graph needs a rebuild
      if (graphData.previousList && (*graphData.previousList) != m_graphFeatures) {
          graphData.graph.reset();
      }
      // Don't launch the rebuild of the graph
      if (graphData.graph) {
          return StatusCode::SUCCESS;
      }
      if (!m_graphFeatures.isValid()) {
          ATH_MSG_ERROR("The feature list is incomplete. Either it has no features or no node connector set");
          return StatusCode::FAILURE;
      }

      graphData.graph = std::make_unique<InferenceGraph>();

      const SG::ReadHandle spacePoints{m_readKey, ctx};
      ATH_CHECK(spacePoints.isPresent());

      int64_t nNodes{0}, possConn{0};
      graphData.spacePointsInBucket.clear();
      graphData.spacePointsInBucket.reserve(spacePoints->size());

      for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
          nNodes += graphData.spacePointsInBucket.emplace_back(bucket->size());
          possConn += Acts::sumUpToN(graphData.spacePointsInBucket.back());
      }

      graphData.nodeIndex = 0;
      graphData.featureLeaves.resize(nNodes * m_graphFeatures.numFeatures());
      graphData.currLeave = graphData.featureLeaves.begin();

      graphData.srcEdges.reserve(possConn);
      graphData.desEdges.reserve(possConn);

      // Fill the graph edge features and all their respective connections
      for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
          const LayerSpBucket mlBucket{*bucket};
          m_graphFeatures.fillInData(mlBucket, graphData);
      }

      graphData.srcEdges.insert(graphData.srcEdges.end(), std::make_move_iterator(graphData.desEdges.begin()),
                                std::make_move_iterator(graphData.desEdges.end()));

      ATH_MSG_DEBUG(formatNodeFeatures(graphData.featureLeaves, m_graphFeatures.numFeatures()));
      ATH_MSG_DEBUG("Edge indices:" << makeIndexPairs(makeSortedEdges(graphData.srcEdges)));

      std::vector<int64_t> featShape{nNodes, static_cast<int64_t>(m_graphFeatures.numFeatures())}; // (N, nFeatures)
      std::vector<int64_t> edgeShape{2, static_cast<int64_t>(graphData.srcEdges.size() / 2)};      // (2, E)

      Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
      graphData.graph->dataTensor.emplace_back(Ort::Value::CreateTensor<float>(memInfo,
                                               graphData.featureLeaves.data(), graphData.featureLeaves.size(),
                                               featShape.data(), featShape.size()));

      Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>(memInfo, graphData.srcEdges.data(),
                                                                 graphData.srcEdges.size(), edgeShape.data(), edgeShape.size());

      graphData.graph->dataTensor.emplace_back(std::move(edge_tensor));

      graphData.previousList = &m_graphFeatures;

      graphData.srcEdges.clear();
      graphData.desEdges.clear();
      graphData.featureLeaves.clear();
      ATH_MSG_DEBUG("Graph data built successfully.");
      return StatusCode::SUCCESS;
  }

◆ model()

Ort::Session & MuonML::GraphInferenceToolBase::model ( ) const
protected

Definition at line 87 of file GraphInferenceToolBase.cxx.

  {
      return m_onnxSessionTool->session();
  }

◆ runInference()

StatusCode MuonML::GraphInferenceToolBase::runInference ( GraphRawData &  graphData) const

Definition at line 188 of file GraphInferenceToolBase.cxx.

  {
      if (!m_graphFeatures.isValid()) {
          ATH_MSG_ERROR("ONNX model is not loaded. Please call setupModel()");
          return StatusCode::FAILURE;
      }
      if (!graphData.graph) {
          ATH_MSG_ERROR("Graph data is not built.");
          return StatusCode::FAILURE;
      }

      if (graphData.graph->dataTensor.size() < 2) {
          ATH_MSG_ERROR("Data tensor does not contain both feature and edge tensors.");
          return StatusCode::FAILURE;
      }

      std::vector<const char*> inputNames = {"features", "edge_index"};
      std::vector<const char*> outputNames = {"output"};

      Ort::RunOptions run_options;
      run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_WARNING);

      std::vector<Ort::Value> outputTensors = model().Run(run_options,
                                                          inputNames.data(),                  // input tensor names
                                                          graphData.graph->dataTensor.data(), // pointer to the tensor vector
                                                          graphData.graph->dataTensor.size(), // size of the tensor vector
                                                          outputNames.data(),                 // output tensor names
                                                          outputNames.size());                // number of output tensors

      if (outputTensors.empty()) {
          ATH_MSG_ERROR("Inference returned empty output.");
          return StatusCode::FAILURE;
      }

      float* output_data = outputTensors[0].GetTensorMutableData<float>();
      size_t output_size = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();

      std::span<float> predictions(output_data, output_data + output_size);

      for (size_t i = 0; i < output_size; i++) {
          if (!std::isfinite(predictions[i])) {
              ATH_MSG_WARNING("Non-finite prediction detected! Setting to -100.");
              predictions[i] = -100.0f;
          }
      }

      graphData.graph->dataTensor.emplace_back(std::move(outputTensors[0]));

      return StatusCode::SUCCESS;
  }

◆ setupModel()

StatusCode MuonML::GraphInferenceToolBase::setupModel ( )
protected

Definition at line 90 of file GraphInferenceToolBase.cxx.

  {
      ATH_CHECK(m_onnxSessionTool.retrieve());
      ATH_CHECK(m_readKey.initialize());

      Ort::ModelMetadata metadata = model().GetModelMetadata();
      Ort::AllocatorWithDefaultOptions allocator;
      Ort::AllocatedStringPtr feature_json_ptr = metadata.LookupCustomMetadataMapAllocated("feature_names", allocator);

      if (feature_json_ptr) {
          std::string feature_json = feature_json_ptr.get();
          nlohmann::json json_obj = nlohmann::json::parse(feature_json);
          for (const auto& feature : json_obj) {
              m_graphFeatures.addFeature(feature.get<std::string>(), msgStream());
          }
      }
      m_graphFeatures.setConnector("fullyConnected", msgStream());

      if (!m_graphFeatures.isValid()) {
          ATH_MSG_FATAL("No graph features have been parsed. Please check the model: " << m_graphFeatures.featureNames());
          return StatusCode::FAILURE;
      }
      return StatusCode::SUCCESS;
  }

Member Data Documentation

◆ m_graphFeatures

NodeFeatureList MuonML::GraphInferenceToolBase::m_graphFeatures {}
private

List of features to be used for the inference.

Definition at line 46 of file GraphInferenceToolBase.h.

◆ m_onnxSessionTool

ToolHandle<AthOnnx::IOnnxRuntimeSessionTool> MuonML::GraphInferenceToolBase::m_onnxSessionTool {this, "ModelSession", "" }
private

Definition at line 47 of file GraphInferenceToolBase.h.

◆ m_readKey

SG::ReadHandleKey<MuonR4::SpacePointContainer> MuonML::GraphInferenceToolBase::m_readKey {this, "ReadSpacePoints", "MuonSpacePoints"}
protected

Input space points to filter

Definition at line 43 of file GraphInferenceToolBase.h.


The documentation for this class was generated from the following files:
GraphInferenceToolBase.h
GraphInferenceToolBase.cxx