ATLAS Offline Software
TauGNN.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef TAURECTOOLS_TAUGNN_H
6 #define TAURECTOOLS_TAUGNN_H
7 
#include "xAODTau/TauJet.h"
#include "xAODCaloEvent/CaloVertexedTopoCluster.h"
#include "AsgMessaging/AsgMessaging.h"
#include "FlavorTagDiscriminants/OnnxUtil.h"

#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <vector>
18 
// Forward declarations of helper types referenced by TauGNN; their full
// definitions live in their own packages / translation units.
namespace TauGNNUtils { class GNNVarCalc; }
namespace FlavorTagDiscriminants { class OnnxUtil; }
26 
36 class TauGNN : public asg::AsgMessaging {
37 public:
38  // Configuration of the weight file structure
39  struct Config {
40  std::string input_layer_scalar;
41  std::string input_layer_tracks;
42  std::string input_layer_clusters;
43  std::string output_node_tau;
44  std::string output_node_jet;
45  };
46  std::shared_ptr<const FlavorTagDiscriminants::OnnxUtil> m_onnxUtil;
47 public:
48  TauGNN(const std::string &nnFile, const Config &config);
49  ~TauGNN();
50 
51  // Output the OnnxUtil tuple
52  std::tuple<
53  std::map<std::string, float>,
54  std::map<std::string, std::vector<char>>,
55  std::map<std::string, std::vector<float>> >
56  compute(const xAOD::TauJet &tau,
57  const std::vector<const xAOD::TauTrack *> &tracks,
58  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const;
59 
60  // Compute all input variables and store them in the maps that are passed by reference
61  bool calculateInputVariables(const xAOD::TauJet &tau,
62  const std::vector<const xAOD::TauTrack *> &tracks,
63  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters,
64  std::map<std::string, std::map<std::string, double>>& scalarInputs,
65  std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const;
66 
67  // Getter for the variable calculator
69  return m_var_calc.get();
70  }
71 
72  //Make the output config transparent to external tools
73  FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config;
74 
75 private:
77  // Abbreviations for lwtnn
78  using VariableMap = std::map<std::string, double>;
79  using VectorMap = std::map<std::string, std::vector<double>>;
80 
81  using InputMap = std::map<std::string, VariableMap>;
82  using InputSequenceMap = std::map<std::string, VectorMap>;
83 
84 private:
86 
87  // Names of the input variables
88  std::vector<std::string> m_scalar_inputs;
89  std::vector<std::string> m_track_inputs;
90  std::vector<std::string> m_cluster_inputs;
91  // Names passed to the variable calculator
92  std::vector<std::string> m_scalarCalc_inputs;
93  std::vector<std::string> m_trackCalc_inputs;
94  std::vector<std::string> m_clusterCalc_inputs;
95 
96  // Variable calculator to calculate input variables on the fly
97  std::unique_ptr<TauGNNUtils::GNNVarCalc> m_var_calc;
98 };
99 
100 #endif // TAURECTOOLS_TAUGNN_H
TauGNN::m_config
const Config m_config
Definition: TauGNN.h:85
TauGNN::Config::input_layer_tracks
std::string input_layer_tracks
Definition: TauGNN.h:41
TauGNN::m_scalarCalc_inputs
std::vector< std::string > m_scalarCalc_inputs
Definition: TauGNN.h:92
TauGNN::InputSequenceMap
std::map< std::string, VectorMap > InputSequenceMap
Definition: TauGNN.h:82
TauGNNUtils
Definition: TauGNNUtils.cxx:11
FlavorTagDiscriminants
This file contains "getter" functions used for accessing tagger inputs from the EDM.
Definition: AssociationEnums.h:11
FlavorTagDiscriminants::Inputs
std::pair< std::vector< float >, std::vector< int64_t > > Inputs
Definition: FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxUtil.h:28
TauGNN::Inputs
FlavorTagDiscriminants::Inputs Inputs
Definition: TauGNN.h:76
TauGNN::m_clusterCalc_inputs
std::vector< std::string > m_clusterCalc_inputs
Definition: TauGNN.h:94
TauGNN::TauGNN
TauGNN(const std::string &nnFile, const Config &config)
Definition: TauGNN.cxx:15
TauGNN::gnn_output_config
FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config
Definition: TauGNN.h:73
TauGNN::Config::output_node_tau
std::string output_node_tau
Definition: TauGNN.h:43
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
AsgMessaging.h
OnnxUtil
Definition: JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/OnnxUtil.h:14
TauGNN::m_track_inputs
std::vector< std::string > m_track_inputs
Definition: TauGNN.h:89
TauGNN::VectorMap
std::map< std::string, std::vector< double > > VectorMap
Definition: TauGNN.h:79
TauGNN::compute
std::tuple< std::map< std::string, float >, std::map< std::string, std::vector< char > >, std::map< std::string, std::vector< float > > > compute(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters) const
Definition: TauGNN.cxx:105
TauGNN::Config::output_node_jet
std::string output_node_jet
Definition: TauGNN.h:44
TauGNN::VariableMap
std::map< std::string, double > VariableMap
Definition: TauGNN.h:78
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
TauGNN::m_cluster_inputs
std::vector< std::string > m_cluster_inputs
Definition: TauGNN.h:90
TauGNN::m_scalar_inputs
std::vector< std::string > m_scalar_inputs
Definition: TauGNN.h:88
TauGNN::InputMap
std::map< std::string, VariableMap > InputMap
Definition: TauGNN.h:81
TauGNN::variable_calculator
const TauGNNUtils::GNNVarCalc * variable_calculator() const
Definition: TauGNN.h:68
TauGNN::m_var_calc
std::unique_ptr< TauGNNUtils::GNNVarCalc > m_var_calc
Definition: TauGNN.h:97
asg::AsgMessaging
Class mimicking the AthMessaging class from the offline software.
Definition: AsgMessaging.h:40
TauGNN
Wrapper around OnnxUtil to compute the output score of a model.
Definition: TauGNN.h:36
CaloVertexedTopoCluster.h
Evaluate cluster kinematics with a different vertex / signal state.
TauGNN::m_trackCalc_inputs
std::vector< std::string > m_trackCalc_inputs
Definition: TauGNN.h:93
TauGNN::Config::input_layer_clusters
std::string input_layer_clusters
Definition: TauGNN.h:42
TauGNN::calculateInputVariables
bool calculateInputVariables(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters, std::map< std::string, std::map< std::string, double >> &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double >>> &vectorInputs) const
Definition: TauGNN.cxx:168
TauGNNUtils::GNNVarCalc
Tool to calculate input variables for the GNN-based tau identification.
Definition: TauGNNUtils.h:31
TauJet.h
TauGNN::m_onnxUtil
std::shared_ptr< const FlavorTagDiscriminants::OnnxUtil > m_onnxUtil
Definition: TauGNN.h:46
RunTileMonitoring.clusters
clusters
Definition: RunTileMonitoring.py:133
TauGNN::~TauGNN
~TauGNN()
Definition: TauGNN.cxx:99
TauGNN::Config
Definition: TauGNN.h:39
TauGNN::Config::input_layer_scalar
std::string input_layer_scalar
Definition: TauGNN.h:40
OnnxUtil.h