ATLAS Offline Software
SaltModel.cxx
/*
  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
*/

#include "FlavorTagInference/SaltModel.h"
#include "FlavorTagInference/SaltModelGraphConfig.h"
#include "CxxUtils/checker_macros.h"

#include <stdexcept>
#include <tuple>
#include <set>
namespace FlavorTagInference {

  SaltModel::SaltModel(const std::string& path_to_onnx)
    // load the ONNX model into memory from the given path
    : m_env (std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, ""))
  {
    // initialize session options
    Ort::SessionOptions session_options;
    session_options.SetIntraOpNumThreads(1);
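    // note: with a single intra-op thread onnxruntime does not spawn its
    // own worker pool; any parallelism is left to the calling framework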

    // Ignore all non-fatal errors. This isn't a good idea, but it's
    // what we get for uploading semi-working graphs.
    session_options.SetLogSeverityLevel(4);
    session_options.SetGraphOptimizationLevel(
      GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
    // this should reduce memory use while slowing things down slightly
    // see
    //
    // https://github.com/microsoft/onnxruntime/issues/11627#issuecomment-1137668551
    //
    // and also https://its.cern.ch/jira/browse/AFT-818
    //
    session_options.DisableCpuMemArena();

    // declare an allocator with default options
    Ort::AllocatorWithDefaultOptions allocator;

    // create session and load model into memory
    m_session = std::make_unique<Ort::Session>(
      *m_env, path_to_onnx.c_str(), session_options);

    // get metadata from the onnx model
    m_metadata = loadMetadata("gnn_config");
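    // the training configuration is stored as a JSON string under the
    // "gnn_config" key in the model's custom metadata map (see loadMetadata below)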
    m_num_inputs = m_session->GetInputCount();
    m_num_outputs = m_session->GetOutputCount();

    // get the onnx model version
    if (m_metadata.contains("onnx_model_version")) { // metadata version is explicitly set
      m_onnx_model_version = m_metadata["onnx_model_version"].get<SaltModelVersion>();
      if (m_onnx_model_version == SaltModelVersion::UNKNOWN) {
        throw std::runtime_error("Unknown Onnx model version!");
      }
    } else { // metadata version is not set, infer from the presence of "outputs" key
      if (m_metadata.contains("outputs")) {
        m_onnx_model_version = SaltModelVersion::V0;
      } else {
        throw std::runtime_error("Onnx model version not found in metadata");
      }
    }

    // get the model name
    m_model_name = determineModelName();

    // iterate over input nodes and get their names
    for (size_t i = 0; i < m_num_inputs; i++) {
      m_input_node_names.push_back(m_session->GetInputNameAllocated(i, allocator).get());
    }

    // iterate over output nodes and get their configuration
    for (size_t i = 0; i < m_num_outputs; i++) {
      const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get());
      const auto type = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType();
      const int rank = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size();
      if (m_onnx_model_version == SaltModelVersion::V0) {
        m_output_nodes.emplace_back(name, type, m_model_name);
      } else {
        m_output_nodes.emplace_back(name, type, rank);
      }
    }
  }

  const nlohmann::json SaltModel::loadMetadata(const std::string& key) const {
    Ort::AllocatorWithDefaultOptions allocator;
    Ort::ModelMetadata modelMetadata = m_session->GetModelMetadata();
    std::string metadataString(modelMetadata.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get());
    return nlohmann::json::parse(metadataString);
  }

  const std::string SaltModel::determineModelName() const {
    Ort::AllocatorWithDefaultOptions allocator;
    if (m_onnx_model_version == SaltModelVersion::V0) {
      // get the model name directly from the metadata
      return std::string(m_metadata["outputs"].begin().key());
    } else {
      // get the model name from the output node names
      // each output node name is of the form "<model_name>_<output_name>"
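      // e.g. a node named "GN2vXX_pb" (an illustrative name) gives the model name "GN2vXX"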
      std::set<std::string> model_names;
      for (size_t i = 0; i < m_num_outputs; i++) {
        const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get());
        size_t underscore_pos = name.find('_');
        if (underscore_pos != std::string::npos) {
          model_names.insert(name.substr(0, underscore_pos));
        } else {
          return std::string("UnknownModelName");
        }
      }
      if (model_names.size() != 1) {
        throw std::runtime_error("SaltModel: model names are not consistent between outputs");
      }
      return *model_names.begin();
    }
  }

  const SaltModelGraphConfig::GraphConfig SaltModel::getGraphConfig() const {
    return SaltModelGraphConfig::parse_json_graph(m_metadata);
  }

  const OutputConfig& SaltModel::getOutputConfig() const {
    return m_output_nodes;
  }

  SaltModelVersion SaltModel::getSaltModelVersion() const {
    return m_onnx_model_version;
  }

  const std::string& SaltModel::getModelName() const {
    return m_model_name;
  }

  InferenceOutput SaltModel::runInference(
      std::map<std::string, Inputs>& gnn_inputs) const {

    std::vector<float> input_tensor_values;

    // create input tensor object from data values
    auto memory_info = Ort::MemoryInfo::CreateCpu(
      OrtArenaAllocator, OrtMemTypeDefault
    );
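    // each entry of gnn_inputs maps an input node name to a pair of
    // (flattened float values, tensor shape); CreateTensor wraps these
    // buffers without copying them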
    std::vector<Ort::Value> input_tensors;
    for (auto& node_name : m_input_node_names) {
      input_tensors.push_back(Ort::Value::CreateTensor<float>(
        memory_info, gnn_inputs.at(node_name).first.data(), gnn_inputs.at(node_name).first.size(),
        gnn_inputs.at(node_name).second.data(), gnn_inputs.at(node_name).second.size())
      );
    }

    // casting vector<string> to vector<const char*>. this is what ORT expects
    std::vector<const char*> input_node_names;
    input_node_names.reserve(m_input_node_names.size());
    for (const auto& name : m_input_node_names) {
      input_node_names.push_back(name.c_str());
    }
    std::vector<const char*> output_node_names;
    output_node_names.reserve(m_output_nodes.size());
    for (const auto& node : m_output_nodes) {
      output_node_names.push_back(node.name_in_model.c_str());
    }

    // run the model on the input tensors and get back the output tensors.
    // Although Session::Run is non-const, the onnx authors say
    // it is safe to call from multiple threads:
    // https://github.com/microsoft/onnxruntime/discussions/10107
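    // ATLAS_THREAD_SAFE (from checker_macros.h) tells the ATLAS static
    // checker that this non-const access inside a const method is intentional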
    Ort::Session& session ATLAS_THREAD_SAFE = *m_session;
    auto output_tensors = session.Run(Ort::RunOptions{nullptr},
      input_node_names.data(), input_tensors.data(), input_node_names.size(),
      output_node_names.data(), output_node_names.size()
    );

    // extract the output tensors into the InferenceOutput struct
    InferenceOutput output;
    for (size_t node_idx = 0; node_idx < m_output_nodes.size(); ++node_idx) {
      const auto& output_node = m_output_nodes[node_idx];
      const auto& tensor = output_tensors[node_idx];
      auto tensor_type = tensor.GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementType();
      auto tensor_shape = tensor.GetTypeInfo().GetTensorTypeAndShapeInfo().GetShape();
      int length = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
      if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
        if (tensor_shape.size() == 0) {
          output.singleFloat[output_node.name] = *tensor.GetTensorData<float>();
        } else if (tensor_shape.size() == 1) {
          const float* data = tensor.GetTensorData<float>();
          output.vecFloat[output_node.name] = std::vector<float>(data, data + length);
        } else {
          throw std::runtime_error("Unsupported tensor shape for FLOAT type");
        }
      } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
        if (tensor_shape.size() == 1) {
          const char* data = tensor.GetTensorData<char>();
          output.vecChar[output_node.name] = std::vector<char>(data, data + length);
        } else {
          throw std::runtime_error("Unsupported tensor shape for INT8 type");
        }
      } else {
        throw std::runtime_error("Unsupported tensor type");
      }
    }

    return output;
  }

} // end of FlavorTagInference namespace
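
For orientation, a minimal usage sketch (not part of the file). It assumes Inputs is a pair of (flattened float values, int64_t shape), as the .first/.second usage in runInference implies, and that InferenceOutput::singleFloat is a map from output name to float, as the extraction loop implies; the model path and the input node name "jet_var" are hypothetical.

#include "FlavorTagInference/SaltModel.h"

#include <iostream>
#include <map>
#include <string>

int main() {
  using namespace FlavorTagInference;

  // hypothetical path to a trained salt model
  SaltModel model("path/to/salt_model.onnx");

  // "jet_var" is a hypothetical input node name; the shape {1, 2}
  // describes one jet with two features
  std::map<std::string, Inputs> inputs;
  inputs["jet_var"] = {{0.5f, 1.5f}, {1, 2}};

  InferenceOutput out = model.runInference(inputs);
  for (const auto& [name, score] : out.singleFloat) {
    std::cout << name << " = " << score << "\n";
  }
  return 0;
}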