ATLAS Offline Software
Loading...
Searching...
No Matches
TauJetRNN.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
6
7#include <algorithm>
8#include <fstream>
9
10#include "lwtnn/LightweightGraph.hh"
11#include "lwtnn/Exceptions.hh"
12#include "lwtnn/parse_json.hh"
13
15
16
17TauJetRNN::TauJetRNN(const std::string &filename, const Config &config, bool useTRT)
18 : asg::AsgMessaging("TauJetRNN"), m_config(config), m_graph(nullptr), m_useTRT(useTRT) {
19 // Load the json file defining the network
20 std::ifstream input_file(filename);
21 lwt::GraphConfig lwtnn_config;
22 try {
23 lwtnn_config = lwt::parse_json_graph(input_file);
24 } catch (const std::logic_error &e) {
25 ATH_MSG_ERROR("Error parsing network config: " << e.what());
26 throw;
27 }
28
29 // Search for input layer names specified in 'config'
30 auto node_is_scalar = [&config](const lwt::InputNodeConfig &in_node) {
31 return in_node.name == config.input_layer_scalar;
32 };
33 auto node_is_track = [&config](const lwt::InputNodeConfig &in_node) {
34 return in_node.name == config.input_layer_tracks;
35 };
36 auto node_is_cluster = [&config](const lwt::InputNodeConfig &in_node) {
37 return in_node.name == config.input_layer_clusters;
38 };
39
40 auto scalar_node = std::find_if(lwtnn_config.inputs.cbegin(),
41 lwtnn_config.inputs.cend(),
42 node_is_scalar);
43
44 auto track_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
45 lwtnn_config.input_sequences.cend(),
46 node_is_track);
47
48 auto cluster_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
49 lwtnn_config.input_sequences.cend(),
50 node_is_cluster);
51
52 // Check which input layers were found
53 auto has_scalar_node = scalar_node != lwtnn_config.inputs.cend();
54 auto has_track_node = track_node != lwtnn_config.input_sequences.cend();
55 auto has_cluster_node = cluster_node != lwtnn_config.input_sequences.cend();
56
57 // Fill the variable names of each input layer into the corresponding vector
58 if (has_scalar_node) {
59 for (const auto &in : scalar_node->variables) {
60 m_scalar_inputs.push_back(in.name);
61 }
62 }
63
64 if (has_track_node) {
65 for (const auto &in : track_node->variables) {
66 m_track_inputs.push_back(in.name);
67 }
68 }
69
70 if (has_cluster_node) {
71 for (const auto &in : cluster_node->variables) {
72 m_cluster_inputs.push_back(in.name);
73 }
74 }
75
76 // Configure the network
77 try {
78 m_graph = std::make_unique<lwt::LightweightGraph>(
79 lwtnn_config, config.output_layer);
80 } catch (const lwt::NNConfigurationException &e) {
81 ATH_MSG_ERROR(e.what());
82 throw;
83 }
84
85 // Load the variable calculator
87}
88
90
92 const std::vector<const xAOD::TauTrack *> &tracks,
93 const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
94 InputMap scalarInputs;
95 InputSequenceMap vectorInputs;
96 if (!calculateInputVariables(tau, tracks, clusters, scalarInputs, vectorInputs)) {
97 return -1111.0;
98 }
99 // Compute the network outputs
100 const auto outputs = m_graph->compute(scalarInputs, vectorInputs);
101 // Return value of the output neuron
102 return outputs.at(m_config.output_node);
103}
104
106 const std::vector<const xAOD::TauTrack *> &tracks,
107 const std::vector<xAOD::CaloVertexedTopoCluster> &clusters,
108 std::map<std::string, std::map<std::string, double>>& scalarInputs,
109 std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const {
110 scalarInputs.clear();
111 vectorInputs.clear();
112 // Populate input (sequence) map with input variables
113 for (const auto &varname : m_scalar_inputs) {
114 if (!m_var_calc->compute(varname, tau,
115 scalarInputs[m_config.input_layer_scalar][varname])) {
116 ATH_MSG_WARNING("Error computing '" << varname
117 << "' returning default");
118 return false;
119 }
120 }
121
122 for (const auto &varname : m_track_inputs) {
123 if (!m_var_calc->compute(varname, tau, tracks,
124 vectorInputs[m_config.input_layer_tracks][varname])) {
125 ATH_MSG_WARNING("Error computing '" << varname
126 << "' returning default");
127 return false;
128 }
129 }
130
131 for (const auto &varname : m_cluster_inputs) {
132 if (!m_var_calc->compute(varname, tau, clusters,
133 vectorInputs[m_config.input_layer_clusters][varname])) {
134 ATH_MSG_WARNING("Error computing '" << varname
135 << "' returning default");
136 return false;
137 }
138 }
139 return true;
140}
#define ATH_MSG_ERROR(x)
#define ATH_MSG_WARNING(x)
const Config m_config
Definition TauJetRNN.h:81
std::vector< std::string > m_track_inputs
Definition TauJetRNN.h:86
std::vector< std::string > m_scalar_inputs
Definition TauJetRNN.h:85
std::unique_ptr< TauJetRNNUtils::VarCalc > m_var_calc
Definition TauJetRNN.h:90
std::map< std::string, VectorMap > InputSequenceMap
Definition TauJetRNN.h:78
std::map< std::string, VariableMap > InputMap
Definition TauJetRNN.h:77
bool calculateInputVariables(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters, std::map< std::string, std::map< std::string, double > > &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double > > > &vectorInputs) const
bool m_useTRT
Definition TauJetRNN.h:92
std::unique_ptr< const lwt::LightweightGraph > m_graph
Definition TauJetRNN.h:82
std::vector< std::string > m_cluster_inputs
Definition TauJetRNN.h:87
TauJetRNN(const std::string &filename, const Config &config, bool useTRT)
Definition TauJetRNN.cxx:17
float compute(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters) const
Definition TauJetRNN.cxx:91
AsgMessaging(const std::string &name)
Constructor with a name.
std::unique_ptr< VarCalc > get_calculator(const std::vector< std::string > &scalar_vars, const std::vector< std::string > &track_vars, const std::vector< std::string > &cluster_vars, bool useTRT)
TauJet_v3 TauJet
Definition of the current "tau version".