ATLAS Offline Software
Loading...
Searching...
No Matches
TFCSNetworkFactory.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
10
#include <boost/property_tree/ptree.hpp>
#include <filesystem>
#include <fstream> // For checking if files exist
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
14
15// For messaging
18
19void TFCSNetworkFactory::resolveGlobs(std::string &filename) {
21 const std::string ending = ".*";
22 const int ending_len = ending.length();
23 const int filename_len = filename.length();
24 if (filename_len < ending_len) {
25 ATH_MSG_NOCLASS(logger, "Filename is implausably short.");
26 } else if (0 ==
27 filename.compare(filename_len - ending_len, ending_len, ending)) {
28 ATH_MSG_NOCLASS(logger, "Filename ends in glob.");
29 // Remove the glob
30 filename.pop_back();
31 if (std::filesystem::exists(filename + "onnx")) {
32 filename += "onnx";
33 } else if (std::filesystem::exists(filename + "json")) {
34 filename += std::string("json");
35 } else {
36 throw std::invalid_argument("No file found matching globbed filename " +
37 filename);
38 };
39 };
40};
41
42bool TFCSNetworkFactory::isOnnxFile(std::string const &filename) {
44 const std::string ending = ".onnx";
45 const int ending_len = ending.length();
46 const int filename_len = filename.length();
47 bool is_onnx;
48 if (filename_len < ending_len) {
49 is_onnx = false;
50 } else {
51 is_onnx =
52 (0 == filename.compare(filename_len - ending_len, ending_len, ending));
53 };
54 return is_onnx;
55};
56
57std::unique_ptr<VNetworkBase>
58TFCSNetworkFactory::create(std::vector<char> const &input) {
60 ATH_MSG_NOCLASS(logger, "Directly creating ONNX network from bytes length "
61 << input.size());
62 std::unique_ptr<VNetworkBase> created(new TFCSONNXHandler(input));
63 return created;
64};
65
66std::unique_ptr<VNetworkBase> TFCSNetworkFactory::create(std::string input) {
68 resolveGlobs(input);
69 if (VNetworkBase::isFile(input) && isOnnxFile(input)) {
70 ATH_MSG_NOCLASS(logger, "Creating ONNX network from file ..."
71 << input.substr(input.length() - 10));
72 std::unique_ptr<VNetworkBase> created(new TFCSONNXHandler(input));
73 return created;
74 } else {
75 try {
76 std::unique_ptr<VNetworkBase> created(new TFCSSimpleLWTNNHandler(input));
78 "Succedeed in creating LWTNN nn from string starting "
79 << input.substr(0, 10));
80 return created;
81 } catch (const boost::property_tree::ptree_bad_path &e) {
82 // If we get this error, it was actually a graph, not a NeuralNetwork
83 std::unique_ptr<VNetworkBase> created(new TFCSGANLWTNNHandler(input));
84 ATH_MSG_NOCLASS(logger, "Succedeed in creating LWTNN graph from string");
85 return created;
86 };
87 };
88};
89
90std::unique_ptr<VNetworkBase> TFCSNetworkFactory::create(std::string input,
91 bool graph_form) {
93 resolveGlobs(input);
94 if (VNetworkBase::isFile(input) && isOnnxFile(input)) {
95 ATH_MSG_NOCLASS(logger, "Creating ONNX network from file ..."
96 << input.substr(input.length() - 10));
97 std::unique_ptr<VNetworkBase> created(new TFCSONNXHandler(input));
98 return created;
99 } else if (graph_form) {
100 ATH_MSG_NOCLASS(logger, "Creating LWTNN graph from string");
101 std::unique_ptr<VNetworkBase> created(new TFCSGANLWTNNHandler(input));
102 return created;
103 } else {
104 std::unique_ptr<VNetworkBase> created(new TFCSSimpleLWTNNHandler(input));
105 ATH_MSG_NOCLASS(logger, "Creating LWTNN nn from string");
106 return created;
107 };
108};
109
110std::unique_ptr<VNetworkBase>
111TFCSNetworkFactory::create(std::vector<char> const &vector_input,
112 std::string string_input) {
114 ATH_MSG_NOCLASS(logger, "Given both bytes and a string to create an nn.");
115 resolveGlobs(string_input);
116 if (vector_input.size() > 0) {
118 "Bytes contains data, size=" << vector_input.size()
119 << ", creating from bytes.");
120 return create(vector_input);
121 } else if (string_input.length() > 0) {
122 ATH_MSG_NOCLASS(logger, "No data in bytes, string contains data, "
123 << "creating from string.");
124 return create(std::move(string_input));
125 } else {
126 throw std::invalid_argument(
127 "Neither vector_input nor string_input contain data");
128 };
129};
130
131std::unique_ptr<VNetworkBase>
132TFCSNetworkFactory::create(std::vector<char> const &vector_input,
133 std::string string_input, bool graph_form) {
136 logger,
137 "Given both bytes, a string and graph form sepcified to create an nn.");
138 resolveGlobs(string_input);
139 if (vector_input.size() > 0) {
141 "Bytes contains data, size=" << vector_input.size()
142 << ", creating from bytes.");
143 return create(vector_input);
144 } else if (string_input.length() > 0) {
145 ATH_MSG_NOCLASS(logger, "No data in bytes, string contains data, "
146 << "creating from string.");
147 return create(std::move(string_input), graph_form);
148 } else {
149 throw std::invalid_argument(
150 "Neither vector_input nor string_input contain data");
151 };
152};
#define ATH_MSG_NOCLASS(logger_name, x)
Definition MLogging.h:52
Cut down AthMessaging.
Definition MLogging.h:176
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration.
static bool isOnnxFile(std::string const &filename)
Check if a filename seems to be an onnx file.
static std::unique_ptr< VNetworkBase > create(std::string input)
Given a string, make a network.
static void resolveGlobs(std::string &filename)
If the filepath ends in .*, replace the glob with .onnx or .json, whichever file exists (.onnx is preferred).
Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration.
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration.
static bool isFile(std::string const &inputFile)
Check if a string is the path of a file on disk.
static Root::TMsgLogger logger("iLumiCalc")