ATLAS Offline Software
Loading...
Searching...
No Matches
VNetworkBase.h
Go to the documentation of this file.
1
18#ifndef VNETWORKBASE_H
19#define VNETWORKBASE_H
20
21// For conversion to ostream
22#include <iostream>
23#include <map>
24
25// For reading and writing
26#include "TFile.h"
27#include "TTree.h"
28
29// For messaging
31
39public:
46
47 // explicit = Don't let this do implicit type conversion
59 explicit VNetworkBase(const std::string &inputFile);
60
71 VNetworkBase(const VNetworkBase &copy_from);
72
73 // virtual destructor, to ensure that it is always called, even
74 // when a base class is deleted via a pointer to a derived class
75 virtual ~VNetworkBase();
76
77 // same as for lwtnn
90 typedef std::map<std::string, std::map<std::string, double>> NetworkInputs;
100 typedef std::map<std::string, double> NetworkOutputs;
101
112 static std::string representNetworkInputs(NetworkInputs const &inputs,
113 int maxValues = 3);
114
125 static std::string representNetworkOutputs(NetworkOutputs const &outputs,
126 int maxValues = 3);
127
128 // pure virtual, derived classes must implement this
141 virtual NetworkOutputs compute(NetworkInputs const &inputs) const = 0;
142
143 // Conversion to ostream
144 // It's not possible to have a virtual friend function
145 // so instead, have a friend function that calls a virtual protected method
154 friend std::ostream &operator<<(std::ostream &strm,
155 const VNetworkBase &vNetworkBase) {
156 vNetworkBase.print(strm);
157 return strm;
158 }
159
168 virtual void writeNetToTTree(TTree &tree) = 0;
169
173 inline static const std::string m_defaultTreeName = "onnxruntime_session";
174
184 void writeNetToTTree(TFile &root_file,
185 std::string const &tree_name = m_defaultTreeName);
186
196 void writeNetToTTree(std::string const &root_name,
197 std::string const &tree_name = m_defaultTreeName);
198
207 virtual std::vector<std::string> getOutputLayers() const = 0;
208
218 static bool isFile(std::string const &inputFile);
219
229 bool isFile() const;
230
239 virtual void deleteAllButNet() = 0;
240
241protected:
245 std::string m_inputFile;
246
256 virtual void setupPersistedVariables() = 0;
257
267 virtual void setupNet() = 0;
268
277 virtual void print(std::ostream &strm) const;
278
289 bool isRootFile(std::string const &filename = "") const;
290
296 void removePrefixes(NetworkOutputs &outputs) const;
297
303 void removePrefixes(std::vector<std::string> &output_names) const;
304
305private:
306 // Supplying a ClassDef for writing to file.
308};
309
310#endif
Cut down AthMessaging.
Definition MLogging.h:176
virtual ~VNetworkBase()
static const std::string m_defaultTreeName
Default name for the TTree to save in.
VNetworkBase()
VNetworkBase default constructor.
virtual std::vector< std::string > getOutputLayers() const =0
List the names of the outputs.
virtual NetworkOutputs compute(NetworkInputs const &inputs) const =0
Function to pass values to the network.
std::map< std::string, std::map< std::string, double > > NetworkInputs
Format for network inputs.
static std::string representNetworkOutputs(NetworkOutputs const &outputs, int maxValues=3)
String representation of network outputs.
friend std::ostream & operator<<(std::ostream &strm, const VNetworkBase &vNetworkBase)
Put-to operator to facilitate printing.
std::string m_inputFile
Path to the file describing the network, including filename.
virtual void writeNetToTTree(TTree &tree)=0
Save the network to a TTree.
std::map< std::string, double > NetworkOutputs
Format for network outputs.
bool isRootFile(std::string const &filename="") const
Check if a string is possibly a root file path.
void removePrefixes(NetworkOutputs &outputs) const
Remove any common prefix from the outputs.
virtual void print(std::ostream &strm) const
Write a short description of this net to the string stream.
static std::string representNetworkInputs(NetworkInputs const &inputs, int maxValues=3)
String representation of network inputs.
virtual void deleteAllButNet()=0
Get rid of any memory objects that aren't needed to run the net.
bool isFile() const
Check if the argument inputFile is the path of a file on disk.
ClassDef(VNetworkBase, 1)
virtual void setupNet()=0
Perform actions that prepare network for use.
virtual void setupPersistedVariables()=0
Perform actions that prep data to create the net.
TChain * tree