ATLAS Offline Software
MuonML::GraphInferenceToolBase Class Reference

Baseline tool to handle the construction of the space-point graph and its ML inference with ONNX.
More...

#include <GraphInferenceToolBase.h>

Inheritance diagram for MuonML::GraphInferenceToolBase:
Collaboration diagram for MuonML::GraphInferenceToolBase:

Public Member Functions

StatusCode buildGraph (const EventContext &ctx, GraphRawData &graphData) const
 Fill up the GraphRawData and construct the graph for the ML inference with ONNX. More...
 
StatusCode runInference (GraphRawData &graphData) const
 

Protected Member Functions

StatusCode setupModel ()
 
Ort::Session & model () const
 

Protected Attributes

SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey {this, "ReadSpacePoints", "MuonSpacePoints"}
 Input space points to filter
More...
 

Private Attributes

NodeFeatureList m_graphFeatures {}
 List of features to be used for the inference. More...
 
ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_onnxSessionTool {this, "ModelSession", "" }
 

Detailed Description

Baseline tool to handle the construction of the space-point graph and its ML inference with ONNX.

Definition at line 22 of file GraphInferenceToolBase.h.

Member Function Documentation

◆ buildGraph()

StatusCode MuonML::GraphInferenceToolBase::buildGraph ( const EventContext &  ctx,
GraphRawData graphData 
) const

Fill up the GraphRawData and construct the graph for the ML inference with ONNX.

If the graph has already been built by another inference tool and would be the same as this one, the rebuild is skipped.

Parameters
ctx	EventContext to access the space point container from StoreGate
graphData	Reference to the data object to be filled.

Check whether the graph needs a rebuild

Don't launch the rebuild of the graph

Fill the graph edge features and all their respective connections

Definition at line 115 of file GraphInferenceToolBase.cxx.

116  {
117 
119  if (graphData.previousList && (*graphData.previousList) != m_graphFeatures) {
120  graphData.graph.reset();
121  }
123  if (graphData.graph) {
124  return StatusCode::SUCCESS;
125  }
126  if (!m_graphFeatures.isValid()) {
127  ATH_MSG_ERROR("The feature list is incomplete. Either it has no features or no node connector set");
128  return StatusCode::FAILURE;
129  }
130 
131  graphData.graph = std::make_unique<InferenceGraph>();
132 
133  const SG::ReadHandle spacePoints{m_readKey, ctx};
134  ATH_CHECK(spacePoints.isPresent());
135 
136  int64_t nNodes{0}, possConn{0};
137  graphData.spacePointsInBucket.clear();
138  graphData.spacePointsInBucket.reserve(spacePoints->size());
139 
140  for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
141  nNodes += graphData.spacePointsInBucket.emplace_back(bucket->size());
142  possConn += MuonR4::sumUp(graphData.spacePointsInBucket.back());
143  }
144 
145  graphData.nodeIndex = 0;
146  graphData.featureLeaves.resize(nNodes * m_graphFeatures.numFeatures());
147  graphData.currLeave = graphData.featureLeaves.begin();
148 
149  graphData.srcEdges.reserve(possConn);
150  graphData.desEdges.reserve(possConn);
151 
153  for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
154  const LayerSpBucket mlBucket{*bucket};
155  m_graphFeatures.fillInData(mlBucket, graphData);
156  }
157 
158  graphData.srcEdges.insert(graphData.srcEdges.end(), std::make_move_iterator(graphData.desEdges.begin()),
159  std::make_move_iterator(graphData.desEdges.end()));
160 
162  ATH_MSG_DEBUG(formatNodeFeatures(graphData.featureLeaves, m_graphFeatures.numFeatures()));
163  ATH_MSG_DEBUG("Edge indices:"<<makeIndexPairs(makeSortedEdges(graphData.srcEdges)));
164 
165  std::vector<int64_t> featShape{nNodes, static_cast<int64_t>(m_graphFeatures.numFeatures())}; // (N, nFeatures)
166  std::vector<int64_t> edgeShape{2, static_cast<int64_t>(graphData.srcEdges.size() / 2)}; // (2, E)
167 
168 
169  Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
170  graphData.graph->dataTensor.emplace_back(Ort::Value::CreateTensor<float>(memInfo,
171  graphData.featureLeaves.data(), graphData.featureLeaves.size(),
172  featShape.data(), featShape.size()));
173 
174 
175  Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>(memInfo, graphData.srcEdges.data(),
176  graphData.srcEdges.size(), edgeShape.data(), edgeShape.size());
177 
178  graphData.graph->dataTensor.emplace_back(std::move(edge_tensor));
179 
180  graphData.previousList = &m_graphFeatures;
181 
182  graphData.srcEdges.clear();
183  graphData.desEdges.clear();
184  graphData.featureLeaves.clear();
185  ATH_MSG_DEBUG("Graph data built successfully.");
186  return StatusCode::SUCCESS;
187  }

