ATLAS Offline Software
TauGNN.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #include "tauRecTools/TauGNN.h"
7 #include "lwtnn/parse_json.hh"
9 
10 #include <algorithm>
11 #include <fstream>
12 
14 
15 TauGNN::TauGNN(const std::string &nnFile, const Config &config):
16  asg::AsgMessaging("TauGNN"),
17  m_onnxUtil(nullptr)
18  {
19  //==================================================//
20  // This part is ported from FTagDiscriminant GNN.cxx//
21  //==================================================//
22 
23  m_onnxUtil = std::make_shared<FlavorTagDiscriminants::OnnxUtil>(nnFile);
24 
25  // get the configuration of the model outputs
26  FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig();
27 
28  //Let's see the output!
29  for (const auto& out_node: gnn_output_config) {
30  if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::FLOAT) ATH_MSG_INFO("Found output FLOAT node named:" << out_node.name);
31  if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECCHAR) ATH_MSG_INFO("Found output VECCHAR node named:" << out_node.name);
32  if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECFLOAT) ATH_MSG_INFO("Found output VECFLOAT node named:" << out_node.name);
33  }
34 
35  //Get model config (for inputs)
36  auto lwtnn_config = m_onnxUtil->getLwtConfig();
37 
38  //===================================================//
39  // This part is ported from tauRecTools TauJetRNN.cxx//
40  //===================================================//
41 
42  // Search for input layer names specified in 'config'
43  auto node_is_scalar = [&config](const lwt::InputNodeConfig &in_node) {
44  return in_node.name == config.input_layer_scalar;
45  };
46  auto node_is_track = [&config](const lwt::InputNodeConfig &in_node) {
47  return in_node.name == config.input_layer_tracks;
48  };
49  auto node_is_cluster = [&config](const lwt::InputNodeConfig &in_node) {
50  return in_node.name == config.input_layer_clusters;
51  };
52 
53  auto scalar_node = std::find_if(lwtnn_config.inputs.cbegin(),
54  lwtnn_config.inputs.cend(),
55  node_is_scalar);
56 
57  auto track_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
58  lwtnn_config.input_sequences.cend(),
59  node_is_track);
60 
61  auto cluster_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
62  lwtnn_config.input_sequences.cend(),
63  node_is_cluster);
64 
65  // Check which input layers were found
66  auto has_scalar_node = scalar_node != lwtnn_config.inputs.cend();
67  auto has_track_node = track_node != lwtnn_config.input_sequences.cend();
68  auto has_cluster_node = cluster_node != lwtnn_config.input_sequences.cend();
69  if(!has_scalar_node) ATH_MSG_WARNING("No scalar node with name "<<config.input_layer_scalar<<" found!");
70  if(!has_track_node) ATH_MSG_WARNING("No track node with name "<<config.input_layer_tracks<<" found!");
71  if(!has_cluster_node) ATH_MSG_WARNING("No cluster node with name "<<config.input_layer_clusters<<" found!");
72 
73  // Fill the variable names of each input layer into the corresponding vector
74  if (has_scalar_node) {
75  for (const auto &in : scalar_node->variables) {
76  std::string name = in.name;
77  m_scalarCalc_inputs.push_back(name);
78  }
79  }
80 
81  if (has_track_node) {
82  for (const auto &in : track_node->variables) {
83  std::string name = in.name;
84  m_trackCalc_inputs.push_back(name);
85  }
86  }
87 
88  if (has_cluster_node) {
89  for (const auto &in : cluster_node->variables) {
90  std::string name = in.name;
91  m_clusterCalc_inputs.push_back(name);
92  }
93  }
94  // Load the variable calculator
96  ATH_MSG_INFO("TauGNN object initialized successfully!");
97 }
98 
100 
101 std::tuple<
102  std::map<std::string, float>,
103  std::map<std::string, std::vector<char>>,
104  std::map<std::string, std::vector<float>> >
106  const std::vector<const xAOD::TauTrack *> &tracks,
107  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
108  InputMap scalarInputs;
109  InputSequenceMap vectorInputs;
110  std::map<std::string, Inputs> gnn_input;
111  ATH_MSG_DEBUG("Starting compute...");
112  //Prepare input variables
113  if (!calculateInputVariables(tau, tracks, clusters, scalarInputs, vectorInputs)) {
114  ATH_MSG_FATAL("Failed calculateInputVariables");
115  throw StatusCode::FAILURE;
116  }
117 
118  // Add TauJet-level features to the input
119  std::vector<float> tau_feats;
120  for (const auto &varname : m_scalarCalc_inputs) {
121  tau_feats.push_back(static_cast<float>(scalarInputs[m_config.input_layer_scalar][varname]));
122  }
123  std::vector<int64_t> tau_feats_dim = {1, static_cast<int64_t>(tau_feats.size())};
124  Inputs tau_info (tau_feats, tau_feats_dim);
125  gnn_input.insert({"tau_vars", tau_info});
126 
127  //Add track-level features to the input
128  std::vector<float> trk_feats;
129  int num_nodes=static_cast<int>(vectorInputs[m_config.input_layer_tracks][m_trackCalc_inputs.at(0)].size());
130  int num_node_vars=static_cast<int>(m_trackCalc_inputs.size());
131  trk_feats.resize(num_nodes * num_node_vars);
132  int var_idx=0;
133  for (const auto &varname : m_trackCalc_inputs) {
134  for (int node_idx=0; node_idx<num_nodes; node_idx++){
135  trk_feats.at(node_idx*num_node_vars + var_idx)
136  = static_cast<float>(vectorInputs[m_config.input_layer_tracks][varname].at(node_idx));
137  }
138  var_idx++;
139  }
140  std::vector<int64_t> trk_feats_dim = {num_nodes, num_node_vars};
141  Inputs trk_info (trk_feats, trk_feats_dim);
142  gnn_input.insert({"track_vars", trk_info});
143 
144  //Add cluster-level features to the input
145  std::vector<float> cls_feats;
146  num_nodes=static_cast<int>(vectorInputs[m_config.input_layer_clusters][m_clusterCalc_inputs.at(0)].size());
147  num_node_vars=static_cast<int>(m_clusterCalc_inputs.size());
148  cls_feats.resize(num_nodes * num_node_vars);
149  var_idx=0;
150  for (const auto &varname : m_clusterCalc_inputs) {
151  for (int node_idx=0; node_idx<num_nodes; node_idx++){
152  cls_feats.at(node_idx*num_node_vars + var_idx)
153  = static_cast<float>(vectorInputs[m_config.input_layer_clusters][varname].at(node_idx));
154  }
155  var_idx++;
156  }
157  std::vector<int64_t> cls_feats_dim = {num_nodes, num_node_vars};
158  Inputs cls_info (cls_feats, cls_feats_dim);
159  gnn_input.insert({"cluster_vars", cls_info});
160 
161  //RUN THE INFERENCE!!!
162  ATH_MSG_DEBUG("Prepared inputs, running inference...");
163  auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input);
164  ATH_MSG_DEBUG("Finished compute!");
165  return std::make_tuple(out_f, out_vc, out_vf);
166 }
167 
169  const std::vector<const xAOD::TauTrack *> &tracks,
170  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters,
171  std::map<std::string, std::map<std::string, double>>& scalarInputs,
172  std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const {
173  scalarInputs.clear();
174  vectorInputs.clear();
175  // Populate input (sequence) map with input variables
176  for (const auto &varname : m_scalarCalc_inputs) {
177  if (!m_var_calc->compute(varname, tau,
178  scalarInputs[m_config.input_layer_scalar][varname])) {
179  ATH_MSG_WARNING("Error computing '" << varname
180  << "' returning default");
181  return false;
182  }
183  }
184 
185  for (const auto &varname : m_trackCalc_inputs) {
186  if (!m_var_calc->compute(varname, tau, tracks,
187  vectorInputs[m_config.input_layer_tracks][varname])) {
188  ATH_MSG_WARNING("Error computing '" << varname
189  << "' returning default");
190  return false;
191  }
192  }
193 
194  for (const auto &varname : m_clusterCalc_inputs) {
195  if (!m_var_calc->compute(varname, tau, clusters,
196  vectorInputs[m_config.input_layer_clusters][varname])) {
197  ATH_MSG_WARNING("Error computing '" << varname
198  << "' returning default");
199  return false;
200  }
201  }
202  return true;
203 }
TauGNN::m_config
const Config m_config
Definition: TauGNN.h:85
FlavorTagDiscriminants::OnnxOutput::OutputType::VECCHAR
@ VECCHAR
ATH_MSG_FATAL
#define ATH_MSG_FATAL(x)
Definition: AthMsgStreamMacros.h:34
TauGNN::Config::input_layer_tracks
std::string input_layer_tracks
Definition: TauGNN.h:41
TauGNN::m_scalarCalc_inputs
std::vector< std::string > m_scalarCalc_inputs
Definition: TauGNN.h:92
TauGNNUtils::get_calculator
std::unique_ptr< GNNVarCalc > get_calculator(const std::vector< std::string > &scalar_vars, const std::vector< std::string > &track_vars, const std::vector< std::string > &cluster_vars)
Definition: TauGNNUtils.cxx:113
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
TauGNN::InputSequenceMap
std::map< std::string, VectorMap > InputSequenceMap
Definition: TauGNN.h:82
FlavorTagDiscriminants::OnnxOutput::OutputType::FLOAT
@ FLOAT
TauGNN::Inputs
FlavorTagDiscriminants::Inputs Inputs
Definition: TauGNN.h:76
TauGNN::m_clusterCalc_inputs
std::vector< std::string > m_clusterCalc_inputs
Definition: TauGNN.h:94
TauGNN::TauGNN
TauGNN(const std::string &nnFile, const Config &config)
Definition: TauGNN.cxx:15
TauGNNUtils.h
TauGNN::gnn_output_config
FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config
Definition: TauGNN.h:73
asg
Definition: DataHandleTestTool.h:28
FlavorTagDiscriminants::OnnxOutput::OutputType::VECFLOAT
@ VECFLOAT
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
TauGNN::compute
std::tuple< std::map< std::string, float >, std::map< std::string, std::vector< char > >, std::map< std::string, std::vector< float > > > compute(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters) const
Definition: TauGNN.cxx:105
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
TauGNN::InputMap
std::map< std::string, VariableMap > InputMap
Definition: TauGNN.h:81
TauGNN.h
TauGNN::m_var_calc
std::unique_ptr< TauGNNUtils::GNNVarCalc > m_var_calc
Definition: TauGNN.h:97
PathResolver.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:195
TauGNN::m_trackCalc_inputs
std::vector< std::string > m_trackCalc_inputs
Definition: TauGNN.h:93
TauGNN::Config::input_layer_clusters
std::string input_layer_clusters
Definition: TauGNN.h:42
TauGNN::calculateInputVariables
bool calculateInputVariables(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters, std::map< std::string, std::map< std::string, double >> &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double >>> &vectorInputs) const
Definition: TauGNN.cxx:168
LArG4AODNtuplePlotter.varname
def varname(hname)
Definition: LArG4AODNtuplePlotter.py:37
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
TauGNN::m_onnxUtil
std::shared_ptr< const FlavorTagDiscriminants::OnnxUtil > m_onnxUtil
Definition: TauGNN.h:46
config
std::vector< std::string > config
Definition: fbtTestBasics.cxx:72
RunTileMonitoring.clusters
clusters
Definition: RunTileMonitoring.py:133
TauGNN::~TauGNN
~TauGNN()
Definition: TauGNN.cxx:99
TauGNN::Config
Definition: TauGNN.h:39
TauGNN::Config::input_layer_scalar
std::string input_layer_scalar
Definition: TauGNN.h:40
OnnxUtil.h