ATLAS Offline Software
Loading...
Searching...
No Matches
TauGNNEvaluator.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef TAURECTOOLS_TAUGNNEVALUATOR_H
6#define TAURECTOOLS_TAUGNNEVALUATOR_H
7
10
11#include "xAODTau/TauJet.h"
14
17
18#include <memory>
19
30public:
32
33 TauGNNEvaluator(const std::string &name = "TauGNNEvaluator");
34 virtual ~TauGNNEvaluator();
35
36 virtual StatusCode initialize() override; // Tool initializer.
37 virtual StatusCode execute(xAOD::TauJet &tau) const override; // Execute - called for each tau candidate.
38 // Getters for the underlying GNN implementations: each returns the raw, non-owning pointer held by the matching std::unique_ptr member (nullptr if that member is empty)
39 inline const TauGNN* get_gnn_inclusive() const { return m_net_inclusive.get(); }
40 inline const TauGNN* get_gnn_0p() const { return m_net_0p.get(); }
41 inline const TauGNN* get_gnn_1p() const { return m_net_1p.get(); }
42 inline const TauGNN* get_gnn_2p() const { return m_net_2p.get(); }
43 inline const TauGNN* get_gnn_3p() const { return m_net_3p.get(); }
44
47 PTau = 1
48 };
49
50private:
51
52 Gaudi::Property<std::string> m_tauContainerName{this, "TauContainerName", "", "Name of TauJetContainer, must be set when using "};
53 SG::WriteDecorHandleKey<xAOD::TauJetContainer> m_scoreHandleKey{this, "ScoreHandleKey","","Output Score"};
54
55 // Configurable tool properties (property name, default value)
56 Gaudi::Property<std::string> m_weightfile_inclusive{this, "NetworkFileInclusive", ""};
57 Gaudi::Property<std::string> m_weightfile_0p{this, "NetworkFile0P", ""};
58 Gaudi::Property<std::string> m_weightfile_1p{this, "NetworkFile1P", ""};
59 Gaudi::Property<std::string> m_weightfile_2p{this, "NetworkFile2P", ""};
60 Gaudi::Property<std::string> m_weightfile_3p{this, "NetworkFile3P", ""};
61 Gaudi::Property<std::string> m_input_layer_scalar{this, "InputLayerScalar","tau_vars"};
62 Gaudi::Property<std::string> m_input_layer_tracks{this, "InputLayerTracks","track_vars"};
63 Gaudi::Property<std::string> m_input_layer_clusters{this, "InputLayerClusters","cluster_vars"};
64 Gaudi::Property<std::string> m_output_varname{this, "OutputVarname", "GNTauScore"};
65 Gaudi::Property<std::string> m_output_ptau{this, "OutputPTau", "GNTauProbTau"};
66 Gaudi::Property<std::string> m_output_pjet{this, "OutputPJet", "GNTauProbJet"};
67 Gaudi::Property<unsigned int> m_output_discriminant{this, "OutputDiscriminant", Discriminant::NegLogPJet,
68 "Discriminant used to calculate the output score: 0 -> -log(PJet), 1 -> PTau"};
69 Gaudi::Property<int> m_max_tracks{this, "MaxTracks", 30};
70 Gaudi::Property<int> m_max_clusters{this, "MaxClusters", 20};
71 Gaudi::Property<float> m_max_cluster_dr{this, "MaxClusterDR", 1.0f};
72 Gaudi::Property<bool> m_doVertexCorrection{this, "VertexCorrection", true};
73 Gaudi::Property<bool> m_doTrackClassification{this, "TrackClassification", true};
74 Gaudi::Property<bool> m_useTRT{this, "useTRT", true};
75 Gaudi::Property<float> m_minTauPt{this, "MinTauPt", 0.};
76 Gaudi::Property<bool> m_applyLooseTrackSel{this, "ApplyLooseTrackSel", false};
77 Gaudi::Property<bool> m_applyTightTrackSel{this, "ApplyTightTrackSel", false};
78 Gaudi::Property<std::string> m_outnode_tau{this, "NodeNameTau", "GN2TauNoAux_pb"};
79 Gaudi::Property<std::string> m_outnode_jet{this, "NodeNameJet", "GN2TauNoAux_pu"};
80 Gaudi::Property<float> m_min_prong_track_pt{this, "MinProngTrackPt", 0.};
81
82 // Owned TauGNN network wrappers (TauGNN wraps SaltModel): inclusive plus 0p/1p/2p/3p variants -- NOTE(review): suffixes presumably denote prong multiplicity; confirm
83 std::unique_ptr<TauGNN> m_net_inclusive;
84 std::unique_ptr<TauGNN> m_net_0p;
85 std::unique_ptr<TauGNN> m_net_1p;
86 std::unique_ptr<TauGNN> m_net_2p;
87 std::unique_ptr<TauGNN> m_net_3p;
88
89 std::unique_ptr<TauGNN> load_network(const std::string& network_file) const; // NOTE(review): presumably builds a TauGNN from network_file; definition not visible in this header -- confirm
90};
91
92#endif // TAURECTOOLS_TAUGNNEVALUATOR_H
#define ASG_TOOL_CLASS2(CLASSNAME, INT1, INT2)
Evaluate cluster kinematics with a different vertex / signal state.
The base class for all tau tools.
Property holding a SG store/key/clid/attr name from which a WriteDecorHandle is made.
Gaudi::Property< std::string > m_weightfile_inclusive
TauGNNEvaluator(const std::string &name="TauGNNEvaluator")
Gaudi::Property< int > m_max_tracks
const TauGNN * get_gnn_inclusive() const
std::unique_ptr< TauGNN > m_net_1p
Gaudi::Property< std::string > m_input_layer_scalar
SG::WriteDecorHandleKey< xAOD::TauJetContainer > m_scoreHandleKey
virtual ~TauGNNEvaluator()
Gaudi::Property< float > m_max_cluster_dr
std::unique_ptr< TauGNN > m_net_3p
Gaudi::Property< unsigned int > m_output_discriminant
Gaudi::Property< float > m_minTauPt
std::unique_ptr< TauGNN > load_network(const std::string &network_file) const
Gaudi::Property< std::string > m_input_layer_clusters
Gaudi::Property< int > m_max_clusters
Gaudi::Property< float > m_min_prong_track_pt
Gaudi::Property< std::string > m_outnode_tau
std::unique_ptr< TauGNN > m_net_0p
Gaudi::Property< std::string > m_tauContainerName
Gaudi::Property< bool > m_doVertexCorrection
Gaudi::Property< bool > m_applyTightTrackSel
const TauGNN * get_gnn_2p() const
Gaudi::Property< std::string > m_output_varname
Gaudi::Property< std::string > m_output_pjet
Gaudi::Property< bool > m_useTRT
const TauGNN * get_gnn_1p() const
std::unique_ptr< TauGNN > m_net_inclusive
std::unique_ptr< TauGNN > m_net_2p
const TauGNN * get_gnn_0p() const
const TauGNN * get_gnn_3p() const
Gaudi::Property< bool > m_doTrackClassification
Gaudi::Property< std::string > m_outnode_jet
Gaudi::Property< std::string > m_output_ptau
Gaudi::Property< std::string > m_weightfile_1p
Gaudi::Property< std::string > m_weightfile_3p
virtual StatusCode execute(xAOD::TauJet &tau) const override
Execute - called for each tau candidate.
Gaudi::Property< bool > m_applyLooseTrackSel
virtual StatusCode initialize() override
Tool initializer.
Gaudi::Property< std::string > m_weightfile_2p
Gaudi::Property< std::string > m_input_layer_tracks
Gaudi::Property< std::string > m_weightfile_0p
Wrapper around SaltModel to compute the output score of a model.
Definition TauGNN.h:36
TauRecToolBase(const std::string &name)
TauJet_v3 TauJet
Definition of the current "tau version".