◆ model()

Ort::Session & MuonML::GraphInferenceToolBase::model ( ) const
protected

Definition at line 88 of file GraphInferenceToolBase.cxx.

88  {
89  return m_onnxSessionTool->session();
90  }

◆ runInference()

StatusCode MuonML::GraphInferenceToolBase::runInference ( GraphRawData graphData) const

Definition at line 189 of file GraphInferenceToolBase.cxx.

189  {
190  if (!m_graphFeatures.isValid()) {
191  ATH_MSG_ERROR("ONNX model is not loaded. Please call setupModel()");
192  return StatusCode::FAILURE;
193  }
194  if (!graphData.graph) {
195  ATH_MSG_ERROR("Graph data is not built.");
196  return StatusCode::FAILURE;
197  }
198 
199  if (graphData.graph->dataTensor.size() < 2) {
200  ATH_MSG_ERROR("Data tensor does not contain both feature and edge tensors.");
201  return StatusCode::FAILURE;
202  }
203 
204  std::vector<const char*> inputNames = {"features", "edge_index"};
205  std::vector<const char*> outputNames = {"output"};
206 
207  Ort::RunOptions run_options;
208  run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_WARNING);
209 
210  std::vector<Ort::Value> outputTensors = model().Run(run_options,
211  inputNames.data(), // input tensor names
212  graphData.graph->dataTensor.data(), // pointer to the tensor vector
213  graphData.graph->dataTensor.size(), // size of the tensor vector
214  outputNames.data(), // output tensor names
215  outputNames.size()); // number of output tensors
216 
217  if (outputTensors.empty()) {
218  ATH_MSG_ERROR("Inference returned empty output.");
219  return StatusCode::FAILURE;
220  }
221 
222  float* output_data = outputTensors[0].GetTensorMutableData<float>();
223  size_t output_size = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
224 
225  std::span<float> predictions(output_data, output_data + output_size);
226 
227  for (size_t i = 0; i < output_size; i++) {
228  if (!std::isfinite(predictions[i])) {
229  ATH_MSG_WARNING("Non-finite prediction detected! Setting to -100.");
230  predictions[i] = -100.0f;
231  }
232  }
233 
234  graphData.graph->dataTensor.emplace_back(std::move(outputTensors[0]));
235 
236  return StatusCode::SUCCESS;
237  }

◆ setupModel()

StatusCode MuonML::GraphInferenceToolBase::setupModel ( )
protected

Definition at line 91 of file GraphInferenceToolBase.cxx.

91  {
92  ATH_CHECK(m_onnxSessionTool.retrieve());
93  ATH_CHECK(m_readKey.initialize());
94 
95  Ort::ModelMetadata metadata = model().GetModelMetadata();
96  Ort::AllocatorWithDefaultOptions allocator;
97  Ort::AllocatedStringPtr feature_json_ptr = metadata.LookupCustomMetadataMapAllocated("feature_names", allocator);
98 
99  if (feature_json_ptr) {
100  std::string feature_json = feature_json_ptr.get();
101  nlohmann::json json_obj = nlohmann::json::parse(feature_json);
102  for (const auto& feature : json_obj) {
103  m_graphFeatures.addFeature(feature.get<std::string>(), msgStream());
104  }
105  }
106  m_graphFeatures.setConnector("fullyConnected", msgStream());
107 
108  if (!m_graphFeatures.isValid()) {
109  ATH_MSG_FATAL("No graph features have been parsed. Please check the model: "<<m_graphFeatures.featureNames());
110  return StatusCode::FAILURE;
111  }
112  return StatusCode::SUCCESS;
113  }

Member Data Documentation

◆ m_graphFeatures

NodeFeatureList MuonML::GraphInferenceToolBase::m_graphFeatures {}
private

List of features to be used for the inference.

Definition at line 46 of file GraphInferenceToolBase.h.

◆ m_onnxSessionTool

ToolHandle<AthOnnx::IOnnxRuntimeSessionTool> MuonML::GraphInferenceToolBase::m_onnxSessionTool {this, "ModelSession", "" }
private

Definition at line 47 of file GraphInferenceToolBase.h.

◆ m_readKey

SG::ReadHandleKey<MuonR4::SpacePointContainer> MuonML::GraphInferenceToolBase::m_readKey {this, "ReadSpacePoints", "MuonSpacePoints"}
protected

Input space points to filter

Definition at line 43 of file GraphInferenceToolBase.h.


The documentation for this class was generated from the following files:
GraphInferenceToolBase.h
GraphInferenceToolBase.cxx