40 std::string representation =
41 "NetworkInputs, outer size " + std::to_string(inputs.size());
42 int valuesIncluded = 0;
43 for (
const auto &outer : inputs) {
44 representation +=
"\n key->" + outer.first +
"; ";
45 for (
const auto &inner : outer.second) {
46 representation += inner.first +
"=" + std::format(
"{:f}",inner.second) +
", ";
48 if (valuesIncluded > maxValues)
51 if (valuesIncluded > maxValues)
54 representation +=
"\n";
55 return representation;
60 std::string representation =
61 "NetworkOutputs, size " + std::to_string(outputs.size()) +
"; \n";
62 int valuesIncluded = 0;
63 for (
const auto &item : outputs) {
64 representation += item.first +
"=" + std::format(
"{:f}", item.second) +
", ";
66 if (valuesIncluded > maxValues)
69 representation +=
"\n";
70 return representation;
77 strm <<
"Unknown network";
85 std::string
const &tree_name) {
88 const std::string title =
"onnxruntime saved network";
89 TTree
tree(tree_name.c_str(), title.c_str());
95 std::string
const &tree_name) {
97 TFile root_file(root_name.c_str(),
"UPDATE");
103 const std::string *to_check = &filename;
104 if (filename.length() == 0) {
108 const std::string ending =
".root";
109 const int ending_len = ending.length();
110 const int filename_len = to_check->length();
111 if (filename_len < ending_len) {
115 to_check->compare(filename_len - ending_len, ending_len, ending));
121 if (FILE *
file = std::fopen(inputFile.c_str(),
"r")) {
130int GetPrefixLength(
const std::vector<std::string>& strings) {
131 const std::string first = strings[0];
132 int length = first.length();
133 for (
const std::string& this_string : strings) {
134 for (
int i = 0; i <
length; i++) {
135 if (first[i] != this_string[i]) {
146 std::vector<std::string> &output_names)
const {
147 const int length = GetPrefixLength(output_names);
148 for (
long unsigned int i = 0; i < output_names.size(); i++)
149 output_names[i] = output_names[i].substr(
length);
153 std::vector<std::string> output_layers;
154 for (
auto const &output : outputs)
155 output_layers.push_back(output.first);
156 const int length = GetPrefixLength(output_layers);
157 for (std::string layer_name : output_layers) {
159 auto nodeHandle = outputs.extract(layer_name);
161 nodeHandle.key() = layer_name.substr(
length);
163 outputs.insert(std::move(nodeHandle));
MLogging(const std::string &name="ISF_FastCaloSimEvent")
Constructor.
VNetworkBase()
VNetworkBase default constructor.
std::map< std::string, std::map< std::string, double > > NetworkInputs
Format for network inputs.
static std::string representNetworkOutputs(NetworkOutputs const &outputs, int maxValues=3)
String representation of network outputs.
std::string m_inputFile
Path to the file describing the network, including filename.
virtual void writeNetToTTree(TTree &tree)=0
Save the network to a TTree.
std::map< std::string, double > NetworkOutputs
Format for network outputs.
bool isRootFile(std::string const &filename="") const
Check if a string is possibly a root file path.
void removePrefixes(NetworkOutputs &outputs) const
Remove any common prefix from the outputs.
virtual void print(std::ostream &strm) const
Write a short description of this net to the string stream.
static std::string representNetworkInputs(NetworkInputs const &inputs, int maxValues=3)
String representation of network inputs.
bool isFile() const
Check if the argument inputFile is the path of a file on disk.