ATLAS Offline Software
TauGNN.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef TAURECTOOLS_TAUGNN_H
6 #define TAURECTOOLS_TAUGNN_H
7 
#include "xAODTau/TauJet.h"
#include "xAODCaloEvent/CaloVertexedTopoCluster.h"
#include "AsgMessaging/AsgMessaging.h"
#include "FlavorTagInference/SaltModel.h"

#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <vector>
18 
// Forward declarations of helper types referenced by TauGNN below.

// Tool computing the GNN input variables for tau identification.
namespace TauGNNUtils { class GNNVarCalc; }

// Inference model wrapped by TauGNN.
namespace FlavorTagInference { class SaltModel; }
26 
36 class TauGNN : public asg::AsgMessaging {
37 public:
38  // Configuration of the weight file structure
39  struct Config {
40  std::string input_layer_scalar;
41  std::string input_layer_tracks;
42  std::string input_layer_clusters;
43  std::string output_node_tau;
44  std::string output_node_jet;
45  };
46 public:
47  TauGNN(const std::string &nnFile, const Config &config, bool useTRT);
48  ~TauGNN();
49 
50  // Output the SaltModel tuple
51  std::tuple<
52  std::map<std::string, float>,
53  std::map<std::string, std::vector<char>>,
54  std::map<std::string, std::vector<float>> >
55  compute(const xAOD::TauJet &tau,
56  const std::vector<const xAOD::TauTrack *> &tracks,
57  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const;
58 
59  // Compute all input variables and store them in the maps that are passed by reference
60  std::tuple<std::vector<float>, std::vector<float>, std::vector<float>> calculateInputVariables(
61  const xAOD::TauJet &tau,
62  const std::vector<const xAOD::TauTrack *> &tracks,
63  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters
64  ) const;
65 
66  // Getter for the variable calculator
68  return m_var_calc.get();
69  }
70 
71  //Make the output config transparent to external tools
72  FlavorTagInference::SaltModel::OutputConfig gnn_output_config;
73 
74 private:
76  // Abbreviations for lwtnn
77  using VariableMap = std::map<std::string, double>;
78  using VectorMap = std::map<std::string, std::vector<double>>;
79 
80  using InputMap = std::map<std::string, VariableMap>;
81  using InputSequenceMap = std::map<std::string, VectorMap>;
82 
83 private:
84  std::shared_ptr<const FlavorTagInference::SaltModel> m_saltModel;
86 
87  // Names of the input variables
88  std::vector<std::string> m_scalar_inputs;
89  std::vector<std::string> m_track_inputs;
90  std::vector<std::string> m_cluster_inputs;
91  // Names passed to the variable calculator
92  std::vector<std::string> m_scalarCalc_inputs;
93  std::vector<std::string> m_trackCalc_inputs;
94  std::vector<std::string> m_clusterCalc_inputs;
95 
96  // Variable calculator to calculate input variables on the fly
97  std::unique_ptr<TauGNNUtils::GNNVarCalc> m_var_calc;
98  bool m_useTRT = true;
99 
100  std::vector<float> flatten(const std::vector<std::vector<float>>& mat) const {
101  std::vector<float> flat;
102  for (size_t col = 0; col < mat[0].size(); col++){
103  for (size_t row = 0; row < mat.size(); row++){
104  flat.push_back(mat[row][col]);
105  }
106  }
107  return flat;
108  };
109 };
110 
111 #endif // TAURECTOOLS_TAUGNN_H
SaltModel.h
TauGNN::m_config
const Config m_config
Definition: TauGNN.h:85
TauGNN::Config::input_layer_tracks
std::string input_layer_tracks
Definition: TauGNN.h:41
TauGNN::m_scalarCalc_inputs
std::vector< std::string > m_scalarCalc_inputs
Definition: TauGNN.h:92
FlavorTagInference
This file contains "getter" functions used for accessing tagger inputs from the EDM.
Definition: PhysicsAnalysis/JetTagging/FlavorTagInference/FlavorTagInference/ConstituentsLoader.h:27
TauGNN::InputSequenceMap
std::map< std::string, VectorMap > InputSequenceMap
Definition: TauGNN.h:81
TauGNNUtils
Definition: TauGNNUtils.cxx:12
TauGNN::m_clusterCalc_inputs
std::vector< std::string > m_clusterCalc_inputs
Definition: TauGNN.h:94
mat
GeoMaterial * mat
Definition: LArDetectorConstructionTBEC.cxx:55
keylayer_zslicemap.row
row
Definition: keylayer_zslicemap.py:155
TauGNN::m_useTRT
bool m_useTRT
Definition: TauGNN.h:98
TauGNN::Config::output_node_tau
std::string output_node_tau
Definition: TauGNN.h:43
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
SaltModel
Definition: JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/SaltModel.h:14
AsgMessaging.h
TauGNN::m_track_inputs
std::vector< std::string > m_track_inputs
Definition: TauGNN.h:89
TauGNN::VectorMap
std::map< std::string, std::vector< double > > VectorMap
Definition: TauGNN.h:78
TauGNN::compute
std::tuple< std::map< std::string, float >, std::map< std::string, std::vector< char > >, std::map< std::string, std::vector< float > > > compute(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters) const
Definition: TauGNN.cxx:104
TauGNN::Config::output_node_jet
std::string output_node_jet
Definition: TauGNN.h:44
TauGNN::VariableMap
std::map< std::string, double > VariableMap
Definition: TauGNN.h:77
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
TauGNN::m_cluster_inputs
std::vector< std::string > m_cluster_inputs
Definition: TauGNN.h:90
TauGNN::m_scalar_inputs
std::vector< std::string > m_scalar_inputs
Definition: TauGNN.h:88
TauGNN::gnn_output_config
FlavorTagInference::SaltModel::OutputConfig gnn_output_config
Definition: TauGNN.h:72
TauGNN::InputMap
std::map< std::string, VariableMap > InputMap
Definition: TauGNN.h:80
TauGNN::variable_calculator
const TauGNNUtils::GNNVarCalc * variable_calculator() const
Definition: TauGNN.h:67
TauGNN::calculateInputVariables
std::tuple< std::vector< float >, std::vector< float >, std::vector< float > > calculateInputVariables(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters) const
Definition: TauGNN.cxx:132
TauGNN::m_var_calc
std::unique_ptr< TauGNNUtils::GNNVarCalc > m_var_calc
Definition: TauGNN.h:97
asg::AsgMessaging
Class mimicking the AthMessaging class from the offline software.
Definition: AsgMessaging.h:40
TauGNN
Wrapper around SaltModel to compute the output score of a model.
Definition: TauGNN.h:36
CaloVertexedTopoCluster.h
Evaluate cluster kinematics with a different vertex / signal state.
TauGNN::m_trackCalc_inputs
std::vector< std::string > m_trackCalc_inputs
Definition: TauGNN.h:93
TauGNN::Config::input_layer_clusters
std::string input_layer_clusters
Definition: TauGNN.h:42
TauGNN::m_saltModel
std::shared_ptr< const FlavorTagInference::SaltModel > m_saltModel
Definition: TauGNN.h:84
TauGNN::TauGNN
TauGNN(const std::string &nnFile, const Config &config, bool useTRT)
Definition: TauGNN.cxx:15
TauGNNUtils::GNNVarCalc
Tool to calculate input variables for the GNN-based tau identification.
Definition: TauGNNUtils.h:275
TauGNN::flatten
std::vector< float > flatten(const std::vector< std::vector< float >> &mat) const
Definition: TauGNN.h:100
TauJet.h
RunTileMonitoring.clusters
clusters
Definition: RunTileMonitoring.py:133
TauGNN::Inputs
FlavorTagInference::Inputs Inputs
Definition: TauGNN.h:75
TauGNN::~TauGNN
~TauGNN()
Definition: TauGNN.cxx:98
FlavorTagInference::Inputs
std::pair< std::vector< float >, std::vector< int64_t > > Inputs
Definition: FlavorTagInference/FlavorTagInference/SaltModel.h:28
TauGNN::Config
Definition: TauGNN.h:39
TauGNN::Config::input_layer_scalar
std::string input_layer_scalar
Definition: TauGNN.h:40