ATLAS Offline Software
Loading...
Searching...
No Matches
VNetworkBase.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
6#include <iostream>
7#include <format>
8
9// For streamer
10#include "TBuffer.h"
11
12// For reading and writing to root
13#include "TFile.h"
14#include "TTree.h"
15
16// Probably called by a streamer.
18
19// record the input file and provided it's not empty call SetUp
20VNetworkBase::VNetworkBase(const std::string &inputFile)
21 : m_inputFile(inputFile) {
22 ATH_MSG_DEBUG("Constructor called with inputFile");
23};
24
25// No setupPersistedVariables or setupNet here!
27 : MLogging(),
28 m_inputFile (copy_from.m_inputFile)
29{
30};
31
32// Nothing is needed from the destructor right now.
33// We don't use new anywhere, so the whole thing should clean
34// itself up.
36
// Build a short, human-readable summary of a NetworkInputs map for logging.
//   inputs    - nested map: outer key -> (input name -> value)
//   maxValues - maximum number of name=value pairs to include, counted across
//               all outer keys (previously one extra pair was emitted)
// Returns the summary, always terminated by a newline.
std::string
VNetworkBase::representNetworkInputs(VNetworkBase::NetworkInputs const &inputs,
                                     int maxValues) {
  std::string representation =
      "NetworkInputs, outer size " + std::to_string(inputs.size());
  int valuesIncluded = 0;
  for (const auto &[key, values] : inputs) {
    // Don't start a new outer key once the value budget is used up.
    if (valuesIncluded >= maxValues)
      break;
    representation += "\n key->" + key + "; ";
    for (const auto &[name, value] : values) {
      // Check before appending so at most maxValues pairs are emitted.
      if (valuesIncluded >= maxValues)
        break;
      representation += name + "=" + std::format("{:f}", value) + ", ";
      ++valuesIncluded;
    }
  }
  representation += "\n";
  return representation;
}
57
// Build a short, human-readable summary of a NetworkOutputs map for logging.
//   outputs   - map: output name -> value
//   maxValues - maximum number of name=value pairs to include (previously one
//               extra pair was emitted)
// Returns the summary, always terminated by a newline.
std::string VNetworkBase::representNetworkOutputs(
    VNetworkBase::NetworkOutputs const &outputs, int maxValues) {
  std::string representation =
      "NetworkOutputs, size " + std::to_string(outputs.size()) + "; \n";
  int valuesIncluded = 0;
  for (const auto &[name, value] : outputs) {
    // Check before appending so at most maxValues pairs are emitted.
    if (valuesIncluded >= maxValues)
      break;
    representation += name + "=" + std::format("{:f}", value) + ", ";
    ++valuesIncluded;
  }
  representation += "\n";
  return representation;
}
72
73// this is also used for the stream operator
74void VNetworkBase::print(std::ostream &strm) const {
75 if (m_inputFile.empty()) {
76 ATH_MSG_DEBUG("Making a network without a named inputFile");
77 strm << "Unknown network";
78 } else {
79 ATH_MSG_DEBUG("Making a network with input file " << m_inputFile);
80 strm << m_inputFile;
81 };
82};
83
84void VNetworkBase::writeNetToTTree(TFile &root_file,
85 std::string const &tree_name) {
86 ATH_MSG_DEBUG("Making tree name " << tree_name);
87 root_file.cd();
88 const std::string title = "onnxruntime saved network";
89 TTree tree(tree_name.c_str(), title.c_str());
90 this->writeNetToTTree(tree);
91 root_file.Write();
92};
93
94void VNetworkBase::writeNetToTTree(std::string const &root_name,
95 std::string const &tree_name) {
96 ATH_MSG_DEBUG("Making or updating file name " << root_name);
97 TFile root_file(root_name.c_str(), "UPDATE");
98 this->writeNetToTTree(root_file, tree_name);
99 root_file.Close();
100};
101
102bool VNetworkBase::isRootFile(std::string const &filename) const {
103 const std::string *to_check = &filename;
104 if (filename.length() == 0) {
105 to_check = &this->m_inputFile;
106 ATH_MSG_DEBUG("No file name given, so using m_inputFile, " << m_inputFile);
107 };
108 const std::string ending = ".root";
109 const int ending_len = ending.length();
110 const int filename_len = to_check->length();
111 if (filename_len < ending_len) {
112 return false;
113 }
114 return (0 ==
115 to_check->compare(filename_len - ending_len, ending_len, ending));
116};
117
118bool VNetworkBase::isFile() const { return isFile(m_inputFile); };
119
120bool VNetworkBase::isFile(std::string const &inputFile) {
121 if (FILE *file = std::fopen(inputFile.c_str(), "r")) {
122 std::fclose(file);
123 return true;
124 } else {
125 return false;
126 };
127};
128
namespace {
// Return the length of the longest prefix shared by every string in
// `strings`, or 0 when `strings` is empty. Indices are bounded by both
// strings' lengths, so no character past the end of either string is read
// (this also makes embedded '\0' characters safe).
int GetPrefixLength(const std::vector<std::string> &strings) {
  if (strings.empty())
    return 0;
  const std::string &first = strings.front();
  std::size_t length = first.length();
  for (const std::string &this_string : strings) {
    std::size_t i = 0;
    while (i < length && i < this_string.length() &&
           first[i] == this_string[i])
      ++i;
    length = i;
  }
  return static_cast<int>(length);
}
} // namespace
144
146 std::vector<std::string> &output_names) const {
147 const int length = GetPrefixLength(output_names);
148 for (long unsigned int i = 0; i < output_names.size(); i++)
149 output_names[i] = output_names[i].substr(length);
150};
151
153 std::vector<std::string> output_layers;
154 for (auto const &output : outputs)
155 output_layers.push_back(output.first);
156 const int length = GetPrefixLength(output_layers);
157 for (std::string layer_name : output_layers) {
158 // remove this output
159 auto nodeHandle = outputs.extract(layer_name);
160 // change the key
161 nodeHandle.key() = layer_name.substr(length);
162 // replace the output
163 outputs.insert(std::move(nodeHandle));
164 }
165};
#define ATH_MSG_DEBUG(x)
double length(const pvec &v)
MLogging(const std::string &name="ISF_FastCaloSimEvent")
Constructor.
Definition MLogging.cxx:91
virtual ~VNetworkBase()
VNetworkBase()
VNetworkBase default constructor.
std::map< std::string, std::map< std::string, double > > NetworkInputs
Format for network inputs.
static std::string representNetworkOutputs(NetworkOutputs const &outputs, int maxValues=3)
String representation of network outputs.
std::string m_inputFile
Path to the file describing the network, including filename.
virtual void writeNetToTTree(TTree &tree)=0
Save the network to a TTree.
std::map< std::string, double > NetworkOutputs
Format for network outputs.
bool isRootFile(std::string const &filename="") const
Check if a string is possibly a root file path.
void removePrefixes(NetworkOutputs &outputs) const
Remove any common prefix from the outputs.
virtual void print(std::ostream &strm) const
Write a short description of this net to the string stream.
static std::string representNetworkInputs(NetworkInputs const &inputs, int maxValues=3)
String representation of network inputs.
bool isFile() const
Check if the argument inputFile is the path of a file on disk.
TChain * tree
TFile * file