ATLAS Offline Software
TFCSNetworkFactory.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
10 
11 #include <boost/property_tree/ptree.hpp>
#include <filesystem>
12 #include <fstream> // For checking if files exist
#include <memory>
13 #include <stdexcept>
14 
15 // For messaging
17 using ISF_FCS::MLogging;
18 
// Body of TFCSNetworkFactory::resolveGlobs(std::string &filename). The
// signature line, and any local declarations before the first statement
// (presumably including the one for `logger`), are not part of this listing.
//
// Resolves a trailing ".*" glob in place. If `filename` ends in ".*", the '*'
// is replaced by "onnx" when "<stem>.onnx" exists on disk. Otherwise it is
// replaced by "json" when "<stem>.json" exists. "onnx" wins when both exist.
// Names that do not end in ".*" are left unchanged.
// Throws std::invalid_argument if the name ends in ".*" but neither candidate
// file exists.
// NOTE(review): std::filesystem::exists is used below, but <filesystem> is not
// among the includes visible in this listing. Confirm it is included directly
// rather than relied upon transitively.
// NOTE(review): "implausably" in the first log message is a typo.
21  const std::string ending = ".*";
22  const int ending_len = ending.length();
23  const int filename_len = filename.length();
  // Names shorter than ".*" cannot carry the glob. This case is only reported
  // via ATH_MSG_NOCLASS, and the name is left as-is.
