ATLAS Offline Software
TFCSONNXHandler Class Reference

Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration.

#include <TFCSONNXHandler.h>


Public Types

typedef std::map< std::string, std::map< std::string, double > > NetworkInputs
 Format for network inputs.
typedef std::map< std::string, double > NetworkOutputs
 Format for network outputs.

Public Member Functions

 TFCSONNXHandler (const std::string &inputFile)
 TFCSONNXHandler constructor.
 TFCSONNXHandler (const std::vector< char > &bytes)
 TFCSONNXHandler constructor.
 TFCSONNXHandler (const TFCSONNXHandler &copy_from)
 TFCSONNXHandler copy constructor.
NetworkOutputs compute (NetworkInputs const &inputs) const override
 Function to pass values to the network.
void writeNetToTTree (TTree &tree) override
 Save the network to a TTree.
std::vector< std::string > getOutputLayers () const override
 List the names of the outputs.
void deleteAllButNet () override
 Get rid of any memory objects that aren't needed to run the net.
template<typename Tin, typename Tout>
VNetworkBase::NetworkOutputs computeTemplate (VNetworkBase::NetworkInputs const &inputs)
 VNetworkBase ()
 VNetworkBase default constructor.
 VNetworkBase (const std::string &inputFile)
 VNetworkBase constructor.
 VNetworkBase (const VNetworkBase &copy_from)
 VNetworkBase copy constructor.
void writeNetToTTree (TFile &root_file, std::string const &tree_name=m_defaultTreeName)
 Save the network to a TTree.
void writeNetToTTree (std::string const &root_name, std::string const &tree_name=m_defaultTreeName)
 Save the network to a TTree.
bool isFile () const
 Check if the argument inputFile is the path of a file on disk.
bool msgLvl (const MSG::Level lvl) const
 Check whether the logging system is active at the provided verbosity level.
MsgStream & msg () const
 Return a stream for sending messages directly (no decoration)
MsgStream & msg (const MSG::Level lvl) const
 Return a decorated starting stream for sending messages.
MSG::Level level () const
 Retrieve output level.
virtual void setLevel (MSG::Level lvl)
 Update output level.

Static Public Member Functions

static std::string representNetworkInputs (NetworkInputs const &inputs, int maxValues=3)
 String representation of network inputs.
static std::string representNetworkOutputs (NetworkOutputs const &outputs, int maxValues=3)
 String representation of network outputs.
static bool isFile (std::string const &inputFile)
 Check if a string is the path of a file on disk.
static std::string startMsg (MSG::Level lvl, const std::string &file, int line)
 Make a message to decorate the start of logging.

Static Public Attributes

static const std::string m_defaultTreeName = "onnxruntime_session"
 Default name for the TTree to save in.

Protected Member Functions

virtual void print (std::ostream &strm) const override
 Write a short description of this net to the string stream.
void setupPersistedVariables () override
 Perform actions that prep data to create the net.
void setupNet () override
 Perform actions that prepare network for use.
bool isRootFile (std::string const &filename="") const
 Check if a string is possibly a root file path.
void removePrefixes (NetworkOutputs &outputs) const
 Remove any common prefix from the outputs.
void removePrefixes (std::vector< std::string > &output_names) const
 Remove any common prefix from the outputs.

Protected Attributes

std::string m_inputFile
 Path to the file describing the network, including filename.

Private Member Functions

std::vector< char > getSerializedSession (const std::string &tree_name=m_defaultTreeName)
 Return content of the proto (.onnx) file in memory.
std::vector< char > readBytesFromTTree (TTree &tree)
 Retrieve the content of the proto file from a TTree.
void writeBytesToTTree (TTree &tree, const std::vector< char > &bytes)
 Write the content of the proto file to a TTree as a branch.
void readSerializedSession ()
 Do not persistify.
template<typename Tin, typename Tout>
NetworkOutputs computeTemplate (NetworkInputs const &input)
 Do not persistify.
 ClassDefOverride (TFCSONNXHandler, 1)
 Do not persistify.
 ClassDef (VNetworkBase, 1)

Private Attributes

std::vector< char > m_bytes
 Content of the proto file.
std::unique_ptr< Ort::Session > m_session
 The network session itself.
std::vector< const char * > m_inputNodeNames
 Names that index the input nodes.
std::vector< const char * > m_outputNodeNames
 Do not persistify.
std::vector< std::vector< int64_t > > m_inputNodeDims
 Do not persistify.
std::vector< std::vector< int64_t > > m_outputNodeDims
 Do not persistify.
std::vector< int64_t > m_outputNodeSize
 Do not persistify.
std::function< NetworkOutputs(NetworkInputs)> m_computeLambda
 computeTemplate with appropriate types selected.
Ort::MemoryInfo m_memoryInfo
 Do not persistify.
std::vector< std::string > m_outputLayers
 Do not persistify.
std::string m_nm
 Message source name.

Static Private Attributes

static boost::thread_specific_ptr< MsgStream > m_msg_tls ATLAS_THREAD_SAFE
 Do not persistify!

Detailed Description

Class for a neural network read in the ONNX format. Derived from the abstract base class VNetworkBase such that it can be used interchangeably with its sibling classes, TFCSSimpleLWTNNHandler and TFCSGANLWTNNHandler.

The TFCSNetworkFactory::Create function has VNetworkBase as its return type so that it can make a run-time decision about which derived class to use, based on the data or file presented. As such it's best not to create this directly; instead allow TFCSNetworkFactory::Create to create the appropriate network object, so that new network formats can be accommodated by writing new subclasses of VNetworkBase.

A handler specific to an ONNX network.

Inherits from the generic interface VNetworkBase, such that it can be used interchangeably with other network formats and libraries.

Definition at line 43 of file TFCSONNXHandler.h.

Member Typedef Documentation

◆ NetworkInputs

typedef std::map<std::string, std::map<std::string, double> > VNetworkBase::NetworkInputs
inherited

Format for network inputs.

The doubles are the values to be passed into the network. Strings in the outer map identify the input node, and must correspond to the names of the nodes as read from the description of the network found by the constructor. Strings in the inner map identify the part of the input node; for some networks these must be simple integers in string form, as parts of nodes do not always have the ability to carry real string labels.
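As a minimal sketch of this format, the following fills a NetworkInputs map for one node, keying each value by its position in string form. The alias is redeclared locally so the example compiles without the ATLAS headers; the helper addNode is hypothetical and not part of VNetworkBase.

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Same shape as VNetworkBase::NetworkInputs, redeclared here so the
// sketch compiles without the ATLAS headers.
using NetworkInputs = std::map<std::string, std::map<std::string, double>>;

// Hypothetical helper: store a flat list of values under one input node,
// keyed by their position ("0", "1", ...) in string form.
void addNode(NetworkInputs &inputs, const std::string &nodeName,
             const std::vector<double> &values) {
  std::map<std::string, double> &node = inputs[nodeName];
  for (std::size_t i = 0; i < values.size(); ++i) {
    node[std::to_string(i)] = values[i];
  }
}
```

Calling addNode(inputs, "node_a", {0.5, 1.5, 2.5}) would leave one outer entry, "node_a", whose inner map holds the keys "0", "1" and "2". The node name here is made up; a real network expects the names read from its own description.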

Definition at line 90 of file VNetworkBase.h.

◆ NetworkOutputs

typedef std::map<std::string, double> VNetworkBase::NetworkOutputs
inherited

Format for network outputs.

The doubles are the values generated by the network. Strings identify which node each value came from and, when nodes have multiple values, are suffixed with a number to indicate which part of the node they came from. So in multi-value nodes the format becomes "<node_name>_<part_n>".
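The key format can be sketched as follows. This is an illustration under stated assumptions rather than the class's own code: the alias is redeclared locally, and appendNode is a hypothetical helper that mirrors the way computeTemplate composes its output names.

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Same shape as VNetworkBase::NetworkOutputs, redeclared for the sketch.
using NetworkOutputs = std::map<std::string, double>;

// Hypothetical helper: copy the flat values of one output node into the
// map, using keys of the form "<node_name>_<part_n>".
void appendNode(NetworkOutputs &outputs, const std::string &nodeName,
                const std::vector<float> &values) {
  for (std::size_t part = 0; part < values.size(); ++part) {
    outputs[nodeName + "_" + std::to_string(part)] =
        static_cast<double>(values[part]);
  }
}
```

With a made-up node name "energy" and two values, the map would contain the keys "energy_0" and "energy_1".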

Definition at line 100 of file VNetworkBase.h.

Constructor & Destructor Documentation

◆ TFCSONNXHandler() [1/3]

TFCSONNXHandler::TFCSONNXHandler ( const std::string & inputFile)
explicit

TFCSONNXHandler constructor.

Calls setupPersistedVariables and setupNet.

Parameters
inputFile    file path on disk (with file name) of a readable onnx file containing a proto format description of the network to be constructed.

Definition at line 24 of file TFCSONNXHandler.cxx.

25 : VNetworkBase(inputFile) {
26 ATH_MSG_INFO("Setting up from inputFile.");
29 ATH_MSG_DEBUG("Setup from file complete");
30};

◆ TFCSONNXHandler() [2/3]

TFCSONNXHandler::TFCSONNXHandler ( const std::vector< char > & bytes)
explicit

TFCSONNXHandler constructor.

As this passes nothing to the super constructor the setupPersistedVariables will not be called.

Parameters
bytes    byte content of a .onnx file (.onnx files are a subset of proto files). Allows TFCSONNXHandler objects to be created from data in memory, retrieved from any source. The bytes are not copied internally, and must remain in memory while the net is in use. (TODO check that assertion)

Definition at line 32 of file TFCSONNXHandler.cxx.

33 : m_bytes(bytes) {
34 ATH_MSG_INFO("Given onnx session bytes as input.");
35 // The super constructor got no inputFile,
36 // so it won't call setupNet itself
38 ATH_MSG_DEBUG("Setup from session complete");
39};

◆ TFCSONNXHandler() [3/3]

TFCSONNXHandler::TFCSONNXHandler ( const TFCSONNXHandler & copy_from)

TFCSONNXHandler copy constructor.

Will copy the variables that would be generated by setupPersistedVariables and setupNet.

Parameters
copy_from    existing network that we are copying

Definition at line 41 of file TFCSONNXHandler.cxx.

42 : VNetworkBase(copy_from) {
43 ATH_MSG_DEBUG("TFCSONNXHandler copy construtor called");
44 m_bytes = copy_from.m_bytes;
45 // Cannot copy a session
46 // m_session = copy_from.m_session;
47 // But can read it from bytes
54};

Member Function Documentation

◆ ClassDef()

VNetworkBase::ClassDef ( VNetworkBase ,
1  )
privateinherited

◆ ClassDefOverride()

TFCSONNXHandler::ClassDefOverride ( TFCSONNXHandler ,
1  )
private

Do not persistify.

◆ compute()

TFCSONNXHandler::NetworkOutputs TFCSONNXHandler::compute ( TFCSONNXHandler::NetworkInputs const & inputs) const
overridevirtual

Function to pass values to the network.

This function is used to actually run data through the loaded network and obtain results.

Parameters
inputs    values to be evaluated by the network
Returns
the output of the network
See also
VNetworkBase::NetworkInputs
VNetworkBase::NetworkOutputs

Implements VNetworkBase.

Definition at line 57 of file TFCSONNXHandler.cxx.

57 {
58 return m_computeLambda(inputs);
59};

◆ computeTemplate() [1/2]

template<typename Tin, typename Tout>
NetworkOutputs TFCSONNXHandler::computeTemplate ( NetworkInputs const & input)
private

Do not persistify.

Computation template with adjustable types for input.

During setupNet, a lambda that calls the instantiation with the correct types for the session/net is stored as a member variable (m_computeLambda).
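The pattern of choosing a template instantiation once and hiding it behind a callable can be sketched as below. Everything here is a stand-in: ElementType replaces the ONNX element type enum, the body of computeTemplate is a placeholder, and selectCompute is a hypothetical name. Which type combinations the real class supports is not shown on this page.

```cpp
#include <functional>
#include <map>
#include <stdexcept>
#include <string>

using NetworkInputs = std::map<std::string, std::map<std::string, double>>;
using NetworkOutputs = std::map<std::string, double>;

// Hypothetical element-type tag standing in for ONNXTensorElementDataType.
enum class ElementType { Float, Double };

// Placeholder for computeTemplate: casts each value through Tin and Tout
// so that the chosen types take part in the result.
template <typename Tin, typename Tout>
NetworkOutputs computeTemplate(const NetworkInputs &inputs) {
  NetworkOutputs outputs;
  for (const auto &node : inputs) {
    for (const auto &element : node.second) {
      const Tin in = static_cast<Tin>(element.second);
      outputs[node.first + "_" + element.first] =
          static_cast<double>(static_cast<Tout>(in));
    }
  }
  return outputs;
}

// Choose the instantiation once, at setup time, and return it as a
// type-erased callable that later calls can use without knowing the types.
std::function<NetworkOutputs(const NetworkInputs &)>
selectCompute(ElementType inType, ElementType outType) {
  if (inType == ElementType::Float && outType == ElementType::Float) {
    return [](const NetworkInputs &in) {
      return computeTemplate<float, float>(in);
    };
  }
  if (inType == ElementType::Double && outType == ElementType::Double) {
    return [](const NetworkInputs &in) {
      return computeTemplate<double, double>(in);
    };
  }
  throw std::runtime_error("unsupported element type combination");
}
```

Storing the result in a std::function keeps compute() free of template parameters, at the cost of one indirect call per evaluation.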

◆ computeTemplate() [2/2]

template<typename Tin, typename Tout>
VNetworkBase::NetworkOutputs TFCSONNXHandler::computeTemplate ( VNetworkBase::NetworkInputs const & inputs)

Definition at line 338 of file TFCSONNXHandler.cxx.

338 {
339 // working from
340 // https://github.com/microsoft/onnxruntime-inference-examples/blob/main/c_cxx/squeezenet/main.cpp#L71
341 // and
342 // https://github.com/microsoft/onnxruntime-inference-examples/blob/main/c_cxx/MNIST/MNIST.cpp
343 ATH_MSG_DEBUG("Setting up inputs for computation on ONNX network.");
344 ATH_MSG_DEBUG("Input type " << typeid(Tin).name() << " output type "
345 << typeid(Tout).name());
346
347 // The inputs must be reformatted to the correct data structure.
348 const size_t num_input_nodes = m_inputNodeNames.size();
349 // A pointer to all the nodes we will make
350 // Gonna keep the data in each node flat, because that's easier
351 std::vector<std::vector<Tin>> input_values(num_input_nodes);
352 std::vector<Ort::Value> node_values;
353 // Non const values that will be needed at each step.
354 std::string node_name;
355 int n_dimensions, elements_in_node, key_number;
356 size_t first_digit;
357 // Move along the list of node names gathered in the constructor
358 // we need both the node name, and the dimension
359 // so we cannot iterate directly on the vector.
360 ATH_MSG_DEBUG("Looping over " << num_input_nodes
361 << " input nodes of ONNX network.");
362 for (size_t node_n = 0; node_n < m_inputNodeNames.size(); node_n++) {
363 ATH_MSG_DEBUG("Node n = " << node_n);
364 node_name = m_inputNodeNames[node_n];
365 ATH_MSG_DEBUG("Node name " << node_name);
366 // Get the shape of this node
367 n_dimensions = m_inputNodeDims[node_n].size();
368 ATH_MSG_DEBUG("Node dimensions " << n_dimensions);
369 elements_in_node = 1;
370 for (int dimension_len : m_inputNodeDims[node_n]) {
371 elements_in_node *= dimension_len;
372 };
373 ATH_MSG_DEBUG("Elements in node " << elements_in_node);
374 for (const auto & inp : inputs) {
375 ATH_MSG_DEBUG("Have input named " << inp.first);
376 };
377 // Get the node content and remove any common prefix from the elements
378 const std::map<std::string, double> node_inputs = inputs.at(node_name);
379 std::vector<Tin> node_elements(elements_in_node);
380
381 ATH_MSG_DEBUG("Found node named " << node_name << " with "
382 << elements_in_node << " elements.");
383 // Then the rest should be numbers from 0 up
384 for (const auto & element : node_inputs){
385 first_digit = element.first.find_first_of("0123456789");
386 // if there is no digit, it's not an element
387 if (first_digit < element.first.length()){
388 key_number = std::stoi(element.first.substr(first_digit));
389 node_elements[key_number] = element.second;
390 }
391 }
392 input_values[node_n] = std::move(node_elements);
393
394 ATH_MSG_DEBUG("Creating ort tensor n_dimensions = "
395 << n_dimensions
396 << ", elements_in_node = " << elements_in_node);
397 // Doesn't copy data internally, so vector arguments need to stay alive
398 Ort::Value node = Ort::Value::CreateTensor<Tin>(
399 m_memoryInfo, input_values[node_n].data(), elements_in_node,
400 m_inputNodeDims[node_n].data(), n_dimensions);
401 // Problems with the string stream when compiling separately.
402 // ATH_MSG_DEBUG("Created input node " << node << " from values " <<
403 // input_values[node_n]);
404
405 node_values.push_back(std::move(node));
406 }
407
408 ATH_MSG_DEBUG("Running computation on ONNX network.");
409 // All inputs have been correctly formatted and the net can be run.
410 auto output_tensors = m_session->Run(
411 Ort::RunOptions{nullptr}, m_inputNodeNames.data(), &node_values[0],
412 num_input_nodes, m_outputNodeNames.data(), m_outputNodeNames.size());
413
414 ATH_MSG_DEBUG("Sorting outputs from computation on ONNX network.");
415 // Finally, the output must be rearranged in the expected format.
417 // as the output format is just a string to double map
418 // the outputs will be keyed like "<node_name>_<part_n>"
419 std::string output_name;
420 const Tout *output_node;
421 for (size_t node_n = 0; node_n < m_outputNodeNames.size(); node_n++) {
422 // get a pointer to the data
423 output_node = output_tensors[node_n].GetTensorMutableData<Tout>();
424 ATH_MSG_VERBOSE("output node " << output_node);
425 elements_in_node = m_outputNodeSize[node_n];
426 node_name = m_outputNodeNames[node_n];
427 // Does the GetTensorMutableData really always return a
428 // flat array?
429 // Likely yes, see use of memcopy on line 301 of
430 // onnxruntime/core/languge_interop_ops/pyop/pyop.cc
431 for (int part_n = 0; part_n < elements_in_node; part_n++) {
432 ATH_MSG_VERBOSE("Node part " << part_n << " contains "
433 << output_node[part_n]);
434 // compose the output name
435 output_name = node_name + "_" + std::to_string(part_n);
436 outputs[output_name] = static_cast<double>(output_node[part_n]);
437 }
438 }
439 removePrefixes(outputs);
440 ATH_MSG_DEBUG("Returning outputs from computation on ONNX network.");
441 return outputs;
442};
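The key-to-index step in the listing above can be isolated as a small function. This is a sketch, not the class's code: flattenNode is a hypothetical name, float stands in for Tin, and the bounds check is an addition that the listing does not have.

```cpp
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-alone version of the key-to-index step: every key
// that contains a digit is read as an index into a flat vector holding
// the node's values. Positions with no matching key stay at zero.
std::vector<float> flattenNode(const std::map<std::string, double> &nodeInputs,
                               std::size_t elementsInNode) {
  std::vector<float> flat(elementsInNode, 0.0f);
  for (const auto &element : nodeInputs) {
    const std::size_t firstDigit = element.first.find_first_of("0123456789");
    if (firstDigit == std::string::npos) {
      continue;  // no digit, so this key is not an element of the node
    }
    const std::size_t index = std::stoul(element.first.substr(firstDigit));
    if (index >= elementsInNode) {
      // Bounds check added in this sketch; the listing above has none.
      throw std::out_of_range("input key " + element.first +
                              " exceeds node size");
    }
    flat[index] = static_cast<float>(element.second);
  }
  return flat;
}
```

A key such as "x_1" is read as index 1, because parsing starts at the first digit found in the key.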

◆ deleteAllButNet()

void TFCSONNXHandler::deleteAllButNet ( )
overridevirtual

Get rid of any memory objects that aren't needed to run the net.

Minimise memory usage by deleting any inputs that are no longer required to run the compute function. Doesn't actually do anything for this network type.

Implements VNetworkBase.

Definition at line 72 of file TFCSONNXHandler.cxx.

72 {
73 // As we don't copy the bytes, and the inputFile
74 // is at most a name, nothing is needed here.
75 ATH_MSG_DEBUG("Deleted nothing for ONNX.");
76};

◆ getOutputLayers()

std::vector< std::string > TFCSONNXHandler::getOutputLayers ( ) const
overridevirtual

List the names of the outputs.

Outputs are stored in a NetworkOutputs object which is indexed by strings. This function returns the list of all strings that will index the outputs.

Implements VNetworkBase.

Definition at line 67 of file TFCSONNXHandler.cxx.

67 {
68 ATH_MSG_DEBUG("TFCSONNXHandler output layers requested.");
69 return m_outputLayers;
70};

◆ getSerializedSession()

std::vector< char > TFCSONNXHandler::getSerializedSession ( const std::string & tree_name = m_defaultTreeName)
private

Return content of the proto (.onnx) file in memory.

Get the session as a stream of bytes. It's a vector<char> rather than a string because we need the guarantee that &bytes[0]+n == &bytes[n] (string has this only since C++11). Also, the bytes may not be terminated by a null byte (which early strings required).

Definition at line 273 of file TFCSONNXHandler.cxx.

273 {
274 ATH_MSG_DEBUG("Getting serialized session for ONNX network.");
275
276 if (this->isRootFile()) {
277 ATH_MSG_INFO("Reading bytes from root file.");
278 TFile tfile(this->m_inputFile.c_str(), "READ");
279 TTree *tree = (TTree *)tfile.Get(tree_name.c_str());
280 std::vector<char> bytes = this->readBytesFromTTree(*tree);
281 ATH_MSG_DEBUG("Found bytes size " << bytes.size());
282 return bytes;
283 } else {
284 ATH_MSG_INFO("Reading bytes from text file.");
285 // see https://stackoverflow.com/a/50317432
286 std::ifstream input(this->m_inputFile, std::ios::binary);
287
288 std::vector<char> bytes((std::istreambuf_iterator<char>(input)),
289 (std::istreambuf_iterator<char>()));
290
291 input.close();
292 ATH_MSG_DEBUG("Found bytes size " << bytes.size());
293 return bytes;
294 }
295};
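The non-ROOT branch of the listing above reads the whole file through a pair of stream iterators. A stand-alone sketch of that technique follows; readAllBytes is a hypothetical name, and returning an empty vector for an unreadable file is a choice made for this example only.

```cpp
#include <fstream>
#include <iterator>
#include <string>
#include <vector>

// Read a whole file as raw bytes. std::ios::binary disables newline
// translation, and the vector keeps embedded null bytes intact.
std::vector<char> readAllBytes(const std::string &path) {
  std::ifstream input(path, std::ios::binary);
  if (!input) {
    return {};  // unreadable file: empty result (a choice made for this sketch)
  }
  return std::vector<char>(std::istreambuf_iterator<char>(input),
                           std::istreambuf_iterator<char>());
}
```

Because the result is a std::vector<char>, its data() pointer and size() can be handed directly to an API that expects a contiguous buffer and a length.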

◆ isFile() [1/2]

bool VNetworkBase::isFile ( ) const
inherited

Check if the argument inputFile is the path of a file on disk.

Determines if the string that was passed to the constructor as inputFile corresponds to the path of a file that can be read on the disk.

Returns
is it a readable file on disk

Definition at line 117 of file VNetworkBase.cxx.

117{ return isFile(m_inputFile); };

◆ isFile() [2/2]

bool VNetworkBase::isFile ( std::string const & inputFile)
staticinherited

Check if a string is the path of a file on disk.

Determines if a string corresponds to the path of a file that can be read on the disk.

Parameters
inputFile    name of the potential file
Returns
is it a readable file on disk

Definition at line 119 of file VNetworkBase.cxx.

119 {
120 if (FILE *file = std::fopen(inputFile.c_str(), "r")) {
121 std::fclose(file);
122 return true;
123 } else {
124 return false;
125 };
126};

◆ isRootFile()

bool VNetworkBase::isRootFile ( std::string const & filename = "") const
protectedinherited

Check if a string is possibly a root file path.

Just checks if the string ends in .root as there are almost no reliable rules for file paths.

Parameters
filename    name of the potential file; if blank, m_inputFile is used.
Returns
is it the path of a root file

Definition at line 101 of file VNetworkBase.cxx.

101 {
102 const std::string *to_check = &filename;
103 if (filename.length() == 0) {
104 to_check = &this->m_inputFile;
105 ATH_MSG_DEBUG("No file name given, so using m_inputFile, " << m_inputFile);
106 };
107 const std::string ending = ".root";
108 const int ending_len = ending.length();
109 const int filename_len = to_check->length();
110 if (filename_len < ending_len) {
111 return false;
112 }
113 return (0 ==
114 to_check->compare(filename_len - ending_len, ending_len, ending));
115};
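A stand-alone sketch of the same suffix test makes its edge cases easy to check. The free function endsWithRoot is a hypothetical name; unlike the member above it has no fallback to m_inputFile.

```cpp
#include <string>

// Stand-alone version of the suffix test: true when `name` ends in ".root".
// The comparison is case-sensitive and looks only at the final five characters.
bool endsWithRoot(const std::string &name) {
  const std::string ending = ".root";
  if (name.length() < ending.length()) {
    return false;
  }
  return name.compare(name.length() - ending.length(), ending.length(),
                      ending) == 0;
}
```

Under this rule "net.root" matches, while "net.onnx", "net.root.bak" and "net.ROOT" do not. The file's content is never inspected.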

◆ level()

MSG::Level ISF_FCS::MLogging::level ( ) const
inlineinherited

Retrieve output level.

Definition at line 201 of file MLogging.h.

201{ return msg().level(); }

◆ msg() [1/2]

MsgStream & ISF_FCS::MLogging::msg ( ) const
inlineinherited

Return a stream for sending messages directly (no decoration)

Definition at line 231 of file MLogging.h.

231 {
232 MsgStream *ms = m_msg_tls.get();
233 if (!ms) {
234 ms = new MsgStream(Athena::getMessageSvc(), m_nm);
235 m_msg_tls.reset(ms);
236 }
237 return *ms;
238}

◆ msg() [2/2]

MsgStream & ISF_FCS::MLogging::msg ( const MSG::Level lvl) const
inlineinherited

Return a decorated starting stream for sending messages.

Definition at line 240 of file MLogging.h.

240 {
241 return msg() << lvl;
242}

◆ msgLvl()

bool ISF_FCS::MLogging::msgLvl ( const MSG::Level lvl) const
inlineinherited

Check whether the logging system is active at the provided verbosity level.

Definition at line 222 of file MLogging.h.

222 {
223 if (msg().level() <= lvl) {
224 msg() << lvl;
225 return true;
226 } else {
227 return false;
228 }
229}

◆ print()

void TFCSONNXHandler::print ( std::ostream & strm) const
overrideprotectedvirtual

Write a short description of this net to the string stream.

Specialised for ONNX to print the input and output nodes with their dimensions.

Parameters
strm    output parameter, to which the description will be written.

Reimplemented from VNetworkBase.

Definition at line 78 of file TFCSONNXHandler.cxx.

78 {
79 if (m_inputFile.empty()) {
80 strm << "Unknown network";
81 } else {
82 strm << m_inputFile;
83 };
84 strm << "\nHas input nodes (name:dimensions);\n";
85 for (size_t inp_n = 0; inp_n < m_inputNodeNames.size(); inp_n++) {
86 strm << "\t" << m_inputNodeNames[inp_n] << ":[";
87 for (int dim : m_inputNodeDims[inp_n]) {
88 strm << " " << dim << ",";
89 };
90 strm << "]\n";
91 };
92 strm << "\nHas output nodes (name:dimensions);\n";
93 for (size_t out_n = 0; out_n < m_outputNodeNames.size(); out_n++) {
94 strm << "\t" << m_outputNodeNames[out_n] << ":[";
95 for (int dim : m_outputNodeDims[out_n]) {
96 strm << " " << dim << ",";
97 };
98 strm << "]\n";
99 };
100};

◆ readBytesFromTTree()

std::vector< char > TFCSONNXHandler::readBytesFromTTree ( TTree & tree)
private

Retrieve the content of the proto file from a TTree.

If the ONNX file has been saved as a loose variable in a TTree this method will read it back into m_bytes.

Definition at line 297 of file TFCSONNXHandler.cxx.

297 {
298 ATH_MSG_DEBUG("TFCSONNXHandler reading bytes from tree.");
299 std::vector<char> bytes;
300 char data;
301 tree.SetBranchAddress("serialized_m_session", &data);
302 for (int i = 0; tree.LoadTree(i) >= 0; i++) {
303 tree.GetEntry(i);
304 bytes.push_back(data);
305 };
306 ATH_MSG_DEBUG("TFCSONNXHandler read bytes from tree.");
307 return bytes;
308};

◆ readSerializedSession()

void TFCSONNXHandler::readSerializedSession ( )
private

Do not persistify.

Using content of the proto (.onnx) file make a session.

The m_session variable is initialised from the m_bytes variable so that the net can be run. Requires that the m_bytes variable is retained while the net is used.

Definition at line 324 of file TFCSONNXHandler.cxx.

324 {
325 ATH_MSG_DEBUG("Transforming bytes to session.");
326 Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
327 Ort::SessionOptions opts;
328 opts.SetInterOpNumThreads(1);
329 opts.SetIntraOpNumThreads(1);
330 // Prevent ONNX from spawning additional threads
331 m_session =
332 std::make_unique<Ort::Session>(env, m_bytes.data(), m_bytes.size(), opts);
333 ATH_MSG_DEBUG("Transformed bytes to session.");
334};

◆ removePrefixes() [1/2]

void VNetworkBase::removePrefixes ( VNetworkBase::NetworkOutputs & outputs) const
protectedinherited

Remove any common prefix from the outputs.

Parameters
outputs    The outputs, changed in place.

Definition at line 151 of file VNetworkBase.cxx.

151 {
152 std::vector<std::string> output_layers;
153 for (auto const &output : outputs)
154 output_layers.push_back(output.first);
155 const int length = GetPrefixLength(output_layers);
156 for (std::string layer_name : output_layers) {
157 // remove this output
158 auto nodeHandle = outputs.extract(layer_name);
159 // change the key
160 nodeHandle.key() = layer_name.substr(length);
161 // replace the output
162 outputs.insert(std::move(nodeHandle));
163 }
164};
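The re-keying in the listing above relies on std::map::extract (C++17), which detaches a node so that its key can be changed without copying the value. The sketch below shows that technique end to end. GetPrefixLength is not shown on this page, so commonPrefixLength is an assumption: it takes the longest prefix shared by every name, which may differ from what the real helper does.

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

using NetworkOutputs = std::map<std::string, double>;

// Assumed behaviour of GetPrefixLength: length of the longest prefix
// shared by every name. The real helper is not shown on this page.
std::size_t commonPrefixLength(const std::vector<std::string> &names) {
  if (names.empty()) {
    return 0;
  }
  const std::string &first = names.front();
  std::size_t length = first.size();
  for (const std::string &name : names) {
    std::size_t i = 0;
    while (i < length && i < name.size() && name[i] == first[i]) {
      ++i;
    }
    length = i;
  }
  return length;
}

// Re-key the map in place. extract() detaches the node, so the value is
// moved along with the new key rather than copied.
void stripCommonPrefix(NetworkOutputs &outputs) {
  std::vector<std::string> names;
  for (const auto &output : outputs) {
    names.push_back(output.first);
  }
  const std::size_t length = commonPrefixLength(names);
  for (const std::string &name : names) {
    auto node = outputs.extract(name);
    node.key() = name.substr(length);
    outputs.insert(std::move(node));
  }
}
```

With the made-up keys "layer_out_0" and "layer_out_1" the shared prefix is "layer_out_", so the map ends up keyed by "0" and "1".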

◆ removePrefixes() [2/2]

void VNetworkBase::removePrefixes ( std::vector< std::string > & output_names) const
protectedinherited

Remove any common prefix from the outputs.

Parameters
output_names    The output names, changed in place.

Definition at line 144 of file VNetworkBase.cxx.

145 {
146 const int length = GetPrefixLength(output_names);
147 for (long unsigned int i = 0; i < output_names.size(); i++)
148 output_names[i] = output_names[i].substr(length);
149};

◆ representNetworkInputs()

std::string VNetworkBase::representNetworkInputs ( VNetworkBase::NetworkInputs const & inputs,
int maxValues = 3 )
staticinherited

String representation of network inputs.

Create a string that summarises a set of network inputs. Gives basic dimensions plus a few values, up to the maxValues.

Parameters
inputs    values to be evaluated by the network
maxValues    maximum number of values to include in the representation
Returns
string representing the inputs

Definition at line 37 of file VNetworkBase.cxx.

38 {
39 std::string representation =
40 "NetworkInputs, outer size " + std::to_string(inputs.size());
41 int valuesIncluded = 0;
42 for (const auto &outer : inputs) {
43 representation += "\n key->" + outer.first + "; ";
44 for (const auto &inner : outer.second) {
45 representation += inner.first + "=" + std::to_string(inner.second) + ", ";
46 ++valuesIncluded;
47 if (valuesIncluded > maxValues)
48 break;
49 };
50 if (valuesIncluded > maxValues)
51 break;
52 };
53 representation += "\n";
54 return representation;
55};

◆ representNetworkOutputs()

std::string VNetworkBase::representNetworkOutputs ( VNetworkBase::NetworkOutputs const & outputs,
int maxValues = 3 )
staticinherited

String representation of network outputs.

Create a string that summarises a set of network outputs. Gives basic dimensions plus a few values, up to the maxValues.

Parameters
outputs    output of the network
maxValues    maximum number of values to include in the representation
Returns
string representing the outputs

Definition at line 57 of file VNetworkBase.cxx.

58 {
59 std::string representation =
60 "NetworkOutputs, size " + std::to_string(outputs.size()) + "; \n";
61 int valuesIncluded = 0;
62 for (const auto &item : outputs) {
63 representation += item.first + "=" + std::to_string(item.second) + ", ";
64 ++valuesIncluded;
65 if (valuesIncluded > maxValues)
66 break;
67 };
68 representation += "\n";
69 return representation;
70};

◆ setLevel()

void ISF_FCS::MLogging::setLevel ( MSG::Level lvl)
virtualinherited

Update output level.

Definition at line 105 of file MLogging.cxx.

105 {
106 lvl = (lvl >= MSG::NUM_LEVELS) ? MSG::ALWAYS
107 : (lvl < MSG::NIL) ? MSG::NIL
108 : lvl;
109 msg().setLevel(lvl);
110}

◆ setupNet()

void TFCSONNXHandler::setupNet ( )
overrideprotectedvirtual

Perform actions that prepare network for use.

Will be called in the streamer or class constructor after the inputs have been set (either automatically by the streamer or by setupPersistedVariables in the constructor). Does not delete any resources used.
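One step of the listing below is the handling of symbolic dimensions: any dimension reported as less than 1 is treated as 1 before the shape is stored. A minimal sketch of that rule, with the hypothetical name concretiseShape, is:

```cpp
#include <cstdint>
#include <vector>

// Replace every non-positive (symbolic) dimension by 1 and return the
// number of elements in the resulting shape.
std::int64_t concretiseShape(std::vector<std::int64_t> &shape) {
  std::int64_t size = 1;
  for (std::int64_t &dim : shape) {
    if (dim < 1) {
      dim = 1;  // symbolic dimension, e.g. -1 from a (None, 7) array
    }
    size *= dim;
  }
  return size;
}
```

A shape reported as (-1, 7) becomes (1, 7) with 7 elements, which matches the single-sample evaluation that compute performs.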

Implements VNetworkBase.

Definition at line 112 of file TFCSONNXHandler.cxx.

112 {
113 // From
114 // https://gitlab.cern.ch/atlas/athena/-/blob/master/Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/OnnxUtils.h
115 // m_session = AthONNX::CreateORTSession(inputFile);
116 // This segfaults.
117
118 // TODO; should I be using m_session_options? see
119 // https://github.com/microsoft/onnxruntime-inference-examples/blob/2b42b442526b9454d1e2d08caeb403e28a71da5f/c_cxx/squeezenet/main.cpp#L71
120 ATH_MSG_INFO("Setting up ONNX session.");
121 this->readSerializedSession();
122
123 // Need the type from the first node (which will be used to set
124 // m_computeLambda); just set it to undefined to avoid not initialised warnings
125 ONNXTensorElementDataType first_input_type =
126 ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
127
128 // iterate over all input nodes
129 ATH_MSG_DEBUG("Getting input nodes.");
130 const int num_input_nodes = m_session->GetInputCount();
131 Ort::AllocatorWithDefaultOptions allocator;
132 for (int i = 0; i < num_input_nodes; i++) {
133
134#if ORT_API_VERSION > 11
135 Ort::AllocatedStringPtr node_names = m_session->GetInputNameAllocated(i, allocator);
136 m_storeInputNodeNames.push_back(std::move(node_names));
137 const char *input_name = m_storeInputNodeNames.back().get();
138#else
139 const char *input_name = m_session->GetInputName(i, allocator);
140#endif
141 m_inputNodeNames.push_back(input_name);
142 ATH_MSG_VERBOSE("Found input node named " << input_name);
143
144 Ort::TypeInfo type_info = m_session->GetInputTypeInfo(i);
145
146 // For some reason unless auto is used as the return type
147 // this causes a segfault once the loop ends....
148 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
149 if (i == 0)
150 first_input_type = tensor_info.GetElementType();
151 // Check the type has not changed
152 if (tensor_info.GetElementType() != first_input_type) {
153 ATH_MSG_ERROR("First type was " << first_input_type << ". In node " << i
154 << " found type "
155 << tensor_info.GetElementType());
156 throw std::runtime_error("Networks with varying input types not "
157 "yet impelmented in TFCSONNXHandler.");
158 };
159
160 std::vector<int64_t> recieved_dimension = tensor_info.GetShape();
161 ATH_MSG_VERBOSE("There are " << recieved_dimension.size()
162 << " dimensions.");
163 // This vector sometimes includes a symbolic dimension
164 // which is represented by -1
165 // A symbolic dimension is usually a conversion error,
166 // from a numpy array with a shape like (None, 7),
167 // in which case it's safe to treat it as having
168 // dimension 1.
169 std::vector<int64_t> dimension_of_node;
170 for (int64_t node_dim : recieved_dimension) {
171 if (node_dim < 1) {
172 ATH_MSG_WARNING("Found symbolic dimension "
173 << node_dim << " in node named " << input_name
174 << ". Will treat this as dimension 1.");
175 dimension_of_node.push_back(1);
176 } else {
177 dimension_of_node.push_back(node_dim);
178 };
179 };
180 m_inputNodeDims.push_back(std::move(dimension_of_node));
181 };
182 ATH_MSG_DEBUG("Finished looping on inputs.");
183
184 // Outputs
185 // Store the type from the first node (which will be used to set
186 // m_computeLambda)
187 ONNXTensorElementDataType first_output_type =
188 ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
189
190 // iterate over all output nodes
191 int num_output_nodes = m_session->GetOutputCount();
192 ATH_MSG_DEBUG("Getting " << num_output_nodes << " output nodes.");
193 for (int i = 0; i < num_output_nodes; i++) {
194#if ORT_API_VERSION > 11
195 Ort::AllocatedStringPtr node_names = m_session->GetOutputNameAllocated(i, allocator);
196 m_storeOutputNodeNames.push_back(std::move(node_names));
197 const char *output_name = m_storeOutputNodeNames.back().get();
198#else
199 const char *output_name = m_session->GetOutputName(i, allocator);
200#endif
201 m_outputNodeNames.push_back(output_name);
202 ATH_MSG_VERBOSE("Found output node named " << output_name);
203
204 const Ort::TypeInfo type_info = m_session->GetOutputTypeInfo(i);
205 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
206 if (i == 0)
207 first_output_type = tensor_info.GetElementType();
208
209 // Check the type has not changed
210 if (tensor_info.GetElementType() != first_output_type) {
211 ATH_MSG_ERROR("First type was " << first_output_type << ". In node " << i
212 << " found type "
213 << tensor_info.GetElementType());
214 throw std::runtime_error("Networks with varying output types not "
215 "yet impelmented in TFCSONNXHandler.");
216 };
217
218 const std::vector<int64_t> recieved_dimension = tensor_info.GetShape();
219 ATH_MSG_VERBOSE("There are " << recieved_dimension.size()
220 << " dimensions.");
221 // Again, check for symbolic dimensions
222 std::vector<int64_t> dimension_of_node;
223 int node_size = 1;
224 for (int64_t node_dim : recieved_dimension) {
225 if (node_dim < 1) {
226 ATH_MSG_WARNING("Found symbolic dimension "
227 << node_dim << " in node named " << output_name
228 << ". Will treat this as dimension 1.");
229 dimension_of_node.push_back(1);
230 } else {
231 dimension_of_node.push_back(node_dim);
232 node_size *= node_dim;
233 };
234 };
235 m_outputNodeDims.push_back(std::move(dimension_of_node));
236 m_outputNodeSize.push_back(node_size);
237
238 // The outputs are treated as a flat vector
239 for (int part_n = 0; part_n < node_size; part_n++) {
240 // compose the output name
241 std::string layer_name =
242 std::string(output_name) + "_" + std::to_string(part_n);
243 ATH_MSG_VERBOSE("Found output layer named " << layer_name);
244 m_outputLayers.push_back(std::move(layer_name));
245 }
246 }
247 ATH_MSG_DEBUG("Removing prefix from stored layers.");
249 ATH_MSG_DEBUG("Finished output nodes.");
250
251 ATH_MSG_DEBUG("Setting up m_computeLambda with input type "
252 << first_input_type << " and output type "
253 << first_output_type);
254 if (first_input_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT &&
255 first_output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
256 // gotta capture this in the lambda so it can access class methods
257 m_computeLambda = [this](NetworkInputs const &inputs) {
258 return computeTemplate<float, float>(inputs);
259 };
260 } else if (first_input_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE &&
261 first_output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
262 m_computeLambda = [this](NetworkInputs const &inputs) {
263 return computeTemplate<double, double>(inputs);
264 };
265 } else {
266 throw std::runtime_error("Haven't yet implemented that combination of "
267 "input and output types as a subclass of VState.");
268 };
269 ATH_MSG_DEBUG("Finished setting lambda function.");
270};
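
The two shape-handling steps in the listing above (replacing symbolic dimensions with 1, then naming every element of a flattened output node) can be sketched without onnxruntime. This is only an illustration under stated assumptions, not the class's own code: `resolveDims` and `flatLayerNames` are hypothetical helper names, and the dimension vectors stand in for what `GetShape()` returns.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper: replace symbolic (non-positive) dimensions with 1,
// mirroring the treatment of -1 entries in the listing above.
std::vector<int64_t> resolveDims(const std::vector<int64_t> &received) {
  std::vector<int64_t> resolved;
  resolved.reserve(received.size());
  for (int64_t dim : received)
    resolved.push_back(dim < 1 ? 1 : dim);
  return resolved;
}

// Hypothetical helper: build the flat layer names "<node>_<index>" for one
// output node whose dimensions have already been resolved.
std::vector<std::string> flatLayerNames(const std::string &node_name,
                                        const std::vector<int64_t> &dims) {
  int64_t node_size = 1;
  for (int64_t dim : dims)
    node_size *= dim; // total number of elements in the node
  std::vector<std::string> names;
  for (int64_t part_n = 0; part_n < node_size; ++part_n)
    names.push_back(node_name + "_" + std::to_string(part_n));
  return names;
}
```

With a received shape of (-1, 3) and a node called `out`, the resolved shape is (1, 3) and the layer names are `out_0`, `out_1` and `out_2`.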

◆ setupPersistedVariables()

void TFCSONNXHandler::setupPersistedVariables ( )
override protected virtual

Perform actions that prep data to create the net.

Will be called in the class constructor before calling setupNet, but not in the streamer. It sets any variables that the streamer would persist when saving to or loading from file.

Implements VNetworkBase.

Definition at line 102 of file TFCSONNXHandler.cxx.

102 {
103 ATH_MSG_DEBUG("Setting up persisted variables for ONNX network.");
104 // depending which constructor was called,
105 // bytes may already be filled
106 if (m_bytes.empty()) {
108 };
109 ATH_MSG_DEBUG("Setup persisted variables for ONNX network.");
110};

◆ startMsg()

std::string ISF_FCS::MLogging::startMsg ( MSG::Level lvl,
const std::string & file,
int line )
static inherited

Make a message to decorate the start of logging.

Print a message for the start of logging.

Definition at line 116 of file MLogging.cxx.

116 {
117 int col1_len = 20;
118 int col2_len = 5;
119 int col3_len = 10;
120 auto last_slash = file.find_last_of('/');
121 int path_len = last_slash == std::string::npos ? 0 : last_slash;
122 int trim_point = path_len;
123 int total_len = file.length();
124 if (total_len - path_len > col1_len)
125 trim_point = total_len - col1_len;
126 std::string trimmed_name = file.substr(trim_point);
127 const char *LevelNames[MSG::NUM_LEVELS] = {
128 "NIL", "VERBOSE", "DEBUG", "INFO", "WARNING", "ERROR", "FATAL", "ALWAYS"};
129 std::string level = LevelNames[lvl];
130 std::string level_string = std::string("(") + level + ") ";
131 std::stringstream output;
132 output << std::setw(col1_len) << std::right << trimmed_name << ":"
133 << std::setw(col2_len) << std::left << line << std::setw(col3_len)
134 << std::right << level_string;
135 return output.str();
136}
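
To make the column layout of the prefix easier to picture, here is a standalone restatement of the same formatting steps. It is a sketch under assumptions: `logPrefix` is a hypothetical name, and the level is passed as text instead of an `MSG::Level` so that nothing outside the standard library is needed.

```cpp
#include <cstddef>
#include <iomanip>
#include <sstream>
#include <string>

// Hypothetical standalone version of the prefix formatting shown above.
std::string logPrefix(const std::string &file, int line,
                      const std::string &level) {
  const std::size_t col1_len = 20; // width of the file-name column
  const int col2_len = 5;          // width of the line-number column
  const int col3_len = 10;         // width of the level column

  // Start from the last '/', or from the beginning if there is none.
  const std::size_t last_slash = file.find_last_of('/');
  std::size_t trim_point = last_slash == std::string::npos ? 0 : last_slash;
  // If what remains is still too wide, keep only the last col1_len characters.
  if (file.length() - trim_point > col1_len)
    trim_point = file.length() - col1_len;

  const std::string level_string = "(" + level + ") ";
  std::ostringstream output;
  output << std::setw(static_cast<int>(col1_len)) << std::right
         << file.substr(trim_point) << ":" << std::setw(col2_len) << std::left
         << line << std::setw(col3_len) << std::right << level_string;
  return output.str();
}
```

For `logPrefix("src/Net.cxx", 7, "INFO")` the file column holds `/Net.cxx` right-aligned in 20 characters. The `/` is kept because the substring starts at the separator itself, which is also what the listing above does.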

◆ VNetworkBase() [1/3]

VNetworkBase::VNetworkBase ( )

VNetworkBase default constructor.

For use in streamers.

Definition at line 45 of file VNetworkBase.cxx.

16 : m_inputFile("unknown") {};

◆ VNetworkBase() [2/3]

VNetworkBase::VNetworkBase ( const std::string & inputFile)
explicit

VNetworkBase constructor.

Only saves inputFile to m_inputFile. Inheriting classes should call setupPersistedVariables and setupNet in their constructors.

Parameters
inputFile    File path on disk (including the file name) of a readable file containing a description of the network to be constructed, or the content of the file.

Definition at line 59 of file VNetworkBase.cxx.

20 : m_inputFile(inputFile) {
21 ATH_MSG_DEBUG("Constructor called with inputFile");
22};

◆ VNetworkBase() [3/3]

VNetworkBase::VNetworkBase ( const VNetworkBase & copy_from)

VNetworkBase copy constructor.

Does not call setupPersistedVariables or setupNet, but will pass on m_inputFile. Inheriting classes should do whatever they need to move the variables created in the setup functions.

Parameters
copy_from    Existing network that we are copying.

Definition at line 71 of file VNetworkBase.cxx.

26 : MLogging(),
27 m_inputFile (copy_from.m_inputFile)
28{
29};

◆ writeBytesToTTree()

void TFCSONNXHandler::writeBytesToTTree ( TTree & tree,
const std::vector< char > & bytes )
private

Write the content of the proto file to a TTree as a branch.

The ONNX proto file is saved as a simple branch (no streamers involved).

Definition at line 310 of file TFCSONNXHandler.cxx.

311 {
312 ATH_MSG_DEBUG("TFCSONNXHandler writing bytes to tree.");
313 char m_session_data;
314 tree.Branch("serialized_m_session", &m_session_data,
315 "serialized_m_session/B");
316 for (Char_t here : bytes) {
317 m_session_data = here;
318 tree.Fill();
319 };
320 tree.Write();
321 ATH_MSG_DEBUG("TFCSONNXHandler written bytes to tree.");
322};

◆ writeNetToTTree() [1/3]

void VNetworkBase::writeNetToTTree ( std::string const & root_name,
std::string const & tree_name = m_defaultTreeName )

Save the network to a TTree.

All data required to recreate the network object is saved into a TTree. The format is not specified.

Parameters
root_name    The path of the file to save inside.
tree_name    The name of the TTree to save inside.

Definition at line 196 of file VNetworkBase.cxx.

94 {
95 ATH_MSG_DEBUG("Making or updating file name " << root_name);
96 TFile root_file(root_name.c_str(), "UPDATE");
97 this->writeNetToTTree(root_file, tree_name);
98 root_file.Close();
99};

◆ writeNetToTTree() [2/3]

void VNetworkBase::writeNetToTTree ( TFile & root_file,
std::string const & tree_name = m_defaultTreeName )

Save the network to a TTree.

All data required to recreate the network object is saved into a TTree. The format is not specified.

Parameters
root_file    The file to save inside.
tree_name    The name of the TTree to save inside.

Definition at line 184 of file VNetworkBase.cxx.

84 {
85 ATH_MSG_DEBUG("Making tree name " << tree_name);
86 root_file.cd();
87 const std::string title = "onnxruntime saved network";
88 TTree tree(tree_name.c_str(), title.c_str());
89 this->writeNetToTTree(tree);
90 root_file.Write();
91};

◆ writeNetToTTree() [3/3]

void TFCSONNXHandler::writeNetToTTree ( TTree & tree)
override virtual

Save the network to a TTree.

All data required to recreate the network object is saved into a TTree. The format is not specified. Will still work even if deleteAllButNet has already been called.

Parameters
tree    The tree to save inside.

Implements VNetworkBase.

Definition at line 62 of file TFCSONNXHandler.cxx.

62 {
63 ATH_MSG_DEBUG("TFCSONNXHandler writing net to tree.");
65};

Member Data Documentation

◆ ATLAS_THREAD_SAFE

boost::thread_specific_ptr<MsgStream> m_msg_tls ISF_FCS::MLogging::ATLAS_THREAD_SAFE
inline static private inherited

Do not persistify!

MsgStream instance (similar to std::cout, but with print-out levels)

Definition at line 215 of file MLogging.h.

◆ m_bytes

std::vector<char> TFCSONNXHandler::m_bytes
private

Content of the proto file.

Definition at line 170 of file TFCSONNXHandler.h.

◆ m_computeLambda

std::function<NetworkOutputs(NetworkInputs)> TFCSONNXHandler::m_computeLambda
private

computeTemplate with appropriate types selected.

Definition at line 307 of file TFCSONNXHandler.h.
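
The idea of choosing a template instantiation once and storing it in a `std::function` can be sketched as follows. The names here are assumptions made for illustration: `ElementType` is a hypothetical stand-in for `ONNXTensorElementDataType`, and `passThrough` stands in for `computeTemplate`, which in the real class runs the network rather than converting a single value.

```cpp
#include <functional>
#include <stdexcept>

// Hypothetical stand-in for ONNXTensorElementDataType.
enum class ElementType { Float, Double, Int64 };

// Stand-in for computeTemplate<Tin, Tout>: pass one value through both types.
template <typename Tin, typename Tout>
double passThrough(double value) {
  const Tin as_input = static_cast<Tin>(value);
  const Tout as_output = static_cast<Tout>(as_input);
  return static_cast<double>(as_output);
}

// Pick the instantiation once, so later calls need no further type checks.
std::function<double(double)> selectCompute(ElementType in, ElementType out) {
  if (in == ElementType::Float && out == ElementType::Float)
    return [](double v) { return passThrough<float, float>(v); };
  if (in == ElementType::Double && out == ElementType::Double)
    return [](double v) { return passThrough<double, double>(v); };
  // Mixed or unsupported types are rejected up front.
  throw std::runtime_error("Unsupported combination of input and output types.");
}
```

Storing the selected callable keeps the public `compute` interface free of template parameters, at the cost of one indirect call per evaluation.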

◆ m_defaultTreeName

const std::string VNetworkBase::m_defaultTreeName = "onnxruntime_session"
inline static inherited

Default name for the TTree to save in.

Definition at line 173 of file VNetworkBase.h.

◆ m_inputFile

std::string VNetworkBase::m_inputFile
protected inherited

Path to the file describing the network, including filename.

Definition at line 245 of file VNetworkBase.h.

◆ m_inputNodeDims

std::vector<std::vector<int64_t> > TFCSONNXHandler::m_inputNodeDims
private

Do not persistify.

dimension lengths in each named input node

Describes the shape of the input nodes.

See also
TFCSONNXHandler::m_inputNodeNames

Definition at line 276 of file TFCSONNXHandler.h.

◆ m_inputNodeNames

std::vector<const char *> TFCSONNXHandler::m_inputNodeNames
private

names that index the input nodes

An ONNX network is capable of having two layers of labels (input node names, then labels within each node), but its twin, LWTNN, is not. LWTNN supports one list of nodes indexed by strings, and each input node may have more than one value, indexed by positive integers (list-like), so this interface only supports that more limited format.

Definition at line 234 of file TFCSONNXHandler.h.
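
As an illustration of the limited format described above, the sketch below flattens one named node of a `NetworkInputs` map into a contiguous vector. It rests on an assumption that is not confirmed by this page: that the inner keys are decimal strings giving each value's position, counted from 0. `flattenNode` is a hypothetical helper, not a member of the class.

```cpp
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Same shape as the NetworkInputs typedef documented for this class.
using NetworkInputs = std::map<std::string, std::map<std::string, double>>;

// Hypothetical helper: copy the values of one node into a flat vector,
// placing each value at the position given by its (assumed numeric) key.
template <typename Tin>
std::vector<Tin> flattenNode(const NetworkInputs &inputs,
                             const std::string &node_name) {
  // at() throws std::out_of_range if the node is missing.
  const std::map<std::string, double> &node = inputs.at(node_name);
  std::vector<Tin> flat(node.size());
  for (const auto &entry : node) {
    // Parse the key as an index; map order is lexicographic, so the key
    // itself, not the iteration order, decides the position.
    const std::size_t index = std::stoul(entry.first);
    if (index >= flat.size())
      throw std::out_of_range("Input index " + entry.first +
                              " exceeds the node size.");
    flat[index] = static_cast<Tin>(entry.second);
  }
  return flat;
}
```

A node `{"0": 1.5, "1": -2.0, "2": 0.25}` flattened as `float` gives the vector (1.5, -2.0, 0.25).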

◆ m_memoryInfo

Ort::MemoryInfo TFCSONNXHandler::m_memoryInfo
private
Initial value:
= Ort::MemoryInfo::CreateCpu(
OrtArenaAllocator, OrtMemTypeDefault)

Do not persistify.

Specifies memory behavior for vectors in ONNX.

Definition at line 312 of file TFCSONNXHandler.h.

◆ m_nm

std::string ISF_FCS::MLogging::m_nm
private inherited

Message source name.

Definition at line 211 of file MLogging.h.

◆ m_outputLayers

std::vector<std::string> TFCSONNXHandler::m_outputLayers
private

Do not persistify.

Externally visible names that index the output.

Definition at line 318 of file TFCSONNXHandler.h.

◆ m_outputNodeDims

std::vector<std::vector<int64_t> > TFCSONNXHandler::m_outputNodeDims
private

Do not persistify.

dimension lengths in each named output node

As the final output must be flat in each output node, this is for internal manipulations only.

See also
TFCSONNXHandler::m_inputNodeDims

Definition at line 284 of file TFCSONNXHandler.h.

◆ m_outputNodeNames

std::vector<const char *> TFCSONNXHandler::m_outputNodeNames
private

Do not persistify.

the names that index the output nodes

An ONNX network is capable of having two layers of labels (input node names, then labels within each node), but its twin, LWTNN, is not. LWTNN supports one list of nodes indexed by strings, and each input node may have more than one value, indexed by positive integers (list-like), so this interface only supports that more limited format.

Definition at line 258 of file TFCSONNXHandler.h.

◆ m_outputNodeSize

std::vector<int64_t> TFCSONNXHandler::m_outputNodeSize
private

Do not persistify.

total elements in each named output node

For internal use only, gives the total number of elements in the output nodes.

See also
TFCSONNXHandler::m_inputNodeDims

Definition at line 292 of file TFCSONNXHandler.h.

◆ m_session

std::unique_ptr<Ort::Session> TFCSONNXHandler::m_session
private

The network session itself.

This is the object created by onnxruntime_cxx_api which contains information about the network and can run inputs through it.

Held as a unique pointer to avoid the need for manual memory management.

Definition at line 210 of file TFCSONNXHandler.h.


The documentation for this class was generated from the following files: