ATLAS Offline Software
TauJetRNN.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration
3 */
4 
6 
#include <algorithm>
#include <exception>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "lwtnn/LightweightGraph.hh"
#include "lwtnn/Exceptions.hh"
#include "lwtnn/parse_json.hh"
13 
15 
16 
17 TauJetRNN::TauJetRNN(const std::string &filename, const Config &config)
18  : asg::AsgMessaging("TauJetRNN"), m_config(config), m_graph(nullptr) {
19  // Load the json file defining the network
20  std::ifstream input_file(filename);
21  lwt::GraphConfig lwtnn_config;
22  try {
23  lwtnn_config = lwt::parse_json_graph(input_file);
24  } catch (const std::logic_error &e) {
25  ATH_MSG_ERROR("Error parsing network config: " << e.what());
26  throw;
27  }
28 
29  // Search for input layer names specified in 'config'
30  auto node_is_scalar = [&config](const lwt::InputNodeConfig &in_node) {
31  return in_node.name == config.input_layer_scalar;
32  };
33  auto node_is_track = [&config](const lwt::InputNodeConfig &in_node) {
34  return in_node.name == config.input_layer_tracks;
35  };
36  auto node_is_cluster = [&config](const lwt::InputNodeConfig &in_node) {
37  return in_node.name == config.input_layer_clusters;
38  };
39 
40  auto scalar_node = std::find_if(lwtnn_config.inputs.cbegin(),
41  lwtnn_config.inputs.cend(),
42  node_is_scalar);
43 
44  auto track_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
45  lwtnn_config.input_sequences.cend(),
46  node_is_track);
47 
48  auto cluster_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
49  lwtnn_config.input_sequences.cend(),
50  node_is_cluster);
51 
52  // Check which input layers were found
53  auto has_scalar_node = scalar_node != lwtnn_config.inputs.cend();
54  auto has_track_node = track_node != lwtnn_config.input_sequences.cend();
55  auto has_cluster_node = cluster_node != lwtnn_config.input_sequences.cend();
56 
57  // Fill the variable names of each input layer into the corresponding vector
58  if (has_scalar_node) {
59  for (const auto &in : scalar_node->variables) {
60  m_scalar_inputs.push_back(in.name);
61  }
62  }
63 
64  if (has_track_node) {
65  for (const auto &in : track_node->variables) {
66  m_track_inputs.push_back(in.name);
67  }
68  }
69 
70  if (has_cluster_node) {
71  for (const auto &in : cluster_node->variables) {
72  m_cluster_inputs.push_back(in.name);
73  }
74  }
75 
76  // Configure the network
77  try {
78  m_graph = std::make_unique<lwt::LightweightGraph>(
79  lwtnn_config, config.output_layer);
80  } catch (const lwt::NNConfigurationException &e) {
81  ATH_MSG_ERROR(e.what());
82  throw;
83  }
84 
85  // Load the variable calculator
87 }
88 
90 
92  const std::vector<const xAOD::TauTrack *> &tracks,
93  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
94  InputMap scalarInputs;
95  InputSequenceMap vectorInputs;
96  if (!calculateInputVariables(tau, tracks, clusters, scalarInputs, vectorInputs)) {
97  return -1111.0;
98  }
99  // Compute the network outputs
100  const auto outputs = m_graph->compute(scalarInputs, vectorInputs);
101  // Return value of the output neuron
102  return outputs.at(m_config.output_node);
103 }
104 
106  const std::vector<const xAOD::TauTrack *> &tracks,
107  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters,
108  std::map<std::string, std::map<std::string, double>>& scalarInputs,
109  std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const {
110  scalarInputs.clear();
111  vectorInputs.clear();
112  // Populate input (sequence) map with input variables
113  for (const auto &varname : m_scalar_inputs) {
114  if (!m_var_calc->compute(varname, tau,
115  scalarInputs[m_config.input_layer_scalar][varname])) {
116  ATH_MSG_WARNING("Error computing '" << varname
117  << "' returning default");
118  return false;
119  }
120  }
121 
122  for (const auto &varname : m_track_inputs) {
123  if (!m_var_calc->compute(varname, tau, tracks,
124  vectorInputs[m_config.input_layer_tracks][varname])) {
125  ATH_MSG_WARNING("Error computing '" << varname
126  << "' returning default");
127  return false;
128  }
129  }
130 
131  for (const auto &varname : m_cluster_inputs) {
132  if (!m_var_calc->compute(varname, tau, clusters,
133  vectorInputs[m_config.input_layer_clusters][varname])) {
134  ATH_MSG_WARNING("Error computing '" << varname
135  << "' returning default");
136  return false;
137  }
138  }
139  return true;
140 }
TauJetRNN::InputMap
std::map< std::string, VariableMap > InputMap
Definition: TauJetRNN.h:77
TauJetRNNUtils.h
TauJetRNN::compute
float compute(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters) const
Definition: TauJetRNN.cxx:91
TauJetRNN::m_graph
std::unique_ptr< const lwt::LightweightGraph > m_graph
Definition: TauJetRNN.h:82
TauJetRNN::TauJetRNN
TauJetRNN(const std::string &filename, const Config &config)
Definition: TauJetRNN.cxx:17
asg
Definition: DataHandleTestTool.h:28
TauJetRNN::calculateInputVariables
bool calculateInputVariables(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters, std::map< std::string, std::map< std::string, double >> &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double >>> &vectorInputs) const
Definition: TauJetRNN.cxx:105
python.resample_meson.input_file
input_file
Definition: resample_meson.py:164
TauJetRNN::m_config
const Config m_config
Definition: TauJetRNN.h:81
TauJetRNN::Config::output_node
std::string output_node
Definition: TauJetRNN.h:42
TauJetRNN::m_track_inputs
std::vector< std::string > m_track_inputs
Definition: TauJetRNN.h:86
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
TauJetRNN::m_scalar_inputs
std::vector< std::string > m_scalar_inputs
Definition: TauJetRNN.h:85
TauJetRNNUtils::get_calculator
std::unique_ptr< VarCalc > get_calculator(const std::vector< std::string > &scalar_vars, const std::vector< std::string > &track_vars, const std::vector< std::string > &cluster_vars)
Definition: TauJetRNNUtils.cxx:110
TauJetRNN::Config::input_layer_tracks
std::string input_layer_tracks
Definition: TauJetRNN.h:39
python.CreateTierZeroArgdict.outputs
outputs
Definition: CreateTierZeroArgdict.py:189
TauJetRNN.h
TauJetRNN::InputSequenceMap
std::map< std::string, VectorMap > InputSequenceMap
Definition: TauJetRNN.h:78
lwtDev::parse_json_graph
GraphConfig parse_json_graph(std::istream &json)
Definition: parse_json.cxx:71
LArG4AODNtuplePlotter.varname
def varname(hname)
Definition: LArG4AODNtuplePlotter.py:37
DiTauMassTools::MaxHistStrategyV2::e
e
Definition: PhysicsAnalysis/TauID/DiTauMassTools/DiTauMassTools/HelperFunctions.h:26
TauJetRNN::~TauJetRNN
~TauJetRNN()
Definition: TauJetRNN.cxx:89
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
TauJetRNN::m_cluster_inputs
std::vector< std::string > m_cluster_inputs
Definition: TauJetRNN.h:87
config
std::vector< std::string > config
Definition: fbtTestBasics.cxx:72
RunTileMonitoring.clusters
clusters
Definition: RunTileMonitoring.py:133
CaloCellTimeCorrFiller.filename
filename
Definition: CaloCellTimeCorrFiller.py:24
TauJetRNN::Config::input_layer_clusters
std::string input_layer_clusters
Definition: TauJetRNN.h:40
TauJetRNN::Config::input_layer_scalar
std::string input_layer_scalar
Definition: TauJetRNN.h:38
TauJetRNN::m_var_calc
std::unique_ptr< TauJetRNNUtils::VarCalc > m_var_calc
Definition: TauJetRNN.h:90
TauJetRNN::Config
Definition: TauJetRNN.h:37