24  if (filename_len < ending_len) {
25  ATH_MSG_NOCLASS(logger, "Filename is implausably short.");
26  } else if (0 ==
27  filename.compare(filename_len - ending_len, ending_len, ending)) {
28  ATH_MSG_NOCLASS(logger, "Filename ends in glob.");
29  // Remove only the '*' of the glob. The '.' is kept so that appending an
  // extension below yields "<stem>.onnx" or "<stem>.json".
30  filename.pop_back();
  // Probe the candidates in priority order: onnx first, then json.
31  if (std::filesystem::exists(filename + "onnx")) {
32  filename += "onnx";
33  } else if (std::filesystem::exists(filename + "json")) {
34  filename += std::string("json");
35  } else {
  // Neither candidate exists. The message shows the stem with its trailing '.'.
36  throw std::invalid_argument("No file found matching globbed filename " +
37  filename);
38  };
39  };
40 };
41 
42 bool TFCSNetworkFactory::isOnnxFile(std::string const &filename) {
44  const std::string ending = ".onnx";
45  const int ending_len = ending.length();
46  const int filename_len = filename.length();
47  bool is_onnx;
48  if (filename_len < ending_len) {
49  is_onnx = false;
50  } else {
51  is_onnx =
52  (0 == filename.compare(filename_len - ending_len, ending_len, ending));
53  };
54  return is_onnx;
55 };
56 
57 std::unique_ptr<VNetworkBase>
58 TFCSNetworkFactory::create(std::vector<char> const &input) {
60  ATH_MSG_NOCLASS(logger, "Directly creating ONNX network from bytes length "
61  << input.size());
62  std::unique_ptr<VNetworkBase> created(new TFCSONNXHandler(input));
63  return created;
64 };
65 
// Create a network from a single string.
// The first branch builds a TFCSONNXHandler. The condition that selects it is
// on a line not shown in this listing (presumably something like
// VNetworkBase::isFile(input) && isOnnxFile(input); TODO confirm).
// Otherwise the string is first given to TFCSSimpleLWTNNHandler. If that
// throws boost::property_tree::ptree_bad_path, the input is retried as an
// LWTNN graph with TFCSGANLWTNNHandler. Any other exception propagates.
// @param input  Taken by value. Presumably a file path or serialized LWTNN
//               content; confirm against the handler constructors.
// @return Owning pointer to the created network.
66 std::unique_ptr<VNetworkBase> TFCSNetworkFactory::create(std::string input) {
  // NOTE(review): input.length() - 10 is unsigned arithmetic. For an input
  // shorter than 10 characters it wraps around, and substr() throws
  // std::out_of_range whenever this message is evaluated. A guard such as
  // input.substr(input.size() > 10 ? input.size() - 10 : 0) avoids this.
70  ATH_MSG_NOCLASS(logger, "Creating ONNX network from file ..."
71  << input.substr(input.length() - 10));
72  std::unique_ptr<VNetworkBase> created(new TFCSONNXHandler(input));
73  return created;
74  } else {
75  try {
  // Try the plain LWTNN network format first.
  // NOTE(review): "Succedeed" in the log messages below is a typo.
76  std::unique_ptr<VNetworkBase> created(new TFCSSimpleLWTNNHandler(input));
78  "Succedeed in creating LWTNN nn from string starting "
79  << input.substr(0, 10));
80  return created;
81  } catch (const boost::property_tree::ptree_bad_path &e) {
82  // If we get this error, it was actually a graph, not a NeuralNetwork
83  std::unique_ptr<VNetworkBase> created(new TFCSGANLWTNNHandler(input));
84  ATH_MSG_NOCLASS(logger, "Succedeed in creating LWTNN graph from string");
85  return created;
86  };
87  };
88 };
89 
// Create a network from a string when the caller states whether an LWTNN
// input is a graph.
// The first branch builds a TFCSONNXHandler. The condition that selects it is
// on a line not shown in this listing (presumably a file-existence plus
// isOnnxFile(input) check; TODO confirm).
// Otherwise, graph_form == true uses TFCSGANLWTNNHandler and
// graph_form == false uses TFCSSimpleLWTNNHandler. No fallback from one
// LWTNN handler to the other is attempted here.
// @param input       Taken by value. Presumably a file path or serialized
//                    LWTNN content; confirm against the handler constructors.
// @param graph_form  Selects the LWTNN graph handler when true. It is ignored
//                    on the ONNX branch.
// @return Owning pointer to the created network.
90 std::unique_ptr<VNetworkBase> TFCSNetworkFactory::create(std::string input,
91  bool graph_form) {
  // NOTE(review): input.length() - 10 is unsigned arithmetic. For an input
  // shorter than 10 characters it wraps around, and substr() throws
  // std::out_of_range whenever this message is evaluated. A guard such as
  // input.substr(input.size() > 10 ? input.size() - 10 : 0) avoids this.
95  ATH_MSG_NOCLASS(logger, "Creating ONNX network from file ..."
96  << input.substr(input.length() - 10));
97  std::unique_ptr<VNetworkBase> created(new TFCSONNXHandler(input));
98  return created;
99  } else if (graph_form) {
100  ATH_MSG_NOCLASS(logger, "Creating LWTNN graph from string");
101  std::unique_ptr<VNetworkBase> created(new TFCSGANLWTNNHandler(input));
102  return created;
103  } else {
104  std::unique_ptr<VNetworkBase> created(new TFCSSimpleLWTNNHandler(input));
105  ATH_MSG_NOCLASS(logger, "Creating LWTNN nn from string");
106  return created;
107  };
108 };
109 
// Create a network from whichever input carries data.
// A non-empty vector_input takes priority and is forwarded to
// create(std::vector<char> const&). Otherwise a non-empty string_input is
// forwarded to create(std::string). If both are empty, std::invalid_argument
// is thrown.
// @param vector_input  Raw network bytes; may be empty.
// @param string_input  Taken by value. A trailing ".*" glob is resolved by
//                      resolveGlobs before either input is examined.
// @return Owning pointer to the created network.
110 std::unique_ptr<VNetworkBase>
111 TFCSNetworkFactory::create(std::vector<char> const &vector_input,
112  std::string string_input) {
114  ATH_MSG_NOCLASS(logger, "Given both bytes and a string to create an nn.");
  // NOTE(review): globs are resolved before vector_input is checked. A string
  // ending in ".*" with no matching file therefore makes resolveGlobs throw
  // std::invalid_argument, even when vector_input holds data that would have
  // been used. Consider resolving only on the string branch.
115  resolveGlobs(string_input);
116  if (vector_input.size() > 0) {
118  "Bytes contains data, size=" << vector_input.size()
119  << ", creating from bytes.");
120  return create(vector_input);
121  } else if (string_input.length() > 0) {
122  ATH_MSG_NOCLASS(logger, "No data in bytes, string contains data, "
123  << "creating from string.");
124  return create(string_input);
125  } else {
126  throw std::invalid_argument(
127  "Neither vector_input nor string_input contain data");
128  };
129 };
130 
// Create a network from whichever input carries data, with the LWTNN form
// given explicitly.
// A non-empty vector_input takes priority and is forwarded to
// create(std::vector<char> const&), so graph_form is not used on that path.
// Otherwise a non-empty string_input is forwarded, together with graph_form,
// to create(std::string, bool). If both are empty, std::invalid_argument is
// thrown.
// @param vector_input  Raw network bytes; may be empty.
// @param string_input  Taken by value. A trailing ".*" glob is resolved by
//                      resolveGlobs before either input is examined.
// @param graph_form    Forwarded only on the string path.
// @return Owning pointer to the created network.
131 std::unique_ptr<VNetworkBase>
132 TFCSNetworkFactory::create(std::vector<char> const &vector_input,
133  std::string string_input, bool graph_form) {
136  logger,
137  "Given both bytes, a string and graph form sepcified to create an nn.");
  // NOTE(review): "sepcified" in the message above is a typo.
  // NOTE(review): globs are resolved before vector_input is checked. A string
  // ending in ".*" with no matching file therefore makes resolveGlobs throw
  // std::invalid_argument, even when vector_input holds data that would have
  // been used. Consider resolving only on the string branch.
