ATLAS Offline Software
TFCSSimpleLWTNNHandler Class Reference

Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration.

#include <TFCSSimpleLWTNNHandler.h>


Public Types

typedef std::map< std::string, std::map< std::string, double > > NetworkInputs
 Format for network inputs.

typedef std::map< std::string, double > NetworkOutputs
 Format for network outputs.


Public Member Functions

 TFCSSimpleLWTNNHandler (const std::string &inputFile)
 TFCSSimpleLWTNNHandler constructor.

 TFCSSimpleLWTNNHandler (const TFCSSimpleLWTNNHandler &copy_from)
 TFCSSimpleLWTNNHandler copy constructor.

NetworkOutputs compute (NetworkInputs const &inputs) const override
 Function to pass values to the network.

std::vector< std::string > getOutputLayers () const override
 List the names of the outputs.

 VNetworkLWTNN (const VNetworkLWTNN &copy_from)
 VNetworkLWTNN copy constructor.

void writeNetToTTree (TTree &tree) override
 Save the network to a TTree.

virtual void writeNetToTTree (TTree &tree)=0
 Save the network to a TTree.

void writeNetToTTree (TFile &root_file, std::string const &tree_name=m_defaultTreeName)
 Save the network to a TTree.

void writeNetToTTree (std::string const &root_name, std::string const &tree_name=m_defaultTreeName)
 Save the network to a TTree.

void deleteAllButNet () override
 Get rid of any memory objects that aren't needed to run the net.

 VNetworkBase ()
 VNetworkBase default constructor.

 VNetworkBase (const std::string &inputFile)
 VNetworkBase constructor.

 VNetworkBase (const VNetworkBase &copy_from)
 VNetworkBase copy constructor.

bool isFile () const
 Check if the argument inputFile is the path of a file on disk.

bool msgLvl (const MSG::Level lvl) const
 Check whether the logging system is active at the provided verbosity level.

MsgStream & msg () const
 Return a stream for sending messages directly (no decoration).

MsgStream & msg (const MSG::Level lvl) const
 Return a decorated starting stream for sending messages.

MSG::Level level () const
 Retrieve output level.

virtual void setLevel (MSG::Level lvl)
 Update output level.


Static Public Member Functions

static std::string representNetworkInputs (NetworkInputs const &inputs, int maxValues=3)
 String representation of network inputs.

static std::string representNetworkOutputs (NetworkOutputs const &outputs, int maxValues=3)
 String representation of network outputs.

static bool isFile (std::string const &inputFile)
 Check if a string is the path of a file on disk.

static std::string startMsg (MSG::Level lvl, const std::string &file, int line)
 Make a message to decorate the start of logging.


Static Public Attributes

static const std::string m_defaultTreeName = "onnxruntime_session"
 Default name for the TTree to save in.


Protected Member Functions

void setupNet () override
 Perform actions that prepare network for use.

virtual void print (std::ostream &strm) const override
 Write a short description of this net to the string stream.

void setupPersistedVariables () override
 Perform actions that prep data to create the net.

bool isRootFile (std::string const &filename="") const
 Check if a string is possibly a root file path.

void removePrefixes (NetworkOutputs &outputs) const
 Remove any common prefix from the outputs.

void removePrefixes (std::vector< std::string > &output_names) const
 Remove any common prefix from the outputs.


Protected Attributes

std::string m_json
 String containing json input file.

std::string m_inputFile
 Path to the file describing the network, including filename.


Private Member Functions

 ClassDefOverride (TFCSSimpleLWTNNHandler, 1)
 Do not persistify.

void fillJson (std::string const &tree_name=m_defaultTreeName)
 Fill out m_json from a file provided to the constructor.

std::string readStringFromTTree (TTree &tree)
 Get json string from TTree.

void writeStringToTTree (TTree &tree, std::string json_string)
 Write json string to TTree.

 ClassDef (VNetworkBase, 1)


Private Attributes

std::unique_ptr< lwt::LightweightNeuralNetwork > m_lwtnn_neural
 The network that we are wrapping here.

std::vector< std::string > m_outputLayers
 Do not persistify.

std::string m_printable_name
 Stores a printable identifier for the net.

std::string m_nm
 Message source name.


Static Private Attributes

static boost::thread_specific_ptr< MsgStream > m_msg_tls ATLAS_THREAD_SAFE
 Do not persistify!


Detailed Description


Class for a neural network read in the LWTNN format. Derived (via VNetworkLWTNN) from the abstract base class VNetworkBase so that it can be used interchangeably with its sibling classes TFCSGANLWTNNHandler and TFCSONNXHandler.

Frustratingly, LightweightNeuralNetwork and LightweightGraph from lwtnn do not have a common ancestor. They could be connected with the bridge pattern, but that is more complex than currently required.

The LoadNetwork function has VNetworkBase as its return type so that it can make a run-time decision about which derived class to use, based on the file name presented.

Definition at line 34 of file TFCSSimpleLWTNNHandler.h.
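The run-time choice made by LoadNetwork can be pictured as a small factory returning a base-class pointer. This is a minimal sketch, not the ATLAS implementation: the class names (NetworkBase, LwtnnNetwork, OnnxNetwork), the helper loadNetworkSketch, and the rule that a ".onnx" suffix selects the ONNX handler are all assumptions made for illustration, since the real selection rules are not given on this page.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-ins for VNetworkBase and two of its derived handlers.
struct NetworkBase {
  virtual ~NetworkBase() = default;
  virtual std::string kind() const = 0;
};
struct LwtnnNetwork : NetworkBase {
  std::string kind() const override { return "lwtnn"; }
};
struct OnnxNetwork : NetworkBase {
  std::string kind() const override { return "onnx"; }
};

bool hasSuffix(const std::string &s, const std::string &suffix) {
  return s.size() >= suffix.size() &&
         s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// Assumed rule, for illustration only: ".onnx" selects ONNX, anything else
// LWTNN. The caller only ever sees the base-class pointer.
std::unique_ptr<NetworkBase> loadNetworkSketch(const std::string &inputFile) {
  if (hasSuffix(inputFile, ".onnx"))
    return std::make_unique<OnnxNetwork>();
  return std::make_unique<LwtnnNetwork>();
}
```

Returning the base type is what lets calling code evaluate a network without knowing which library backs it.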

Member Typedef Documentation

◆ NetworkInputs

typedef std::map<std::string, std::map<std::string, double> > VNetworkBase::NetworkInputs
inherited

Format for network inputs.

The doubles are the values to be passed into the network. Strings in the outer map identify the input node and must correspond to the names of the nodes as read from the description of the network found by the constructor. Strings in the inner map identify the part of the input node; for some networks these must be simple integers in string form, because parts of nodes do not always have the ability to carry real string labels.

Definition at line 90 of file VNetworkBase.h.
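As an illustration of the nested map, the sketch below packs a plain list of values into a single input node whose parts are labelled "0", "1", "2", and so on. The helper makeSingleNodeInputs and the node names used are hypothetical; only the type alias mirrors VNetworkBase::NetworkInputs.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Same shape as VNetworkBase::NetworkInputs, redeclared so the sketch compiles alone.
using NetworkInputs = std::map<std::string, std::map<std::string, double>>;

// Hypothetical helper: one outer key (the node), inner keys are integers in
// string form ("0", "1", ...) identifying each part of that node.
NetworkInputs makeSingleNodeInputs(const std::string &nodeName,
                                   const std::vector<double> &values) {
  NetworkInputs inputs;
  for (std::size_t i = 0; i < values.size(); ++i)
    inputs[nodeName][std::to_string(i)] = values[i];
  return inputs;
}
```

Because the inner keys are strings, std::map orders them lexicographically, so with eleven or more parts "10" sorts before "2".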

◆ NetworkOutputs

typedef std::map<std::string, double> VNetworkBase::NetworkOutputs
inherited

Format for network outputs.

The doubles are the values generated by the network. Strings identify which node each value came from; when a node has multiple values, the string is suffixed with a number indicating which part of the node the value came from. So for multi-value nodes the format becomes "<node_name>_<part_n>".

Definition at line 100 of file VNetworkBase.h.
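A minimal sketch of the "<node_name>_<part_n>" convention, assuming zero-based part numbers (the page does not say where numbering starts). outputKey and splitOutputKey are hypothetical helpers, not part of VNetworkBase.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Build the key for part `part` of a multi-value node.
std::string outputKey(const std::string &node, int part) {
  return node + "_" + std::to_string(part);
}

// Split on the last underscore so node names that contain '_' survive.
// Returns part -1 when there is no underscore at all.
std::pair<std::string, int> splitOutputKey(const std::string &key) {
  const std::string::size_type pos = key.rfind('_');
  if (pos == std::string::npos)
    return {key, -1};
  return {key.substr(0, pos), std::stoi(key.substr(pos + 1))};
}
```

The split is ambiguous for single-value nodes whose names contain an underscore: "cell_energy" would be read as node "cell", and std::stoi would throw on "energy". It is only safe when the node is known to be multi-valued.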

Constructor & Destructor Documentation

◆ TFCSSimpleLWTNNHandler() [1/2]

TFCSSimpleLWTNNHandler::TFCSSimpleLWTNNHandler ( const std::string &  inputFile)
explicit

TFCSSimpleLWTNNHandler constructor.

Calls setupPersistedVariables and setupNet.

Parameters
inputFile: file-path on disk (with file name) of a readable lwtnn file containing a json format description of the network to be constructed, or the json itself as a string.

Definition at line 15 of file TFCSSimpleLWTNNHandler.cxx.

  ATH_MSG_DEBUG("Setting up from inputFile.");
  setupPersistedVariables();
  setupNet();
};

◆ TFCSSimpleLWTNNHandler() [2/2]

TFCSSimpleLWTNNHandler::TFCSSimpleLWTNNHandler ( const TFCSSimpleLWTNNHandler &  copy_from)

TFCSSimpleLWTNNHandler copy constructor.

Will copy the variables that would be generated by setupPersistedVariables and setupNet.

Parameters
copy_from: existing network that we are copying

Definition at line 22 of file TFCSSimpleLWTNNHandler.cxx.

    : VNetworkLWTNN(copy_from) {
  // Cannot take copy of lwt::LightweightNeuralNetwork
  // (copy constructor disabled)
  ATH_MSG_DEBUG("Making new m_lwtnn_neural for copy of network.");
  std::stringstream json_stream(m_json);
  const lwt::JSONConfig config = lwt::parse_json(json_stream);
  m_lwtnn_neural = std::make_unique<lwt::LightweightNeuralNetwork>(
      config.inputs, config.layers, config.outputs);
  m_outputLayers = copy_from.m_outputLayers;
};
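The copy constructor above rebuilds the network from the persisted json rather than copying it, because lwt::LightweightNeuralNetwork has its copy constructor disabled. The same pattern in a self-contained sketch: NonCopyableNet is a hypothetical stand-in for the lwtnn class, HandlerSketch for the handler, and the "config" is just the first line of the stored text.

```cpp
#include <cassert>
#include <istream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>

// Hypothetical stand-in for lwt::LightweightNeuralNetwork: not copyable.
class NonCopyableNet {
public:
  explicit NonCopyableNet(std::istream &config) { std::getline(config, m_name); }
  NonCopyableNet(const NonCopyableNet &) = delete;
  NonCopyableNet &operator=(const NonCopyableNet &) = delete;
  const std::string &name() const { return m_name; }

private:
  std::string m_name;
};

class HandlerSketch {
public:
  explicit HandlerSketch(std::string json) : m_json(std::move(json)) { build(); }
  // Copy the persisted text, then reconstruct the net instead of copying it.
  HandlerSketch(const HandlerSketch &other) : m_json(other.m_json) { build(); }
  const NonCopyableNet &net() const { return *m_net; }

private:
  void build() {
    std::stringstream stream(m_json);
    m_net = std::make_unique<NonCopyableNet>(stream);
  }
  std::string m_json;
  std::unique_ptr<NonCopyableNet> m_net;
};
```

The cost is a second parse of the json for every copy, which keeps the handler copyable without requiring the wrapped network to be.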

Member Function Documentation

◆ ClassDef()

VNetworkBase::ClassDef ( VNetworkBase  , 1  )
private inherited

◆ ClassDefOverride()

TFCSSimpleLWTNNHandler::ClassDefOverride ( TFCSSimpleLWTNNHandler  , 1  )
private

Do not persistify.

◆ compute()

TFCSSimpleLWTNNHandler::NetworkOutputs TFCSSimpleLWTNNHandler::compute ( TFCSSimpleLWTNNHandler::NetworkInputs const &  inputs) const
override virtual

Function to pass values to the network.

This function hides variations in the format needed by different network libraries, providing a uniform input and output type.

Parameters
inputs: values to be evaluated by the network
Returns
the output of the network
See also
VNetworkBase::NetworkInputs
VNetworkBase::NetworkOutputs

Implements VNetworkBase.

Definition at line 59 of file TFCSSimpleLWTNNHandler.cxx.

{
  ATH_MSG_DEBUG("Running computation on LWTNN neural network");
  // Flatten the map depth
  if (inputs.size() != 1) {
    ATH_MSG_ERROR("The inputs have multiple elements."
                  << " An LWTNN neural network can only handle one node.");
  }
  std::map<std::string, double> flat_inputs;
  // If several nodes were passed, only the last one (in key order) is kept.
  for (const auto &node : inputs) {
    flat_inputs = node.second;
  }
  // Now we have flattened, we can compute.
  NetworkOutputs outputs = m_lwtnn_neural->compute(flat_inputs);
  // Strip the common prefix so the keys match getOutputLayers()
  removePrefixes(outputs);
  ATH_MSG_DEBUG("Computation on LWTNN neural network done, returning");
  return outputs;
};
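The flattening step in compute can be isolated as below. This sketch departs from the code above on purpose: it throws std::invalid_argument when the number of nodes is not exactly one, instead of logging an error and carrying on with the last node. flattenSingleNode is a hypothetical name.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

using NetworkInputs = std::map<std::string, std::map<std::string, double>>;

// An lwt::LightweightNeuralNetwork takes a single std::map<std::string, double>,
// so exactly one outer node is expected; its inner map is returned as-is.
std::map<std::string, double> flattenSingleNode(const NetworkInputs &inputs) {
  if (inputs.size() != 1)
    throw std::invalid_argument("expected exactly one input node");
  return inputs.begin()->second;
}
```

Failing loudly makes a mis-shaped input visible at the call site rather than silently evaluating only part of it.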

◆ deleteAllButNet()

void VNetworkLWTNN::deleteAllButNet ( )
override virtual inherited

Get rid of any memory objects that aren't needed to run the net.

Minimise memory usage by deleting any inputs that are no longer required to run the compute function. This prevents the net from being saved, so if you need to call writeNetToTTree, do so before calling this function. It also prevents copying: the VNetworkLWTNN copy constructor throws if m_json is empty.

Implements VNetworkBase.

Definition at line 87 of file VNetworkLWTNN.cxx.

{
  ATH_MSG_DEBUG("Replacing m_inputFile with unknown");
  m_inputFile.assign("unknown");
  m_inputFile.shrink_to_fit();
  ATH_MSG_DEBUG("Emptying the m_json string");
  m_json.clear();
  m_json.shrink_to_fit();
  ATH_MSG_VERBOSE("m_json now has capacity "
                  << m_json.capacity() << ". m_inputFile now has capacity "
                  << m_inputFile.capacity()
                  << ". m_printable_name now has capacity "
                  << m_printable_name.capacity());
};
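A hypothetical miniature (JsonHolderSketch) of the trade-off described for deleteAllButNet: once the json is dropped, the object can no longer be saved or copied, while copies taken earlier are unaffected. Throwing from save() is a choice made for the sketch; writeStringToTTree, as shown on this page, does not check for an empty string.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <utility>

class JsonHolderSketch {
public:
  explicit JsonHolderSketch(std::string json) : m_json(std::move(json)) {}
  // Mirrors the VNetworkLWTNN copy constructor: refuse to copy an emptied object.
  JsonHolderSketch(const JsonHolderSketch &other) : m_json(other.m_json) {
    if (m_json.empty())
      throw std::invalid_argument("deleteAllButNet was called on the source");
  }
  void deleteAllButNet() {
    m_json.clear();
    m_json.shrink_to_fit(); // non-binding request to release the buffer
  }
  std::string save() const {
    if (m_json.empty())
      throw std::logic_error("nothing left to save");
    return m_json;
  }

private:
  std::string m_json;
};
```

Since shrink_to_fit is only a non-binding request, the capacity printed by the verbose message in deleteAllButNet is not guaranteed to reach zero.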

◆ fillJson()

void VNetworkLWTNN::fillJson ( std::string const &  tree_name = m_defaultTreeName)
private inherited

Fill out m_json from a file provided to the constructor.

Provided that the string passed to the constructor as inputFile is a known file type (root or json), this function retrieves the json string itself and puts it into m_json. Any path not ending in .root is read as a text json file.

Parameters
tree_name: TTree name to check in when reading root files.

Definition at line 52 of file VNetworkLWTNN.cxx.

{
  ATH_MSG_VERBOSE("Trying to fill the m_json variable");
  if (this->isRootFile()) {
    ATH_MSG_VERBOSE("Treating input file as a root file");
    TFile tfile(this->m_inputFile.c_str(), "READ");
    // Note: tree is not checked; a missing TTree would be dereferenced as nullptr.
    TTree *tree = (TTree *)tfile.Get(tree_name.c_str());
    std::string found = this->readStringFromTTree(*tree);
    ATH_MSG_DEBUG("Read json from root file, length " << found.length());
    m_json = found;
  } else {
    ATH_MSG_VERBOSE("Treating input file as a text json file");
    // The input file is read into a stringstream
    std::ifstream input(m_inputFile);
    std::ostringstream sstr;
    sstr << input.rdbuf();
    m_json = sstr.str();
    input.close();
    ATH_MSG_DEBUG("Read json from text file");
  }
}
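The text branch of fillJson reads the whole file through its stream buffer. A standalone sketch of that idiom follows; readWholeFile is a hypothetical name, and the ROOT branch is omitted because it needs ROOT's TFile and TTree.

```cpp
#include <cassert>
#include <fstream>
#include <sstream>
#include <string>

// Stream the whole file buffer into an ostringstream, as fillJson does.
// Returns an empty string if the file cannot be opened; the code above does
// not check this explicitly.
std::string readWholeFile(const std::string &path) {
  std::ifstream input(path);
  if (!input)
    return {};
  std::ostringstream sstr;
  sstr << input.rdbuf();
  return sstr.str();
}
```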

◆ getOutputLayers()

std::vector< std::string > TFCSSimpleLWTNNHandler::getOutputLayers ( ) const
override virtual

List the names of the outputs.

Outputs are stored in a NetworkOutputs object, which is indexed by strings. This function returns the list of all strings that will index the outputs.

Implements VNetworkBase.

Definition at line 53 of file TFCSSimpleLWTNNHandler.cxx.

{
  return m_outputLayers;
};

◆ isFile() [1/2]

bool VNetworkBase::isFile ( ) const
inherited

Check if the argument inputFile is the path of a file on disk.

Determines if the string that was passed to the constructor as inputFile corresponds to the path of a file that can be read on the disk.

Returns
is it a readable file on disk

Definition at line 117 of file VNetworkBase.cxx.

{ return isFile(m_inputFile); };

◆ isFile() [2/2]

bool VNetworkBase::isFile ( std::string const &  inputFile)
static inherited

Check if a string is the path of a file on disk.

Determines if a string corresponds to the path of a file that can be read on the disk.

Parameters
inputFile: name of the potential file
Returns
is it a readable file on disk

Definition at line 119 of file VNetworkBase.cxx.

{
  if (FILE *file = std::fopen(inputFile.c_str(), "r")) {
    std::fclose(file);
    return true;
  } else {
    return false;
  };
};

◆ isRootFile()

bool VNetworkBase::isRootFile ( std::string const &  filename = "") const
protected inherited

Check if a string is possibly a root file path.

Just checks if the string ends in .root as there are almost no reliable rules for file paths.

Parameters
filename: name of the potential file; if blank, m_inputFile is used.
Returns
is it the path of a root file

Definition at line 101 of file VNetworkBase.cxx.

{
  const std::string *to_check = &filename;
  if (filename.length() == 0) {
    to_check = &this->m_inputFile;
    ATH_MSG_DEBUG("No file name given, so using m_inputFile, " << m_inputFile);
  };
  const std::string ending = ".root";
  const int ending_len = ending.length();
  const int filename_len = to_check->length();
  if (filename_len < ending_len) {
    return false;
  }
  return (0 ==
          to_check->compare(filename_len - ending_len, ending_len, ending));
};
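The suffix test in isRootFile, extracted into a standalone function so its edge cases can be checked (looksLikeRootFile is a hypothetical name). Like the original, it is case-sensitive and looks only at the last five characters.

```cpp
#include <cassert>
#include <string>

// True when filename ends in exactly ".root" (case-sensitive).
bool looksLikeRootFile(const std::string &filename) {
  const std::string ending = ".root";
  if (filename.length() < ending.length())
    return false;
  return filename.compare(filename.length() - ending.length(), ending.length(),
                          ending) == 0;
}
```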

◆ level()

MSG::Level ISF_FCS::MLogging::level ( ) const
inline inherited

Retrieve output level.

Definition at line 201 of file MLogging.h.

{ return msg().level(); }

◆ msg() [1/2]

MsgStream & ISF_FCS::MLogging::msg ( ) const
inline inherited

Return a stream for sending messages directly (no decoration)

Definition at line 231 of file MLogging.h.

{
  MsgStream *ms = m_msg_tls.get();
  if (!ms) {
    ms = new MsgStream(Athena::getMessageSvc(), m_nm);
    m_msg_tls.reset(ms);
  }
  return *ms;
}
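MLogging::msg() creates one MsgStream per thread on first use via boost::thread_specific_ptr. The sketch below shows the same lazy per-thread pattern with the C++11 thread_local keyword instead; this is a swapped-in technique, threadStream is a hypothetical name, and std::ostringstream stands in for Gaudi's MsgStream.

```cpp
#include <cassert>
#include <memory>
#include <sstream>

// Each thread gets its own stream, created the first time that thread asks.
std::ostringstream &threadStream() {
  thread_local std::unique_ptr<std::ostringstream> tls;
  if (!tls)
    tls = std::make_unique<std::ostringstream>();
  return *tls;
}
```

Repeated calls on the same thread return the same object, so text written through one call is visible through the next.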

◆ msg() [2/2]

MsgStream & ISF_FCS::MLogging::msg ( const MSG::Level  lvl) const
inline inherited

Return a decorated starting stream for sending messages.

Definition at line 240 of file MLogging.h.

{
  return msg() << lvl;
}

◆ msgLvl()

bool ISF_FCS::MLogging::msgLvl ( const MSG::Level  lvl) const
inline inherited

Check whether the logging system is active at the provided verbosity level.

Definition at line 222 of file MLogging.h.

{
  if (msg().level() <= lvl) {
    msg() << lvl;
    return true;
  } else {
    return false;
  }
}

◆ print()

void VNetworkLWTNN::print ( std::ostream &  strm) const
override protected virtual inherited

Write a short description of this net to the string stream.

Outputs a printable name, which may be a file name, or a note specifying that the json has been provided from memory.

Parameters
strm: output parameter, to which the description will be written.

Reimplemented from VNetworkBase.

Definition at line 44 of file VNetworkLWTNN.cxx.

{
  strm << m_printable_name;
};

◆ readStringFromTTree()

std::string VNetworkLWTNN::readStringFromTTree ( TTree &  tree)
private inherited

Get json string from TTree.

Given a TTree object, retrieve the json string from the standard branch (lwtnn_json). This is used to retrieve a network previously saved using writeNetToTTree.

Parameters
tree: TTree with the json saved inside.

Definition at line 73 of file VNetworkLWTNN.cxx.

{
  std::string found = std::string();
  std::string *to_found = &found;
  tree.SetBranchAddress("lwtnn_json", &to_found);
  tree.GetEntry(0);
  return found;
};

◆ removePrefixes() [1/2]

void VNetworkBase::removePrefixes ( VNetworkBase::NetworkOutputs &  outputs) const
protected inherited

Remove any common prefix from the outputs.

Parameters
outputs: the outputs, changed in place.

Definition at line 151 of file VNetworkBase.cxx.

{
  std::vector<std::string> output_layers;
  for (auto const &output : outputs)
    output_layers.push_back(output.first);
  const int length = GetPrefixLength(output_layers);
  for (std::string layer_name : output_layers) {
    // remove this output
    auto nodeHandle = outputs.extract(layer_name);
    // change the key
    nodeHandle.key() = layer_name.substr(length);
    // replace the output
    outputs.insert(std::move(nodeHandle));
  }
};
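removePrefixes renames map keys with C++17 node handles: extract the node, edit its key, and re-insert it, so the mapped value travels with its node instead of being copied. A standalone sketch over std::map<std::string, double>, with the prefix length passed in directly (stripKeyPrefix is a hypothetical name):

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

void stripKeyPrefix(std::map<std::string, double> &outputs, std::size_t length) {
  // Collect the keys first so the map is not modified while being iterated.
  std::vector<std::string> keys;
  for (const auto &item : outputs)
    keys.push_back(item.first);
  for (const std::string &key : keys) {
    auto node = outputs.extract(key);  // unlink the node, keeping its value
    node.key() = key.substr(length);   // edit the key in place
    outputs.insert(std::move(node));   // relink under the new key
  }
}
```

If a shortened key collides with a key still in the map, insert fails and the entry is dropped when the returned insert_return_type, which holds the node, goes out of scope.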

◆ removePrefixes() [2/2]

void VNetworkBase::removePrefixes ( std::vector< std::string > &  output_names) const
protected inherited

Remove any common prefix from the outputs.

Parameters
output_names: the output names, changed in place.

Definition at line 144 of file VNetworkBase.cxx.

{
  const int length = GetPrefixLength(output_names);
  for (long unsigned int i = 0; i < output_names.size(); i++)
    output_names[i] = output_names[i].substr(length);
};
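GetPrefixLength is not shown on this page. Below, a hypothetical stand-in (commonPrefixLength) computes the plain character-wise longest common prefix, together with the vector version of the stripping loop; the real helper may use a different rule, for example stopping at a separator.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Length of the longest character-wise prefix shared by every name.
std::size_t commonPrefixLength(const std::vector<std::string> &names) {
  if (names.empty())
    return 0;
  const std::string &first = names.front();
  std::size_t length = first.size();
  for (const std::string &name : names) {
    const auto diff = std::mismatch(first.begin(), first.begin() + length,
                                    name.begin(), name.end());
    length = static_cast<std::size_t>(diff.first - first.begin());
  }
  return length;
}

// Same loop as removePrefixes(std::vector<std::string>&).
void stripCommonPrefix(std::vector<std::string> &names) {
  const std::size_t length = commonPrefixLength(names);
  for (std::string &name : names)
    name = name.substr(length);
}
```

A purely character-wise prefix can cut into words: for "out_energy" and "out_eta" it is "out_e", leaving "nergy" and "ta". A single name would be stripped to an empty string.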

◆ representNetworkInputs()

std::string VNetworkBase::representNetworkInputs ( VNetworkBase::NetworkInputs const &  inputs,
int  maxValues = 3 
)
static inherited

String representation of network inputs.

Create a string that summarises a set of network inputs: the basic dimensions plus a few values. As written, the limit is checked after each value is appended, so up to maxValues + 1 values are included.

Parameters
inputs: values to be evaluated by the network
maxValues: maximum number of values to include in the representation
Returns
string representing the inputs

Definition at line 37 of file VNetworkBase.cxx.

{
  std::string representation =
      "NetworkInputs, outer size " + std::to_string(inputs.size());
  int valuesIncluded = 0;
  for (const auto &outer : inputs) {
    representation += "\n key->" + outer.first + "; ";
    for (const auto &inner : outer.second) {
      representation += inner.first + "=" + std::to_string(inner.second) + ", ";
      ++valuesIncluded;
      if (valuesIncluded > maxValues)
        break;
    };
    if (valuesIncluded > maxValues)
      break;
  };
  representation += "\n";
  return representation;
};

◆ representNetworkOutputs()

std::string VNetworkBase::representNetworkOutputs ( VNetworkBase::NetworkOutputs const &  outputs,
int  maxValues = 3 
)
static inherited

String representation of network outputs.

Create a string that summarises a set of network outputs: the basic dimensions plus a few values. As written, the limit is checked after each value is appended, so up to maxValues + 1 values are included.

Parameters
outputs: output of the network
maxValues: maximum number of values to include in the representation
Returns
string representing the outputs

Definition at line 57 of file VNetworkBase.cxx.

{
  std::string representation =
      "NetworkOutputs, size " + std::to_string(outputs.size()) + "; \n";
  int valuesIncluded = 0;
  for (const auto &item : outputs) {
    representation += item.first + "=" + std::to_string(item.second) + ", ";
    ++valuesIncluded;
    if (valuesIncluded > maxValues)
      break;
  };
  representation += "\n";
  return representation;
};
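The cut-off in representNetworkOutputs (and representNetworkInputs) is checked after a value has been appended, so with the default maxValues = 3 four values are printed. The sketch below reproduces the loop under a hypothetical name, representOutputsSketch, so this can be checked:

```cpp
#include <cassert>
#include <map>
#include <string>

std::string representOutputsSketch(const std::map<std::string, double> &outputs,
                                   int maxValues = 3) {
  std::string representation =
      "NetworkOutputs, size " + std::to_string(outputs.size()) + "; \n";
  int valuesIncluded = 0;
  for (const auto &item : outputs) {
    representation += item.first + "=" + std::to_string(item.second) + ", ";
    ++valuesIncluded;
    if (valuesIncluded > maxValues) // checked after appending
      break;
  }
  representation += "\n";
  return representation;
}
```

Moving the check before the append would cap the output at exactly maxValues entries, if that is the intended behaviour.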

◆ setLevel()

void ISF_FCS::MLogging::setLevel ( MSG::Level  lvl)
virtual inherited

Update output level.

Definition at line 105 of file MLogging.cxx.

{
  lvl = (lvl >= MSG::NUM_LEVELS) ? MSG::ALWAYS
        : (lvl < MSG::NIL)       ? MSG::NIL
                                 : lvl;
  msg().setLevel(lvl);
}
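setLevel pins out-of-range values to the nearest valid level. A standalone sketch with a hypothetical enum whose order matches the LevelNames table in startMsg (NIL through ALWAYS, then NUM_LEVELS); the real MSG::Level comes from Gaudi, and clampLevel is a hypothetical name.

```cpp
#include <cassert>

// Hypothetical copy of the level ordering; prefixed names avoid clashing with
// common macros such as DEBUG or ERROR.
enum Level { L_NIL = 0, L_VERBOSE, L_DEBUG, L_INFO, L_WARNING,
             L_ERROR, L_FATAL, L_ALWAYS, L_NUM_LEVELS };

// Same clamping as setLevel: too high becomes ALWAYS, too low becomes NIL.
Level clampLevel(int lvl) {
  return lvl >= L_NUM_LEVELS ? L_ALWAYS
         : lvl < L_NIL       ? L_NIL
                             : static_cast<Level>(lvl);
}
```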

◆ setupNet()

void TFCSSimpleLWTNNHandler::setupNet ( )
override protected virtual

Perform actions that prepare network for use.

Will be called in the streamer or class constructor after the inputs have been set (either automatically by the streamer or by setupPersistedVariables in the constructor). Does not delete any resources used.

Implements VNetworkBase.

Definition at line 35 of file TFCSSimpleLWTNNHandler.cxx.

{
  // build the graph
  ATH_MSG_DEBUG("Reading the m_json string stream into a neural network");
  std::stringstream json_stream(m_json);
  const lwt::JSONConfig config = lwt::parse_json(json_stream);
  m_lwtnn_neural = std::make_unique<lwt::LightweightNeuralNetwork>(
      config.inputs, config.layers, config.outputs);
  // Get the output layers
  ATH_MSG_DEBUG("Getting output layers for neural network");
  for (const std::string& name : config.outputs) {
    ATH_MSG_VERBOSE("Found output layer called " << name);
    m_outputLayers.push_back(name);
  };
  ATH_MSG_DEBUG("Removing prefix from stored layers.");
  removePrefixes(m_outputLayers);
  ATH_MSG_DEBUG("Finished output nodes.");
}

◆ setupPersistedVariables()

void VNetworkLWTNN::setupPersistedVariables ( )
override protected virtual inherited

Perform actions that prep data to create the net.

Will be called in the constructor before calling setupNet, but not in the streamer. It sets any variables that the streamer would persist when saving or loading to file.

Implements VNetworkBase.

Definition at line 30 of file VNetworkLWTNN.cxx.

{
  if (this->isFile(m_inputFile)) {
    ATH_MSG_DEBUG("Making an LWTNN network using a file on disk, "
                  << m_inputFile);
    m_printable_name = m_inputFile;
    fillJson();
  } else {
    ATH_MSG_DEBUG("Making an LWTNN network using a json in memory, length "
                  << m_inputFile.length());
    m_printable_name = "JSON from memory";
    m_json = m_inputFile;
  };
};
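setupPersistedVariables accepts either a path or the json text itself. The sketch below makes that decision explicit with a hypothetical free function, persistInput, which opens the input with std::ifstream as a rough proxy for isFile; it skips the ROOT-file branch handled by fillJson.

```cpp
#include <cassert>
#include <fstream>
#include <sstream>
#include <string>

struct PersistedSketch {
  std::string json;
  std::string printableName;
};

PersistedSketch persistInput(const std::string &inputFile) {
  PersistedSketch result;
  std::ifstream input(inputFile);
  if (input) { // readable path: load the json from disk
    std::ostringstream sstr;
    sstr << input.rdbuf();
    result.json = sstr.str();
    result.printableName = inputFile;
  } else { // otherwise the string is taken to be the json itself
    result.json = inputFile;
    result.printableName = "JSON from memory";
  }
  return result;
}
```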

◆ startMsg()

std::string ISF_FCS::MLogging::startMsg ( MSG::Level  lvl,
const std::string &  file,
int  line 
)
static inherited

Make a message to decorate the start of logging.

The message holds the tail of the source file name, the line number and the level name in brackets, each padded to a fixed column width.

Definition at line 116 of file MLogging.cxx.

{
  int col1_len = 20;
  int col2_len = 5;
  int col3_len = 10;
  auto last_slash = file.find_last_of('/');
  int path_len = last_slash == std::string::npos ? 0 : last_slash;
  int trim_point = path_len;
  int total_len = file.length();
  if (total_len - path_len > col1_len)
    trim_point = total_len - col1_len;
  std::string trimmed_name = file.substr(trim_point);
  const char *LevelNames[MSG::NUM_LEVELS] = {
      "NIL", "VERBOSE", "DEBUG", "INFO", "WARNING", "ERROR", "FATAL", "ALWAYS"};
  std::string level = LevelNames[lvl];
  std::string level_string = std::string("(") + level + ") ";
  std::stringstream output;
  output << std::setw(col1_len) << std::right << trimmed_name << ":"
         << std::setw(col2_len) << std::left << line << std::setw(col3_len)
         << std::right << level_string;
  return output.str();
}
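The file-name trimming in startMsg, pulled out into a hypothetical trimFileName so it can be tested. When a directory is present the substring starts at the last '/', so the slash itself is kept; names longer than the column keep only their last col1_len characters.

```cpp
#include <cassert>
#include <string>

std::string trimFileName(const std::string &file, int col1_len = 20) {
  const auto last_slash = file.find_last_of('/');
  const int path_len =
      last_slash == std::string::npos ? 0 : static_cast<int>(last_slash);
  int trim_point = path_len; // starts at the slash, not after it
  const int total_len = static_cast<int>(file.length());
  if (total_len - path_len > col1_len)
    trim_point = total_len - col1_len; // keep only the tail
  return file.substr(trim_point);
}
```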

◆ VNetworkBase() [1/3]

VNetworkBase::VNetworkBase ( )
inherited

VNetworkBase default constructor.

For use in streamers.

Definition at line 45 of file VNetworkBase.cxx.

    : m_inputFile("unknown"){};

◆ VNetworkBase() [2/3]

VNetworkBase::VNetworkBase ( const std::string &  inputFile)
explicit inherited

VNetworkBase constructor.

Only saves inputFile to m_inputFile. Inheriting classes should call setupPersistedVariables and setupNet in their constructors.

Parameters
inputFile: file-path on disk (with file name) of a readable file containing a description of the network to be constructed, or the content of the file.

Definition at line 59 of file VNetworkBase.cxx.

  ATH_MSG_DEBUG("Constructor called with inputFile");
};

◆ VNetworkBase() [3/3]

VNetworkBase::VNetworkBase ( const VNetworkBase &  copy_from)
inherited

VNetworkBase copy constructor.

Does not call setupPersistedVariables or setupNet, but passes on m_inputFile. Inheriting classes should do whatever they need to carry over the variables created in the setup functions.

Parameters
copy_from: existing network that we are copying

Definition at line 71 of file VNetworkBase.cxx.

    : MLogging(),
      m_inputFile (copy_from.m_inputFile)
{
};

◆ VNetworkLWTNN()

VNetworkLWTNN::VNetworkLWTNN ( const VNetworkLWTNN &  copy_from)

VNetworkLWTNN copy constructor.

Will copy the variables that would be generated by setupPersistedVariables and setupNet. Will fail if deleteAllButNet has already been called.

Parameters
copy_from: existing network that we are copying

Definition at line 45 of file VNetworkLWTNN.cxx.

    : VNetworkBase(copy_from),
      m_json (copy_from.m_json),
      m_printable_name (copy_from.m_printable_name)
{
  if (m_json.length() == 0) {
    throw std::invalid_argument(
        "Trying to copy a VNetworkLWTNN with length 0 m_json, probably "
        "deleteAllButNet was called on the object being coppied from.");
  };
};

◆ writeNetToTTree() [1/6]

void VNetworkBase::writeNetToTTree ( std::string const &  root_name,
std::string const &  tree_name = m_defaultTreeName 
)
inherited

Save the network to a TTree.

All data required to recreate the network object is saved into a TTree. The format is not specified.

Parameters
root_name: the path of the file to save inside.
tree_name: the name of the TTree to save inside.

Definition at line 196 of file VNetworkBase.cxx.

{
  ATH_MSG_DEBUG("Making or updating file name " << root_name);
  TFile root_file(root_name.c_str(), "UPDATE");
  this->writeNetToTTree(root_file, tree_name);
  root_file.Close();
};

◆ writeNetToTTree() [2/6]

void VNetworkBase::writeNetToTTree ( std::string const &  root_name,
std::string const &  tree_name = m_defaultTreeName 
)
inherited

Save the network to a TTree.

All data required to recreate the network object is saved into a TTree. The format is not specified.

Parameters
root_name: the path of the file to save inside.
tree_name: the name of the TTree to save inside.

Definition at line 93 of file VNetworkBase.cxx.

{
  ATH_MSG_DEBUG("Making or updating file name " << root_name);
  TFile root_file(root_name.c_str(), "UPDATE");
  this->writeNetToTTree(root_file, tree_name);
  root_file.Close();
};

◆ writeNetToTTree() [3/6]

void VNetworkBase::writeNetToTTree ( TFile &  root_file,
std::string const &  tree_name = m_defaultTreeName 
)
inherited

Save the network to a TTree.

All data required to recreate the network object is saved into a TTree. The format is not specified.

Parameters
root_file: the file to save inside.
tree_name: the name of the TTree to save inside.

Definition at line 184 of file VNetworkBase.cxx.

{
  ATH_MSG_DEBUG("Making tree name " << tree_name);
  root_file.cd();
  const std::string title = "onnxruntime saved network";
  TTree tree(tree_name.c_str(), title.c_str());
  this->writeNetToTTree(tree);
  root_file.Write();
};

◆ writeNetToTTree() [4/6]

void VNetworkBase::writeNetToTTree ( TFile &  root_file,
std::string const &  tree_name = m_defaultTreeName 
)
inherited

Save the network to a TTree.

All data required to recreate the network object is saved into a TTree. The format is not specified.

Parameters
root_file: the file to save inside.
tree_name: the name of the TTree to save inside.

Definition at line 83 of file VNetworkBase.cxx.

{
  ATH_MSG_DEBUG("Making tree name " << tree_name);
  root_file.cd();
  const std::string title = "onnxruntime saved network";
  TTree tree(tree_name.c_str(), title.c_str());
  this->writeNetToTTree(tree);
  root_file.Write();
};

◆ writeNetToTTree() [5/6]

void VNetworkLWTNN::writeNetToTTree ( TTree &  tree)
override virtual inherited

Save the network to a TTree.

All data required to recreate the network object is saved into a TTree. The format is not specified.

Parameters
tree: the tree to save inside.

Implements VNetworkBase.

Definition at line 48 of file VNetworkLWTNN.cxx.

{
  writeStringToTTree(tree, m_json);
};

◆ writeNetToTTree() [6/6]

virtual void VNetworkBase::writeNetToTTree ( TTree &  tree)
pure virtual inherited

Save the network to a TTree.

All data required to recreate the network object is saved into a TTree. The format is not specified.

Parameters
tree: the tree to save inside.

◆ writeStringToTTree()

void VNetworkLWTNN::writeStringToTTree ( TTree &  tree,
std::string  json_string 
)
private inherited

Write json string to TTree.

Given a TTree object, store the json string in the standard branch (lwtnn_json), fill one entry and write the tree. This is the branch that readStringFromTTree reads back.

Parameters
tree: TTree to save the json inside.
json_string: the json string to save.

Definition at line 81 of file VNetworkLWTNN.cxx.

{
  tree.Branch("lwtnn_json", &json_string);
  tree.Fill();
  tree.Write();
};

Member Data Documentation

◆ ATLAS_THREAD_SAFE

boost::thread_specific_ptr<MsgStream> m_msg_tls ISF_FCS::MLogging::ATLAS_THREAD_SAFE
inline static private inherited

Do not persistify!

MsgStream instance (a std::cout like with print-out levels)

Definition at line 215 of file MLogging.h.

◆ m_defaultTreeName

const std::string VNetworkBase::m_defaultTreeName = "onnxruntime_session"
inline static inherited

Default name for the TTree to save in.

Definition at line 173 of file VNetworkBase.h.

◆ m_inputFile

std::string VNetworkBase::m_inputFile
protected inherited

Path to the file describing the network, including filename.

Definition at line 245 of file VNetworkBase.h.

◆ m_json

std::string VNetworkLWTNN::m_json
protected inherited

String containing json input file.

Is needed to save the network with writeNetToTTree, but not to run it with compute. It is erased by deleteAllButNet. Should be persisted.

Definition at line 84 of file VNetworkLWTNN.h.

◆ m_lwtnn_neural

std::unique_ptr<lwt::LightweightNeuralNetwork> TFCSSimpleLWTNNHandler::m_lwtnn_neural
private

The network that we are wrapping here.

Definition at line 103 of file TFCSSimpleLWTNNHandler.h.

◆ m_nm

std::string ISF_FCS::MLogging::m_nm
private inherited

Message source name.

Definition at line 211 of file MLogging.h.

◆ m_outputLayers

std::vector<std::string> TFCSSimpleLWTNNHandler::m_outputLayers
private

Do not persistify.

List of names that index the output layer.

Definition at line 107 of file TFCSSimpleLWTNNHandler.h.

◆ m_printable_name

std::string VNetworkLWTNN::m_printable_name
private inherited

Stores a printable identifier for the net.

Not unique.

Definition at line 144 of file VNetworkLWTNN.h.


The documentation for this class was generated from the following files: