ATLAS Offline Software
Loading...
Searching...
No Matches
VNetworkBase.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3*/
4
6#include <iostream>
7
8// For streamer
9#include "TBuffer.h"
10
11// For reading and writing to root
12#include "TFile.h"
13#include "TTree.h"
14
15// Probably called by a streamer.
17
18// record the input file and provided it's not empty call SetUp
19VNetworkBase::VNetworkBase(const std::string &inputFile)
20 : m_inputFile(inputFile) {
21 ATH_MSG_DEBUG("Constructor called with inputFile");
22};
23
24// No setupPersistedVariables or setupNet here!
26 : MLogging(),
27 m_inputFile (copy_from.m_inputFile)
28{
29};
30
31// Nothing is needed from the destructor right now.
32// We don't use new anywhere, so the whole thing should clean
33// itself up.
35
36std::string
38 int maxValues) {
39 std::string representation =
40 "NetworkInputs, outer size " + std::to_string(inputs.size());
41 int valuesIncluded = 0;
42 for (const auto &outer : inputs) {
43 representation += "\n key->" + outer.first + "; ";
44 for (const auto &inner : outer.second) {
45 representation += inner.first + "=" + std::to_string(inner.second) + ", ";
46 ++valuesIncluded;
47 if (valuesIncluded > maxValues)
48 break;
49 };
50 if (valuesIncluded > maxValues)
51 break;
52 };
53 representation += "\n";
54 return representation;
55};
56
58 VNetworkBase::NetworkOutputs const &outputs, int maxValues) {
59 std::string representation =
60 "NetworkOutputs, size " + std::to_string(outputs.size()) + "; \n";
61 int valuesIncluded = 0;
62 for (const auto &item : outputs) {
63 representation += item.first + "=" + std::to_string(item.second) + ", ";
64 ++valuesIncluded;
65 if (valuesIncluded > maxValues)
66 break;
67 };
68 representation += "\n";
69 return representation;
70};
72// this is also used for the stream operator
73void VNetworkBase::print(std::ostream &strm) const {
74 if (m_inputFile.empty()) {
75 ATH_MSG_DEBUG("Making a network without a named inputFile");
76 strm << "Unknown network";
77 } else {
78 ATH_MSG_DEBUG("Making a network with input file " << m_inputFile);
79 strm << m_inputFile;
80 };
81};
82
83void VNetworkBase::writeNetToTTree(TFile &root_file,
84 std::string const &tree_name) {
85 ATH_MSG_DEBUG("Making tree name " << tree_name);
86 root_file.cd();
87 const std::string title = "onnxruntime saved network";
88 TTree tree(tree_name.c_str(), title.c_str());
89 this->writeNetToTTree(tree);
90 root_file.Write();
91};
92
93void VNetworkBase::writeNetToTTree(std::string const &root_name,
94 std::string const &tree_name) {
95 ATH_MSG_DEBUG("Making or updating file name " << root_name);
96 TFile root_file(root_name.c_str(), "UPDATE");
97 this->writeNetToTTree(root_file, tree_name);
98 root_file.Close();
99};
100
101bool VNetworkBase::isRootFile(std::string const &filename) const {
102 const std::string *to_check = &filename;
103 if (filename.length() == 0) {
104 to_check = &this->m_inputFile;
105 ATH_MSG_DEBUG("No file name given, so using m_inputFile, " << m_inputFile);
106 };
107 const std::string ending = ".root";
108 const int ending_len = ending.length();
109 const int filename_len = to_check->length();
110 if (filename_len < ending_len) {
111 return false;
112 }
113 return (0 ==
114 to_check->compare(filename_len - ending_len, ending_len, ending));
115};
116
117bool VNetworkBase::isFile() const { return isFile(m_inputFile); };
118
119bool VNetworkBase::isFile(std::string const &inputFile) {
120 if (FILE *file = std::fopen(inputFile.c_str(), "r")) {
121 std::fclose(file);
122 return true;
123 } else {
124 return false;
125 };
126};
127
namespace {
// Return the length of the longest prefix shared by every string in
// `strings`. Returns 0 for an empty vector.
int GetPrefixLength(const std::vector<std::string>& strings) {
  if (strings.empty()) {
    return 0;
  }
  // Bind by reference; there is no need to copy the first string.
  const std::string& first = strings.front();
  std::string::size_type length = first.length();
  for (const std::string& this_string : strings) {
    // Never compare past the end of the shorter string.
    if (this_string.length() < length) {
      length = this_string.length();
    }
    std::string::size_type i = 0;
    while (i < length && first[i] == this_string[i]) {
      ++i;
    }
    length = i;
  }
  return static_cast<int>(length);
}
} // namespace
143
145 std::vector<std::string> &output_names) const {
146 const int length = GetPrefixLength(output_names);
147 for (long unsigned int i = 0; i < output_names.size(); i++)
148 output_names[i] = output_names[i].substr(length);
149};
150
152 std::vector<std::string> output_layers;
153 for (auto const &output : outputs)
154 output_layers.push_back(output.first);
155 const int length = GetPrefixLength(output_layers);
156 for (std::string layer_name : output_layers) {
157 // remove this output
158 auto nodeHandle = outputs.extract(layer_name);
159 // change the key
160 nodeHandle.key() = layer_name.substr(length);
161 // replace the output
162 outputs.insert(std::move(nodeHandle));
163 }
164};
#define ATH_MSG_DEBUG(x)
double length(const pvec &v)
MLogging(const std::string &name="ISF_FastCaloSimEvent")
Constructor.
Definition MLogging.cxx:91
virtual ~VNetworkBase()
VNetworkBase()
VNetworkBase default constructor.
std::map< std::string, std::map< std::string, double > > NetworkInputs
Format for network inputs.
static std::string representNetworkOutputs(NetworkOutputs const &outputs, int maxValues=3)
String representation of network outputs.
std::string m_inputFile
Path to the file describing the network, including filename.
virtual void writeNetToTTree(TTree &tree)=0
Save the network to a TTree.
std::map< std::string, double > NetworkOutputs
Format for network outputs.
bool isRootFile(std::string const &filename="") const
Check if a string is possibly a root file path.
void removePrefixes(NetworkOutputs &outputs) const
Remove any common prefix from the outputs.
virtual void print(std::ostream &strm) const
Write a short description of this net to the string stream.
static std::string representNetworkInputs(NetworkInputs const &inputs, int maxValues=3)
String representation of network inputs.
bool isFile() const
Check if the argument inputFile is the path of a file on disk.
TChain * tree
TFile * file