SaltModel.cxx
/*
  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
*/

#include "FlavorTagInference/SaltModel.h"
#include "CxxUtils/checker_macros.h"

#include "lwtnn/parse_json.hh"

#include <stdexcept>
#include <sstream>
#include <tuple>
#include <set>
namespace FlavorTagInference {

  SaltModel::SaltModel(const std::string& path_to_onnx)
    // load the onnx model into memory using the path given by path_to_onnx
    : m_env (std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, ""))
  {
    // initialize session options
    Ort::SessionOptions session_options;
    session_options.SetIntraOpNumThreads(1);

    // Ignore all non-fatal errors. This isn't a good idea, but it's
    // what we get for uploading semi-working graphs.
    session_options.SetLogSeverityLevel(4);
    session_options.SetGraphOptimizationLevel(
      GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

    // declare an allocator with default options
    Ort::AllocatorWithDefaultOptions allocator;

    // create session and load model into memory
    m_session = std::make_unique<Ort::Session>(
      *m_env, path_to_onnx.c_str(), session_options);

    // get metadata from the onnx model
    m_metadata = loadMetadata("gnn_config");
    m_num_inputs = m_session->GetInputCount();
    m_num_outputs = m_session->GetOutputCount();

    // get the onnx model version
    if (m_metadata.contains("onnx_model_version")) { // metadata version is explicitly set
      m_onnx_model_version = m_metadata["onnx_model_version"].get<SaltModelVersion>();
      if (m_onnx_model_version == SaltModelVersion::UNKNOWN) {
        throw std::runtime_error("Unknown Onnx model version!");
      }
    } else { // metadata version is not set, infer from the presence of the "outputs" key
      if (m_metadata.contains("outputs")) {
        m_onnx_model_version = SaltModelVersion::V0;
      } else {
        throw std::runtime_error("Onnx model version not found in metadata");
      }
    }

    // get the model name
    m_model_name = determineModelName();

    // iterate over input nodes and get their names
    for (size_t i = 0; i < m_num_inputs; i++) {
      std::string input_name = m_session->GetInputNameAllocated(i, allocator).get();
      m_input_node_names.push_back(input_name);
    }

    // iterate over output nodes and get their configuration
    for (size_t i = 0; i < m_num_outputs; i++) {
      const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get());
      const auto type = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType();
      const int rank = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size();
      if (m_onnx_model_version == SaltModelVersion::V0) {
        const SaltModelOutput saltModelOutput(name, type, m_model_name);
        m_output_nodes.push_back(saltModelOutput);
      } else {
        const SaltModelOutput saltModelOutput(name, type, rank);
        m_output_nodes.push_back(saltModelOutput);
      }
    }
  }
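
  // Illustrative sketch (not part of the original source): the "gnn_config"
  // custom metadata entry read in the constructor above is a JSON document
  // produced by the salt training framework.  Assuming hypothetical key names
  // beyond those used in this file, a post-V0 model might carry something like
  //
  //   { "onnx_model_version": "V1", "inputs": [...], ... }
  //
  // while a legacy V0 model instead carries an "outputs" object, whose first
  // key doubles as the model name (see determineModelName below).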

  const nlohmann::json SaltModel::loadMetadata(const std::string& key) const {
    Ort::AllocatorWithDefaultOptions allocator;
    Ort::ModelMetadata modelMetadata = m_session->GetModelMetadata();
    std::string metadataString(modelMetadata.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get());
    return nlohmann::json::parse(metadataString);
  }

  const std::string SaltModel::determineModelName() const {
    Ort::AllocatorWithDefaultOptions allocator;
    if (m_onnx_model_version == SaltModelVersion::V0) {
      // get the model name directly from the metadata
      return std::string(m_metadata["outputs"].begin().key());
    } else {
      // get the model name from the output node names
      // each output node name is of the form "<model_name>_<output_name>"
      std::set<std::string> model_names;
      for (size_t i = 0; i < m_num_outputs; i++) {
        const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get());
        size_t underscore_pos = name.find('_');
        if (underscore_pos != std::string::npos) {
          std::string substring = name.substr(0, underscore_pos);
          model_names.insert(substring);
        } else {
          return std::string("UnknownModelName");
        }
      }
      if (model_names.size() != 1) {
        throw std::runtime_error("SaltModel: model names are not consistent between outputs");
      }
      return *model_names.begin();
    }
  }
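
  // Illustrative sketch (not part of the original source): for a post-V0
  // model whose output nodes were named, say, "MyTagger_pb", "MyTagger_pc"
  // and "MyTagger_pu" (hypothetical names), determineModelName() keeps only
  // the text before the first underscore of each name and returns the common
  // prefix "MyTagger"; if the prefixes disagree, the runtime_error above is
  // thrown instead.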

  const lwt::GraphConfig SaltModel::getLwtConfig() const {
    /* for the new metadata format (>V0), the outputs are inferred directly from
       the model graph, rather than being configured as json metadata.
       however we still need to add an empty "outputs" key to the config so that
       lwt::parse_json_graph doesn't throw an exception */

    // deep copy the metadata by round tripping through a string stream
    nlohmann::json metadataCopy = nlohmann::json::parse(m_metadata.dump());
    if (getSaltModelVersion() != SaltModelVersion::V0) {
      metadataCopy["outputs"] = nlohmann::json::object();
    }
    std::stringstream metadataStream;
    metadataStream << metadataCopy.dump();
    return lwt::parse_json_graph(metadataStream);
  }

  const nlohmann::json& SaltModel::getMetadata() const {
    return m_metadata;
  }

  const SaltModel::OutputConfig& SaltModel::getOutputConfig() const {
    return m_output_nodes;
  }

  SaltModelVersion SaltModel::getSaltModelVersion() const {
    return m_onnx_model_version;
  }

  const std::string& SaltModel::getModelName() const {
    return m_model_name;
  }


  SaltModel::InferenceOutput SaltModel::runInference(
    std::map<std::string, Inputs>& gnn_inputs) const {

    std::vector<float> input_tensor_values;

    // create input tensor objects from data values
    auto memory_info = Ort::MemoryInfo::CreateCpu(
      OrtArenaAllocator, OrtMemTypeDefault
    );
    std::vector<Ort::Value> input_tensors;
    for (auto& node_name : m_input_node_names) {
      input_tensors.push_back(Ort::Value::CreateTensor<float>(
        memory_info, gnn_inputs.at(node_name).first.data(), gnn_inputs.at(node_name).first.size(),
        gnn_inputs.at(node_name).second.data(), gnn_inputs.at(node_name).second.size())
      );
    }
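
    // Note on the loop above: judging from the arguments passed to
    // Ort::Value::CreateTensor<float>, each entry of gnn_inputs is presumably
    // a pair of (flattened float values, tensor shape), so every node named
    // in m_input_node_names must have a matching entry in the map.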

    // casting vector<string> to vector<const char*>. this is what ORT expects
    std::vector<const char*> input_node_names;
    input_node_names.reserve(m_input_node_names.size());
    for (const auto& name : m_input_node_names) {
      input_node_names.push_back(name.c_str());
    }
    std::vector<const char*> output_node_names;
    output_node_names.reserve(m_output_nodes.size());
    for (const auto& node : m_output_nodes) {
      output_node_names.push_back(node.name_in_model.c_str());
    }

    // score model & input tensor, get back output tensor
    // Although Session::Run is non-const, the onnx authors say
    // it is safe to call from multiple threads:
    // https://github.com/microsoft/onnxruntime/discussions/10107
    Ort::Session& session ATLAS_THREAD_SAFE = *m_session;
    auto output_tensors = session.Run(Ort::RunOptions{nullptr},
      input_node_names.data(), input_tensors.data(), input_node_names.size(),
      output_node_names.data(), output_node_names.size()
    );

    // extract the output tensors into the InferenceOutput struct
    InferenceOutput output;
    for (size_t node_idx = 0; node_idx < m_output_nodes.size(); ++node_idx) {
      const auto& output_node = m_output_nodes[node_idx];
      const auto& tensor = output_tensors[node_idx];
      auto tensor_type = tensor.GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementType();
      auto tensor_shape = tensor.GetTypeInfo().GetTensorTypeAndShapeInfo().GetShape();
      int length = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
      if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
        if (tensor_shape.size() == 0) {
          // rank-0 (scalar) output
          output.singleFloat[output_node.name] = *tensor.GetTensorData<float>();
        } else if (tensor_shape.size() == 1) {
          // rank-1 (vector) output
          const float* data = tensor.GetTensorData<float>();
          output.vecFloat[output_node.name] = std::vector<float>(data, data + length);
        } else {
          throw std::runtime_error("Unsupported tensor shape for FLOAT type");
        }
      } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
        if (tensor_shape.size() == 1) {
          const char* data = tensor.GetTensorData<char>();
          output.vecChar[output_node.name] = std::vector<char>(data, data + length);
        } else {
          throw std::runtime_error("Unsupported tensor shape for INT8 type");
        }
      } else {
        throw std::runtime_error("Unsupported tensor type");
      }
    }

    return output;
  }

} // end of FlavorTagInference namespace
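
// Illustrative usage sketch (not part of the original source; the exact
// Inputs layout and all names below are assumptions based on the calls in
// this file):
//
//   using namespace FlavorTagInference;
//
//   SaltModel model("/path/to/network.onnx");        // hypothetical path
//   std::map<std::string, SaltModel::Inputs> inputs; // one entry per input node
//   // ... fill each entry with (flattened float data, shape) ...
//   auto result = model.runInference(inputs);
//   // scalar outputs land in result.singleFloat, vector outputs in
//   // result.vecFloat / result.vecChar, keyed by the configured output names.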