ATLAS Offline Software
Loading...
Searching...
No Matches
TFCSONNXHandler.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5// See header file for documentation.
7
8// For reading the binary onnx files
9#include <fstream>
10#include <iterator>
11#include <vector>
12
13// ONNX Runtime include(s).
14#include <onnxruntime_cxx_api.h>
15
16// For reading and writing to root
17#include "TBranch.h"
18#include "TFile.h"
19#include "TTree.h"
20
21// For throwing exceptions
22#include <stdexcept>
23
// Construct a handler for the network stored at `inputFile`; the path is
// forwarded to the VNetworkBase constructor.
// NOTE(review): this listing jumps from line 26 to line 29, so the statements
// that sat between the two log messages are not visible in this extract --
// check the original file before editing the body.
24TFCSONNXHandler::TFCSONNXHandler(const std::string &inputFile)
25 : VNetworkBase(inputFile) {
26 ATH_MSG_INFO("Setting up from inputFile.");
29 ATH_MSG_DEBUG("Setup from file complete");
30};
31
// Construct a handler from an already-serialized ONNX model held in memory;
// `bytes` is copied into m_bytes.
// NOTE(review): listing line 37 is absent from this extract, so whatever
// statement followed the comment below is not visible here -- check the
// original file before editing the body.
32TFCSONNXHandler::TFCSONNXHandler(const std::vector<char> &bytes)
33 : m_bytes(bytes) {
34 ATH_MSG_INFO("Given onnx session bytes as input.");
35 // The super constructor got no inputFile,
36 // so it won't call setupNet itself
38 ATH_MSG_DEBUG("Setup from session complete");
39};
40
42 : VNetworkBase(copy_from) {
43 ATH_MSG_DEBUG("TFCSONNXHandler copy construtor called");
44 m_bytes = copy_from.m_bytes;
45 // Cannot copy a session
46 // m_session = copy_from.m_session;
47 // But can read it from bytes
54};
55
58 return m_computeLambda(inputs);
59};
60
61// Writing out to ttrees
63 ATH_MSG_DEBUG("TFCSONNXHandler writing net to tree.");
64 this->writeBytesToTTree(tree, m_bytes);
65};
66
67std::vector<std::string> TFCSONNXHandler::getOutputLayers() const {
68 ATH_MSG_DEBUG("TFCSONNXHandler output layers requested.");
69 return m_outputLayers;
70};
71
73 // As we don't copy the bytes, and the inputFile
74 // is at most a name, nothing is needed here.
75 ATH_MSG_DEBUG("Deleted nothing for ONNX.");
76};
77
78void TFCSONNXHandler::print(std::ostream &strm) const {
79 if (m_inputFile.empty()) {
80 strm << "Unknown network";
81 } else {
82 strm << m_inputFile;
83 };
84 strm << "\nHas input nodes (name:dimensions);\n";
85 for (size_t inp_n = 0; inp_n < m_inputNodeNames.size(); inp_n++) {
86 strm << "\t" << m_inputNodeNames[inp_n] << ":[";
87 for (int dim : m_inputNodeDims[inp_n]) {
88 strm << " " << dim << ",";
89 };
90 strm << "]\n";
91 };
92 strm << "\nHas output nodes (name:dimensions);\n";
93 for (size_t out_n = 0; out_n < m_outputNodeNames.size(); out_n++) {
94 strm << "\t" << m_outputNodeNames[out_n] << ":[";
95 for (int dim : m_outputNodeDims[out_n]) {
96 strm << " " << dim << ",";
97 };
98 strm << "]\n";
99 };
100};
101
103 ATH_MSG_DEBUG("Setting up persisted variables for ONNX network.");
104 // depending which constructor was called,
105 // bytes may already be filled
106 if (m_bytes.empty()) {
108 };
109 ATH_MSG_DEBUG("Setup persisted variables for ONNX network.");
110};
111
113 // From
114 // https://gitlab.cern.ch/atlas/athena/-/blob/master/Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/OnnxUtils.h
115 // m_session = AthONNX::CreateORTSession(inputFile);
116 // This segfaults.
117
118 // TODO; should I be using m_session_options? see
119 // https://github.com/microsoft/onnxruntime-inference-examples/blob/2b42b442526b9454d1e2d08caeb403e28a71da5f/c_cxx/squeezenet/main.cpp#L71
120 ATH_MSG_INFO("Setting up ONNX session.");
121 this->readSerializedSession();
122
123 // Need the type from the first node (which will be used to set
124 // m_computeLambda); initialise it to undefined to avoid uninitialised-variable warnings
125 ONNXTensorElementDataType first_input_type =
126 ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
127
128 // iterate over all input nodes
129 ATH_MSG_DEBUG("Getting input nodes.");
130 const int num_input_nodes = m_session->GetInputCount();
131 Ort::AllocatorWithDefaultOptions allocator;
132 for (int i = 0; i < num_input_nodes; i++) {
133
134#if ORT_API_VERSION > 11
135 Ort::AllocatedStringPtr node_names = m_session->GetInputNameAllocated(i, allocator);
136 m_storeInputNodeNames.push_back(std::move(node_names));
137 const char *input_name = m_storeInputNodeNames.back().get();
138#else
139 const char *input_name = m_session->GetInputName(i, allocator);
140#endif
141 m_inputNodeNames.push_back(input_name);
142 ATH_MSG_VERBOSE("Found input node named " << input_name);
143
144 Ort::TypeInfo type_info = m_session->GetInputTypeInfo(i);
145
146 // For some reason unless auto is used as the return type
147 // this causes a segfault once the loop ends....
148 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
149 if (i == 0)
150 first_input_type = tensor_info.GetElementType();
151 // Check the type has not changed
152 if (tensor_info.GetElementType() != first_input_type) {
153 ATH_MSG_ERROR("First type was " << first_input_type << ". In node " << i
154 << " found type "
155 << tensor_info.GetElementType());
156 throw std::runtime_error("Networks with varying input types not "
157 "yet impelmented in TFCSONNXHandler.");
158 };
159
160 std::vector<int64_t> recieved_dimension = tensor_info.GetShape();
161 ATH_MSG_VERBOSE("There are " << recieved_dimension.size()
162 << " dimensions.");
163 // This vector sometimes includes a symbolic dimension
164 // which is represented by -1
165 // A symbolic dimension is usually a conversion error,
166 // from a numpy array with a shape like (None, 7),
167 // in which case it's safe to treat it as having
168 // dimension 1.
169 std::vector<int64_t> dimension_of_node;
170 for (int64_t node_dim : recieved_dimension) {
171 if (node_dim < 1) {
172 ATH_MSG_WARNING("Found symbolic dimension "
173 << node_dim << " in node named " << input_name
174 << ". Will treat this as dimension 1.");
175 dimension_of_node.push_back(1);
176 } else {
177 dimension_of_node.push_back(node_dim);
178 };
179 };
180 m_inputNodeDims.push_back(std::move(dimension_of_node));
181 };
182 ATH_MSG_DEBUG("Finished looping on inputs.");
183
184 // Outputs
185 // Store the type from the first node (which will be used to set
186 // m_computeLambda)
187 ONNXTensorElementDataType first_output_type =
188 ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
189
190 // iterate over all output nodes
191 int num_output_nodes = m_session->GetOutputCount();
192 ATH_MSG_DEBUG("Getting " << num_output_nodes << " output nodes.");
193 for (int i = 0; i < num_output_nodes; i++) {
194#if ORT_API_VERSION > 11
195 Ort::AllocatedStringPtr node_names = m_session->GetOutputNameAllocated(i, allocator);
196 m_storeOutputNodeNames.push_back(std::move(node_names));
197 const char *output_name = m_storeOutputNodeNames.back().get();
198#else
199 const char *output_name = m_session->GetOutputName(i, allocator);
200#endif
201 m_outputNodeNames.push_back(output_name);
202 ATH_MSG_VERBOSE("Found output node named " << output_name);
203
204 const Ort::TypeInfo type_info = m_session->GetOutputTypeInfo(i);
205 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
206 if (i == 0)
207 first_output_type = tensor_info.GetElementType();
208
209 // Check the type has not changed
210 if (tensor_info.GetElementType() != first_output_type) {
211 ATH_MSG_ERROR("First type was " << first_output_type << ". In node " << i
212 << " found type "
213 << tensor_info.GetElementType());
214 throw std::runtime_error("Networks with varying output types not "
215 "yet impelmented in TFCSONNXHandler.");
216 };
217
218 const std::vector<int64_t> recieved_dimension = tensor_info.GetShape();
219 ATH_MSG_VERBOSE("There are " << recieved_dimension.size()
220 << " dimensions.");
221 // Again, check for symbolic dimensions
222 std::vector<int64_t> dimension_of_node;
223 int node_size = 1;
224 for (int64_t node_dim : recieved_dimension) {
225 if (node_dim < 1) {
226 ATH_MSG_WARNING("Found symbolic dimension "
227 << node_dim << " in node named " << output_name
228 << ". Will treat this as dimension 1.");
229 dimension_of_node.push_back(1);
230 } else {
231 dimension_of_node.push_back(node_dim);
232 node_size *= node_dim;
233 };
234 };
235 m_outputNodeDims.push_back(std::move(dimension_of_node));
236 m_outputNodeSize.push_back(node_size);
237
238 // The outputs are treated as a flat vector
239 for (int part_n = 0; part_n < node_size; part_n++) {
240 // compose the output name
241 std::string layer_name =
242 std::string(output_name) + "_" + std::to_string(part_n);
243 ATH_MSG_VERBOSE("Found output layer named " << layer_name);
244 m_outputLayers.push_back(std::move(layer_name));
245 }
246 }
247 ATH_MSG_DEBUG("Removing prefix from stored layers.");
249 ATH_MSG_DEBUG("Finished output nodes.");
250
251 ATH_MSG_DEBUG("Setting up m_computeLambda with input type "
252 << first_input_type << " and output type "
253 << first_output_type);
254 if (first_input_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT &&
255 first_output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
256 // gotta capture this in the lambda so it can access class methods
257 m_computeLambda = [this](NetworkInputs const &inputs) {
258 return computeTemplate<float, float>(inputs);
259 };
260 } else if (first_input_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE &&
261 first_output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
262 m_computeLambda = [this](NetworkInputs const &inputs) {
263 return computeTemplate<double, double>(inputs);
264 };
265 } else {
266 throw std::runtime_error("Haven't yet implemented that combination of "
267 "input and output types as a subclass of VState.");
268 };
269 ATH_MSG_DEBUG("Finished setting lambda function.");
270};
271
272// Needs to also work if the input file is a root file
273std::vector<char> TFCSONNXHandler::getSerializedSession(const std::string& tree_name) {
274 ATH_MSG_DEBUG("Getting serialized session for ONNX network.");
275
276 if (this->isRootFile()) {
277 ATH_MSG_INFO("Reading bytes from root file.");
278 TFile tfile(this->m_inputFile.c_str(), "READ");
279 TTree *tree = (TTree *)tfile.Get(tree_name.c_str());
280 std::vector<char> bytes = this->readBytesFromTTree(*tree);
281 ATH_MSG_DEBUG("Found bytes size " << bytes.size());
282 return bytes;
283 } else {
284 ATH_MSG_INFO("Reading bytes from text file.");
285 // see https://stackoverflow.com/a/50317432
286 std::ifstream input(this->m_inputFile, std::ios::binary);
287
288 std::vector<char> bytes((std::istreambuf_iterator<char>(input)),
289 (std::istreambuf_iterator<char>()));
290
291 input.close();
292 ATH_MSG_DEBUG("Found bytes size " << bytes.size());
293 return bytes;
294 }
295};
296
297std::vector<char> TFCSONNXHandler::readBytesFromTTree(TTree &tree) {
298 ATH_MSG_DEBUG("TFCSONNXHandler reading bytes from tree.");
299 std::vector<char> bytes;
300 char data;
301 tree.SetBranchAddress("serialized_m_session", &data);
302 for (int i = 0; tree.LoadTree(i) >= 0; i++) {
303 tree.GetEntry(i);
304 bytes.push_back(data);
305 };
306 ATH_MSG_DEBUG("TFCSONNXHandler read bytes from tree.");
307 return bytes;
308};
309
311 const std::vector<char> &bytes) {
312 ATH_MSG_DEBUG("TFCSONNXHandler writing bytes to tree.");
313 char m_session_data;
314 tree.Branch("serialized_m_session", &m_session_data,
315 "serialized_m_session/B");
316 for (Char_t here : bytes) {
317 m_session_data = here;
318 tree.Fill();
319 };
320 tree.Write();
321 ATH_MSG_DEBUG("TFCSONNXHandler written bytes to tree.");
322};
323
325 ATH_MSG_DEBUG("Transforming bytes to session.");
326 Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
327 Ort::SessionOptions opts;
328 opts.SetInterOpNumThreads(1);
329 opts.SetIntraOpNumThreads(1);
330 // Prevent ONNX from spawning additional threads
331 m_session =
332 std::make_unique<Ort::Session>(env, m_bytes.data(), m_bytes.size(), opts);
333 ATH_MSG_DEBUG("Transformed bytes to session.");
334};
335
336template <typename Tin, typename Tout>
339 // working from
340 // https://github.com/microsoft/onnxruntime-inference-examples/blob/main/c_cxx/squeezenet/main.cpp#L71
341 // and
342 // https://github.com/microsoft/onnxruntime-inference-examples/blob/main/c_cxx/MNIST/MNIST.cpp
343 ATH_MSG_DEBUG("Setting up inputs for computation on ONNX network.");
344 ATH_MSG_DEBUG("Input type " << typeid(Tin).name() << " output type "
345 << typeid(Tout).name());
346
347 // The inputs must be reformatted to the correct data structure.
348 const size_t num_input_nodes = m_inputNodeNames.size();
349 // A pointer to all the nodes we will make
350 // Keep the data in each node flat, because that's easier
351 std::vector<std::vector<Tin>> input_values(num_input_nodes);
352 std::vector<Ort::Value> node_values;
353 // Non const values that will be needed at each step.
354 std::string node_name;
355 int n_dimensions, elements_in_node, key_number;
356 size_t first_digit;
357 // Move along the list of node names gathered in the constructor
358 // we need both the node name, and the dimension
359 // so we cannot iterate directly on the vector.
360 ATH_MSG_DEBUG("Looping over " << num_input_nodes
361 << " input nodes of ONNX network.");
362 for (size_t node_n = 0; node_n < m_inputNodeNames.size(); node_n++) {
363 ATH_MSG_DEBUG("Node n = " << node_n);
364 node_name = m_inputNodeNames[node_n];
365 ATH_MSG_DEBUG("Node name " << node_name);
366 // Get the shape of this node
367 n_dimensions = m_inputNodeDims[node_n].size();
368 ATH_MSG_DEBUG("Node dimensions " << n_dimensions);
369 elements_in_node = 1;
370 for (int dimension_len : m_inputNodeDims[node_n]) {
371 elements_in_node *= dimension_len;
372 };
373 ATH_MSG_DEBUG("Elements in node " << elements_in_node);
374 for (const auto & inp : inputs) {
375 ATH_MSG_DEBUG("Have input named " << inp.first);
376 };
377 // Get the node content and remove any common prefix from the elements
378 const std::map<std::string, double> node_inputs = inputs.at(node_name);
379 std::vector<Tin> node_elements(elements_in_node);
380
381 ATH_MSG_DEBUG("Found node named " << node_name << " with "
382 << elements_in_node << " elements.");
383 // Then the rest should be numbers from 0 up
384 for (const auto & element : node_inputs){
385 first_digit = element.first.find_first_of("0123456789");
386 // if there is no digit, it's not an element
387 if (first_digit < element.first.length()){
388 key_number = std::stoi(element.first.substr(first_digit));
389 node_elements[key_number] = element.second;
390 }
391 }
392 input_values[node_n] = std::move(node_elements);
393
394 ATH_MSG_DEBUG("Creating ort tensor n_dimensions = "
395 << n_dimensions
396 << ", elements_in_node = " << elements_in_node);
397 // Doesn't copy data internally, so vector arguments need to stay alive
398 Ort::Value node = Ort::Value::CreateTensor<Tin>(
399 m_memoryInfo, input_values[node_n].data(), elements_in_node,
400 m_inputNodeDims[node_n].data(), n_dimensions);
401 // Problems with the string steam when compiling seperatly.
402 // ATH_MSG_DEBUG("Created input node " << node << " from values " <<
403 // input_values[node_n]);
404
405 node_values.push_back(std::move(node));
406 }
407
408 ATH_MSG_DEBUG("Running computation on ONNX network.");
409 // All inputs have been correctly formatted and the net can be run.
410 auto output_tensors = m_session->Run(
411 Ort::RunOptions{nullptr}, m_inputNodeNames.data(), &node_values[0],
412 num_input_nodes, m_outputNodeNames.data(), m_outputNodeNames.size());
413
414 ATH_MSG_DEBUG("Sorting outputs from computation on ONNX network.");
415 // Finally, the output must be rearranged in the expected format.
417 // as the output format is just a string to double map
418 // the outputs will be keyed like "<node_name>_<part_n>"
419 std::string output_name;
420 const Tout *output_node;
421 for (size_t node_n = 0; node_n < m_outputNodeNames.size(); node_n++) {
422 // get a pointer to the data
423 output_node = output_tensors[node_n].GetTensorMutableData<Tout>();
424 ATH_MSG_VERBOSE("output node " << output_node);
425 elements_in_node = m_outputNodeSize[node_n];
426 node_name = m_outputNodeNames[node_n];
427 // Does the GetTensorMutableData really always return a
428 // flat array?
429 // Likely yes, see use of memcopy on line 301 of
430 // onnxruntime/core/languge_interop_ops/pyop/pyop.cc
431 for (int part_n = 0; part_n < elements_in_node; part_n++) {
432 ATH_MSG_VERBOSE("Node part " << part_n << " contains "
433 << output_node[part_n]);
434 // compose the output name
435 output_name = node_name + "_" + std::to_string(part_n);
436 outputs[output_name] = static_cast<double>(output_node[part_n]);
437 }
438 }
439 removePrefixes(outputs);
440 ATH_MSG_DEBUG("Returning outputs from computation on ONNX network.");
441 return outputs;
442};
443
444// Possible to avoid copy?
445// https://github.com/microsoft/onnxruntime/issues/8328
446// https://github.com/microsoft/onnxruntime/pull/11789
447// https://github.com/microsoft/onnxruntime/pull/8502
448
449// Giving this its own streamer to call setupNet
450void TFCSONNXHandler::Streamer(TBuffer &buf) {
451 ATH_MSG_DEBUG("In TFCSONNXHandler streamer.");
452 if (buf.IsReading()) {
453 ATH_MSG_INFO("Reading buffer in TFCSONNXHandler ");
454 // Get the persisted variables filled in
455 TFCSONNXHandler::Class()->ReadBuffer(buf, this);
456 // Setup the net, creating the non persisted variables
457 // exactly as in the constructor
458 this->setupNet();
459#ifndef __FastCaloSimStandAlone__
460 // When running inside Athena, delete persisted information
461 // to conserve memory
462 this->deleteAllButNet();
463#endif
464 } else {
465 ATH_MSG_INFO("Writing buffer in TFCSONNXHandler ");
466 // Persist variables
467 TFCSONNXHandler::Class()->WriteBuffer(buf, this);
468 };
469 ATH_MSG_DEBUG("Finished TFCSONNXHandler streamer.");
470};
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_VERBOSE(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
void readSerializedSession()
Do not persistify.
virtual void print(std::ostream &strm) const override
Write a short description of this net to the string stream.
VNetworkBase()
VNetworkBase default constructor.
std::vector< const char * > m_inputNodeNames
names that index the input nodes
void setupPersistedVariables() override
Perform actions that prep data to create the net.
NetworkOutputs computeTemplate(NetworkInputs const &input)
Do not persistify.
std::vector< char > getSerializedSession(const std::string &tree_name=m_defaultTreeName)
Return content of the proto (.onnx) file in memory.
std::vector< std::string > getOutputLayers() const override
List the names of the outputs.
std::vector< const char * > m_outputNodeNames
Do not persistify.
NetworkOutputs compute(NetworkInputs const &inputs) const override
Function to pass values to the network.
std::vector< std::vector< int64_t > > m_outputNodeDims
Do not persistify.
std::unique_ptr< Ort::Session > m_session
The network session itself.
Ort::MemoryInfo m_memoryInfo
Do not persistify.
std::vector< std::string > m_outputLayers
Do not persistify.
TFCSONNXHandler(const std::string &inputFile)
TFCSONNXHandler constructor.
void setupNet() override
Perform actions that prepare network for use.
std::vector< char > readBytesFromTTree(TTree &tree)
Retrieve the content of the proto file from a TTree.
std::vector< int64_t > m_outputNodeSize
Do not persistify.
std::function< NetworkOutputs(NetworkInputs)> m_computeLambda
computeTemplate with appropriate types selected.
void writeNetToTTree(TTree &tree) override
Save the network to a TTree.
void deleteAllButNet() override
Get rid of any memory objects that aren't needed to run the net.
void writeBytesToTTree(TTree &tree, const std::vector< char > &bytes)
Write the content of the proto file to a TTree as a branch.
std::vector< char > m_bytes
Content of the proto file.
std::vector< std::vector< int64_t > > m_inputNodeDims
Do not persistify.
std::map< std::string, std::map< std::string, double > > NetworkInputs
Format for network inputs.
std::string m_inputFile
Path to the file describing the network, including filename.
std::map< std::string, double > NetworkOutputs
Format for network outputs.
bool isRootFile(std::string const &filename="") const
Check if a string is possibly a root file path.
void removePrefixes(NetworkOutputs &outputs) const
Remove any common prefix from the outputs.
Definition node.h:24
TChain * tree