ATLAS Offline Software
TauGNN.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #include "tauRecTools/TauGNN.h"
7 #include "lwtnn/parse_json.hh"
9 
10 #include <algorithm>
11 #include <fstream>
12 
14 
15 TauGNN::TauGNN(const std::string &nnFile, const Config &config):
16  asg::AsgMessaging("TauGNN"),
17  m_onnxUtil(std::make_shared<FlavorTagDiscriminants::OnnxUtil>(nnFile)),
18  m_config{config}
19  {
20  //==================================================//
21  // This part is ported from FTagDiscriminant GNN.cxx//
22  //==================================================//
23 
24  // get the configuration of the model outputs
25  FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig();
26 
27  //Let's see the output!
28  for (const auto& out_node: gnn_output_config) {
29  if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::FLOAT) ATH_MSG_INFO("Found output FLOAT node named:" << out_node.name);
30  if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECCHAR) ATH_MSG_INFO("Found output VECCHAR node named:" << out_node.name);
31  if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECFLOAT) ATH_MSG_INFO("Found output VECFLOAT node named:" << out_node.name);
32  }
33 
34  //Get model config (for inputs)
35  auto lwtnn_config = m_onnxUtil->getLwtConfig();
36 
37  //===================================================//
38  // This part is ported from tauRecTools TauJetRNN.cxx//
39  //===================================================//
40 
41  // Search for input layer names specified in 'config'
42  auto node_is_scalar = [&config](const lwt::InputNodeConfig &in_node) {
43  return in_node.name == config.input_layer_scalar;
44  };
45  auto node_is_track = [&config](const lwt::InputNodeConfig &in_node) {
46  return in_node.name == config.input_layer_tracks;
47  };
48  auto node_is_cluster = [&config](const lwt::InputNodeConfig &in_node) {
49  return in_node.name == config.input_layer_clusters;
50  };
51 
52  auto scalar_node = std::find_if(lwtnn_config.inputs.cbegin(),
53  lwtnn_config.inputs.cend(),
54  node_is_scalar);
55 
56  auto track_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
57  lwtnn_config.input_sequences.cend(),
58  node_is_track);
59 
60  auto cluster_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
61  lwtnn_config.input_sequences.cend(),
62  node_is_cluster);
63 
64  // Check which input layers were found
65  auto has_scalar_node = scalar_node != lwtnn_config.inputs.cend();
66  auto has_track_node = track_node != lwtnn_config.input_sequences.cend();
67  auto has_cluster_node = cluster_node != lwtnn_config.input_sequences.cend();
68  if(!has_scalar_node) ATH_MSG_WARNING("No scalar node with name "<<config.input_layer_scalar<<" found!");
69  if(!has_track_node) ATH_MSG_WARNING("No track node with name "<<config.input_layer_tracks<<" found!");
70  if(!has_cluster_node) ATH_MSG_WARNING("No cluster node with name "<<config.input_layer_clusters<<" found!");
71 
72  // Fill the variable names of each input layer into the corresponding vector
73  if (has_scalar_node) {
74  for (const auto &in : scalar_node->variables) {
75  std::string name = in.name;
76  m_scalarCalc_inputs.push_back(name);
77  }
78  }
79 
80  if (has_track_node) {
81  for (const auto &in : track_node->variables) {
82  std::string name = in.name;
83  m_trackCalc_inputs.push_back(name);
84  }
85  }
86 
87  if (has_cluster_node) {
88  for (const auto &in : cluster_node->variables) {
89  std::string name = in.name;
90  m_clusterCalc_inputs.push_back(name);
91  }
92  }
93  // Load the variable calculator
94  m_var_calc = TauGNNUtils::get_calculator(m_scalarCalc_inputs, m_trackCalc_inputs, m_clusterCalc_inputs);
95  ATH_MSG_INFO("TauGNN object initialized successfully!");
96 }
97 
99 
100 std::tuple<
101  std::map<std::string, float>,
102  std::map<std::string, std::vector<char>>,
103  std::map<std::string, std::vector<float>> >
105  const std::vector<const xAOD::TauTrack *> &tracks,
106  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
107  InputMap scalarInputs;
108  InputSequenceMap vectorInputs;
109  std::map<std::string, Inputs> gnn_input;
110  ATH_MSG_DEBUG("Starting compute...");
111  //Prepare input variables
112  if (!calculateInputVariables(tau, tracks, clusters, scalarInputs, vectorInputs)) {
113  ATH_MSG_FATAL("Failed calculateInputVariables");
114  throw StatusCode::FAILURE;
115  }
116 
117  // Add TauJet-level features to the input
118  std::vector<float> tau_feats;
119  for (const auto &varname : m_scalarCalc_inputs) {
120  tau_feats.push_back(static_cast<float>(scalarInputs[m_config.input_layer_scalar][varname]));
121  }
122  std::vector<int64_t> tau_feats_dim = {1, static_cast<int64_t>(tau_feats.size())};
123  Inputs tau_info (tau_feats, tau_feats_dim);
124  gnn_input.insert({"tau_vars", tau_info});
125 
126  //Add track-level features to the input
127  std::vector<float> trk_feats;
128  int num_nodes=static_cast<int>(vectorInputs[m_config.input_layer_tracks][m_trackCalc_inputs.at(0)].size());
129  int num_node_vars=static_cast<int>(m_trackCalc_inputs.size());
130  trk_feats.resize(num_nodes * num_node_vars);
131  int var_idx=0;
132  for (const auto &varname : m_trackCalc_inputs) {
133  for (int node_idx=0; node_idx<num_nodes; node_idx++){
134  trk_feats.at(node_idx*num_node_vars + var_idx)
135  = static_cast<float>(vectorInputs[m_config.input_layer_tracks][varname].at(node_idx));
136  }
137  var_idx++;
138  }
139  std::vector<int64_t> trk_feats_dim = {num_nodes, num_node_vars};
140  Inputs trk_info (trk_feats, trk_feats_dim);
141  gnn_input.insert({"track_vars", trk_info});
142 
143  //Add cluster-level features to the input
144  std::vector<float> cls_feats;
145  num_nodes=static_cast<int>(vectorInputs[m_config.input_layer_clusters][m_clusterCalc_inputs.at(0)].size());
146  num_node_vars=static_cast<int>(m_clusterCalc_inputs.size());
147  cls_feats.resize(num_nodes * num_node_vars);
148  var_idx=0;
149  for (const auto &varname : m_clusterCalc_inputs) {
150  for (int node_idx=0; node_idx<num_nodes; node_idx++){
151  cls_feats.at(node_idx*num_node_vars + var_idx)
152  = static_cast<float>(vectorInputs[m_config.input_layer_clusters][varname].at(node_idx));
153  }
154  var_idx++;
155  }
156  std::vector<int64_t> cls_feats_dim = {num_nodes, num_node_vars};
157  Inputs cls_info (cls_feats, cls_feats_dim);
158  gnn_input.insert({"cluster_vars", cls_info});
159 
160  //RUN THE INFERENCE!!!
161  ATH_MSG_DEBUG("Prepared inputs, running inference...");
162  auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input);
163  ATH_MSG_DEBUG("Finished compute!");
164  return std::make_tuple(out_f, out_vc, out_vf);
165 }
166 
168  const std::vector<const xAOD::TauTrack *> &tracks,
169  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters,
170  std::map<std::string, std::map<std::string, double>>& scalarInputs,
171  std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const {
172  scalarInputs.clear();
173  vectorInputs.clear();
174  // Populate input (sequence) map with input variables
175  for (const auto &varname : m_scalarCalc_inputs) {
176  if (!m_var_calc->compute(varname, tau,
177  scalarInputs[m_config.input_layer_scalar][varname])) {
178  ATH_MSG_WARNING("Error computing '" << varname
179  << "' returning default");
180  return false;
181  }
182  }
183 
184  for (const auto &varname : m_trackCalc_inputs) {
185  if (!m_var_calc->compute(varname, tau, tracks,
186  vectorInputs[m_config.input_layer_tracks][varname])) {
187  ATH_MSG_WARNING("Error computing '" << varname
188  << "' returning default");
189  return false;
190  }
191  }
192 
193  for (const auto &varname : m_clusterCalc_inputs) {
194  if (!m_var_calc->compute(varname, tau, clusters,
195  vectorInputs[m_config.input_layer_clusters][varname])) {
196  ATH_MSG_WARNING("Error computing '" << varname
197  << "' returning default");
198  return false;
199  }
200  }
201  return true;
202 }
TauGNN::m_config
const Config m_config
Definition: TauGNN.h:85
FlavorTagDiscriminants::OnnxOutput::OutputType::VECCHAR
@ VECCHAR
ATH_MSG_FATAL
#define ATH_MSG_FATAL(x)
Definition: AthMsgStreamMacros.h:34
TauGNN::Config::input_layer_tracks
std::string input_layer_tracks
Definition: TauGNN.h:41
TauGNN::m_scalarCalc_inputs
std::vector< std::string > m_scalarCalc_inputs
Definition: TauGNN.h:92
TauGNNUtils::get_calculator
std::unique_ptr< GNNVarCalc > get_calculator(const std::vector< std::string > &scalar_vars, const std::vector< std::string > &track_vars, const std::vector< std::string > &cluster_vars)
Definition: TauGNNUtils.cxx:113
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
TauGNN::InputSequenceMap
std::map< std::string, VectorMap > InputSequenceMap
Definition: TauGNN.h:82
FlavorTagDiscriminants
This file contains "getter" functions used for accessing tagger inputs from the EDM.
Definition: AssociationEnums.h:11
FlavorTagDiscriminants::OnnxOutput::OutputType::FLOAT
@ FLOAT
TauGNN::Inputs
FlavorTagDiscriminants::Inputs Inputs
Definition: TauGNN.h:76
TauGNN::m_clusterCalc_inputs
std::vector< std::string > m_clusterCalc_inputs
Definition: TauGNN.h:94
TauGNN::TauGNN
TauGNN(const std::string &nnFile, const Config &config)
Definition: TauGNN.cxx:15
TauGNNUtils.h
python.base_data.config
config
Definition: base_data.py:21
asg
Definition: DataHandleTestTool.h:28
FlavorTagDiscriminants::OnnxOutput::OutputType::VECFLOAT
@ VECFLOAT
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
OnnxUtil
Definition: JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/OnnxUtil.h:14
TauGNN::compute
std::tuple< std::map< std::string, float >, std::map< std::string, std::vector< char > >, std::map< std::string, std::vector< float > > > compute(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters) const
Definition: TauGNN.cxx:104
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
TauGNN::InputMap
std::map< std::string, VariableMap > InputMap
Definition: TauGNN.h:81
TauGNN.h
TauGNN::m_var_calc
std::unique_ptr< TauGNNUtils::GNNVarCalc > m_var_calc
Definition: TauGNN.h:97
PathResolver.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:228
TauGNN::m_trackCalc_inputs
std::vector< std::string > m_trackCalc_inputs
Definition: TauGNN.h:93
TauGNN::Config::input_layer_clusters
std::string input_layer_clusters
Definition: TauGNN.h:42
TauGNN::calculateInputVariables
bool calculateInputVariables(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters, std::map< std::string, std::map< std::string, double >> &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double >>> &vectorInputs) const
Definition: TauGNN.cxx:167
LArG4AODNtuplePlotter.varname
def varname(hname)
Definition: LArG4AODNtuplePlotter.py:37
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
TauGNN::m_onnxUtil
std::shared_ptr< const FlavorTagDiscriminants::OnnxUtil > m_onnxUtil
Definition: TauGNN.h:46
RunTileMonitoring.clusters
clusters
Definition: RunTileMonitoring.py:133
TauGNN::~TauGNN
~TauGNN()
Definition: TauGNN.cxx:98
TauGNN::Config
Definition: TauGNN.h:39
TauGNN::Config::input_layer_scalar
std::string input_layer_scalar
Definition: TauGNN.h:40
OnnxUtil.h