ATLAS Offline Software
VNetworkBase Class Reference (abstract)

Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration.

#include <VNetworkBase.h>


Public Types

typedef std::map< std::string, std::map< std::string, double > > NetworkInputs
 Format for network inputs.
 
typedef std::map< std::string, double > NetworkOutputs
 Format for network outputs.
 

Public Member Functions

 VNetworkBase ()
 VNetworkBase default constructor.
 
 VNetworkBase (const std::string &inputFile)
 VNetworkBase constructor.
 
 VNetworkBase (const VNetworkBase &copy_from)
 VNetworkBase copy constructor.
 
virtual ~VNetworkBase ()
 
virtual NetworkOutputs compute (NetworkInputs const &inputs) const =0
 Function to pass values to the network.
 
virtual void writeNetToTTree (TTree &tree)=0
 Save the network to a TTree.
 
void writeNetToTTree (TFile &root_file, std::string const &tree_name=m_defaultTreeName)
 Save the network to a TTree.
 
void writeNetToTTree (std::string const &root_name, std::string const &tree_name=m_defaultTreeName)
 Save the network to a TTree.
 
virtual std::vector< std::string > getOutputLayers () const =0
 List the names of the outputs.
 
bool isFile () const
 Check if the argument inputFile is the path of a file on disk.
 
virtual void deleteAllButNet ()=0
 Get rid of any memory objects that aren't needed to run the net.
 
bool msgLvl (const MSG::Level lvl) const
 Check whether the logging system is active at the provided verbosity level.
 
MsgStream & msg () const
 Return a stream for sending messages directly (no decoration).
 
MsgStream & msg (const MSG::Level lvl) const
 Return a decorated starting stream for sending messages.
 
MSG::Level level () const
 Retrieve output level.
 
virtual void setLevel (MSG::Level lvl)
 Update the output level.
 

Static Public Member Functions

static std::string representNetworkInputs (NetworkInputs const &inputs, int maxValues=3)
 String representation of network inputs.
 
static std::string representNetworkOutputs (NetworkOutputs const &outputs, int maxValues=3)
 String representation of network outputs.
 
static bool isFile (std::string const &inputFile)
 Check if a string is the path of a file on disk.
 
static std::string startMsg (MSG::Level lvl, const std::string &file, int line)
 Make a message to decorate the start of logging.
 

Static Public Attributes

static const std::string m_defaultTreeName = "onnxruntime_session"
 Default name for the TTree to save in.
 

Protected Member Functions

virtual void setupPersistedVariables ()=0
 Perform actions that prepare the data needed to create the net.
 
virtual void setupNet ()=0
 Perform actions that prepare the network for use.
 
virtual void print (std::ostream &strm) const
 Write a short description of this net to the output stream.
 
bool isRootFile (std::string const &filename="") const
 Check if a string is possibly a ROOT file path.
 
void removePrefixes (NetworkOutputs &outputs) const
 Remove any common prefix from the outputs.
 
void removePrefixes (std::vector< std::string > &output_names) const
 Remove any common prefix from the outputs.
 

Protected Attributes

std::string m_inputFile
 Path to the file describing the network, including filename.
 

Private Member Functions

 ClassDef (VNetworkBase, 1)
 

Private Attributes

std::string m_nm
 Message source name.
 

Static Private Attributes

static boost::thread_specific_ptr< MsgStream > m_msg_tls ATLAS_THREAD_SAFE
 Do not persistify!
 

Friends

std::ostream & operator<< (std::ostream &strm, const VNetworkBase &vNetworkBase)
 Put-to operator to facilitate printing.
 

Detailed Description


Abstract base class for neural networks. It was initially aimed at replacing instances of an lwtnn network with a network that could be either lwtnn or ONNX, so its interface mirrors that of lwtnn graphs. At least three derived classes are available: TFCSONNXHandler, TFCSGANLWTNNHandler and TFCSSimpleLWTNNHandler.

The TFCSNetworkFactory::Create function has this class as its return type so that it can make a run-time decision about which derived class to use, based on the file or data presented.

It defines the interface to a neural network, with various subclasses covering the differing network libraries and save formats.

Definition at line 38 of file VNetworkBase.h.

Member Typedef Documentation

◆ NetworkInputs

typedef std::map<std::string, std::map<std::string, double> > VNetworkBase::NetworkInputs

Format for network inputs.

The doubles are the values to be passed into the network. Strings in the outer map identify the input node and must correspond to the node names in the network description read by the constructor. Strings in the inner map identify the part of the input node; for some networks these must be simple integers in string form, because parts of nodes cannot always carry real string labels.

Definition at line 90 of file VNetworkBase.h.
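A minimal sketch of filling a NetworkInputs map, assuming a hypothetical network with two input nodes. The node names node_0 and node_1 and the part label names are placeholders, not names from any real network description, and the typedef is copied locally so the snippet compiles without the ATLAS headers.

```cpp
#include <cstddef>
#include <map>
#include <string>

// Local copy of the documented typedef so that this sketch compiles on its own.
using NetworkInputs = std::map<std::string, std::map<std::string, double>>;

// Builds inputs for a hypothetical network with two input nodes.
// "node_0" and "node_1" are placeholders: real names must match the
// node names in the network description passed to the constructor.
NetworkInputs makeExampleInputs(double energy, double eta) {
  NetworkInputs inputs;
  // A node whose parts carry real string labels.
  inputs["node_0"]["energy"] = energy;
  inputs["node_0"]["eta"] = eta;
  // A node whose parts can only be addressed by position, so the inner
  // keys are integers written as strings.
  inputs["node_1"]["0"] = 1.0;
  inputs["node_1"]["1"] = -2.0;
  return inputs;
}

// Counts every individual value across all input nodes.
std::size_t countValues(const NetworkInputs &inputs) {
  std::size_t n = 0;
  for (const auto &node : inputs)
    n += node.second.size();
  return n;
}
```

For example, makeExampleInputs(12.5, 0.3) yields two outer entries holding four values in total.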

◆ NetworkOutputs

typedef std::map<std::string, double> VNetworkBase::NetworkOutputs

Format for network outputs.

The doubles are the values generated by the network. Strings identify the node each value came from; when a node has multiple values, its name is suffixed with a number indicating which part of the node the value came from. For multi-value nodes the key format is therefore "<node_name>_<part_n>".

Definition at line 100 of file VNetworkBase.h.
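The key convention can be illustrated with a small helper. This is a sketch only: addNodeValues is a hypothetical function, not part of VNetworkBase, and numbering the parts from 0 is an assumption, since the text does not say where part numbering starts.

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Local copy of the documented typedef.
using NetworkOutputs = std::map<std::string, double>;

// Hypothetical helper illustrating the documented key convention:
// a single-valued node is stored under its own name, while each part of a
// multi-valued node is stored as "<node_name>_<part_n>".
// Assumption: parts are numbered from 0.
void addNodeValues(NetworkOutputs &outputs, const std::string &nodeName,
                   const std::vector<double> &values) {
  if (values.size() == 1) {
    outputs[nodeName] = values.front();
    return;
  }
  for (std::size_t i = 0; i < values.size(); ++i)
    outputs[nodeName + "_" + std::to_string(i)] = values[i];
}
```

Adding a one-value node "energy" and a three-value node "shape" gives the keys energy, shape_0, shape_1 and shape_2.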

Constructor & Destructor Documentation

◆ VNetworkBase() [1/3]

VNetworkBase::VNetworkBase ( )

VNetworkBase default constructor.

For use in streamers.

Definition at line 16 of file VNetworkBase.cxx.

VNetworkBase::VNetworkBase() : m_inputFile("unknown"){};

◆ VNetworkBase() [2/3]

VNetworkBase::VNetworkBase ( const std::string &  inputFile)
explicit

VNetworkBase constructor.

Only saves inputFile to m_inputFile; inheriting classes should call setupPersistedVariables and setupNet in their own constructors.

Parameters
inputFile: file path on disk (including the file name) of a readable file containing a description of the network to be constructed, or the content of such a file.

Definition at line 19 of file VNetworkBase.cxx.

21  ATH_MSG_DEBUG("Constructor called with inputFile");
22 };
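A minimal, self-contained sketch of this construction contract in standard C++. NetworkBaseSketch and LinearNetSketch are hypothetical stand-ins, not ATLAS classes, and the "2 1" description string (coefficients of y = a*x + b) is invented for the example. The base only stores the input string; the derived constructor then calls setupPersistedVariables and setupNet itself.

```cpp
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Stand-in for VNetworkBase: only the parts needed to show the
// construction order are reproduced; this is not the real class.
class NetworkBaseSketch {
public:
  explicit NetworkBaseSketch(std::string inputFile)
      : m_inputFile(std::move(inputFile)) {} // only stores the input
  virtual ~NetworkBaseSketch() = default;
  virtual double compute(double x) const = 0;

protected:
  virtual void setupPersistedVariables() = 0;
  virtual void setupNet() = 0;
  std::string m_inputFile;
};

// Hypothetical derived class: a "network" evaluating y = a*x + b, with the
// coefficients read from a description string such as "2 1".
class LinearNetSketch : public NetworkBaseSketch {
public:
  explicit LinearNetSketch(const std::string &inputFile)
      : NetworkBaseSketch(inputFile) {
    // The derived constructor runs the setup steps, once its own
    // members exist and virtual calls reach these overrides.
    setupPersistedVariables();
    setupNet();
  }
  double compute(double x) const override { return m_a * x + m_b; }

  std::vector<std::string> calls; // records the order of the setup calls

protected:
  void setupPersistedVariables() override {
    calls.push_back("setupPersistedVariables");
    m_description = m_inputFile; // data a streamer would persist
  }
  void setupNet() override {
    calls.push_back("setupNet");
    std::istringstream in(m_description); // build the "net" from the data
    in >> m_a >> m_b;
  }

private:
  std::string m_description;
  double m_a = 0.0;
  double m_b = 0.0;
};
```

A virtual call made from the base-class constructor would not reach the derived overrides (and calling a pure virtual function there is undefined behaviour), which is a plausible reason the base constructor leaves these calls to subclasses.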

◆ VNetworkBase() [3/3]

VNetworkBase::VNetworkBase ( const VNetworkBase & copy_from)

VNetworkBase copy constructor.

Does not call setupPersistedVariables or setupNet, but passes on m_inputFile. Inheriting classes should do whatever they need to carry over the variables created in the setup functions.

Parameters
copy_from: existing network that we are copying

Definition at line 25 of file VNetworkBase.cxx.

VNetworkBase::VNetworkBase(const VNetworkBase &copy_from)
    : MLogging(),
      m_inputFile(copy_from.m_inputFile)
{
};

◆ ~VNetworkBase()

VNetworkBase::~VNetworkBase ( )
virtual

Definition at line 34 of file VNetworkBase.cxx.

VNetworkBase::~VNetworkBase() {};

Member Function Documentation

◆ ClassDef()

VNetworkBase::ClassDef ( VNetworkBase , 1 )
private

◆ compute()

virtual NetworkOutputs VNetworkBase::compute ( NetworkInputs const & inputs) const
pure virtual

Function to pass values to the network.

This function hides the variations in input format needed by different network libraries, providing uniform input and output types.

Parameters
inputs: values to be evaluated by the network
Returns
the output of the network
See also
VNetworkBase::NetworkInputs
VNetworkBase::NetworkOutputs

Implemented in TFCSONNXHandler, TFCSGANLWTNNHandler, and TFCSSimpleLWTNNHandler.

◆ deleteAllButNet()

virtual void VNetworkBase::deleteAllButNet ( )
pure virtual

Get rid of any memory objects that aren't needed to run the net.

Minimise memory usage by deleting any inputs that are no longer required to run the compute function. May prevent the net from being saved.

Implemented in TFCSONNXHandler, and VNetworkLWTNN.

◆ getOutputLayers()

virtual std::vector<std::string> VNetworkBase::getOutputLayers ( ) const
pure virtual

List the names of the outputs.

Outputs are stored in a NetworkOutputs object, which is indexed by strings. This function returns the list of all strings that will index the outputs.

Implemented in TFCSONNXHandler, TFCSGANLWTNNHandler, and TFCSSimpleLWTNNHandler.
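A sketch of how getOutputLayers and compute can be used together to read outputs in a fixed order. ToyNet is a hypothetical stand-in exposing the same two member functions (it is not one of the handler classes), and orderedOutputs is an illustrative helper; with the real class, the argument would be a VNetworkBase-derived object.

```cpp
#include <map>
#include <string>
#include <vector>

// Local copies of the documented typedefs.
using NetworkInputs = std::map<std::string, std::map<std::string, double>>;
using NetworkOutputs = std::map<std::string, double>;

// Hypothetical toy network with the same two calls as VNetworkBase:
// it sums every input value into "sum" and counts them into "count".
struct ToyNet {
  NetworkOutputs compute(const NetworkInputs &inputs) const {
    double sum = 0.0;
    double count = 0.0;
    for (const auto &node : inputs)
      for (const auto &part : node.second) {
        sum += part.second;
        count += 1.0;
      }
    return {{"sum", sum}, {"count", count}};
  }
  std::vector<std::string> getOutputLayers() const { return {"sum", "count"}; }
};

// Reads the outputs in the order reported by getOutputLayers(); works with
// any type providing these two member functions.
template <typename Net>
std::vector<double> orderedOutputs(const Net &net, const NetworkInputs &inputs) {
  const NetworkOutputs outputs = net.compute(inputs);
  std::vector<double> ordered;
  for (const std::string &name : net.getOutputLayers())
    ordered.push_back(outputs.at(name)); // throws if a name is missing
  return ordered;
}
```

Indexing with at() rather than operator[] makes a mismatch between getOutputLayers and compute fail loudly instead of silently inserting a zero.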

◆ isFile() [1/2]

bool VNetworkBase::isFile ( ) const

Check if the argument inputFile is the path of a file on disk.

Determines whether the string that was passed to the constructor as inputFile corresponds to the path of a file that can be read on disk.

Returns
is it a readable file on disk

Definition at line 117 of file VNetworkBase.cxx.

bool VNetworkBase::isFile() const { return isFile(m_inputFile); };

◆ isFile() [2/2]

bool VNetworkBase::isFile ( std::string const & inputFile)
static

Check if a string is the path of a file on disk.

Determines whether a string corresponds to the path of a file that can be read on disk.

Parameters
inputFile: name of the potential file
Returns
is it a readable file on disk

Definition at line 119 of file VNetworkBase.cxx.

bool VNetworkBase::isFile(std::string const &inputFile) {
  if (FILE *file = std::fopen(inputFile.c_str(), "r")) {
    std::fclose(file);
    return true;
  } else {
    return false;
  };
};
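The same readability check can be sketched with std::ifstream instead of std::fopen; this is a swapped-in technique for illustration, not the ATLAS implementation, and isReadableFile is a hypothetical name.

```cpp
#include <fstream>
#include <string>

// Alternative sketch of the check: returns true when the path can be
// opened for reading, using std::ifstream rather than std::fopen.
bool isReadableFile(const std::string &path) {
  std::ifstream stream(path);
  return stream.is_open();
}
```

Like the fopen version, this only tests that opening for reading succeeds; on POSIX systems opening a directory this way may also succeed, so neither version is a strict "regular file" test.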

◆ isRootFile()

bool VNetworkBase::isRootFile ( std::string const & filename = "") const
protected

Check if a string is possibly a root file path.

Only checks whether the string ends in ".root", as there are almost no reliable rules for file paths.

Parameters
filename: name of the potential file; if blank, m_inputFile is used.
Returns
is it the path of a root file

Definition at line 101 of file VNetworkBase.cxx.

bool VNetworkBase::isRootFile(std::string const &filename) const {
  const std::string *to_check = &filename;
  if (filename.length() == 0) {
    to_check = &this->m_inputFile;
    ATH_MSG_DEBUG("No file name given, so using m_inputFile, " << m_inputFile);
  };
  const std::string ending = ".root";
  const int ending_len = ending.length();
  const int filename_len = to_check->length();
  if (filename_len < ending_len) {
    return false;
  }
  return (0 ==
          to_check->compare(filename_len - ending_len, ending_len, ending));
};
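To make the edge cases of this suffix test explicit, here is a standalone copy of the same comparison outside the class; endsWithRoot is a hypothetical name used only for this illustration.

```cpp
#include <string>

// Standalone copy of the suffix test used by isRootFile, kept outside the
// class so that its edge cases can be checked in isolation.
bool endsWithRoot(const std::string &name) {
  const std::string ending = ".root";
  if (name.length() < ending.length())
    return false;
  return name.compare(name.length() - ending.length(), ending.length(),
                      ending) == 0;
}
```

The comparison is case-sensitive, so "net.ROOT" is rejected, while a bare ".root" with no stem is accepted.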

◆ level()

MSG::Level ISF_FCS::MLogging::level ( ) const
inline, inherited

Retrieve output level.

Definition at line 201 of file MLogging.h.

MSG::Level ISF_FCS::MLogging::level() const { return msg().level(); }

◆ msg() [1/2]

MsgStream & ISF_FCS::MLogging::msg ( ) const
inline, inherited

Return a stream for sending messages directly (no decoration).

Definition at line 231 of file MLogging.h.

MsgStream &ISF_FCS::MLogging::msg() const {
  MsgStream *ms = m_msg_tls.get();
  if (!ms) {
    ms = new MsgStream(Athena::getMessageSvc(), m_nm);
    m_msg_tls.reset(ms);
  }
  return *ms;
}
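The lazy, per-thread stream pattern above can be sketched in standard C++. This is an illustration under stated substitutions: a function-local thread_local std::unique_ptr stands in for boost::thread_specific_ptr, std::ostringstream stands in for MsgStream, and LoggerSketch is a hypothetical class.

```cpp
#include <memory>
#include <sstream>
#include <string>
#include <utility>

// Sketch of the lazily created, per-thread stream used by msg().
class LoggerSketch {
public:
  explicit LoggerSketch(std::string name) : m_nm(std::move(name)) {}

  std::ostringstream &msg() const {
    // Static thread-local storage: one stream per thread, shared by all
    // LoggerSketch instances, created on first use and then reused.
    thread_local std::unique_ptr<std::ostringstream> tls;
    if (!tls) {
      tls = std::make_unique<std::ostringstream>();
      *tls << m_nm << ": "; // tagged with the first caller's source name
    }
    return *tls;
  }

private:
  std::string m_nm;
};
```

Because m_msg_tls in the listing is also a static member, a consequence worth noting is that within one thread every instance shares the stream created by the first call, including the source name it was created with; the sketch reproduces that behaviour.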

◆ msg() [2/2]

MsgStream & ISF_FCS::MLogging::msg ( const MSG::Level  lvl) const
inline, inherited

Return a decorated starting stream for sending messages.

Definition at line 240 of file MLogging.h.

MsgStream &ISF_FCS::MLogging::msg(const MSG::Level lvl) const {
  return msg() << lvl;
}

◆ msgLvl()

bool ISF_FCS::MLogging::msgLvl ( const MSG::Level  lvl) const
inline, inherited

Check whether the logging system is active at the provided verbosity level.

Definition at line 222 of file MLogging.h.

bool ISF_FCS::MLogging::msgLvl(const MSG::Level lvl) const {
  if (msg().level() <= lvl) {
    msg() << lvl;
    return true;
  } else {
    return false;
  }
}

◆ print()

void VNetworkBase::print ( std::ostream &  strm) const
protected, virtual

Write a short description of this net to the output stream.

Intended to facilitate the put-to operator, allowing subclasses to change how this object is displayed.

Parameters
strm: output parameter, to which the description will be written.

Reimplemented in TFCSONNXHandler, and VNetworkLWTNN.

Definition at line 73 of file VNetworkBase.cxx.

void VNetworkBase::print(std::ostream &strm) const {
  if (m_inputFile.empty()) {
    ATH_MSG_DEBUG("Making a network without a named inputFile");
    strm << "Unknown network";
  } else {
    ATH_MSG_DEBUG("Making a network with input file " << m_inputFile);
    strm << m_inputFile;
  };
};

◆ removePrefixes() [1/2]

void VNetworkBase::removePrefixes ( VNetworkBase::NetworkOutputs & outputs) const
protected

Remove any common prefix from the outputs.

Parameters
outputs: the outputs, changed in place.

Definition at line 151 of file VNetworkBase.cxx.

void VNetworkBase::removePrefixes(VNetworkBase::NetworkOutputs &outputs) const {
  std::vector<std::string> output_layers;
  for (auto const &output : outputs)
    output_layers.push_back(output.first);
  const int length = GetPrefixLength(output_layers);
  for (std::string layer_name : output_layers) {
    // remove this output
    auto nodeHandle = outputs.extract(layer_name);
    // change the key
    nodeHandle.key() = layer_name.substr(length);
    // replace the output
    outputs.insert(std::move(nodeHandle));
  }
};

◆ removePrefixes() [2/2]

void VNetworkBase::removePrefixes ( std::vector< std::string > &  output_names) const
protected

Remove any common prefix from the outputs.

Parameters
output_names: the output names, changed in place.

Definition at line 144 of file VNetworkBase.cxx.

void VNetworkBase::removePrefixes(std::vector<std::string> &output_names) const {
  const int length = GetPrefixLength(output_names);
  for (long unsigned int i = 0; i < output_names.size(); i++)
    output_names[i] = output_names[i].substr(length);
};
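GetPrefixLength is not shown in these listings, so the sketch below assumes it returns the number of leading characters shared by every name; commonPrefixLength and stripCommonPrefix are hypothetical names, and the real helper may behave differently (for example, by stopping at a separator). The map version rebuilds a fresh map instead of re-keying nodes in place.

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

using NetworkOutputs = std::map<std::string, double>;

// Hypothetical stand-in for GetPrefixLength: the number of leading
// characters shared by every name (character-wise, no separator logic).
std::size_t commonPrefixLength(const std::vector<std::string> &names) {
  if (names.empty())
    return 0;
  std::size_t length = names.front().size();
  for (const std::string &name : names) {
    std::size_t i = 0;
    while (i < length && i < name.size() && name[i] == names.front()[i])
      ++i;
    length = i;
  }
  return length;
}

// Strips the shared prefix by building a new map rather than re-keying in place.
void stripCommonPrefix(NetworkOutputs &outputs) {
  std::vector<std::string> names;
  for (const auto &output : outputs)
    names.push_back(output.first);
  const std::size_t length = commonPrefixLength(names);
  NetworkOutputs stripped;
  for (const auto &output : outputs)
    stripped.emplace(output.first.substr(length), output.second);
  outputs = std::move(stripped);
}
```

Under this character-wise assumption, re-keying in place can collide with an original key that has not been processed yet: with keys "aab" and "ab" (common prefix "a"), "aab" becomes "ab" while the original "ab" is still in the map, so the node-handle insert fails and that value is dropped. Building a fresh map avoids depending on processing order, because the stripped keys are always distinct. A single output would be reduced to an empty key by this naive helper.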

◆ representNetworkInputs()

std::string VNetworkBase::representNetworkInputs ( VNetworkBase::NetworkInputs const & inputs,
int  maxValues = 3 
)
static

String representation of network inputs.

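Create a string that summarises a set of network inputs. It gives the number of input nodes (the outer size) plus a few values; because the stopping check comes after the counter is incremented, up to maxValues + 1 values are included.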

Parameters
inputs: values to be evaluated by the network
maxValues: cap on the number of values to include in the representation
Returns
string representing the inputs

Definition at line 37 of file VNetworkBase.cxx.

std::string VNetworkBase::representNetworkInputs(
    VNetworkBase::NetworkInputs const &inputs, int maxValues) {
  std::string representation =
      "NetworkInputs, outer size " + std::to_string(inputs.size());
  int valuesIncluded = 0;
  for (const auto &outer : inputs) {
    representation += "\n key->" + outer.first + "; ";
    for (const auto &inner : outer.second) {
      representation += inner.first + "=" + std::to_string(inner.second) + ", ";
      ++valuesIncluded;
      if (valuesIncluded > maxValues)
        break;
    };
    if (valuesIncluded > maxValues)
      break;
  };
  representation += "\n";
  return representation;
};

◆ representNetworkOutputs()

std::string VNetworkBase::representNetworkOutputs ( VNetworkBase::NetworkOutputs const & outputs,
int  maxValues = 3 
)
static

String representation of network outputs.

Create a string that summarises a set of network outputs. It gives the number of outputs plus a few values; because the stopping check comes after the counter is incremented, up to maxValues + 1 values are included.

Parameters
outputs: output of the network
maxValues: cap on the number of values to include in the representation
Returns
string representing the outputs

Definition at line 57 of file VNetworkBase.cxx.

std::string VNetworkBase::representNetworkOutputs(
    VNetworkBase::NetworkOutputs const &outputs, int maxValues) {
  std::string representation =
      "NetworkOutputs, size " + std::to_string(outputs.size()) + "; \n";
  int valuesIncluded = 0;
  for (const auto &item : outputs) {
    representation += item.first + "=" + std::to_string(item.second) + ", ";
    ++valuesIncluded;
    if (valuesIncluded > maxValues)
      break;
  };
  representation += "\n";
  return representation;
};
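To make the stopping condition concrete, this standalone copy of the loop counts how many values would be written instead of building the string; countRepresentedValues is a hypothetical name used only here.

```cpp
#include <map>
#include <string>

using NetworkOutputs = std::map<std::string, double>;

// Standalone copy of the loop in representNetworkOutputs, returning how many
// values would be appended to the representation.
int countRepresentedValues(const NetworkOutputs &outputs, int maxValues = 3) {
  int valuesIncluded = 0;
  for (const auto &item : outputs) {
    (void)item; // the real function appends this item to the string here
    ++valuesIncluded;
    if (valuesIncluded > maxValues) // checked after the increment
      break;
  }
  return valuesIncluded;
}
```

With five outputs and the default maxValues of 3, four values are included; with fewer outputs than the cap, all of them are.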

◆ setLevel()

void ISF_FCS::MLogging::setLevel ( MSG::Level  lvl)
virtual, inherited

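Update the output level. Levels at or above MSG::NUM_LEVELS are clamped to MSG::ALWAYS, and levels below MSG::NIL are clamped to MSG::NIL.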

Definition at line 105 of file MLogging.cxx.

void ISF_FCS::MLogging::setLevel(MSG::Level lvl) {
  lvl = (lvl >= MSG::NUM_LEVELS) ? MSG::ALWAYS
        : (lvl < MSG::NIL)       ? MSG::NIL
                                 : lvl;
  msg().setLevel(lvl);
}

◆ setupNet()

virtual void VNetworkBase::setupNet ( )
protected, pure virtual

Perform actions that prepare network for use.

Will be called in the streamer or class constructor after the inputs have been set (either automatically by the streamer or by setupPersistedVariables in the constructor). Does not delete any resources used.

Implemented in TFCSONNXHandler, TFCSGANLWTNNHandler, and TFCSSimpleLWTNNHandler.

◆ setupPersistedVariables()

virtual void VNetworkBase::setupPersistedVariables ( )
protected, pure virtual

Perform actions that prep data to create the net.

Will be called in the class constructor before calling setupNet, but not in the streamer. It sets any variables that the streamer would persist when saving to or loading from file.

Implemented in TFCSONNXHandler, and VNetworkLWTNN.

◆ startMsg()

std::string ISF_FCS::MLogging::startMsg ( MSG::Level  lvl,
const std::string &  file,
int  line 
)
static, inherited

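Make a message to decorate the start of logging.

Builds a fixed-width prefix for a log line: the source file name (taken from its last '/' if there is one, and cut to its final 20 characters if longer), a colon, the line number, and the level name in parentheses.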

Definition at line 116 of file MLogging.cxx.

std::string ISF_FCS::MLogging::startMsg(MSG::Level lvl, const std::string &file,
                                        int line) {
  int col1_len = 20;
  int col2_len = 5;
  int col3_len = 10;
  auto last_slash = file.find_last_of('/');
  int path_len = last_slash == std::string::npos ? 0 : last_slash;
  int trim_point = path_len;
  int total_len = file.length();
  if (total_len - path_len > col1_len)
    trim_point = total_len - col1_len;
  std::string trimmed_name = file.substr(trim_point);
  const char *LevelNames[MSG::NUM_LEVELS] = {
      "NIL", "VERBOSE", "DEBUG", "INFO", "WARNING", "ERROR", "FATAL", "ALWAYS"};
  std::string level = LevelNames[lvl];
  std::string level_string = std::string("(") + level + ") ";
  std::stringstream output;
  output << std::setw(col1_len) << std::right << trimmed_name << ":"
         << std::setw(col2_len) << std::left << line << std::setw(col3_len)
         << std::right << level_string;
  return output.str();
}

◆ writeNetToTTree() [1/3]

void VNetworkBase::writeNetToTTree ( std::string const & root_name,
std::string const & tree_name = m_defaultTreeName 
)

Save the network to a TTree.

All data required to recreate the network object is saved into a TTree. The format is not specified.

Parameters
root_name: the path of the file to save inside.
tree_name: the name of the TTree to save inside.

Definition at line 93 of file VNetworkBase.cxx.

void VNetworkBase::writeNetToTTree(std::string const &root_name,
                                   std::string const &tree_name) {
  ATH_MSG_DEBUG("Making or updating file name " << root_name);
  TFile root_file(root_name.c_str(), "UPDATE");
  this->writeNetToTTree(root_file, tree_name);
  root_file.Close();
};

◆ writeNetToTTree() [2/3]

void VNetworkBase::writeNetToTTree ( TFile &  root_file,
std::string const & tree_name = m_defaultTreeName 
)

Save the network to a TTree.

All data required to recreate the network object is saved into a TTree. The format is not specified.

Parameters
root_file: the file to save inside.
tree_name: the name of the TTree to save inside.

Definition at line 83 of file VNetworkBase.cxx.

void VNetworkBase::writeNetToTTree(TFile &root_file,
                                   std::string const &tree_name) {
  ATH_MSG_DEBUG("Making tree name " << tree_name);
  root_file.cd();
  const std::string title = "onnxruntime saved network";
  TTree tree(tree_name.c_str(), title.c_str());
  this->writeNetToTTree(tree);
  root_file.Write();
};

◆ writeNetToTTree() [3/3]

virtual void VNetworkBase::writeNetToTTree ( TTree &  tree)
pure virtual

Save the network to a TTree.

All data required to recreate the network object is saved into a TTree. The format is not specified.

Parameters
tree: the tree to save inside.

Implemented in TFCSONNXHandler, and VNetworkLWTNN.

Friends And Related Function Documentation

◆ operator<<

std::ostream& operator<< ( std::ostream &  strm,
const VNetworkBase & vNetworkBase 
)
friend

Put-to operator to facilitate printing.

It is useful to be able to display a reasonable representation of a network for debugging. Subclasses can alter this representation by overriding the protected print function.

Definition at line 154 of file VNetworkBase.h.

friend std::ostream &operator<<(std::ostream &strm,
                                const VNetworkBase &vNetworkBase) {
  vNetworkBase.print(strm);
  return strm;
}

Member Data Documentation

◆ m_msg_tls

boost::thread_specific_ptr<MsgStream> ISF_FCS::MLogging::m_msg_tls ATLAS_THREAD_SAFE
inline, static, private, inherited

Do not persistify!

MsgStream instance (a std::cout-like stream with print-out levels).

Definition at line 215 of file MLogging.h.

◆ m_defaultTreeName

const std::string VNetworkBase::m_defaultTreeName = "onnxruntime_session"
inline, static

Default name for the TTree to save in.

Definition at line 173 of file VNetworkBase.h.

◆ m_inputFile

std::string VNetworkBase::m_inputFile
protected

Path to the file describing the network, including filename.

Definition at line 245 of file VNetworkBase.h.

◆ m_nm

std::string ISF_FCS::MLogging::m_nm
private, inherited

Message source name.

Definition at line 211 of file MLogging.h.

