ATLAS Offline Software
SaltModelTriton.cxx
/*
  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
*/

#include "FlavorTagInference/SaltModelTriton.h"
#include "FlavorTagInference/SaltModelOutput.h"
#include "FlavorTagInference/SaltModelGraphConfig.h"
#include "CxxUtils/checker_macros.h"

#include <onnxruntime_cxx_api.h>

#include <stdexcept>
#include <tuple>
#include <set>
#include <memory>
#include <iostream>
#include <cstring>

// DType traits for Triton
template <typename T> struct TritonDType;
template <> struct TritonDType<float> { static constexpr const char* value = "FP32"; };
template <> struct TritonDType<int64_t> { static constexpr const char* value = "INT64"; };
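
// Only FP32 and INT64 inputs are needed here; any other element type would
// require its own specialization before it could be used with prepareInput.
// A hypothetical example (only relevant if a model actually took INT32 inputs):
//
//   template <> struct TritonDType<int32_t> { static constexpr const char* value = "INT32"; };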

template <typename T>
bool prepareInput(const std::string& name,
                  const std::vector<int64_t>& shape,
                  const std::vector<T>& data,
                  std::vector<std::shared_ptr<tc::InferInput>>& inputs)
{
  const char* dtype = TritonDType<T>::value;
  tc::InferInput* rawInputPtr = nullptr;

  // Create the InferInput object with the given name, shape, and data type.
  tc::Error err = tc::InferInput::Create(&rawInputPtr, name, shape, dtype);
  if (!err.IsOk()) {
    std::cerr << "Unable to create input: " << name << std::endl;
    return false;
  }

  // Append tensor values for this input from a byte array.
  // Note: the vector is not copied, so it must not be modified or destroyed
  // until this input is no longer needed (i.e. until the Infer() call(s) that
  // use the input have completed). Multiple calls can be made to this API to
  // keep adding tensor data for this input; the data is delivered in the
  // order it was added.
  std::shared_ptr<tc::InferInput> input(rawInputPtr);
  err = input->AppendRaw(reinterpret_cast<const uint8_t*>(data.data()),
                         data.size() * sizeof(T));
  if (!err.IsOk()) {
    std::cerr << "Unable to set input data for: " << name << std::endl;
    return false;
  }

  inputs.push_back(std::move(input));
  return true;
}
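
// Usage sketch for prepareInput (hypothetical name, shape, and values, not
// taken from any real model):
//
//   std::vector<std::shared_ptr<tc::InferInput>> inputs;
//   std::vector<int64_t> shape{1, 2};
//   std::vector<float> values{0.5f, 1.5f};
//   if (!prepareInput<float>("jet_features", shape, values, inputs)) {
//     // handle the error; note that `values` must outlive the Infer() call
//   }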

template <typename T>
bool extractOutput(const std::string& name,
                   const std::shared_ptr<tc::InferResult>& result,
                   std::vector<T>& outputVec)
{
  const uint8_t* rawData = nullptr;
  size_t size = 0;

  // Get access to the buffer holding the raw result of the specified output
  // returned by the server.
  // Note: the buffer is owned by the InferResult instance; users can copy
  // out the data if they need to extend its lifetime.
  tc::Error err = result->RawData(name, &rawData, &size);
  if (!err.IsOk()) {
    std::cerr << "Unable to get raw output for: " << name << std::endl;
    return false;
  }

  outputVec.resize(size / sizeof(T));
  std::memcpy(outputVec.data(), rawData, size);
  return true;
}
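
// Usage sketch for extractOutput (hypothetical output name; `results` is the
// InferResult obtained from a completed Infer() call):
//
//   std::vector<float> scores;
//   if (extractOutput<float>("MyTagger_pb", results, scores)) {
//     // `scores` now owns a copy of the output buffer
//   }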


namespace FlavorTagInference {

  SaltModelTriton::SaltModelTriton(const std::string& path_to_onnx
    , const std::string& model_name
    , float client_timeout
    , int port
    , const std::string& url
    , bool useSSL)
    : m_model_name(model_name)
    , m_clientTimeout(client_timeout)
    , m_port(port)
    , m_url(url)
    , m_useSSL(useSSL)
    // load the onnx model into memory using path_to_onnx
  {
    std::unique_ptr<Ort::Env> env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, "");

    // initialize session options
    Ort::SessionOptions session_options;
    session_options.SetIntraOpNumThreads(1);

    // Ignore all non-fatal errors. This isn't a good idea, but it's
    // what we get for uploading semi-working graphs.
    session_options.SetLogSeverityLevel(4);
    session_options.SetGraphOptimizationLevel(
      GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
    // This should reduce memory use while slowing things down slightly; see
    //
    // https://github.com/microsoft/onnxruntime/issues/11627#issuecomment-1137668551
    //
    // and also https://its.cern.ch/jira/browse/AFT-818
    //
    session_options.DisableCpuMemArena();

    // declare an allocator with default options
    Ort::AllocatorWithDefaultOptions allocator;

    // create the session and load the model into memory
    std::unique_ptr<Ort::Session> session = std::make_unique<Ort::Session>(
      *env, path_to_onnx.c_str(), session_options);

    // get the metadata from the onnx model
    m_metadata = loadMetadata("gnn_config", session.get());
    m_num_outputs = session->GetOutputCount();

    // get the onnx model version
    if (m_metadata.contains("onnx_model_version")) { // metadata version is explicitly set
      m_onnx_model_version = m_metadata["onnx_model_version"].get<SaltModelVersion>();
      if (m_onnx_model_version == SaltModelVersion::UNKNOWN) {
        throw std::runtime_error("Unknown Onnx model version!");
      }
    } else { // metadata version is not set, infer it from the presence of the "outputs" key
      if (m_metadata.contains("outputs")) {
        m_onnx_model_version = SaltModelVersion::V0;
      } else {
        throw std::runtime_error("Onnx model version not found in metadata");
      }
    }
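
    // For illustration only (hypothetical, not a real tagger config), the
    // "gnn_config" metadata parsed above might look like:
    //   { "onnx_model_version": "...", "outputs": { ... }, ... }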

    // get the model name
    m_model_type = determineModelType(session.get());

    // iterate over the output nodes and get their configuration
    for (size_t i = 0; i < m_num_outputs; i++) {
      const auto name = std::string(session->GetOutputNameAllocated(i, allocator).get());
      const auto type = session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType();
      const int rank = session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size();
      if (m_onnx_model_version == SaltModelVersion::V0) {
        const SaltModelOutput saltModelOutput(name, type, m_model_type);
        m_output_nodes.push_back(std::move(saltModelOutput));
      } else {
        const SaltModelOutput saltModelOutput(name, type, rank);
        m_output_nodes.push_back(std::move(saltModelOutput));
      }
    }

    m_options = std::make_unique<tc::InferOptions>(m_model_name);
    m_options->model_version_ = ""; // FIXME: find a proper solution for this
    m_options->client_timeout_ = m_clientTimeout;
  }
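
  // Construction sketch (all values hypothetical):
  //
  //   SaltModelTriton model("/path/to/model.onnx", // local ONNX file holding the metadata
  //                         "my_salt_model",       // model name on the Triton server
  //                         10.0f,                 // client timeout
  //                         8001,                  // gRPC port
  //                         "localhost",           // server URL
  //                         false);                // useSSL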

  const nlohmann::json SaltModelTriton::loadMetadata(const std::string& key, const Ort::Session* session) const {
    Ort::AllocatorWithDefaultOptions allocator;
    Ort::ModelMetadata modelMetadata = session->GetModelMetadata();
    std::string metadataString(modelMetadata.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get());
    return nlohmann::json::parse(metadataString);
  }

  const std::string SaltModelTriton::determineModelType(const Ort::Session* session) const {
    Ort::AllocatorWithDefaultOptions allocator;
    if (m_onnx_model_version == SaltModelVersion::V0) {
      // get the model name directly from the metadata
      return std::string(m_metadata["outputs"].begin().key());
    }
    else {
      // get the model name from the output node names;
      // each output node name is of the form "<model_type>_<output_name>"
      std::set<std::string> model_types;
      for (size_t i = 0; i < m_num_outputs; i++) {
        const auto name = std::string(session->GetOutputNameAllocated(i, allocator).get());
        size_t underscore_pos = name.find('_');
        if (underscore_pos != std::string::npos) {
          std::string substring = name.substr(0, underscore_pos);
          model_types.insert(std::move(substring));
        }
        else {
          return std::string("UnknownModelName");
        }
      }
      if (model_types.size() != 1) {
        throw std::runtime_error("SaltModelTriton: model names are not consistent between outputs");
      }
      return *model_types.begin();
    }
  }

  const SaltModelGraphConfig::GraphConfig SaltModelTriton::getGraphConfig() const {
    return SaltModelGraphConfig::parse_json_graph(m_metadata);
  }

  const OutputConfig& SaltModelTriton::getOutputConfig() const {
    return m_output_nodes;
  }

  SaltModelVersion SaltModelTriton::getSaltModelVersion() const {
    return m_onnx_model_version;
  }

  const std::string& SaltModelTriton::getModelName() const {
    return m_model_type;
  }


  InferenceOutput SaltModelTriton::runInference(
    std::map<std::string, Inputs>& gnn_inputs) const {

    // Create the Triton input tensors
    std::vector<std::shared_ptr<tc::InferInput>> inputs_;
    inputs_.reserve(gnn_inputs.size());

    for (auto& [inputName, inputInfo]: gnn_inputs) {
      const std::vector<float>& inputData = inputInfo.first;
      const std::vector<int64_t>& inputShape = inputInfo.second;
      if (!prepareInput<float>(inputName, inputShape, inputData, inputs_)) {
        throw std::runtime_error("Failed to prepare input '" + inputName + "' for inference");
      }
    }

    // construct raw pointers for the inference call; ownership stays with inputs_
    std::vector<tc::InferInput*> rawInputs;
    for (auto& input : inputs_) {
      rawInputs.push_back(input.get());
    }

    // perform the inference
    tc::InferResult* rawResultPtr = nullptr;
    tc::Headers http_headers;
    grpc_compression_algorithm compression_algorithm = grpc_compression_algorithm::GRPC_COMPRESS_NONE;

    auto client = getClient();
    if (client) {
      tc::Error err = client->Infer(&rawResultPtr
        , *m_options
        , rawInputs
        , {}
        , http_headers
        , compression_algorithm);
      if (!err.IsOk()) {
        throw std::runtime_error("unable to run model " + m_model_name + " error: " + err.Message());
      }
    }
    else {
      throw std::runtime_error("Failed to create Triton gRPC client");
    }

    // Get the result of the inference
    InferenceOutput output;
    std::shared_ptr<tc::InferResult> results(rawResultPtr);
    for (size_t node_idx = 0; node_idx < m_output_nodes.size(); ++node_idx) {
      const auto& output_node = m_output_nodes[node_idx];
      switch (output_node.type) {
        case SaltModelOutput::OutputType::VECFLOAT:
        {
          std::vector<float> outputVecFloat;
          extractOutput<float>(output_node.name, results, outputVecFloat);
          output.vecFloat[output_node.name] = std::move(outputVecFloat);
        }
        break;
        case SaltModelOutput::OutputType::FLOAT:
        {
          std::vector<float> outputFloat;
          extractOutput<float>(output_node.name, results, outputFloat);
          if (outputFloat.size() == 1) {
            output.singleFloat[output_node.name] = outputFloat[0];
          }
          else {
            throw std::runtime_error("Vector of floats returned instead of a single float for " + output_node.name);
          }
        }
        break;
        case SaltModelOutput::OutputType::VECCHAR:
        {
          std::vector<int8_t> outputVecInt;
          extractOutput<int8_t>(output_node.name, results, outputVecInt);
          // convert the int8_t vector to a char vector
          std::vector<char> outputVecChar(outputVecInt.begin(), outputVecInt.end());
          output.vecChar[output_node.name] = std::move(outputVecChar);
        }
        break;
        case SaltModelOutput::OutputType::UNKNOWN:
          [[fallthrough]];
        default:
          throw std::runtime_error("Unknown output type for the node " + output_node.name);
      }
    }

    return output;
  }
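
  // Usage sketch (hypothetical names; each Inputs entry pairs a flat float
  // vector with its shape):
  //
  //   std::map<std::string, Inputs> gnn_inputs = ...; // filled by the caller
  //   InferenceOutput out = model.runInference(gnn_inputs);
  //   // then e.g., if the model defines such an output:
  //   // float pb = out.singleFloat.at("MyTagger_pb");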

  tc::InferenceServerGrpcClient* SaltModelTriton::getClient() const
  {
    // One gRPC client per thread: each thread lazily creates and caches its
    // own connection, avoiding shared client state across threads.
    thread_local std::unique_ptr<tc::InferenceServerGrpcClient> threadClient;
    if (!threadClient) {
      std::string url = m_url + ":" + std::to_string(m_port);
      tc::Error err = tc::InferenceServerGrpcClient::Create(&threadClient, url, false, m_useSSL);
      if (!err.IsOk()) {
        std::cerr << "SaltModelTriton ERROR: Failed to create Triton gRPC client for model: " << m_model_name
                  << " at URL: " << url << std::endl;
        std::cerr << err.Message() << std::endl;
        return nullptr;
      }
    }
    return threadClient.get();
  }
} // end of FlavorTagInference namespace