ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
TauGNNEvaluator.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef TAURECTOOLS_TAUGNNEVALUATOR_H
6 #define TAURECTOOLS_TAUGNNEVALUATOR_H
7 
9 
10 #include "tauRecTools/TauGNN.h"
11 
12 #include "xAODTau/TauJet.h"
14 
15 #include <memory>
16 
27 public:
29 
30  TauGNNEvaluator(const std::string &name = "TauGNNEvaluator");
31  virtual ~TauGNNEvaluator();
32 
33  virtual StatusCode initialize() override;
34  virtual StatusCode execute(xAOD::TauJet &tau) const override;
35  // Getters for the underlying GNN implementations
36  inline const TauGNN* get_gnn_inclusive() const { return m_net_inclusive.get(); }
37  inline const TauGNN* get_gnn_0p() const { return m_net_0p.get(); }
38  inline const TauGNN* get_gnn_1p() const { return m_net_1p.get(); }
39  inline const TauGNN* get_gnn_2p() const { return m_net_2p.get(); }
40  inline const TauGNN* get_gnn_3p() const { return m_net_3p.get(); }
41 
42  // Selects tracks to be used as input to the network
44  std::vector<const xAOD::TauTrack *> &out) const;
45 
46  // Selects clusters to be used as input to the network
48  std::vector<xAOD::CaloVertexedTopoCluster> &out) const;
49 
50  enum Discriminant {
52  PTau = 1
53  };
54 
55 private:
56  std::string m_output_varname;
57  std::string m_output_ptau;
58  std::string m_output_pjet;
59  unsigned int m_output_discriminant;
60 
62  std::string m_weightfile_0p;
63  std::string m_weightfile_1p;
64  std::string m_weightfile_2p;
65  std::string m_weightfile_3p;
67 
71  float m_minTauPt;
75 
76  // Configuration of the network file
77  std::string m_input_layer_scalar;
78  std::string m_input_layer_tracks;
80  std::string m_outnode_tau;
81  std::string m_outnode_jet;
82 
83  // Wrappers around the GNN (SaltModel) networks
84  std::unique_ptr<TauGNN> m_net_inclusive;
85  std::unique_ptr<TauGNN> m_net_0p;
86  std::unique_ptr<TauGNN> m_net_1p;
87  std::unique_ptr<TauGNN> m_net_2p;
88  std::unique_ptr<TauGNN> m_net_3p;
89 
90  std::unique_ptr<TauGNN> load_network(const std::string& network_file, const TauGNN::Config& config) const;
91 };
92 
93 #endif // TAURECTOOLS_TAUGNNEVALUATOR_H
TauGNNEvaluator::m_max_clusters
int m_max_clusters
Definition: TauGNNEvaluator.h:69
TauGNNEvaluator::Discriminant
Discriminant
Definition: TauGNNEvaluator.h:50
TauGNNEvaluator::m_input_layer_scalar
std::string m_input_layer_scalar
Definition: TauGNNEvaluator.h:77
TauGNNEvaluator::NegLogPJet
@ NegLogPJet
Definition: TauGNNEvaluator.h:51
TauGNNEvaluator::m_doTrackClassification
bool m_doTrackClassification
Definition: TauGNNEvaluator.h:73
ASG_TOOL_CLASS2
#define ASG_TOOL_CLASS2(CLASSNAME, INT1, INT2)
Definition: AsgToolMacros.h:77
TauGNNEvaluator::m_weightfile_2p
std::string m_weightfile_2p
Definition: TauGNNEvaluator.h:64
TauGNNEvaluator::get_gnn_3p
const TauGNN * get_gnn_3p() const
Definition: TauGNNEvaluator.h:40
TauGNNEvaluator::get_tracks
StatusCode get_tracks(const xAOD::TauJet &tau, std::vector< const xAOD::TauTrack * > &out) const
Definition: TauGNNEvaluator.cxx:205
TauGNNEvaluator::get_gnn_inclusive
const TauGNN * get_gnn_inclusive() const
Definition: TauGNNEvaluator.h:36
TauGNNEvaluator::m_output_discriminant
unsigned int m_output_discriminant
Definition: TauGNNEvaluator.h:59
TauRecToolBase.h
TauRecToolBase
The base class for all tau tools.
Definition: TauRecToolBase.h:21
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:71
TauGNNEvaluator::m_net_3p
std::unique_ptr< TauGNN > m_net_3p
Definition: TauGNNEvaluator.h:88
TauGNNEvaluator::m_input_layer_tracks
std::string m_input_layer_tracks
Definition: TauGNNEvaluator.h:78
TauGNNEvaluator::load_network
std::unique_ptr< TauGNN > load_network(const std::string &network_file, const TauGNN::Config &config) const
Definition: TauGNNEvaluator.cxx:107
TauGNNEvaluator::m_input_layer_clusters
std::string m_input_layer_clusters
Definition: TauGNNEvaluator.h:79
TauGNNEvaluator::m_outnode_tau
std::string m_outnode_tau
Definition: TauGNNEvaluator.h:80
TauGNNEvaluator::m_net_1p
std::unique_ptr< TauGNN > m_net_1p
Definition: TauGNNEvaluator.h:86
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
TauGNNEvaluator::m_output_ptau
std::string m_output_ptau
Definition: TauGNNEvaluator.h:57
TauGNNEvaluator::m_max_tracks
int m_max_tracks
Definition: TauGNNEvaluator.h:68
TauGNNEvaluator::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: TauGNNEvaluator.cxx:52
TauGNNEvaluator::m_net_0p
std::unique_ptr< TauGNN > m_net_0p
Definition: TauGNNEvaluator.h:85
TauGNNEvaluator::get_gnn_2p
const TauGNN * get_gnn_2p() const
Definition: TauGNNEvaluator.h:39
TauGNNEvaluator::m_output_pjet
std::string m_output_pjet
Definition: TauGNNEvaluator.h:58
TauGNNEvaluator::get_gnn_0p
const TauGNN * get_gnn_0p() const
Definition: TauGNNEvaluator.h:37
TauGNNEvaluator::m_net_2p
std::unique_ptr< TauGNN > m_net_2p
Definition: TauGNNEvaluator.h:87
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
TauGNNEvaluator::m_decorateTracks
bool m_decorateTracks
Definition: TauGNNEvaluator.h:74
TauGNNEvaluator::m_weightfile_0p
std::string m_weightfile_0p
Definition: TauGNNEvaluator.h:62
TauGNN.h
TauGNNEvaluator::m_net_inclusive
std::unique_ptr< TauGNN > m_net_inclusive
Definition: TauGNNEvaluator.h:84
TauGNNEvaluator::m_weightfile_inclusive
std::string m_weightfile_inclusive
Definition: TauGNNEvaluator.h:61
TauGNNEvaluator::m_max_cluster_dr
float m_max_cluster_dr
Definition: TauGNNEvaluator.h:70
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:228
TauGNN
Wrapper around SaltModel to compute the output score of a model.
Definition: TauGNN.h:32
CaloVertexedTopoCluster.h
Evaluate cluster kinematics with a different vertex / signal state.
ITauToolBase
The base class for all tau tools.
Definition: ITauToolBase.h:30
TauGNNEvaluator::m_output_varname
std::string m_output_varname
Definition: TauGNNEvaluator.h:56
TauGNNEvaluator::m_minTauPt
float m_minTauPt
Definition: TauGNNEvaluator.h:71
TauGNNEvaluator::m_doVertexCorrection
bool m_doVertexCorrection
Definition: TauGNNEvaluator.h:72
TauJet.h
TauGNNEvaluator::m_outnode_jet
std::string m_outnode_jet
Definition: TauGNNEvaluator.h:81
TauGNNEvaluator::m_min_prong_track_pt
float m_min_prong_track_pt
Definition: TauGNNEvaluator.h:66
TauGNNEvaluator::PTau
@ PTau
Definition: TauGNNEvaluator.h:52
TauGNNEvaluator
Tool to calculate tau identification score from .onnx inputs.
Definition: TauGNNEvaluator.h:26
TauGNNEvaluator::m_weightfile_1p
std::string m_weightfile_1p
Definition: TauGNNEvaluator.h:63
TauGNNEvaluator::~TauGNNEvaluator
virtual ~TauGNNEvaluator()
Definition: TauGNNEvaluator.cxx:50
TauGNNEvaluator::get_gnn_1p
const TauGNN * get_gnn_1p() const
Definition: TauGNNEvaluator.h:38
TauGNNEvaluator::m_weightfile_3p
std::string m_weightfile_3p
Definition: TauGNNEvaluator.h:65
TauGNN::Config
Definition: TauGNN.h:35
TauGNNEvaluator::TauGNNEvaluator
TauGNNEvaluator(const std::string &name="TauGNNEvaluator")
Definition: TauGNNEvaluator.cxx:13
TauGNNEvaluator::get_clusters
StatusCode get_clusters(const xAOD::TauJet &tau, std::vector< xAOD::CaloVertexedTopoCluster > &out) const
Definition: TauGNNEvaluator.cxx:234
TauGNNEvaluator::execute
virtual StatusCode execute(xAOD::TauJet &tau) const override
Execute - called for each tau candidate.
Definition: TauGNNEvaluator.cxx:126