ATLAS Offline Software
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration.
#include <TFCSGANLWTNNHandler.h>
Public Types
typedef std::map< std::string, std::map< std::string, double > > | NetworkInputs
Format for network inputs.
typedef std::map< std::string, double > | NetworkOutputs
Format for network outputs.
Public Member Functions
TFCSGANLWTNNHandler (const std::string &inputFile)
TFCSGANLWTNNHandler constructor.
TFCSGANLWTNNHandler (const TFCSGANLWTNNHandler &copy_from)
TFCSGANLWTNNHandler copy constructor.
NetworkOutputs | compute (NetworkInputs const &inputs) const override
Function to pass values to the network.
std::vector< std::string > | getOutputLayers () const override
List the names of the outputs.
VNetworkLWTNN (const VNetworkLWTNN &copy_from)
VNetworkLWTNN copy constructor.
void | writeNetToTTree (TTree &tree) override
Save the network to a TTree.
virtual void | writeNetToTTree (TTree &tree)=0
Save the network to a TTree.
void | writeNetToTTree (TFile &root_file, std::string const &tree_name=m_defaultTreeName)
Save the network to a TTree.
void | writeNetToTTree (std::string const &root_name, std::string const &tree_name=m_defaultTreeName)
Save the network to a TTree.
void | deleteAllButNet () override
Get rid of any memory objects that aren't needed to run the net.
VNetworkBase ()
VNetworkBase default constructor.
VNetworkBase (const std::string &inputFile)
VNetworkBase constructor.
VNetworkBase (const VNetworkBase &copy_from)
VNetworkBase copy constructor.
bool | isFile () const
Check if the argument inputFile is the path of a file on disk.
bool | msgLvl (const MSG::Level lvl) const
Check whether the logging system is active at the provided verbosity level.
MsgStream & | msg () const
Return a stream for sending messages directly (no decoration).
MsgStream & | msg (const MSG::Level lvl) const
Return a decorated starting stream for sending messages.
MSG::Level | level () const
Retrieve output level.
virtual void | setLevel (MSG::Level lvl)
Update output level.
Static Public Member Functions
static std::string | representNetworkInputs (NetworkInputs const &inputs, int maxValues=3)
String representation of network inputs.
static std::string | representNetworkOutputs (NetworkOutputs const &outputs, int maxValues=3)
String representation of network outputs.
static bool | isFile (std::string const &inputFile)
Check if a string is the path of a file on disk.
static std::string | startMsg (MSG::Level lvl, const std::string &file, int line)
Make a message to decorate the start of logging.
Static Public Attributes
static const std::string | m_defaultTreeName = "onnxruntime_session"
Default name for the TTree to save in.
Protected Member Functions
void | setupNet () override
Perform actions that prepare network for use.
virtual void | print (std::ostream &strm) const override
Write a short description of this net to the string stream.
void | setupPersistedVariables () override
Perform actions that prep data to create the net.
bool | isRootFile (std::string const &filename="") const
Check if a string is possibly a root file path.
void | removePrefixes (NetworkOutputs &outputs) const
Remove any common prefix from the outputs.
void | removePrefixes (std::vector< std::string > &output_names) const
Remove any common prefix from the outputs.
Protected Attributes
std::string | m_json
String containing json input file.
std::string | m_inputFile
Path to the file describing the network, including filename.
Private Member Functions
ClassDefOverride (TFCSGANLWTNNHandler, 6)
void | fillJson (std::string const &tree_name=m_defaultTreeName)
Fill out m_json from a file provided to the constructor.
std::string | readStringFromTTree (TTree &tree)
Get json string from TTree.
void | writeStringToTTree (TTree &tree, std::string json_string)
Save json string to TTree.
ClassDef (VNetworkBase, 1)
Private Attributes
std::unique_ptr< lwt::LightweightGraph > | m_lwtnn_graph
The network that we are wrapping here.
std::vector< std::string > | m_outputLayers
Do not persistify.
std::string * | m_input = nullptr
Do not persistify.
std::string | m_printable_name
Stores a printable identifier for the net.
std::string | m_nm
Message source name.
Static Private Attributes
static boost::thread_specific_ptr< MsgStream > m_msg_tls | ATLAS_THREAD_SAFE
Do not persistify!
Class for a neural network read in the LWTNN format. Derived from the abstract base class VNetworkBase so that it can be used interchangeably with its sibling classes, TFCSONNXHandler and TFCSSimpleLWTNNHandler.
Frustratingly, LightweightNeuralNetwork and LightweightGraph from lwtnn do not have a common ancestor. They could be connected with the bridge pattern, but that is more complex than currently required. This class handles the graph case; TFCSSimpleLWTNNHandler handles the non-graph case.
The LoadNetwork function has VNetworkBase as its return type so that it can make a run-time decision about which derived class to use, based on the file name presented.
Definition at line 37 of file TFCSGANLWTNNHandler.h.
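To illustrate the interchangeable use described above, here is a minimal, self-contained sketch of the pattern. The classes ToyNetworkBase, ToyGraphHandler and ToyOnnxHandler and the function loadToyNetwork are hypothetical stand-ins, not the real FastCaloSim classes, and the rule that a ".onnx" suffix selects the ONNX handler is an assumption made only for this example; the actual LoadNetwork criteria are not described here.

```cpp
#include <map>
#include <memory>
#include <string>
#include <vector>

// Same shapes as the documented NetworkInputs / NetworkOutputs typedefs.
using NetworkInputs = std::map<std::string, std::map<std::string, double>>;
using NetworkOutputs = std::map<std::string, double>;

// Hypothetical, stripped-down stand-in for the abstract base class.
class ToyNetworkBase {
public:
  virtual ~ToyNetworkBase() = default;
  virtual NetworkOutputs compute(NetworkInputs const &inputs) const = 0;
  virtual std::vector<std::string> getOutputLayers() const = 0;
  virtual std::string kind() const = 0;
};

// Toy "graph" handler: sums every input value into a single output.
class ToyGraphHandler : public ToyNetworkBase {
public:
  NetworkOutputs compute(NetworkInputs const &inputs) const override {
    double total = 0.0;
    for (auto const &node : inputs) {
      for (auto const &part : node.second) {
        total += part.second;
      }
    }
    return {{"out", total}};
  }
  std::vector<std::string> getOutputLayers() const override { return {"out"}; }
  std::string kind() const override { return "graph"; }
};

// Toy "onnx" handler: always returns zero.
class ToyOnnxHandler : public ToyNetworkBase {
public:
  NetworkOutputs compute(NetworkInputs const &) const override { return {{"out", 0.0}}; }
  std::vector<std::string> getOutputLayers() const override { return {"out"}; }
  std::string kind() const override { return "onnx"; }
};

// Hypothetical loader: chooses the derived class at run time from the file
// name, and returns it through the base-class pointer.
std::unique_ptr<ToyNetworkBase> loadToyNetwork(std::string const &fileName) {
  std::string const suffix = ".onnx";
  bool const isOnnx =
      fileName.size() >= suffix.size() &&
      fileName.compare(fileName.size() - suffix.size(), suffix.size(), suffix) == 0;
  if (isOnnx) {
    return std::make_unique<ToyOnnxHandler>();
  }
  return std::make_unique<ToyGraphHandler>();
}
```

Calling code only ever sees ToyNetworkBase, so changing the file name is enough to change which backend evaluates the inputs.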
typedef std::map< std::string, std::map< std::string, double > > NetworkInputs [inherited]
Format for network inputs.
The doubles are the values to be passed into the network. Strings in the outer map identify the input node, and must correspond to the names of the nodes as read from the description of the network found by the constructor. Strings in the inner map identify the part of the input node; for some networks these must be simple integers in string form, as parts of nodes do not always have the ability to carry real string labels.
Definition at line 90 of file VNetworkBase.h.
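As a sketch of the NetworkInputs layout, the hypothetical helper addIndexedNode below packs a plain vector into one input node, labelling each part with its index in string form. The node names used with it are illustrative assumptions; real names must match those in the network description.

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Same shape as the documented NetworkInputs typedef:
// outer key = input node name, inner key = part of that node.
using NetworkInputs = std::map<std::string, std::map<std::string, double>>;

// Hypothetical helper: store the values of one input node, labelling each
// part with its index in string form ("0", "1", ...), for networks whose
// node parts cannot carry real string labels.
void addIndexedNode(NetworkInputs &inputs, std::string const &nodeName,
                    std::vector<double> const &values) {
  std::map<std::string, double> &node = inputs[nodeName];
  for (std::size_t i = 0; i < values.size(); ++i) {
    node[std::to_string(i)] = values[i];
  }
}
```

For example, addIndexedNode(inputs, "node_0", {0.1, 0.2, 0.3}) fills inputs["node_0"] with the keys "0", "1" and "2".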
typedef std::map< std::string, double > NetworkOutputs [inherited]
Format for network outputs.
The doubles are the values generated by the network. Strings identify which node each value came from; when a node has multiple values, the string is suffixed with a number indicating which part of the node the value came from. So for multi-value nodes the format becomes "<node_name>_<part_n>".
Definition at line 100 of file VNetworkBase.h.
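A minimal sketch of this naming convention follows. It assumes, from the description above, that a single-valued node keeps its bare name; addOutputNode is a hypothetical helper, not part of the class.

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Same shape as the documented NetworkOutputs typedef.
using NetworkOutputs = std::map<std::string, double>;

// Hypothetical helper: a single-valued node is stored under its own name,
// while each part of a multi-valued node is keyed as "<node_name>_<part_n>".
void addOutputNode(NetworkOutputs &outputs, std::string const &nodeName,
                   std::vector<double> const &values) {
  if (values.size() == 1) {
    outputs[nodeName] = values[0];
    return;
  }
  for (std::size_t i = 0; i < values.size(); ++i) {
    outputs[nodeName + "_" + std::to_string(i)] = values[i];
  }
}
```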
TFCSGANLWTNNHandler (const std::string &inputFile) [explicit]
TFCSGANLWTNNHandler constructor.
Calls setupPersistedVariables and setupNet.
inputFile | file-path on disk (with file name) of a readable lwtnn file containing a json format description of the network to be constructed, or the json itself as a string. |
Definition at line 15 of file TFCSGANLWTNNHandler.cxx.
TFCSGANLWTNNHandler::TFCSGANLWTNNHandler (const TFCSGANLWTNNHandler &copy_from)
TFCSGANLWTNNHandler copy constructor.
Will copy the variables that would be generated by setupPersistedVariables and setupNet.
copy_from | existing network that we are copying |
Definition at line 22 of file TFCSGANLWTNNHandler.cxx.
ClassDef (VNetworkBase, 1) [private, inherited]
ClassDefOverride (TFCSGANLWTNNHandler, 6) [private]
NetworkOutputs compute (NetworkInputs const &inputs) const [override, virtual]
Function to pass values to the network.
This function hides variations in the format needed by different network libraries, providing a uniform input and output type.
inputs | values to be evaluated by the network |
Implements VNetworkBase.
Definition at line 69 of file TFCSGANLWTNNHandler.cxx.
void deleteAllButNet () [override, virtual, inherited]
Get rid of any memory objects that aren't needed to run the net.
Minimise memory usage by deleting any inputs that are no longer required to run the compute function. This prevents the net from being saved; if you need to call writeNetToTTree, that must happen before this is called.
Implements VNetworkBase.
Definition at line 87 of file VNetworkLWTNN.cxx.
void fillJson (std::string const &tree_name = m_defaultTreeName) [private, inherited]
Fill out m_json from a file provided to the constructor.
Provided that the string passed as inputFile to the constructor is a known file type (root or json), this function retrieves the json string itself and puts it into m_json.
tree_name | TTree name to check in when reading root files. |
Definition at line 52 of file VNetworkLWTNN.cxx.
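The sketch below shows the path-or-content decision in isolation, assuming that any string which cannot be opened as a file is the json itself (the constructor accepts either). jsonFromInput is hypothetical, and the ROOT branch, which would read the json back out of the TTree named by tree_name, is omitted because it requires ROOT.

```cpp
#include <fstream>
#include <sstream>
#include <string>

// Hypothetical helper: if inputFile names a readable file, return its
// contents; otherwise assume the string itself is the json description.
// The ".root" case (json stored in a TTree) is deliberately not handled.
std::string jsonFromInput(std::string const &inputFile) {
  std::ifstream in(inputFile);
  if (!in.good()) {
    return inputFile;  // not a readable file: treat as json content
  }
  std::ostringstream buffer;
  buffer << in.rdbuf();  // read the whole file into the buffer
  return buffer.str();
}
```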
std::vector< std::string > getOutputLayers () const [override, virtual]
List the names of the outputs.
Outputs are stored in a NetworkOutputs object, which is indexed by strings. This function returns the list of all strings that will index the outputs.
Implements VNetworkBase.
Definition at line 63 of file TFCSGANLWTNNHandler.cxx.
bool isFile () const [inherited]
Check if the argument inputFile is the path of a file on disk.
Determines if the string that was passed to the constructor as inputFile corresponds to the path of a file that can be read on the disk.
Definition at line 117 of file VNetworkBase.cxx.
static bool isFile (std::string const &inputFile) [static, inherited]
Check if a string is the path of a file on disk.
Determines if a string corresponds to the path of a file that can be read on the disk.
inputFile | name of the potential file |
Definition at line 119 of file VNetworkBase.cxx.
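One way to perform such a check with the C++17 standard library is sketched below; isReadableFile is a hypothetical name, and the real function may use a different test.

```cpp
#include <filesystem>
#include <fstream>
#include <string>
#include <system_error>

// Hypothetical check: the string must name a regular file (not a directory)
// that can actually be opened for reading. A json string passed in place of
// a path simply fails the first test.
bool isReadableFile(std::string const &inputFile) {
  std::error_code ec;  // use the non-throwing overload
  if (!std::filesystem::is_regular_file(inputFile, ec)) {
    return false;
  }
  std::ifstream stream(inputFile);
  return stream.is_open();
}
```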
bool isRootFile (std::string const &filename = "") const [protected, inherited]
Check if a string is possibly a root file path.
Just checks if the string ends in .root, as there are almost no reliable rules for file paths.
inputFile | name of the potential file; if blank, m_inputFile is used. |
Definition at line 101 of file VNetworkBase.cxx.
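A sketch of this suffix rule follows, with the stored path passed explicitly as a stand-in for m_inputFile; the function name isPossiblyRootFile is hypothetical.

```cpp
#include <string>

// Hypothetical free-function version of the documented rule: the string is
// "possibly a root file" if it ends in ".root". An empty filename falls back
// to storedInputFile, which plays the role of m_inputFile.
bool isPossiblyRootFile(std::string const &filename,
                        std::string const &storedInputFile) {
  std::string const &name = filename.empty() ? storedInputFile : filename;
  std::string const suffix = ".root";
  return name.size() >= suffix.size() &&
         name.compare(name.size() - suffix.size(), suffix.size(), suffix) == 0;
}
```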
MSG::Level level () const [inline, inherited]
Retrieve output level.
MsgStream & msg () const [inline, inherited]
Return a stream for sending messages directly (no decoration).
Definition at line 231 of file MLogging.h.
MsgStream & msg (const MSG::Level lvl) const [inline, inherited]
Return a decorated starting stream for sending messages.
Definition at line 240 of file MLogging.h.
bool msgLvl (const MSG::Level lvl) const [inline, inherited]
Check whether the logging system is active at the provided verbosity level.
Definition at line 222 of file MLogging.h.
virtual void print (std::ostream &strm) const [override, protected, virtual, inherited]
Write a short description of this net to the string stream.
Outputs a printable name, which may be a file name, or a note specifying that the file has been provided from memory.
strm | output parameter, to which the description will be written. |
Reimplemented from VNetworkBase.
Definition at line 44 of file VNetworkLWTNN.cxx.
std::string readStringFromTTree (TTree &tree) [private, inherited]
Get json string from TTree.
Given a TTree object, retrieve the json string from the standard branch. This is used to retrieve a network previously saved using writeNetToTTree.
tree | TTree with the json saved inside. |
Definition at line 73 of file VNetworkLWTNN.cxx.
void removePrefixes (NetworkOutputs &outputs) const [protected, inherited]
Remove any common prefix from the outputs.
outputs | The outputs, changed in place. |
Definition at line 151 of file VNetworkBase.cxx.
void removePrefixes (std::vector< std::string > &output_names) const [protected, inherited]
Remove any common prefix from the outputs.
outputs | The output names, changed in place. |
Definition at line 144 of file VNetworkBase.cxx.
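The documentation does not say how the common prefix is determined. The sketch below assumes the longest character-wise prefix shared by all names, and leaves the names untouched if stripping would empty one of them; stripCommonPrefix is a hypothetical helper.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper: remove the longest character-wise prefix shared by
// all names. Assumption: the real removePrefixes may use another rule.
void stripCommonPrefix(std::vector<std::string> &names) {
  if (names.size() < 2) {
    return;  // a single name has no meaningful "common" prefix
  }
  std::string const &first = names.front();
  std::size_t prefixLength = first.size();
  for (std::string const &name : names) {
    std::size_t i = 0;
    while (i < prefixLength && i < name.size() && name[i] == first[i]) {
      ++i;
    }
    prefixLength = i;  // shrink to the part shared with this name
  }
  for (std::string const &name : names) {
    if (name.size() == prefixLength) {
      return;  // stripping would leave an empty name, so keep everything
    }
  }
  for (std::string &name : names) {
    name.erase(0, prefixLength);
  }
}
```

A purely character-wise rule has a known weakness: {"out_10", "out_11"} becomes {"0", "1"}, which is one reason a real implementation might instead cut only at a separator such as "_".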
static std::string representNetworkInputs (NetworkInputs const &inputs, int maxValues = 3) [static, inherited]
String representation of network inputs.
Create a string that summarises a set of network inputs. Gives the basic dimensions plus a few values, up to maxValues.
inputs | values to be evaluated by the network |
maxValues | maximum number of values to include in the representation |
Definition at line 37 of file VNetworkBase.cxx.
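A hypothetical summariser in the same spirit is sketched below. The exact text produced by representNetworkInputs is not documented, so the output format here is an assumption.

```cpp
#include <map>
#include <sstream>
#include <string>

using NetworkInputs = std::map<std::string, std::map<std::string, double>>;

// Hypothetical summariser: number of nodes, then for each node its name,
// its number of parts and at most maxValues "part=value" pairs per node.
std::string summariseInputs(NetworkInputs const &inputs, int maxValues = 3) {
  std::ostringstream out;
  out << inputs.size() << " node(s)";
  for (auto const &[nodeName, parts] : inputs) {
    out << "; " << nodeName << " [" << parts.size() << "]:";
    int shown = 0;
    for (auto const &[partName, value] : parts) {
      if (shown >= maxValues) {
        out << " ...";  // more values exist than we are allowed to show
        break;
      }
      out << " " << partName << "=" << value;
      ++shown;
    }
  }
  return out.str();
}
```

With a four-part node_0 holding 1, 2, 3, 4 and a one-part node_1 holding 0.5, the default call yields "2 node(s); node_0 [4]: 0=1 1=2 2=3 ...; node_1 [1]: 0=0.5".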
static std::string representNetworkOutputs (NetworkOutputs const &outputs, int maxValues = 3) [static, inherited]
String representation of network outputs.
Create a string that summarises a set of network outputs. Gives the basic dimensions plus a few values, up to maxValues.
outputs | output of the network |
maxValues | maximum number of values to include in the representation |
Definition at line 57 of file VNetworkBase.cxx.
virtual void setLevel (MSG::Level lvl) [virtual, inherited]
Update output level.
void setupNet () [override, protected, virtual]
Perform actions that prepare network for use.
Will be called in the streamer or class constructor after the inputs have been set (either automatically by the streamer or by setupPersistedVariables in the constructor). Does not delete any resources used.
Implements VNetworkBase.
Definition at line 33 of file TFCSGANLWTNNHandler.cxx.
void setupPersistedVariables () [override, protected, virtual, inherited]
Perform actions that prep data to create the net.
Will be called in the base class constructor before calling setupNet, but not in the streamer. It sets any variables that the streamer would persist when saving or loading to file.
Implements VNetworkBase.
Definition at line 30 of file VNetworkLWTNN.cxx.
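The two-phase setup described for setupPersistedVariables and setupNet can be sketched with a toy class. ToyHandler and fromPersisted are hypothetical, with fromPersisted standing in for what a ROOT streamer does after restoring the persisted members. A likely reason the derived-class constructors, rather than VNetworkBase, make these calls is that virtual functions called from a base-class constructor do not dispatch to the derived overrides.

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy illustrating the documented lifecycle:
//   constructor: setupPersistedVariables() then setupNet()
//   streamer:    persisted members are restored, then only setupNet()
class ToyHandler {
public:
  explicit ToyHandler(std::string inputFile) : m_inputFile(std::move(inputFile)) {
    setupPersistedVariables();
    setupNet();
  }

  // Stand-in for a streamer: fill the persisted member directly,
  // then rebuild only the non-persisted network state.
  static ToyHandler fromPersisted(std::string json) {
    ToyHandler handler;
    handler.m_json = std::move(json);
    handler.setupNet();
    return handler;
  }

  bool netReady() const { return m_netReady; }
  std::vector<std::string> const &calls() const { return m_calls; }

private:
  ToyHandler() = default;  // default constructor, for use by the "streamer"

  void setupPersistedVariables() {
    m_calls.push_back("setupPersistedVariables");
    m_json = m_inputFile;  // assumption: inputFile already holds the json
  }

  void setupNet() {
    m_calls.push_back("setupNet");
    m_netReady = !m_json.empty();  // stands in for building the network
  }

  std::string m_inputFile;
  std::string m_json;                // would be persisted
  bool m_netReady = false;           // rebuilt by setupNet, not persisted
  std::vector<std::string> m_calls;  // records which setup steps ran
};
```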
static std::string startMsg (MSG::Level lvl, const std::string &file, int line) [static, inherited]
Make a message to decorate the start of logging.
Print a message for the start of logging.
Definition at line 116 of file MLogging.cxx.
VNetworkBase () [inherited]
VNetworkBase default constructor.
For use in streamers.
Definition at line 45 of file VNetworkBase.cxx.
VNetworkBase (const std::string &inputFile) [explicit, inherited]
VNetworkBase constructor.
Only saves inputFile to m_inputFile; inheriting classes should call setupPersistedVariables and setupNet in their constructors.
inputFile | file-path on disk (with file name) of a readable file containing a description of the network to be constructed or the content of the file. |
Definition at line 59 of file VNetworkBase.cxx.
VNetworkBase (const VNetworkBase &copy_from) [inherited]
VNetworkBase copy constructor.
Does not call setupPersistedVariables or setupNet, but will pass on m_inputFile. Inheriting classes should do whatever they need to move the variables created in the setup functions.
copy_from | existing network that we are copying |
Definition at line 71 of file VNetworkBase.cxx.
VNetworkLWTNN::VNetworkLWTNN (const VNetworkLWTNN &copy_from)
VNetworkLWTNN copy constructor.
Will copy the variables that would be generated by setupPersistedVariables and setupNet. Will fail if deleteAllButNet has already been called.
copy_from | existing network that we are copying |
Definition at line 45 of file VNetworkLWTNN.cxx.
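void writeNetToTTree (std::string const &root_name, std::string const &tree_name = m_defaultTreeName) [inherited]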
Save the network to a TTree.
All data required to recreate the network object is saved into a TTree. The format is not specified.
root_name | The path of the file to save inside. |
tree_name | The name of the TTree to save inside. |
Definition at line 196 of file VNetworkBase.cxx.
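void writeNetToTTree (std::string const &root_name, std::string const &tree_name = m_defaultTreeName) [inherited]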
Save the network to a TTree.
All data required to recreate the network object is saved into a TTree. The format is not specified.
root_name | The path of the file to save inside. |
tree_name | The name of the TTree to save inside. |
Definition at line 93 of file VNetworkBase.cxx.
void writeNetToTTree (TFile &root_file, std::string const &tree_name = m_defaultTreeName) [inherited]
Save the network to a TTree.
All data required to recreate the network object is saved into a TTree. The format is not specified.
root_file | The file to save inside. |
tree_name | The name of the TTree to save inside. |
Definition at line 184 of file VNetworkBase.cxx.
void writeNetToTTree (TFile &root_file, std::string const &tree_name = m_defaultTreeName) [inherited]
Save the network to a TTree.
All data required to recreate the network object is saved into a TTree. The format is not specified.
root_file | The file to save inside. |
tree_name | The name of the TTree to save inside. |
Definition at line 83 of file VNetworkBase.cxx.
void writeNetToTTree (TTree &tree) [override, virtual, inherited]
Save the network to a TTree.
All data required to recreate the network object is saved into a TTree. The format is not specified.
tree | The tree to save inside. |
Implements VNetworkBase.
Definition at line 48 of file VNetworkLWTNN.cxx.
virtual void writeNetToTTree (TTree &tree) = 0 [inherited]
Save the network to a TTree.
All data required to recreate the network object is saved into a TTree. The format is not specified.
tree | The tree to save inside. |
void writeStringToTTree (TTree &tree, std::string json_string) [private, inherited]
Save json string to TTree.
Given a TTree object, write the json string into the standard branch, so that it can later be retrieved with readStringFromTTree.
tree | TTree in which to save the json. |
Definition at line 81 of file VNetworkLWTNN.cxx.
static boost::thread_specific_ptr< MsgStream > m_msg_tls ATLAS_THREAD_SAFE [inline, static, private, inherited]
Do not persistify!
MsgStream instance (a std::cout-like stream with print-out levels).
Definition at line 215 of file MLogging.h.
static const std::string m_defaultTreeName = "onnxruntime_session" [inline, static, inherited]
Default name for the TTree to save in.
Definition at line 173 of file VNetworkBase.h.
std::string * m_input = nullptr [private]
Do not persistify.
Just for backward compatibility.
Definition at line 115 of file TFCSGANLWTNNHandler.h.
std::string m_inputFile [protected, inherited]
Path to the file describing the network, including filename.
Definition at line 245 of file VNetworkBase.h.
std::string m_json [protected, inherited]
String containing json input file.
Needed to save the network with writeNetToTTree, but not needed to run the network with compute. Erased by deleteAllButNet. Should be persisted.
Definition at line 84 of file VNetworkLWTNN.h.
std::unique_ptr< lwt::LightweightGraph > m_lwtnn_graph [private]
The network that we are wrapping here.
Definition at line 105 of file TFCSGANLWTNNHandler.h.
std::string m_nm [private, inherited]
Message source name.
Definition at line 211 of file MLogging.h.
std::vector< std::string > m_outputLayers [private]
Do not persistify.
List of names that index the output layer.
Definition at line 110 of file TFCSGANLWTNNHandler.h.
std::string m_printable_name [private, inherited]
Stores a printable identifier for the net.
Not unique.
Definition at line 144 of file VNetworkLWTNN.h.