ATLAS Offline Software
TauGNN.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
#include "tauRecTools/TauGNN.h"

#include <algorithm>
#include <fstream>
#include <utility>
12 
14 
15 TauGNN::TauGNN(const std::string &nnFile, const Config &config, bool useTRT):
16  asg::AsgMessaging("TauGNN"),
17  m_saltModel(std::make_shared<FlavorTagInference::SaltModel>(nnFile)),
18  m_config{config}, m_useTRT(useTRT)
19  {
20  //==================================================//
21  // This part is ported from FTagDiscriminant GNN.cxx//
22  //==================================================//
23 
24  // get the configuration of the model outputs
25  FlavorTagInference::SaltModel::OutputConfig gnn_output_config = m_saltModel->getOutputConfig();
26 
27  //Let's see the output!
28  for (const auto& out_node: gnn_output_config) {
29  if(out_node.type==FlavorTagInference::SaltModelOutput::OutputType::FLOAT) ATH_MSG_INFO("Found output FLOAT node named:" << out_node.name);
30  if(out_node.type==FlavorTagInference::SaltModelOutput::OutputType::VECCHAR) ATH_MSG_INFO("Found output VECCHAR node named:" << out_node.name);
31  if(out_node.type==FlavorTagInference::SaltModelOutput::OutputType::VECFLOAT) ATH_MSG_INFO("Found output VECFLOAT node named:" << out_node.name);
32  }
33 
34  //Get model config (for inputs)
35  auto graph_config = m_saltModel->getGraphConfig();
36 
37  //===================================================//
38  // This part is ported from tauRecTools TauJetRNN.cxx//
39  //===================================================//
40 
41  // Search for input layer names specified in 'config'
42  auto node_is_scalar = [&config](const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig &in_node) {
43  return in_node.name == config.input_layer_scalar;
44  };
45  auto node_is_track = [&config](const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig &in_node) {
46  return in_node.name == config.input_layer_tracks;
47  };
48  auto node_is_cluster = [&config](const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig &in_node) {
49  return in_node.name == config.input_layer_clusters;
50  };
51 
52  auto scalar_node = std::find_if(graph_config.inputs.cbegin(),
53  graph_config.inputs.cend(),
54  node_is_scalar);
55 
56  auto track_node = std::find_if(graph_config.input_sequences.cbegin(),
57  graph_config.input_sequences.cend(),
58  node_is_track);
59 
60  auto cluster_node = std::find_if(graph_config.input_sequences.cbegin(),
61  graph_config.input_sequences.cend(),
62  node_is_cluster);
63 
64  // Check which input layers were found
65  auto has_scalar_node = scalar_node != graph_config.inputs.cend();
66  auto has_track_node = track_node != graph_config.input_sequences.cend();
67  auto has_cluster_node = cluster_node != graph_config.input_sequences.cend();
68  if(!has_scalar_node) ATH_MSG_WARNING("No scalar node with name "<<config.input_layer_scalar<<" found!");
69  if(!has_track_node) ATH_MSG_WARNING("No track node with name "<<config.input_layer_tracks<<" found!");
70  if(!has_cluster_node) ATH_MSG_WARNING("No cluster node with name "<<config.input_layer_clusters<<" found!");
71 
72  // Fill the variable names of each input layer into the corresponding vector
73  if (has_scalar_node) {
74  for (const auto &in : scalar_node->variables) {
75  std::string name = in.name;
76  m_scalarCalc_inputs.push_back(name);
77  }
78  }
79 
80  if (has_track_node) {
81  for (const auto &in : track_node->variables) {
82  std::string name = in.name;
83  m_trackCalc_inputs.push_back(name);
84  }
85  }
86 
87  if (has_cluster_node) {
88  for (const auto &in : cluster_node->variables) {
89  std::string name = in.name;
90  m_clusterCalc_inputs.push_back(name);
91  }
92  }
93  // Load the variable calculator
94  m_var_calc = std::make_unique<TauGNNUtils::GNNVarCalc>(m_useTRT);
95  ATH_MSG_INFO("TauGNN object initialized successfully!");
96 }
97 
99 
100 std::tuple<
101  std::map<std::string, float>,
102  std::map<std::string, std::vector<char>>,
103  std::map<std::string, std::vector<float>> >
105  const std::vector<const xAOD::TauTrack *> &tracks,
106  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
107  std::map<std::string, Inputs> gnn_input;
108  ATH_MSG_DEBUG("Starting compute...");
109  //Prepare input variables
110  auto [tau_feats, trk_feats, cls_feats] = calculateInputVariables(tau, tracks, clusters);
111 
112  std::vector<int64_t> tau_feats_dim = {static_cast<int64_t>(1), static_cast<int64_t>(tau_feats.size())};
113  std::vector<int64_t> trk_feats_dim = {static_cast<int64_t>(tracks.size()), static_cast<int64_t>(m_trackCalc_inputs.size())};
114  std::vector<int64_t> cls_feats_dim = {static_cast<int64_t>(clusters.size()), static_cast<int64_t>(m_clusterCalc_inputs.size())};
115 
116  Inputs tau_info (tau_feats, tau_feats_dim);
117  Inputs trk_info (trk_feats, trk_feats_dim);
118  Inputs cls_info (cls_feats, cls_feats_dim);
119 
120  gnn_input.insert({"tau_vars", tau_info});
121  gnn_input.insert({"track_vars", trk_info});
122  gnn_input.insert({"cluster_vars", cls_info});
123 
124  //RUN THE INFERENCE!!!
125  ATH_MSG_DEBUG("Prepared inputs, running inference...");
126  auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_input);
127  ATH_MSG_DEBUG("Finished compute!");
128  return std::make_tuple(out_f, out_vc, out_vf);
129 }
130 
131 std::tuple<std::vector<float>, std::vector<float>, std::vector<float>>
133  const xAOD::TauJet &tau,
134  const std::vector<const xAOD::TauTrack *> &tracks,
135  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters
136  ) const {
137  // Populate input (sequence) map with input variables
138  std::vector<float> tau_feats;
139  std::vector<std::vector<float>> track_feats_2d, cluster_feats_2d;
140  for (const auto &varname : m_scalarCalc_inputs) {
141  tau_feats.push_back(m_var_calc->compute(varname, tau));
142  }
143  for (const auto &varname : m_trackCalc_inputs) {
144  track_feats_2d.push_back(m_var_calc->compute(varname, tau, tracks));
145  }
146  for (const auto &varname : m_clusterCalc_inputs) {
147  cluster_feats_2d.push_back(m_var_calc->compute(varname, tau, clusters));
148  }
149  //transposing the 2d feature arrays
150  std::vector<float> track_feats = flatten(track_feats_2d);
151  std::vector<float> cluster_feats = flatten(cluster_feats_2d);
152  return std::make_tuple(tau_feats, track_feats, cluster_feats);
153 }
FlavorTagInference::SaltModelOutput::OutputType::VECCHAR
@ VECCHAR
SaltModel.h
TauGNN::m_scalarCalc_inputs
std::vector< std::string > m_scalarCalc_inputs
Definition: TauGNN.h:92
FlavorTagInference
This file contains "getter" functions used for accessing tagger inputs from the EDM.
Definition: PhysicsAnalysis/JetTagging/FlavorTagInference/FlavorTagInference/ConstituentsLoader.h:27
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
TauGNN::m_clusterCalc_inputs
std::vector< std::string > m_clusterCalc_inputs
Definition: TauGNN.h:94
TauGNNUtils.h
python.base_data.config
config
Definition: base_data.py:20
asg
Definition: DataHandleTestTool.h:28
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
SaltModel
Definition: JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/SaltModel.h:14
TauGNN::compute
std::tuple< std::map< std::string, float >, std::map< std::string, std::vector< char > >, std::map< std::string, std::vector< float > > > compute(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters) const
Definition: TauGNN.cxx:104
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
FlavorTagInference::SaltModelGraphConfig::InputNodeConfig
Definition: SaltModelGraphConfig.h:22
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
TauGNN.h
TauGNN::calculateInputVariables
std::tuple< std::vector< float >, std::vector< float >, std::vector< float > > calculateInputVariables(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters) const
Definition: TauGNN.cxx:132
TauGNN::m_var_calc
std::unique_ptr< TauGNNUtils::GNNVarCalc > m_var_calc
Definition: TauGNN.h:97
PathResolver.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
TauGNN::m_trackCalc_inputs
std::vector< std::string > m_trackCalc_inputs
Definition: TauGNN.h:93
TauGNN::m_saltModel
std::shared_ptr< const FlavorTagInference::SaltModel > m_saltModel
Definition: TauGNN.h:84
LArG4AODNtuplePlotter.varname
def varname(hname)
Definition: LArG4AODNtuplePlotter.py:37
TauGNN::TauGNN
TauGNN(const std::string &nnFile, const Config &config, bool useTRT)
Definition: TauGNN.cxx:15
FlavorTagInference::SaltModelOutput::OutputType::FLOAT
@ FLOAT
TauGNN::flatten
std::vector< float > flatten(const std::vector< std::vector< float >> &mat) const
Definition: TauGNN.h:100
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
RunTileMonitoring.clusters
clusters
Definition: RunTileMonitoring.py:133
TauGNN::Inputs
FlavorTagInference::Inputs Inputs
Definition: TauGNN.h:75
TauGNN::~TauGNN
~TauGNN()
Definition: TauGNN.cxx:98
TauGNN::Config
Definition: TauGNN.h:39
FlavorTagInference::SaltModelOutput::OutputType::VECFLOAT
@ VECFLOAT
SaltModelGraphConfig.h