ATLAS Offline Software
SaltModel.cxx
/*
  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
*/

#include "FlavorTagInference/SaltModel.h"
#include "FlavorTagInference/SaltModelGraphConfig.h"
#include "CxxUtils/checker_macros.h"

#include <stdexcept>
#include <tuple>
#include <set>

namespace FlavorTagInference {

  SaltModel::SaltModel(const std::string& path_to_onnx)
    // load the ONNX model into memory from the given path_to_onnx
    : m_env (std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, ""))
  {
    // initialize session options
    Ort::SessionOptions session_options;
    session_options.SetIntraOpNumThreads(1);

    // Ignore all non-fatal errors. This isn't a good idea, but it's
    // what we get for uploading semi-working graphs.
    session_options.SetLogSeverityLevel(4);
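    // (ORT log severity 4 keeps only fatal messages, matching the
    //  ORT_LOGGING_LEVEL_FATAL environment above)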
    session_options.SetGraphOptimizationLevel(
      GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
    // this should reduce memory use while slowing things down slightly
    // see
    //
    // https://github.com/microsoft/onnxruntime/issues/11627#issuecomment-1137668551
    //
    // and also https://its.cern.ch/jira/browse/AFT-818
    //
    session_options.DisableCpuMemArena();

    // declare an allocator with default options
    Ort::AllocatorWithDefaultOptions allocator;

    // create session and load model into memory
    m_session = std::make_unique<Ort::Session>(
      *m_env, path_to_onnx.c_str(), session_options);

    // get metadata from the onnx model
    m_metadata = loadMetadata("gnn_config");
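    // ("gnn_config" is the custom-metadata key under which the training
    //  configuration is embedded in the ONNX file, see loadMetadata below)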
    m_num_inputs = m_session->GetInputCount();
    m_num_outputs = m_session->GetOutputCount();

    // get the onnx model version
    if (m_metadata.contains("onnx_model_version")) { // metadata version is explicitly set
      m_onnx_model_version = m_metadata["onnx_model_version"].get<SaltModelVersion>();
      if (m_onnx_model_version == SaltModelVersion::UNKNOWN) {
        throw std::runtime_error("Unknown Onnx model version!");
      }
    } else { // metadata version is not set, infer from the presence of "outputs" key
      if (m_metadata.contains("outputs")) {
        m_onnx_model_version = SaltModelVersion::V0;
      } else {
        throw std::runtime_error("Onnx model version not found in metadata");
      }
    }

    // get the model name
    m_model_name = determineModelName();
    // iterate over input nodes and get their names
    for (size_t i = 0; i < m_num_inputs; i++) {
      std::string input_name = m_session->GetInputNameAllocated(i, allocator).get();
      m_input_node_names.push_back(input_name);
    }

    // iterate over output nodes and get their configuration
    for (size_t i = 0; i < m_num_outputs; i++) {
      const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get());
      const auto type = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType();
      const int rank = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size();
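      // V0 models carry the model name in the metadata rather than in the
      // node names, so the output configuration is built from the name;
      // newer versions encode "<model_name>_<output_name>" and use the rank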
      if (m_onnx_model_version == SaltModelVersion::V0) {
        const SaltModelOutput saltModelOutput(name, type, m_model_name);
        m_output_nodes.push_back(saltModelOutput);
      } else {
        const SaltModelOutput saltModelOutput(name, type, rank);
        m_output_nodes.push_back(saltModelOutput);
      }
    }
  }

  const nlohmann::json SaltModel::loadMetadata(const std::string& key) const {
    Ort::AllocatorWithDefaultOptions allocator;
    Ort::ModelMetadata modelMetadata = m_session->GetModelMetadata();
    std::string metadataString(modelMetadata.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get());
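    // (this assumes the key is present: LookupCustomMetadataMapAllocated
    //  returns a null pointer for a missing key, which .get() would pass
    //  straight into the std::string constructor)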
    return nlohmann::json::parse(metadataString);
  }

  const std::string SaltModel::determineModelName() const {
    Ort::AllocatorWithDefaultOptions allocator;
    if (m_onnx_model_version == SaltModelVersion::V0) {
      // get the model name directly from the metadata
      return std::string(m_metadata["outputs"].begin().key());
    } else {
      // get the model name from the output node names
      // each output node name is of the form "<model_name>_<output_name>"
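      // (e.g. a hypothetical output node "MyTagger_pb" yields "MyTagger")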
      std::set<std::string> model_names;
      for (size_t i = 0; i < m_num_outputs; i++) {
        const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get());
        size_t underscore_pos = name.find('_');
        if (underscore_pos != std::string::npos) {
          std::string substring = name.substr(0, underscore_pos);
          model_names.insert(substring);
        } else {
          return std::string("UnknownModelName");
        }
      }
      if (model_names.size() != 1) {
        throw std::runtime_error("SaltModel: model names are not consistent between outputs");
      }
      return *model_names.begin();
    }
  }

  const SaltModelGraphConfig::GraphConfig SaltModel::getGraphConfig() const {
    return SaltModelGraphConfig::parse_json_graph(m_metadata);
  }

  const nlohmann::json& SaltModel::getMetadata() const {
    return m_metadata;
  }

  const SaltModel::OutputConfig& SaltModel::getOutputConfig() const {
    return m_output_nodes;
  }

  SaltModelVersion SaltModel::getSaltModelVersion() const {
    return m_onnx_model_version;
  }

  const std::string& SaltModel::getModelName() const {
    return m_model_name;
  }

  SaltModel::InferenceOutput SaltModel::runInference(
    std::map<std::string, Inputs>& gnn_inputs) const {

    // create input tensor objects from data values
    auto memory_info = Ort::MemoryInfo::CreateCpu(
      OrtArenaAllocator, OrtMemTypeDefault
    );
    std::vector<Ort::Value> input_tensors;
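    // each Inputs entry is a (flattened values, shape) pair; CreateTensor
    // wraps the caller's buffers without copying, so they must stay alive
    // until Run returns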
    for (auto& node_name : m_input_node_names) {
      input_tensors.push_back(Ort::Value::CreateTensor<float>(
        memory_info, gnn_inputs.at(node_name).first.data(), gnn_inputs.at(node_name).first.size(),
        gnn_inputs.at(node_name).second.data(), gnn_inputs.at(node_name).second.size())
      );
    }

    // convert vector<string> to vector<const char*>: this is what ORT expects
    std::vector<const char*> input_node_names;
    input_node_names.reserve(m_input_node_names.size());
    for (const auto& name : m_input_node_names) {
      input_node_names.push_back(name.c_str());
    }
    std::vector<const char*> output_node_names;
    output_node_names.reserve(m_output_nodes.size());
    for (const auto& node : m_output_nodes) {
      output_node_names.push_back(node.name_in_model.c_str());
    }

    // score model & input tensor, get back output tensor
    // Although Session::Run is non-const, the onnx authors say
    // it is safe to call from multiple threads:
    // https://github.com/microsoft/onnxruntime/discussions/10107
    Ort::Session& session ATLAS_THREAD_SAFE = *m_session;
    auto output_tensors = session.Run(Ort::RunOptions{nullptr},
      input_node_names.data(), input_tensors.data(), input_node_names.size(),
      output_node_names.data(), output_node_names.size()
    );

    // extract the outputs
    InferenceOutput output;
    for (size_t node_idx = 0; node_idx < m_output_nodes.size(); ++node_idx) {
      const auto& output_node = m_output_nodes[node_idx];
      const auto& tensor = output_tensors[node_idx];
      const auto tensor_info = tensor.GetTensorTypeAndShapeInfo();
      const auto tensor_type = tensor_info.GetElementType();
      const auto tensor_shape = tensor_info.GetShape();
      const size_t length = tensor_info.GetElementCount();
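      // rank-0 float tensors hold a single score; rank-1 tensors hold a
      // flat vector (e.g. one value per track)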
      if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
        if (tensor_shape.size() == 0) {
          output.singleFloat[output_node.name] = *tensor.GetTensorData<float>();
        } else if (tensor_shape.size() == 1) {
          const float* data = tensor.GetTensorData<float>();
          output.vecFloat[output_node.name] = std::vector<float>(data, data + length);
        } else {
          throw std::runtime_error("Unsupported tensor shape for FLOAT type");
        }
      } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
        if (tensor_shape.size() == 1) {
          const char* data = tensor.GetTensorData<char>();
          output.vecChar[output_node.name] = std::vector<char>(data, data + length);
        } else {
          throw std::runtime_error("Unsupported tensor shape for INT8 type");
        }
      } else {
        throw std::runtime_error("Unsupported tensor type");
      }
    }

    return output;
  }

} // end of FlavorTagInference namespace
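
A minimal usage sketch (not part of this file): the model path, the input node name "jet_var", and the shape are hypothetical and must match the input metadata of the actual ONNX file; Inputs is assumed to be the (values, shape) pair alias declared in SaltModel.h.

#include "FlavorTagInference/SaltModel.h"
#include <map>
#include <string>

int main() {
  using FlavorTagInference::SaltModel;
  SaltModel model("path/to/network.onnx");
  // one (flattened values, shape) pair per input node
  std::map<std::string, SaltModel::Inputs> inputs;
  inputs["jet_var"] = { {1.0f, 2.0f}, {1, 2} };
  auto out = model.runInference(inputs);
  // scalar scores land in singleFloat, keyed by output name
  for (const auto& [name, score] : out.singleFloat) {
    (void)name; (void)score; // use the scores here
  }
  return 0;
}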