ATLAS Offline Software
TFCSGANLWTNNHandler.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
6 
7 // For writing to a tree
8 #include "TBranch.h"
9 #include "TTree.h"
10 
11 // LWTNN
12 #include "lwtnn/LightweightGraph.hh"
13 #include "lwtnn/parse_json.hh"
14 
// Body of the constructor TFCSGANLWTNNHandler(const std::string &inputFile).
// NOTE(review): the constructor's signature/initializer line and part of its
// body are missing from this listing; restore them from the original .cxx
// before editing this code. Only the debug message below is visible here.
17  ATH_MSG_DEBUG("Setting up from inputFile.");
20 };
21 
// Copy constructor (its signature line is missing from this listing).
// lwt::LightweightGraph cannot be copied, so a fresh graph is built by
// re-parsing m_json; the list of output layer names is copied directly from
// copy_from.
// NOTE(review): m_json is presumably copied by the VNetworkLWTNN base-class
// copy constructor — TODO confirm. If deleteAllButNet() has already run on
// copy_from and it clears m_json, there is no JSON left to parse here.
23  : VNetworkLWTNN(copy_from) {
24  // Cannot take copies of lwt::LightweightGraph
25  // (copy constructor disabled)
26  ATH_MSG_DEBUG("Making a new m_lwtnn_graph for copied network");
// Rebuild the graph configuration from the stored JSON text.
27  std::stringstream json_stream(m_json);
28  const lwt::GraphConfig config = lwt::parse_json_graph(json_stream);
29  m_lwtnn_graph = std::make_unique<lwt::LightweightGraph>(config);
30  m_outputLayers = copy_from.m_outputLayers;
31 };
32 
// Body of TFCSGANLWTNNHandler::setupNet() (the signature line is missing from
// this listing). Prepares the network for use:
//  - migrates legacy data held in m_input into m_json,
//  - builds m_lwtnn_graph from the JSON text in m_json,
//  - appends one "<node name>_<label>" entry to m_outputLayers for every label
//    of every output node in the parsed graph configuration.
// NOTE(review): m_outputLayers is appended to without being cleared first;
// calling setupNet() twice on the same object would duplicate entries —
// confirm whether that can happen.
34  // Backward compatibility: previous versions stored this in m_input
35  if (m_json.length() == 0 && m_input != nullptr) {
36  m_json = *m_input;
// This object owns m_input, so release it once its contents are copied.
37  delete m_input;
38  m_input = nullptr;
39  }
40  // build the graph
41  ATH_MSG_VERBOSE("m_json has size " << m_json.length());
42  ATH_MSG_DEBUG("m_json starts with " << m_json.substr(0, 10));
43  ATH_MSG_VERBOSE("Reading the m_json string stream into a graph network");
44  std::stringstream json_stream(m_json);
45  const lwt::GraphConfig config = lwt::parse_json_graph(json_stream);
46  m_lwtnn_graph = std::make_unique<lwt::LightweightGraph>(config);
47  // Get the output layers
48  ATH_MSG_VERBOSE("Getting output layers for neural network");
// NOTE(review): `auto node` and the `node_config` copy duplicate each map entry
// and its label vector; `const auto &` bindings would avoid the copies.
49  for (auto node : config.outputs) {
50  const std::string node_name = node.first;
51  const lwt::OutputNodeConfig node_config = node.second;
52  for (const std::string & label : node_config.labels) {
53  ATH_MSG_VERBOSE("Found output layer called " << node_name << "_"
54  << label);
55  m_outputLayers.push_back(node_name + "_" + label);
56  }
57  };
58  ATH_MSG_VERBOSE("Removing prefix from stored layers.");
// NOTE(review): the statement that follows this message in the original
// source (presumably stripping a common prefix from m_outputLayers) is
// missing from this listing.
60  ATH_MSG_VERBOSE("Finished output nodes.");
61 };
62 
63 std::vector<std::string> TFCSGANLWTNNHandler::getOutputLayers() const {
64  return m_outputLayers;
65 };
66 
67 // Implements the network-specific compute and ensures the output is returned
68 // in the regular format. For LWTNN, that's easy.
// Body of TFCSGANLWTNNHandler::compute(NetworkInputs const &inputs) const
// (the signature lines are missing from this listing).
// @param inputs  NetworkInputs, i.e. std::map<std::string,
//                std::map<std::string, double>>: input node name -> values.
// @return        NetworkOutputs, i.e. std::map<std::string, double>.
// NOTE(review): the lines that declare/fill `outputs` from the graph's
// compute() call (and any post-processing of it) are missing from this
// listing.
71  ATH_MSG_DEBUG("Running computation on LWTNN graph network");
// Work on a copy because the const reference `inputs` cannot be re-keyed.
72  NetworkInputs local_copy = inputs;
73  if (inputs.find("Noise") != inputs.end()) {
74  // Graphs from EnergyAndHitsGANV2 name these inputs node_0 (for Noise)
75  // and node_1 (for mycond), so rename the keys to match.
76  auto noiseNode = local_copy.extract("Noise");
77  noiseNode.key() = "node_0";
// NOTE(review): insert() of a node handle does nothing if "node_0" already
// exists (the node is handed back and discarded here).
78  local_copy.insert(std::move(noiseNode));
// NOTE(review): if "mycond" is absent, extract() returns an empty node handle
// and calling key() on it is undefined behaviour; a presence check is needed.
79  auto mycondNode = local_copy.extract("mycond");
80  mycondNode.key() = "node_1";
81  local_copy.insert(std::move(mycondNode));
82  }
83  // now we can compute
85  m_lwtnn_graph->compute(local_copy);
87  ATH_MSG_DEBUG("Computation on LWTNN graph network done, returning.");
88  return outputs;
89 };
90 
91 // Giving this it's own streamer to call setupNet
92 void TFCSGANLWTNNHandler::Streamer(TBuffer &buf) {
93  ATH_MSG_DEBUG("In streamer of " << __FILE__);
94  if (buf.IsReading()) {
95  ATH_MSG_DEBUG("Reading buffer in TFCSGANLWTNNHandler ");
96  // Get the persisted variables filled in
97  TFCSGANLWTNNHandler::Class()->ReadBuffer(buf, this);
98  ATH_MSG_DEBUG("m_json has size " << m_json.length());
99  ATH_MSG_DEBUG("m_json starts with " << m_json.substr(0, 10));
100  // Setup the net, creating the non persisted variables
101  // exactly as in the constructor
102  this->setupNet();
103 #ifndef __FastCaloSimStandAlone__
104  // When running inside Athena, delete persisted information
105  // to conserve memory
106  this->deleteAllButNet();
107 #endif
108  } else {
109  if (!m_json.empty()) {
110  ATH_MSG_DEBUG("Writing buffer in TFCSGANLWTNNHandler ");
111  } else {
113  "Writing buffer in TFCSGANLWTNNHandler, but m_json is empty");
114  };
115  // Persist variables
116  TFCSGANLWTNNHandler::Class()->WriteBuffer(buf, this);
117  };
118 };
TFCSGANLWTNNHandler::m_lwtnn_graph
std::unique_ptr< lwt::LightweightGraph > m_lwtnn_graph
The network that we are wrapping here.
Definition: TFCSGANLWTNNHandler.h:105
VNetworkBase::NetworkOutputs
std::map< std::string, double > NetworkOutputs
Format for network outputs.
Definition: VNetworkBase.h:100
TFCSGANLWTNNHandler
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration.
Definition: TFCSGANLWTNNHandler.h:37
PlotCalibFromCool.label
label
Definition: PlotCalibFromCool.py:78
VNetworkBase::NetworkInputs
std::map< std::string, std::map< std::string, double > > NetworkInputs
Format for network inputs.
Definition: VNetworkBase.h:90
VNetworkLWTNN::m_json
std::string m_json
String containing json input file.
Definition: VNetworkLWTNN.h:84
ATH_MSG_VERBOSE
#define ATH_MSG_VERBOSE(x)
Definition: AthMsgStreamMacros.h:28
TFCSGANLWTNNHandler::m_input
std::string * m_input
Do not persistify.
Definition: TFCSGANLWTNNHandler.h:115
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
TFCSGANLWTNNHandler::compute
NetworkOutputs compute(NetworkInputs const &inputs) const override
Function to pass values to the network.
Definition: TFCSGANLWTNNHandler.cxx:69
VNetworkLWTNN::deleteAllButNet
void deleteAllButNet() override
Get rid of any memory objects that aren't needed to run the net.
Definition: VNetworkLWTNN.cxx:87
CaloCondBlobAlgs_fillNoiseFromASCII.inputFile
string inputFile
Definition: CaloCondBlobAlgs_fillNoiseFromASCII.py:17
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
TFCSGANLWTNNHandler::getOutputLayers
std::vector< std::string > getOutputLayers() const override
List the names of the outputs.
Definition: TFCSGANLWTNNHandler.cxx:63
TFCSGANLWTNNHandler::m_outputLayers
std::vector< std::string > m_outputLayers
Do not persistify.
Definition: TFCSGANLWTNNHandler.h:110
python.CreateTierZeroArgdict.outputs
outputs
Definition: CreateTierZeroArgdict.py:189
VNetworkBase::removePrefixes
void removePrefixes(NetworkOutputs &outputs) const
Remove any common prefix from the outputs.
Definition: VNetworkBase.cxx:151
TFCSGANLWTNNHandler::setupNet
void setupNet() override
Perform actions that prepare the network for use.
Definition: TFCSGANLWTNNHandler.cxx:33
VNetworkLWTNN
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration.
Definition: VNetworkLWTNN.h:31
lwtDev::parse_json_graph
GraphConfig parse_json_graph(std::istream &json)
Definition: parse_json.cxx:71
TFCSGANLWTNNHandler::TFCSGANLWTNNHandler
TFCSGANLWTNNHandler(const std::string &inputFile)
TFCSGANLWTNNHandler constructor.
Definition: TFCSGANLWTNNHandler.cxx:15
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
VNetworkLWTNN::setupPersistedVariables
void setupPersistedVariables() override
Perform actions that prep data to create the net.
Definition: VNetworkLWTNN.cxx:30
TFCSGANLWTNNHandler.h
node
Definition: memory_hooks-stdcmalloc.h:74