138  resolveGlobs(string_input);
139  if (vector_input.size() > 0) {
141  "Bytes contains data, size=" << vector_input.size()
142  << ", creating from bytes.");
143  return create(vector_input);
144  } else if (string_input.length() > 0) {
145  ATH_MSG_NOCLASS(logger, "No data in bytes, string contains data, "
146  << "creating from string.");
147  return create(string_input, graph_form);
148  } else {
149  throw std::invalid_argument(
150  "Neither vector_input nor string_input contain data");
151  };
152 };
AllowedVariables::e
e
Definition: AsgElectronSelectorTool.cxx:37
VNetworkBase.h
ISF_FCS::MLogging
Cut down AthMessaging.
Definition: MLogging.h:176
TFCSONNXHandler.h
TFCSGANLWTNNHandler
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration.
Definition: TFCSGANLWTNNHandler.h:37
TFCSNetworkFactory::resolveGlobs
static void resolveGlobs(std::string &filename)
If the filepath ends in .*, replace the * with onnx or json, whichever matching file exists (onnx preferred).
Definition: TFCSNetworkFactory.cxx:19
VNetworkBase::isFile
bool isFile() const
Check if the argument inputFile is the path of a file on disk.
Definition: VNetworkBase.cxx:117
TFCSSimpleLWTNNHandler
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration.
Definition: TFCSSimpleLWTNNHandler.h:34
TFCSNetworkFactory::create
static std::unique_ptr< VNetworkBase > create(std::string input)
Given a string, make a network.
Definition: TFCSNetworkFactory.cxx:66
PlotPulseshapeFromCool.input
input
Definition: PlotPulseshapeFromCool.py:106
TFCSNetworkFactory::isOnnxFile
static bool isOnnxFile(std::string const &filename)
Check if a filename seems to be an onnx file.
Definition: TFCSNetworkFactory.cxx:42
ATH_MSG_NOCLASS
#define ATH_MSG_NOCLASS(logger_name, x)
Definition: MLogging.h:52
CaloCellTimeCorrFiller.filename
filename
Definition: CaloCellTimeCorrFiller.py:24
python.dummyaccess.exists
def exists(filename)
Definition: dummyaccess.py:9
TFCSNetworkFactory.h
TFCSSimpleLWTNNHandler.h
MLogging.h
TFCSGANLWTNNHandler.h
python.iconfTool.gui.pad.logger
logger
Definition: pad.py:14
TFCSONNXHandler
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration.
Definition: TFCSONNXHandler.h:43