ATLAS Offline Software
TFCSSimpleLWTNNHandler.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
6 
7 // For writing to a tree
8 #include "TBranch.h"
9 #include "TTree.h"
10 
11 // LWTNN
12 #include "lwtnn/LightweightNeuralNetwork.hh"
13 #include "lwtnn/parse_json.hh"
14 
17  ATH_MSG_DEBUG("Setting up from inputFile.");
20 };
21 
23  const TFCSSimpleLWTNNHandler &copy_from)
24  : VNetworkLWTNN(copy_from) {
25  // Cannot take copy of lwt::LightweightNeuralNetwork
26  // (copy constructor disabled)
27  ATH_MSG_DEBUG("Making new m_lwtnn_neural for copy of network.");
28  std::stringstream json_stream(m_json);
29  const lwt::JSONConfig config = lwt::parse_json(json_stream);
30  m_lwtnn_neural = std::make_unique<lwt::LightweightNeuralNetwork>(
31  config.inputs, config.layers, config.outputs);
32  m_outputLayers = copy_from.m_outputLayers;
33 };
34 
36  // build the graph
37  ATH_MSG_DEBUG("Reading the m_json string stream into a neural network");
38  std::stringstream json_stream(m_json);
39  const lwt::JSONConfig config = lwt::parse_json(json_stream);
40  m_lwtnn_neural = std::make_unique<lwt::LightweightNeuralNetwork>(
41  config.inputs, config.layers, config.outputs);
42  // Get the output layers
43  ATH_MSG_DEBUG("Getting output layers for neural network");
44  for (std::string name : config.outputs) {
45  ATH_MSG_VERBOSE("Found output layer called " << name);
46  m_outputLayers.push_back(name);
47  };
48  ATH_MSG_DEBUG("Removing prefix from stored layers.");
50  ATH_MSG_DEBUG("Finished output nodes.");
51 }
52 
53 std::vector<std::string> TFCSSimpleLWTNNHandler::getOutputLayers() const {
54  return m_outputLayers;
55 };
56 
57 // This implements the network-specific compute and ensures the output is
58 // returned in the regular format. For LWTNN, that's easy.
61  ATH_MSG_DEBUG("Running computation on LWTNN neural network");
63  // Flatten the map depth
64  if (inputs.size() != 1) {
65  ATH_MSG_ERROR("The inputs have multiple elements."
66  << " An LWTNN neural network can only handle one node.");
67  };
68  std::map<std::string, double> flat_inputs;
69  for (auto node : inputs) {
70  flat_inputs = node.second;
71  }
72  // Now we have flattened, we can compute.
73  NetworkOutputs outputs = m_lwtnn_neural->compute(flat_inputs);
76  ATH_MSG_DEBUG("Computation on LWTNN neural network done, returning");
77  return outputs;
78 };
79 
80 // Giving this it's own streamer to call setupNet
81 void TFCSSimpleLWTNNHandler::Streamer(TBuffer &buf) {
82  ATH_MSG_DEBUG("In streamer of " << __FILE__);
83  if (buf.IsReading()) {
84  ATH_MSG_DEBUG("Reading buffer in TFCSSimpleLWTNNHandler ");
85  // Get the persisted variables filled in
86  TFCSSimpleLWTNNHandler::Class()->ReadBuffer(buf, this);
87  // Setup the net, creating the non persisted variables
88  // exactly as in the constructor
89  this->setupNet();
90 #ifndef __FastCaloSimStandAlone__
91  // When running inside Athena, delete persisted information
92  // to conserve memory
93  this->deleteAllButNet();
94 #endif
95  } else {
96  if (!m_json.empty()) {
97  ATH_MSG_DEBUG("Writing buffer in TFCSSimpleLWTNNHandler ");
98  } else {
100  "Writing buffer in TFCSSimpleLWTNNHandler, but m_json is empty.");
101  }
102  // Persist variables
103  TFCSSimpleLWTNNHandler::Class()->WriteBuffer(buf, this);
104  };
105 };
VNetworkBase::NetworkOutputs
std::map< std::string, double > NetworkOutputs
Format for network outputs.
Definition: VNetworkBase.h:100
TFCSSimpleLWTNNHandler::setupNet
void setupNet() override
Perform actions that prepare network for use.
Definition: TFCSSimpleLWTNNHandler.cxx:35
TFCSSimpleLWTNNHandler::m_lwtnn_neural
std::unique_ptr< lwt::LightweightNeuralNetwork > m_lwtnn_neural
The network that we are wrapping here.
Definition: TFCSSimpleLWTNNHandler.h:103
VNetworkBase::representNetworkOutputs
static std::string representNetworkOutputs(NetworkOutputs const &outputs, int maxValues=3)
String representation of network outputs.
Definition: VNetworkBase.cxx:57
VNetworkBase::NetworkInputs
std::map< std::string, std::map< std::string, double > > NetworkInputs
Format for network inputs.
Definition: VNetworkBase.h:90
TFCSSimpleLWTNNHandler::compute
NetworkOutputs compute(NetworkInputs const &inputs) const override
Function to pass values to the network.
Definition: TFCSSimpleLWTNNHandler.cxx:59
VNetworkLWTNN::m_json
std::string m_json
String containing json input file.
Definition: VNetworkLWTNN.h:84
TFCSSimpleLWTNNHandler::m_outputLayers
std::vector< std::string > m_outputLayers
Do not persistify.
Definition: TFCSSimpleLWTNNHandler.h:107
ATH_MSG_VERBOSE
#define ATH_MSG_VERBOSE(x)
Definition: AthMsgStreamMacros.h:28
TFCSSimpleLWTNNHandler
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration.
Definition: TFCSSimpleLWTNNHandler.h:34
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
lwtDev::parse_json
JSONConfig parse_json(std::istream &json)
Definition: parse_json.cxx:42
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
VNetworkLWTNN::deleteAllButNet
void deleteAllButNet() override
Get rid of any memory objects that aren't needed to run the net.
Definition: VNetworkLWTNN.cxx:87
CaloCondBlobAlgs_fillNoiseFromASCII.inputFile
string inputFile
Definition: CaloCondBlobAlgs_fillNoiseFromASCII.py:17
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
TFCSSimpleLWTNNHandler::TFCSSimpleLWTNNHandler
TFCSSimpleLWTNNHandler(const std::string &inputFile)
TFCSSimpleLWTNNHandler constructor.
Definition: TFCSSimpleLWTNNHandler.cxx:15
python.CreateTierZeroArgdict.outputs
outputs
Definition: CreateTierZeroArgdict.py:189
VNetworkBase::removePrefixes
void removePrefixes(NetworkOutputs &outputs) const
Remove any common prefix from the outputs.
Definition: VNetworkBase.cxx:151
TFCSSimpleLWTNNHandler::getOutputLayers
std::vector< std::string > getOutputLayers() const override
List the names of the outputs.
Definition: TFCSSimpleLWTNNHandler.cxx:53
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:195
VNetworkBase::representNetworkInputs
static std::string representNetworkInputs(NetworkInputs const &inputs, int maxValues=3)
String representation of network inputs.
Definition: VNetworkBase.cxx:37
VNetworkLWTNN
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration.
Definition: VNetworkLWTNN.h:31
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
VNetworkLWTNN::setupPersistedVariables
void setupPersistedVariables() override
Perform actions that prep data to create the net.
Definition: VNetworkLWTNN.cxx:30
TFCSSimpleLWTNNHandler.h
node
Definition: memory_hooks-stdcmalloc.h:74