ATLAS Offline Software
Loading...
Searching...
No Matches
TFCSSimpleLWTNNHandler.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
6
7// For writing to a tree
8#include "TBranch.h"
9#include "TTree.h"
10
11// LWTNN
12#include "lwtnn/LightweightNeuralNetwork.hh"
13#include "lwtnn/parse_json.hh"
14
16 : VNetworkLWTNN(inputFile) {
17 ATH_MSG_DEBUG("Setting up from inputFile.");
20};
21
23 const TFCSSimpleLWTNNHandler &copy_from)
24 : VNetworkLWTNN(copy_from) {
25 // Cannot take copy of lwt::LightweightNeuralNetwork
26 // (copy constructor disabled)
27 ATH_MSG_DEBUG("Making new m_lwtnn_neural for copy of network.");
28 std::stringstream json_stream(m_json);
29 const lwt::JSONConfig config = lwt::parse_json(json_stream);
30 m_lwtnn_neural = std::make_unique<lwt::LightweightNeuralNetwork>(
31 config.inputs, config.layers, config.outputs);
33};
34
36 // build the graph
37 ATH_MSG_DEBUG("Reading the m_json string stream into a neural network");
38 std::stringstream json_stream(m_json);
39 const lwt::JSONConfig config = lwt::parse_json(json_stream);
40 m_lwtnn_neural = std::make_unique<lwt::LightweightNeuralNetwork>(
41 config.inputs, config.layers, config.outputs);
42 // Get the output layers
43 ATH_MSG_DEBUG("Getting output layers for neural network");
44 for (const std::string& name : config.outputs) {
45 ATH_MSG_VERBOSE("Found output layer called " << name);
46 m_outputLayers.push_back(name);
47 };
48 ATH_MSG_DEBUG("Removing prefix from stored layers.");
50 ATH_MSG_DEBUG("Finished output nodes.");
51}
52
53std::vector<std::string> TFCSSimpleLWTNNHandler::getOutputLayers() const {
54 return m_outputLayers;
55};
56
57// This implements the specific compute and ensures the output is returned in
58// the regular format. For LWTNN, that's easy.
60 TFCSSimpleLWTNNHandler::NetworkInputs const &inputs) const {
61 ATH_MSG_DEBUG("Running computation on LWTNN neural network");
63 // Flatten the map depth
64 if (inputs.size() != 1) {
65 ATH_MSG_ERROR("The inputs have multiple elements."
66 << " An LWTNN neural network can only handle one node.");
67 };
68 std::map<std::string, double> flat_inputs;
69 for (const auto & node : inputs) {
70 flat_inputs = node.second;
71 }
72 // Now we have flattened, we can compute.
73 NetworkOutputs outputs = m_lwtnn_neural->compute(flat_inputs);
74 removePrefixes(outputs);
76 ATH_MSG_DEBUG("Computation on LWTNN neural network done, returning");
77 return outputs;
78};
79
80// Giving this its own streamer to call setupNet
81void TFCSSimpleLWTNNHandler::Streamer(TBuffer &buf) {
82 ATH_MSG_DEBUG("In streamer of " << __FILE__);
83 if (buf.IsReading()) {
84 ATH_MSG_DEBUG("Reading buffer in TFCSSimpleLWTNNHandler ");
85 // Get the persisted variables filled in
86 TFCSSimpleLWTNNHandler::Class()->ReadBuffer(buf, this);
87 // Setup the net, creating the non persisted variables
88 // exactly as in the constructor
89 this->setupNet();
90#ifndef __FastCaloSimStandAlone__
91 // When running inside Athena, delete persisted information
92 // to conserve memory
93 this->deleteAllButNet();
94#endif
95 } else {
96 if (!m_json.empty()) {
97 ATH_MSG_DEBUG("Writing buffer in TFCSSimpleLWTNNHandler ");
98 } else {
100 "Writing buffer in TFCSSimpleLWTNNHandler, but m_json is empty.");
101 }
102 // Persist variables
103 TFCSSimpleLWTNNHandler::Class()->WriteBuffer(buf, this);
104 };
105};
#define ATH_MSG_ERROR(x)
#define ATH_MSG_VERBOSE(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
std::vector< std::string > m_outputLayers
Do not persistify.
NetworkOutputs compute(NetworkInputs const &inputs) const override
Function to pass values to the network.
void setupNet() override
Perform actions that prepare network for use.
std::unique_ptr< lwt::LightweightNeuralNetwork > m_lwtnn_neural
The network that we are wrapping here.
VNetworkLWTNN(const VNetworkLWTNN &copy_from)
VNetworkLWTNN copy constructor.
std::vector< std::string > getOutputLayers() const override
List the names of the outputs.
TFCSSimpleLWTNNHandler(const std::string &inputFile)
TFCSSimpleLWTNNHandler constructor.
std::map< std::string, std::map< std::string, double > > NetworkInputs
Format for network inputs.
static std::string representNetworkOutputs(NetworkOutputs const &outputs, int maxValues=3)
String representation of network outputs.
std::map< std::string, double > NetworkOutputs
Format for network outputs.
void removePrefixes(NetworkOutputs &outputs) const
Remove any common prefix from the outputs.
static std::string representNetworkInputs(NetworkInputs const &inputs, int maxValues=3)
String representation of network inputs.
void deleteAllButNet() override
Get rid of any memory objects that aren't needed to run the net.
std::string m_json
String containing json input file.
void setupPersistedVariables() override
Perform actions that prep data to create the net.
Definition node.h:24