39 std::string representation =
40 "NetworkInputs, outer size " + std::to_string(inputs.size());
41 int valuesIncluded = 0;
42 for (
const auto &outer : inputs) {
43 representation +=
"\n key->" + outer.first +
"; ";
44 for (
const auto &inner : outer.second) {
45 representation += inner.first +
"=" + std::to_string(inner.second) +
", ";
47 if (valuesIncluded > maxValues)
50 if (valuesIncluded > maxValues)
53 representation +=
"\n";
54 return representation;
59 std::string representation =
60 "NetworkOutputs, size " + std::to_string(outputs.size()) +
"; \n";
61 int valuesIncluded = 0;
62 for (
const auto &item : outputs) {
63 representation += item.first +
"=" + std::to_string(item.second) +
", ";
65 if (valuesIncluded > maxValues)
68 representation +=
"\n";
69 return representation;
76 strm <<
"Unknown network";
84 std::string
const &tree_name) {
87 const std::string title =
"onnxruntime saved network";
88 TTree
tree(tree_name.c_str(), title.c_str());
94 std::string
const &tree_name) {
96 TFile root_file(root_name.c_str(),
"UPDATE");
102 const std::string *to_check = &filename;
103 if (filename.length() == 0) {
107 const std::string ending =
".root";
108 const int ending_len = ending.length();
109 const int filename_len = to_check->length();
110 if (filename_len < ending_len) {
114 to_check->compare(filename_len - ending_len, ending_len, ending));
120 if (FILE *
file = std::fopen(inputFile.c_str(),
"r")) {
129int GetPrefixLength(
const std::vector<std::string>& strings) {
130 const std::string first = strings[0];
131 int length = first.length();
132 for (
const std::string& this_string : strings) {
133 for (
int i = 0; i <
length; i++) {
134 if (first[i] != this_string[i]) {
145 std::vector<std::string> &output_names)
const {
146 const int length = GetPrefixLength(output_names);
147 for (
long unsigned int i = 0; i < output_names.size(); i++)
148 output_names[i] = output_names[i].substr(
length);
152 std::vector<std::string> output_layers;
153 for (
auto const &output : outputs)
154 output_layers.push_back(output.first);
155 const int length = GetPrefixLength(output_layers);
156 for (std::string layer_name : output_layers) {
158 auto nodeHandle = outputs.extract(layer_name);
160 nodeHandle.key() = layer_name.substr(
length);
162 outputs.insert(std::move(nodeHandle));
MLogging(const std::string &name="ISF_FastCaloSimEvent")
Constructor.
VNetworkBase()
VNetworkBase default constructor.
std::map< std::string, std::map< std::string, double > > NetworkInputs
Format for network inputs.
static std::string representNetworkOutputs(NetworkOutputs const &outputs, int maxValues=3)
String representation of network outputs.
std::string m_inputFile
Path to the file describing the network, including filename.
virtual void writeNetToTTree(TTree &tree)=0
Save the network to a TTree.
std::map< std::string, double > NetworkOutputs
Format for network outputs.
bool isRootFile(std::string const &filename="") const
Check if a string is possibly a ROOT file path, judged by whether it ends in ".root".
void removePrefixes(NetworkOutputs &outputs) const
Remove any common prefix from the outputs.
virtual void print(std::ostream &strm) const
Write a short description of this net to the given output stream.
static std::string representNetworkInputs(NetworkInputs const &inputs, int maxValues=3)
String representation of network inputs.
bool isFile() const
Check if the argument inputFile is the path of a file on disk.