ATLAS Offline Software
VNetworkBase.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
6 #include <iostream>
7 
8 // For streamer
9 #include "TBuffer.h"
10 
11 // For reading and writing to root
12 #include "TFile.h"
13 #include "TTree.h"
14 
15 // Probably called by a streamer.
16 VNetworkBase::VNetworkBase() : m_inputFile("unknown"){};
17 
18 // Construct from the path of the file describing the network: only m_inputFile is recorded and a debug message is logged (no SetUp or other set-up call is made here).
20 : m_inputFile(inputFile) {
21 ATH_MSG_DEBUG("Constructor called with inputFile");
22 };
23 
24 // Presumably the copy constructor (source object named copy_from; its signature is not visible here - confirm). Only m_inputFile is copied from copy_from; MLogging() is value-initialised rather than copied. No setupPersistedVariables or setupNet here!
26 : MLogging(),
27 m_inputFile (copy_from.m_inputFile)
28 {
29 };
30 
31 // Nothing is needed from the destructor right now.
32 // We don't use new anywhere, so the whole thing should clean
33 // itself up.
35 
36 std::string
38  int maxValues) {
39  std::string representation =
40  "NetworkInputs, outer size " + std::to_string(inputs.size());
41  int valuesIncluded = 0;
42  for (const auto &outer : inputs) {
43  representation += "\n key->" + outer.first + "; ";
44  for (const auto &inner : outer.second) {
45  representation += inner.first + "=" + std::to_string(inner.second) + ", ";
46  ++valuesIncluded;
47  if (valuesIncluded > maxValues)
48  break;
49  };
50  if (valuesIncluded > maxValues)
51  break;
52  };
53  representation += "\n";
54  return representation;
55 };
56 
58  VNetworkBase::NetworkOutputs const &outputs, int maxValues) {
59  std::string representation =
60  "NetworkOutputs, size " + std::to_string(outputs.size()) + "; \n";
61  int valuesIncluded = 0;
62  for (const auto &item : outputs) {
63  representation += item.first + "=" + std::to_string(item.second) + ", ";
64  ++valuesIncluded;
65  if (valuesIncluded > maxValues)
66  break;
67  };
68  representation += "\n";
69  return representation;
70 };
71 
72 // this is also used for the stream operator
73 void VNetworkBase::print(std::ostream &strm) const {
74  if (m_inputFile.empty()) {
75  ATH_MSG_DEBUG("Making a network without a named inputFile");
76  strm << "Unknown network";
77  } else {
78  ATH_MSG_DEBUG("Making a network with input file " << m_inputFile);
79  strm << m_inputFile;
80  };
81 };
82 
83 void VNetworkBase::writeNetToTTree(TFile &root_file,
84  std::string const &tree_name) {
85  ATH_MSG_DEBUG("Making tree name " << tree_name);
86  root_file.cd();
87  const std::string title = "onnxruntime saved network";
88  TTree tree(tree_name.c_str(), title.c_str());
89  this->writeNetToTTree(tree);
90  root_file.Write();
91 };
92 
93 void VNetworkBase::writeNetToTTree(std::string const &root_name,
94  std::string const &tree_name) {
95  ATH_MSG_DEBUG("Making or updating file name " << root_name);
96  TFile root_file(root_name.c_str(), "UPDATE");
97  this->writeNetToTTree(root_file, tree_name);
98  root_file.Close();
99 };
100 
101 bool VNetworkBase::isRootFile(std::string const &filename) const {
102  const std::string *to_check = &filename;
103  if (filename.length() == 0) {
104  to_check = &this->m_inputFile;
105  ATH_MSG_DEBUG("No file name given, so using m_inputFile, " << m_inputFile);
106  };
107  const std::string ending = ".root";
108  const int ending_len = ending.length();
109  const int filename_len = to_check->length();
110  if (filename_len < ending_len) {
111  return false;
112  }
113  return (0 ==
114  to_check->compare(filename_len - ending_len, ending_len, ending));
115 };
116 
117 bool VNetworkBase::isFile() const { return isFile(m_inputFile); };
118 
119 bool VNetworkBase::isFile(std::string const &inputFile) {
120  if (FILE *file = std::fopen(inputFile.c_str(), "r")) {
121  std::fclose(file);
122  return true;
123  } else {
124  return false;
125  };
126 };
127 
namespace {
// Length of the longest prefix shared by every string in `strings`.
// Returns 0 for an empty vector, and the full length for a single string.
// Callers may pass a vector built from an empty container, so the empty
// case must be handled rather than reading strings[0].
int GetPrefixLength(const std::vector<std::string>& strings) {
  if (strings.empty()) {
    return 0;
  }
  // Bind by reference: no need to copy the first string.
  const std::string& first = strings.front();
  std::size_t length = first.length();
  for (const std::string& this_string : strings) {
    // Never compare past the end of the shorter of the two strings.
    const std::size_t limit =
        this_string.length() < length ? this_string.length() : length;
    std::size_t i = 0;
    while (i < limit && first[i] == this_string[i]) {
      ++i;
    }
    length = i;
  }
  return static_cast<int>(length);
}
} // namespace
143 
145 std::vector<std::string> &output_names) const {
// Strip the longest prefix shared by all names from every name, in place.
// NOTE(review): GetPrefixLength must handle an empty vector for this to be
// safe when output_names is empty - confirm.
// With a single name, the whole name is the common prefix, so it becomes "".
146 const int length = GetPrefixLength(output_names);
147 for (long unsigned int i = 0; i < output_names.size(); i++)
148 output_names[i] = output_names[i].substr(length);
149 };
150 
152  std::vector<std::string> output_layers;
153  for (auto const &output : outputs)
154  output_layers.push_back(output.first);
155  const int length = GetPrefixLength(output_layers);
156  for (std::string layer_name : output_layers) {
157  // remove this output
158  auto nodeHandle = outputs.extract(layer_name);
159  // change the key
160  nodeHandle.key() = layer_name.substr(length);
161  // replace the output
162  outputs.insert(std::move(nodeHandle));
163  }
164 };
VNetworkBase::NetworkOutputs
std::map< std::string, double > NetworkOutputs
Format for network outputs.
Definition: VNetworkBase.h:100
VNetworkBase.h
VNetworkBase::VNetworkBase
VNetworkBase()
VNetworkBase default constructor.
Definition: VNetworkBase.cxx:16
VNetworkBase::representNetworkOutputs
static std::string representNetworkOutputs(NetworkOutputs const &outputs, int maxValues=3)
String representation of network outputs.
Definition: VNetworkBase.cxx:57
tree
TChain * tree
Definition: tile_monitor.h:30
VNetworkBase::NetworkInputs
std::map< std::string, std::map< std::string, double > > NetworkInputs
Format for network inputs.
Definition: VNetworkBase.h:90
VNetworkBase::isFile
bool isFile() const
Check if m_inputFile is the path of a file on disk that can be opened for reading.
Definition: VNetworkBase.cxx:117
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
CaloCondBlobAlgs_fillNoiseFromASCII.inputFile
string inputFile
Definition: CaloCondBlobAlgs_fillNoiseFromASCII.py:17
lumiFormat.i
int i
Definition: lumiFormat.py:85
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
covarianceTool.title
title
Definition: covarianceTool.py:542
file
TFile * file
Definition: tile_monitor.h:29
python.CreateTierZeroArgdict.outputs
outputs
Definition: CreateTierZeroArgdict.py:189
VNetworkBase::removePrefixes
void removePrefixes(NetworkOutputs &outputs) const
Remove any common prefix from the outputs.
Definition: VNetworkBase.cxx:151
merge.output
output
Definition: merge.py:17
VNetworkBase
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration.
Definition: VNetworkBase.h:38
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
VNetworkBase::representNetworkInputs
static std::string representNetworkInputs(NetworkInputs const &inputs, int maxValues=3)
String representation of network inputs.
Definition: VNetworkBase.cxx:37
item
Definition: ItemListSvc.h:43
VNetworkBase::~VNetworkBase
virtual ~VNetworkBase()
Definition: VNetworkBase.cxx:34
DeMoScan.first
bool first
Definition: DeMoScan.py:536
CaloCellTimeCorrFiller.filename
filename
Definition: CaloCellTimeCorrFiller.py:24
VNetworkBase::isRootFile
bool isRootFile(std::string const &filename="") const
Check if a string is possibly a root file path.
Definition: VNetworkBase.cxx:101
VNetworkBase::writeNetToTTree
virtual void writeNetToTTree(TTree &tree)=0
Save the network to a TTree.
VNetworkBase::m_inputFile
std::string m_inputFile
Path to the file describing the network, including filename.
Definition: VNetworkBase.h:245
length
double length(const pvec &v)
Definition: FPGATrackSimLLPDoubletHoughTransformTool.cxx:26
VNetworkBase::print
virtual void print(std::ostream &strm) const
Write a short description of this net to the output stream.
Definition: VNetworkBase.cxx:73