ATLAS Offline Software
TauGNNEvaluator.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef TAURECTOOLS_TAUGNNEVALUATOR_H
6 #define TAURECTOOLS_TAUGNNEVALUATOR_H
7 
9 #include "tauRecTools/TauGNN.h"
10 
11 #include "xAODTau/TauJet.h"
14 
17 
18 #include <memory>
19 
30 public:
32 
33  TauGNNEvaluator(const std::string &name = "TauGNNEvaluator");
34  virtual ~TauGNNEvaluator();
35 
36  virtual StatusCode initialize() override;
37  virtual StatusCode execute(xAOD::TauJet &tau) const override;
38  // Getter for the underlying RNN implementation
39  inline const TauGNN* get_gnn_inclusive() const { return m_net_inclusive.get(); }
40  inline const TauGNN* get_gnn_0p() const { return m_net_0p.get(); }
41  inline const TauGNN* get_gnn_1p() const { return m_net_1p.get(); }
42  inline const TauGNN* get_gnn_2p() const { return m_net_2p.get(); }
43  inline const TauGNN* get_gnn_3p() const { return m_net_3p.get(); }
44 
45  // Selects tracks to be used as input to the network
47  std::vector<const xAOD::TauTrack *> &out) const;
48 
49  // Selects clusters to be used as input to the network
51  std::vector<xAOD::CaloVertexedTopoCluster> &out) const;
52 
53  enum Discriminant {
55  PTau = 1
56  };
57 
58 private:
59 
60  Gaudi::Property<std::string> m_tauContainerName{this, "TauContainerName", "", "Name of TauJetContainer, must be set when using "};
61  SG::WriteDecorHandleKey<xAOD::TauJetContainer> m_scoreHandleKey{this, "ScoreHandleKey","","Output Score"};
62 
63  // properties
64  Gaudi::Property<std::string> m_weightfile_inclusive{this, "NetworkFileInclusive", ""};
65  Gaudi::Property<std::string> m_weightfile_0p{this, "NetworkFile0P", ""};
66  Gaudi::Property<std::string> m_weightfile_1p{this, "NetworkFile1P", ""};
67  Gaudi::Property<std::string> m_weightfile_2p{this, "NetworkFile2P", ""};
68  Gaudi::Property<std::string> m_weightfile_3p{this, "NetworkFile3P", ""};
69  Gaudi::Property<std::string> m_output_varname{this, "OutputVarname", "GNTauScore"};
70  Gaudi::Property<std::string> m_output_ptau{this, "OutputPTau", "GNTauProbTau"};
71  Gaudi::Property<std::string> m_output_pjet{this, "OutputPJet", "GNTauProbJet"};
72  Gaudi::Property<unsigned int> m_output_discriminant{this, "OutputDiscriminant", Discriminant::NegLogPJet,
73  "Discriminant used to calculate the output score: 0 -> -log(PJet), 1 -> PTau"};
74  Gaudi::Property<int> m_max_tracks{this, "MaxTracks", 30};
75  Gaudi::Property<int> m_max_clusters{this, "MaxClusters", 20};
76  Gaudi::Property<float> m_max_cluster_dr{this, "MaxClusterDR", 1.0f};
77  Gaudi::Property<bool> m_doVertexCorrection{this, "VertexCorrection", true};
78  Gaudi::Property<bool> m_doTrackClassification{this, "TrackClassification", true};
79  Gaudi::Property<float> m_minTauPt{this, "MinTauPt", 0.};
80  Gaudi::Property<bool> m_applyLooseTrackSel{this, "ApplyLooseTrackSel", false};
81  Gaudi::Property<bool> m_applyTightTrackSel{this, "ApplyTightTrackSel", false};
82  Gaudi::Property<float> m_min_prong_track_pt{this, "MinProngTrackPt", 0.};
83  Gaudi::Property<std::string> m_input_layer_scalar{this, "InputLayerScalar", "tau_vars"};
84  Gaudi::Property<std::string> m_input_layer_tracks{this, "InputLayerTracks", "track_vars"};
85  Gaudi::Property<std::string> m_input_layer_clusters{this, "InputLayerClusters", "cluster_vars"};
86  Gaudi::Property<std::string> m_outnode_tau{this, "NodeNameTau", "GN2TauNoAux_pb"};
87  Gaudi::Property<std::string> m_outnode_jet{this, "NodeNameJet", "GN2TauNoAux_pu"};
88 
89  // Wrappers for lwtnn
90  std::unique_ptr<TauGNN> m_net_inclusive;
91  std::unique_ptr<TauGNN> m_net_0p;
92  std::unique_ptr<TauGNN> m_net_1p;
93  std::unique_ptr<TauGNN> m_net_2p;
94  std::unique_ptr<TauGNN> m_net_3p;
95 
96  std::unique_ptr<TauGNN> load_network(const std::string& network_file, const TauGNN::Config& config) const;
97 };
98 
99 #endif // TAURECTOOLS_TAUGNNEVALUATOR_H
TauGNNEvaluator::m_doTrackClassification
Gaudi::Property< bool > m_doTrackClassification
Definition: TauGNNEvaluator.h:78
SG::WriteDecorHandleKey
Property holding a SG store/key/clid/attr name from which a WriteDecorHandle is made.
Definition: StoreGate/StoreGate/WriteDecorHandleKey.h:89
TauGNNEvaluator::Discriminant
Discriminant
Definition: TauGNNEvaluator.h:53
TauGNNEvaluator::m_output_pjet
Gaudi::Property< std::string > m_output_pjet
Definition: TauGNNEvaluator.h:71
PropertyWrapper.h
TauGNNEvaluator::m_applyLooseTrackSel
Gaudi::Property< bool > m_applyLooseTrackSel
Definition: TauGNNEvaluator.h:80
TauGNNEvaluator::NegLogPJet
@ NegLogPJet
Definition: TauGNNEvaluator.h:54
ASG_TOOL_CLASS2
#define ASG_TOOL_CLASS2(CLASSNAME, INT1, INT2)
Definition: AsgToolMacros.h:77
TauGNNEvaluator::get_gnn_3p
const TauGNN * get_gnn_3p() const
Definition: TauGNNEvaluator.h:43
TauGNNEvaluator::get_tracks
StatusCode get_tracks(const xAOD::TauJet &tau, std::vector< const xAOD::TauTrack * > &out) const
Definition: TauGNNEvaluator.cxx:182
TauGNNEvaluator::get_gnn_inclusive
const TauGNN * get_gnn_inclusive() const
Definition: TauGNNEvaluator.h:39
TauRecToolBase.h
TauRecToolBase
The base class for all tau tools.
Definition: TauRecToolBase.h:21
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:70
TauGNNEvaluator::m_net_3p
std::unique_ptr< TauGNN > m_net_3p
Definition: TauGNNEvaluator.h:94
TauGNNEvaluator::load_network
std::unique_ptr< TauGNN > load_network(const std::string &network_file, const TauGNN::Config &config) const
Definition: TauGNNEvaluator.cxx:81
TauGNNEvaluator::m_weightfile_inclusive
Gaudi::Property< std::string > m_weightfile_inclusive
Definition: TauGNNEvaluator.h:64
TauGNNEvaluator::m_max_clusters
Gaudi::Property< int > m_max_clusters
Definition: TauGNNEvaluator.h:75
TauGNNEvaluator::m_net_1p
std::unique_ptr< TauGNN > m_net_1p
Definition: TauGNNEvaluator.h:92
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
TauGNNEvaluator::m_weightfile_3p
Gaudi::Property< std::string > m_weightfile_3p
Definition: TauGNNEvaluator.h:68
TauGNNEvaluator::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: TauGNNEvaluator.cxx:21
TauGNNEvaluator::m_applyTightTrackSel
Gaudi::Property< bool > m_applyTightTrackSel
Definition: TauGNNEvaluator.h:81
TauGNNEvaluator::m_max_cluster_dr
Gaudi::Property< float > m_max_cluster_dr
Definition: TauGNNEvaluator.h:76
TauGNNEvaluator::m_net_0p
std::unique_ptr< TauGNN > m_net_0p
Definition: TauGNNEvaluator.h:91
TauGNNEvaluator::get_gnn_2p
const TauGNN * get_gnn_2p() const
Definition: TauGNNEvaluator.h:42
TauGNNEvaluator::get_gnn_0p
const TauGNN * get_gnn_0p() const
Definition: TauGNNEvaluator.h:40
TauGNNEvaluator::m_net_2p
std::unique_ptr< TauGNN > m_net_2p
Definition: TauGNNEvaluator.h:93
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
TauGNNEvaluator::m_weightfile_2p
Gaudi::Property< std::string > m_weightfile_2p
Definition: TauGNNEvaluator.h:67
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
TauGNNEvaluator::m_output_discriminant
Gaudi::Property< unsigned int > m_output_discriminant
Definition: TauGNNEvaluator.h:72
TauGNNEvaluator::m_tauContainerName
Gaudi::Property< std::string > m_tauContainerName
Definition: TauGNNEvaluator.h:60
TauGNNEvaluator::m_max_tracks
Gaudi::Property< int > m_max_tracks
Definition: TauGNNEvaluator.h:74
TauJetContainer.h
TauGNN.h
TauGNNEvaluator::m_net_inclusive
std::unique_ptr< TauGNN > m_net_inclusive
Definition: TauGNNEvaluator.h:90
TauGNNEvaluator::m_minTauPt
Gaudi::Property< float > m_minTauPt
Definition: TauGNNEvaluator.h:79
TauGNNEvaluator::m_outnode_tau
Gaudi::Property< std::string > m_outnode_tau
Definition: TauGNNEvaluator.h:86
TauGNNEvaluator::m_input_layer_tracks
Gaudi::Property< std::string > m_input_layer_tracks
Definition: TauGNNEvaluator.h:84
TauGNNEvaluator::m_input_layer_scalar
Gaudi::Property< std::string > m_input_layer_scalar
Definition: TauGNNEvaluator.h:83
TauGNNEvaluator::m_min_prong_track_pt
Gaudi::Property< float > m_min_prong_track_pt
Definition: TauGNNEvaluator.h:82
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
TauGNN
Wrapper around SaltModel to compute the output score of a model.
Definition: TauGNN.h:36
TauGNNEvaluator::m_doVertexCorrection
Gaudi::Property< bool > m_doVertexCorrection
Definition: TauGNNEvaluator.h:77
TauGNNEvaluator::m_outnode_jet
Gaudi::Property< std::string > m_outnode_jet
Definition: TauGNNEvaluator.h:87
CaloVertexedTopoCluster.h
Evaluate cluster kinematics with a different vertex / signal state.
ITauToolBase
The base class for all tau tools.
Definition: ITauToolBase.h:30
WriteDecorHandleKey.h
TauGNNEvaluator::m_weightfile_0p
Gaudi::Property< std::string > m_weightfile_0p
Definition: TauGNNEvaluator.h:65
TauJet.h
TauGNNEvaluator::m_weightfile_1p
Gaudi::Property< std::string > m_weightfile_1p
Definition: TauGNNEvaluator.h:66
TauGNNEvaluator::PTau
@ PTau
Definition: TauGNNEvaluator.h:55
TauGNNEvaluator
Tool to calculate tau identification score from .onnx inputs.
Definition: TauGNNEvaluator.h:29
TauGNNEvaluator::~TauGNNEvaluator
virtual ~TauGNNEvaluator()
Definition: TauGNNEvaluator.cxx:19
TauGNNEvaluator::get_gnn_1p
const TauGNN * get_gnn_1p() const
Definition: TauGNNEvaluator.h:41
TauGNN::Config
Definition: TauGNN.h:39
TauGNNEvaluator::m_input_layer_clusters
Gaudi::Property< std::string > m_input_layer_clusters
Definition: TauGNNEvaluator.h:85
TauGNNEvaluator::TauGNNEvaluator
TauGNNEvaluator(const std::string &name="TauGNNEvaluator")
Definition: TauGNNEvaluator.cxx:13
TauGNNEvaluator::m_scoreHandleKey
SG::WriteDecorHandleKey< xAOD::TauJetContainer > m_scoreHandleKey
Definition: TauGNNEvaluator.h:61
TauGNNEvaluator::get_clusters
StatusCode get_clusters(const xAOD::TauJet &tau, std::vector< xAOD::CaloVertexedTopoCluster > &out) const
Definition: TauGNNEvaluator.cxx:211
TauGNNEvaluator::m_output_ptau
Gaudi::Property< std::string > m_output_ptau
Definition: TauGNNEvaluator.h:70
TauGNNEvaluator::m_output_varname
Gaudi::Property< std::string > m_output_varname
Definition: TauGNNEvaluator.h:69
TauGNNEvaluator::execute
virtual StatusCode execute(xAOD::TauJet &tau) const override
Execute - called for each tau candidate.
Definition: TauGNNEvaluator.cxx:100