ATLAS Offline Software
FlavorTagDiscriminants/Root/OnnxUtil.cxx
/*
  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
*/

#include "FlavorTagDiscriminants/OnnxUtil.h"
#include "CxxUtils/checker_macros.h"

#include "lwtnn/parse_json.hh"

#include <stdexcept>
#include <sstream>
#include <tuple>
#include <set>
namespace FlavorTagDiscriminants {

  OnnxUtil::OnnxUtil(const std::string& path_to_onnx)
    // load the onnx model into memory from path_to_onnx
    : m_env (std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, ""))
  {
    // initialize session options
    Ort::SessionOptions session_options;
    session_options.SetIntraOpNumThreads(1);

    // Ignore all non-fatal errors. This isn't a good idea, but it's
    // what we get for uploading semi-working graphs.
    session_options.SetLogSeverityLevel(4);
    session_options.SetGraphOptimizationLevel(
      GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

    // declare an allocator with default options
    Ort::AllocatorWithDefaultOptions allocator;

    // create session and load model into memory
    m_session = std::make_unique<Ort::Session>(
      *m_env, path_to_onnx.c_str(), session_options);

    // get metadata from the onnx model
    m_metadata = loadMetadata("gnn_config");
    m_num_inputs = m_session->GetInputCount();
    m_num_outputs = m_session->GetOutputCount();

    // get the onnx model version
    if (m_metadata.contains("onnx_model_version")) { // metadata version is explicitly set
      m_onnx_model_version = m_metadata["onnx_model_version"].get<OnnxModelVersion>();
      if (m_onnx_model_version == OnnxModelVersion::UNKNOWN) {
        throw std::runtime_error("Unknown Onnx model version!");
      }
    } else { // metadata version is not set, infer from the presence of the "outputs" key
      if (m_metadata.contains("outputs")) {
        m_onnx_model_version = OnnxModelVersion::V0;
      } else {
        throw std::runtime_error("Onnx model version not found in metadata");
      }
    }

    // get the model name
    m_model_name = determineModelName();

    // iterate over input nodes and get their names
    for (size_t i = 0; i < m_num_inputs; i++) {
      std::string input_name = m_session->GetInputNameAllocated(i, allocator).get();
      m_input_node_names.push_back(input_name);
    }

    // iterate over output nodes and get their configuration
    for (size_t i = 0; i < m_num_outputs; i++) {
      const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get());
      const auto type = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType();
      const int rank = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size();
      if (m_onnx_model_version == OnnxModelVersion::V0) {
        const OnnxOutput onnxOutput(name, type, m_model_name);
        m_output_nodes.push_back(onnxOutput);
      } else {
        const OnnxOutput onnxOutput(name, type, rank);
        m_output_nodes.push_back(onnxOutput);
      }
    }
  }
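
  /* For illustration only (not part of the original file): the version
     logic above distinguishes two metadata layouts. A hypothetical
     "gnn_config" entry for a newer model might look like

       {
         "onnx_model_version": "V1",
         "inputs": [...],
         "input_sequences": [...]
       }

     whereas a V0-era model omits "onnx_model_version" and instead
     carries an "outputs" object, which is what the fallback branch
     keys on. The exact schema and the "V1" tag are assumptions here;
     the real values are written by the training framework. */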

  const nlohmann::json OnnxUtil::loadMetadata(const std::string& key) const {
    Ort::AllocatorWithDefaultOptions allocator;
    Ort::ModelMetadata modelMetadata = m_session->GetModelMetadata();
    std::string metadataString(modelMetadata.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get());
    return nlohmann::json::parse(metadataString);
  }

  const std::string OnnxUtil::determineModelName() const {
    Ort::AllocatorWithDefaultOptions allocator;
    if (m_onnx_model_version == OnnxModelVersion::V0) {
      // get the model name directly from the metadata
      return std::string(m_metadata["outputs"].begin().key());
    } else {
      // get the model name from the output node names,
      // which each have the form "<model_name>_<output_name>"
      std::set<std::string> model_names;
      for (size_t i = 0; i < m_num_outputs; i++) {
        const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get());
        size_t underscore_pos = name.find('_');
        if (underscore_pos != std::string::npos) {
          std::string substring = name.substr(0, underscore_pos);
          model_names.insert(substring);
        } else {
          return std::string("UnknownModelName");
        }
      }
      if (model_names.size() != 1) {
        throw std::runtime_error("OnnxUtil: model names are not consistent between outputs");
      }
      return *model_names.begin();
    }
  }
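
  /* Worked example (hypothetical node names, not from the original
     file): outputs named "GN2_pb", "GN2_pc" and "GN2_pu" share the
     prefix before the first underscore, so determineModelName()
     returns "GN2". Mixed prefixes such as "GN2_pb" and "DL1_pu" throw
     a runtime error, and a name with no underscore at all yields
     "UnknownModelName". */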

  const lwt::GraphConfig OnnxUtil::getLwtConfig() const {
    /* for the new metadata format (>V0), the outputs are inferred directly
       from the model graph, rather than being configured as json metadata.
       however we still need to add an empty "outputs" key to the config so
       that lwt::parse_json_graph doesn't throw an exception */

    // deep copy the metadata by round tripping through a string stream
    nlohmann::json metadataCopy = nlohmann::json::parse(m_metadata.dump());
    if (getOnnxModelVersion() != OnnxModelVersion::V0) {
      metadataCopy["outputs"] = nlohmann::json::object();
    }
    std::stringstream metadataStream;
    metadataStream << metadataCopy.dump();
    return lwt::parse_json_graph(metadataStream);
  }
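
  /* Sketch of the effect (assumed metadata shape, for illustration):
     given >V0 metadata {"inputs": [...], "input_sequences": [...]},
     the copy handed to lwt::parse_json_graph becomes
       {"inputs": [...], "input_sequences": [...], "outputs": {}}
     while m_metadata itself is left unmodified. */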

  const nlohmann::json& OnnxUtil::getMetadata() const {
    return m_metadata;
  }

  const OnnxUtil::OutputConfig& OnnxUtil::getOutputConfig() const {
    return m_output_nodes;
  }

  OnnxModelVersion OnnxUtil::getOnnxModelVersion() const {
    return m_onnx_model_version;
  }

  const std::string& OnnxUtil::getModelName() const {
    return m_model_name;
  }

  OnnxUtil::InferenceOutput OnnxUtil::runInference(
    std::map<std::string, Inputs>& gnn_inputs) const {

    // create input tensor objects from the data values
    auto memory_info = Ort::MemoryInfo::CreateCpu(
      OrtArenaAllocator, OrtMemTypeDefault
    );
    std::vector<Ort::Value> input_tensors;
    for (const auto& node_name : m_input_node_names) {
      auto& input = gnn_inputs.at(node_name);
      input_tensors.push_back(Ort::Value::CreateTensor<float>(
        memory_info, input.first.data(), input.first.size(),
        input.second.data(), input.second.size())
      );
    }

    // cast vector<string> to vector<const char*>; this is what ORT expects
    std::vector<const char*> input_node_names;
    input_node_names.reserve(m_input_node_names.size());
    for (const auto& name : m_input_node_names) {
      input_node_names.push_back(name.c_str());
    }
    std::vector<const char*> output_node_names;
    output_node_names.reserve(m_output_nodes.size());
    for (const auto& node : m_output_nodes) {
      output_node_names.push_back(node.name_in_model.c_str());
    }

    // score model & input tensor, get back output tensor
    // Although Session::Run is non-const, the onnx authors say
    // it is safe to call from multiple threads:
    // https://github.com/microsoft/onnxruntime/discussions/10107
    Ort::Session& session ATLAS_THREAD_SAFE = *m_session;
    auto output_tensors = session.Run(Ort::RunOptions{nullptr},
      input_node_names.data(), input_tensors.data(), input_node_names.size(),
      output_node_names.data(), output_node_names.size()
    );

    // extract the output tensors into the InferenceOutput struct
    InferenceOutput output;
    for (size_t node_idx = 0; node_idx < m_output_nodes.size(); ++node_idx) {
      const auto& output_node = m_output_nodes[node_idx];
      const auto& tensor = output_tensors[node_idx];
      auto info = tensor.GetTensorTypeAndShapeInfo();
      auto tensor_type = info.GetElementType();
      auto tensor_shape = info.GetShape();
      size_t length = info.GetElementCount();
      if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
        if (tensor_shape.size() == 0) {
          // rank-0 tensor: a single scalar value
          output.singleFloat[output_node.name] = *tensor.GetTensorData<float>();
        } else if (tensor_shape.size() == 1) {
          const float* data = tensor.GetTensorData<float>();
          output.vecFloat[output_node.name] = std::vector<float>(data, data + length);
        } else {
          throw std::runtime_error("Unsupported tensor shape for FLOAT type");
        }
      } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
        if (tensor_shape.size() == 1) {
          const char* data = tensor.GetTensorData<char>();
          output.vecChar[output_node.name] = std::vector<char>(data, data + length);
        } else {
          throw std::runtime_error("Unsupported tensor shape for INT8 type");
        }
      } else {
        throw std::runtime_error("Unsupported tensor type");
      }
    }

    return output;
  }

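  /* Usage sketch (hypothetical caller, not part of this file), assuming
     Inputs pairs a flat float buffer with its int64_t shape:

       OnnxUtil util("path/to/model.onnx");
       std::map<std::string, OnnxUtil::Inputs> inputs;
       inputs["jet_features"] = {{1.0f, 2.0f, 3.0f}, {1, 3}};  // shape [1, 3]
       auto result = util.runInference(inputs);

     Scalar outputs land in result.singleFloat and one-dimensional float
     outputs in result.vecFloat, keyed by the parsed output names. The
     node name "jet_features" is illustrative; the real names come from
     the model graph. */
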
} // end of FlavorTagDiscriminants namespace