ATLAS Offline Software
Loading...
Searching...
No Matches
TFCSGANLWTNNHandler.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
6
7// For writing to a tree
8#include "TBranch.h"
9#include "TTree.h"
10
11// LWTNN
12#include "lwtnn/LightweightGraph.hh"
13#include "lwtnn/parse_json.hh"
14
// Construct the handler from inputFile, which is forwarded unchanged to the
// VNetworkLWTNN base-class constructor.
15TFCSGANLWTNNHandler::TFCSGANLWTNNHandler(const std::string &inputFile)
16 : VNetworkLWTNN(inputFile) {
17 ATH_MSG_DEBUG("Setting up from inputFile.");
// NOTE(review): source lines 18-19 are missing from this extract. The
// Streamer's comment says the net is set up "exactly as in the constructor",
// so presumably these lines call setupPersistedVariables() and setupNet() --
// confirm against the real file.
20};
21
23 : VNetworkLWTNN(copy_from) {
// Copy constructor body (its signature, source line 22, is missing from this
// extract). The base-class part is copied via VNetworkLWTNN(copy_from);
// m_json is read below, so presumably it is copied there too -- confirm.
// The graph itself cannot be copied, so it is rebuilt by re-parsing m_json.
24 // Cannot take copies of lwt::LightweightGraph
25 // (copy constructor disabled)
26 ATH_MSG_DEBUG("Making a new m_lwtnn_graph for copied network");
27 std::stringstream json_stream(m_json);
28 const lwt::GraphConfig config = lwt::parse_json_graph(json_stream);
29 m_lwtnn_graph = std::make_unique<lwt::LightweightGraph>(config);
// NOTE(review): source line 30 is missing from this extract, so it cannot be
// confirmed here whether m_outputLayers is carried over from copy_from --
// verify that a copied handler reports the same output layers as the original.
31};
32
// NOTE(review): the function signature (source line 33) is missing from this
// extract. This body rebuilds the non-persisted state (m_lwtnn_graph and
// m_outputLayers) from m_json; presumably it is TFCSGANLWTNNHandler::setupNet(),
// which the Streamer calls after reading -- confirm against the real file.
34 // Backwards compatibility: previous versions stored the JSON in m_input
35 if (m_json.length() == 0 && m_input != nullptr) {
36 m_json = *m_input;
// m_input is owned by this object: free it once its contents are in m_json.
37 delete m_input;
38 m_input = nullptr;
39 }
40 // build the graph
41 ATH_MSG_VERBOSE("m_json has size " << m_json.length());
42 ATH_MSG_DEBUG("m_json starts with " << m_json.substr(0, 10));
43 ATH_MSG_VERBOSE("Reading the m_json string stream into a graph network");
// Parse the JSON text into an lwtnn graph configuration and build the graph.
44 std::stringstream json_stream(m_json);
45 const lwt::GraphConfig config = lwt::parse_json_graph(json_stream);
46 m_lwtnn_graph = std::make_unique<lwt::LightweightGraph>(config);
47 // Get the output layers
// config.outputs pairs each output-node name with its lwt::OutputNodeConfig;
// one layer name "<node name>_<label>" is recorded per label of every node.
48 ATH_MSG_VERBOSE("Getting output layers for neural network");
49 for (const auto & node : config.outputs) {
50 const std::string node_name = node.first;
51 const lwt::OutputNodeConfig node_config = node.second;
52 for (const std::string & label : node_config.labels) {
53 ATH_MSG_VERBOSE("Found output layer called " << node_name << "_"
54 << label);
// NOTE(review): m_outputLayers is only appended to (never cleared here), so
// running this body twice on the same object would duplicate entries --
// confirm it runs at most once per object.
55 m_outputLayers.push_back(node_name + "_" + label);
56 }
57 };
58 ATH_MSG_VERBOSE("Removing prefix from stored layers.");
// NOTE(review): source line 59 is missing from this extract; given the message
// above, it presumably strips the common prefix from m_outputLayers (e.g. via
// removePrefixes) -- confirm.
60 ATH_MSG_VERBOSE("Finished output nodes.");
61};
62
63std::vector<std::string> TFCSGANLWTNNHandler::getOutputLayers() const {
64 return m_outputLayers;
65};
66
67// This is implement the specific compute, and ensure the output is returned in
68// regular format. For LWTNN, that's easy.
70 TFCSGANLWTNNHandler::NetworkInputs const &inputs) const {
71 ATH_MSG_DEBUG("Running computation on LWTNN graph network");
72 NetworkInputs local_copy = inputs;
73 if (inputs.find("Noise") != inputs.end()) {
74 // Graphs from EnergyAndHitsGANV2 have the local_copy encoded as Noise =
75 // node_0 and mycond = node_1
76 auto noiseNode = local_copy.extract("Noise");
77 noiseNode.key() = "node_0";
78 local_copy.insert(std::move(noiseNode));
79 auto mycondNode = local_copy.extract("mycond");
80 mycondNode.key() = "node_1";
81 local_copy.insert(std::move(mycondNode));
82 }
83 // now we can compute
85 m_lwtnn_graph->compute(local_copy);
86 removePrefixes(outputs);
87 ATH_MSG_DEBUG("Computation on LWTNN graph network done, returning.");
88 return outputs;
89};
90
91// Giving this it's own streamer to call setupNet
92void TFCSGANLWTNNHandler::Streamer(TBuffer &buf) {
93 ATH_MSG_DEBUG("In streamer of " << __FILE__);
94 if (buf.IsReading()) {
95 ATH_MSG_DEBUG("Reading buffer in TFCSGANLWTNNHandler ");
96 // Get the persisted variables filled in
97 TFCSGANLWTNNHandler::Class()->ReadBuffer(buf, this);
98 ATH_MSG_DEBUG("m_json has size " << m_json.length());
99 ATH_MSG_DEBUG("m_json starts with " << m_json.substr(0, 10));
100 // Setup the net, creating the non persisted variables
101 // exactly as in the constructor
102 this->setupNet();
103#ifndef __FastCaloSimStandAlone__
104 // When running inside Athena, delete persisted information
105 // to conserve memory
106 this->deleteAllButNet();
107#endif
108 } else {
109 if (!m_json.empty()) {
110 ATH_MSG_DEBUG("Writing buffer in TFCSGANLWTNNHandler ");
111 } else {
113 "Writing buffer in TFCSGANLWTNNHandler, but m_json is empty");
114 };
115 // Persist variables
116 TFCSGANLWTNNHandler::Class()->WriteBuffer(buf, this);
117 };
118};
#define ATH_MSG_VERBOSE(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
void setupNet() override
Perform actions that prepare network for use.
std::string * m_input
Do not persistify.
NetworkOutputs compute(NetworkInputs const &inputs) const override
Function to pass values to the network.
TFCSGANLWTNNHandler(const std::string &inputFile)
TFCSGANLWTNNHandler constructor.
std::unique_ptr< lwt::LightweightGraph > m_lwtnn_graph
The network that we are wrapping here.
VNetworkLWTNN(const VNetworkLWTNN &copy_from)
VNetworkLWTNN copy constructor.
std::vector< std::string > getOutputLayers() const override
List the names of the outputs.
std::vector< std::string > m_outputLayers
Do not persistify.
std::map< std::string, std::map< std::string, double > > NetworkInputs
Format for network inputs.
std::map< std::string, double > NetworkOutputs
Format for network outputs.
void removePrefixes(NetworkOutputs &outputs) const
Remove any common prefix from the outputs.
void deleteAllButNet() override
Get rid of any memory objects that aren't needed to run the net.
std::string m_json
String containing json input file.
void setupPersistedVariables() override
Perform actions that prep data to create the net.
Definition node.h:24
std::string label(const std::string &format, int i)
Definition label.h:19