ATLAS Offline Software
Loading...
Searching...
No Matches
TauGNNEvaluator.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
7
9
10#include <algorithm>
11
12
// Constructor: forwards the tool name to the TauRecToolBase base class.
// The body is empty; no other work is done here.
13TauGNNEvaluator::TauGNNEvaluator(const std::string &name):
14 TauRecToolBase(name) {}
15
17
19 ATH_MSG_INFO("Initializing TauGNNEvaluator with "<<m_max_tracks.value()<<" tracks and "<<m_max_clusters<<" clusters...");
20
21 // We can either use an inclusive GNN (e.g. Offline GNTauv0), or a prong-dependent GNN (e.g. HLT GNTau), not both!
22
23 if(!m_weightfile_inclusive.empty()) { // Prong-inclusive network
24 if(!m_weightfile_0p.empty() || !m_weightfile_1p.empty() || !m_weightfile_2p.empty() || !m_weightfile_3p.empty()) {
25 ATH_MSG_ERROR("Cannot load both prong-inclusive and prong-dependent networks!");
26 return StatusCode::FAILURE;
27 }
28
29 ATH_MSG_INFO("Loading prong-inclusive TauID GNN");
31 if(!m_net_inclusive) return StatusCode::FAILURE;
32
33 } else { // Prong-dependent networks
34
35 // 0-prong is optional
36 if(!m_weightfile_0p.empty()) {
37 ATH_MSG_INFO("Loading 0-prong TauID GNN");
39 if(!m_net_0p) return StatusCode::FAILURE;
40 }
41
42 ATH_MSG_INFO("Loading 1-prong TauID GNN");
44 if(!m_net_1p) return StatusCode::FAILURE;
45
46 // 2-prong is optional
47 if(!m_weightfile_2p.empty()) {
48 ATH_MSG_INFO("Loading 2-prong TauID GNN");
50 if(!m_net_2p) return StatusCode::FAILURE;
51 }
52
53 ATH_MSG_INFO("Loading 3-prong TauID GNN");
55 if(!m_net_3p) return StatusCode::FAILURE;
56 }
57
59 ATH_MSG_FATAL("Invalid TauGNNEvaluator discriminant setting: " << m_output_discriminant);
60 }
61
62 if (!m_tauContainerName.empty()){
64 ATH_CHECK(m_scoreHandleKey.initialize());
65 }
66
67 return StatusCode::SUCCESS;
68}
69
// Build a TauGNN from the given weight file.
//
// network_file: name of the weight file, resolved through find_file().
// Returns nullptr when network_file is empty, or when find_file() returns an
// empty path (that case is also reported with ATH_MSG_ERROR). Otherwise returns
// the TauGNN constructed from the resolved path and this tool's properties.
70std::unique_ptr<TauGNN> TauGNNEvaluator::load_network(const std::string& network_file) const {
71 // Use PathResolver to search for the weight files
 // Nothing to load for an empty file name.
72 if(network_file.empty()) return nullptr;
73
74 const std::string pr_network_file = find_file(network_file);
 // An empty result from find_file() is treated as "file not found".
75 if(pr_network_file.empty()) {
76 ATH_MSG_ERROR("Could not find network weights: " << network_file);
77 return nullptr;
78 }
79
80 ATH_MSG_INFO("Using network config: " << pr_network_file);
81
82 // Load the weights and create the network
 // Copy the resolved path and this tool's property values into the network configuration.
84 config.nnFile = pr_network_file;
85 config.input_layer_scalar = m_input_layer_scalar.value();
86 config.input_layer_tracks = m_input_layer_tracks.value();
87 config.input_layer_clusters = m_input_layer_clusters.value();
88 config.output_node_tau = m_outnode_tau.value();
89 config.output_node_jet = m_outnode_jet.value();
90 config.n_max_tracks = m_max_tracks.value();
91 config.n_max_clusters = m_max_clusters.value();
92 config.max_dr_cluster = m_max_cluster_dr.value();
93 config.doVertexCorrection = m_doVertexCorrection.value();
94 config.trackClassification = m_doTrackClassification.value();
95 config.useTRT = m_useTRT.value();
96
97 std::unique_ptr<TauGNN> net = std::make_unique<TauGNN>(config);
 // NOTE(review): std::make_unique reports allocation failure by throwing rather
 // than returning null, so this check is not expected to fire -- confirm whether
 // it can be dropped or whether TauGNN construction errors need handling here.
98 if(!net) ATH_MSG_ERROR("No network configured.");
99
100 return net;
101}
102
104 // Output variable Decorators
106 const SG::Accessor<float> out_ptau(m_output_ptau);
107 const SG::Accessor<float> out_pjet(m_output_pjet);
108 const SG::Decorator<char> out_trkclass("GNTau_TrackClass");
109 // Set default score and overwrite later
110 output(tau) = -1111.0f;
111 out_ptau(tau) = -1111.0f;
112 out_pjet(tau) = -1111.0f;
113
114 //Skip execution for low-pT taus to save resources
115 if (tau.pt() < m_minTauPt) {
116 return StatusCode::SUCCESS;
117 }
118
119 // save CPU when running PHYS derivations
121 if (tau.nTracks()>5) return StatusCode::SUCCESS;
122 }
123
124 // save CPU when running in RAWtoALL for tau trigger monitoring purpose
126 if (tau.nTracks()!=1 && tau.nTracks()!=3) return StatusCode::SUCCESS;
127 }
128
129 // Network outputs
130 std::map<std::string, float> out_f;
131 std::map<std::string, std::vector<char>> out_vc;
132 std::map<std::string, std::vector<float>> out_vf;
133
134 // Evaluate networks
135 ATH_MSG_DEBUG("Evaluating GNN for tau with nTracks = " << tau.nTracksCharged());
136 if(m_net_inclusive) {
137 std::tie(out_f, out_vc, out_vf) = m_net_inclusive->compute(tau);
138 } else {
139 // First we calculate the tau prongness
140 int n_tracks = tau.nTracksCharged();
141 // in trigger, we need to apply a min pT cut on the tracks to count the prongs,
142 // as no track classification is available
144 auto trks = tau.allTracks();
145 const float threshold = m_min_prong_track_pt;
146 n_tracks = std::count_if(trks.begin(), trks.end(),
147 [&threshold](const xAOD::TauTrack* trk) { return trk->pt() > threshold; }
148 );
149 }
150
151 if(n_tracks == 0 && m_net_0p) std::tie(out_f, out_vc, out_vf) = m_net_0p->compute(tau);
152 else if(n_tracks == 1) std::tie(out_f, out_vc, out_vf) = m_net_1p->compute(tau);
153 else if(n_tracks == 2) {
154 if(m_net_2p) std::tie(out_f, out_vc, out_vf) = m_net_2p->compute(tau);
155 else std::tie(out_f, out_vc, out_vf) = m_net_3p->compute(tau);
156 } else if(n_tracks == 3) std::tie(out_f, out_vc, out_vf) = m_net_3p->compute(tau);
157 }
158
159 // Store scores only if the inferences actually ran
160 if(out_f.contains(m_outnode_tau)) {
162 output(tau) = std::log10(1/(1-out_f.at(m_outnode_tau)));
164 output(tau) = out_f.at(m_outnode_tau);
165 }
166
167 out_ptau(tau) = out_f.at(m_outnode_tau);
168 out_pjet(tau) = out_f.at(m_outnode_jet);
169 }
170
171 return StatusCode::SUCCESS;
172}
173
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_FATAL(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_DEBUG(x)
Helper class to provide type-safe access to aux data.
Helper class to provide type-safe access to aux data.
Definition Decorator.h:59
Gaudi::Property< std::string > m_weightfile_inclusive
TauGNNEvaluator(const std::string &name="TauGNNEvaluator")
Gaudi::Property< int > m_max_tracks
std::unique_ptr< TauGNN > m_net_1p
Gaudi::Property< std::string > m_input_layer_scalar
SG::WriteDecorHandleKey< xAOD::TauJetContainer > m_scoreHandleKey
virtual ~TauGNNEvaluator()
Gaudi::Property< float > m_max_cluster_dr
std::unique_ptr< TauGNN > m_net_3p
Gaudi::Property< unsigned int > m_output_discriminant
Gaudi::Property< float > m_minTauPt
std::unique_ptr< TauGNN > load_network(const std::string &network_file) const
Gaudi::Property< std::string > m_input_layer_clusters
Gaudi::Property< int > m_max_clusters
Gaudi::Property< float > m_min_prong_track_pt
Gaudi::Property< std::string > m_outnode_tau
std::unique_ptr< TauGNN > m_net_0p
Gaudi::Property< std::string > m_tauContainerName
Gaudi::Property< bool > m_doVertexCorrection
Gaudi::Property< bool > m_applyTightTrackSel
Gaudi::Property< std::string > m_output_varname
Gaudi::Property< std::string > m_output_pjet
Gaudi::Property< bool > m_useTRT
std::unique_ptr< TauGNN > m_net_inclusive
std::unique_ptr< TauGNN > m_net_2p
Gaudi::Property< bool > m_doTrackClassification
Gaudi::Property< std::string > m_outnode_jet
Gaudi::Property< std::string > m_output_ptau
Gaudi::Property< std::string > m_weightfile_1p
Gaudi::Property< std::string > m_weightfile_3p
virtual StatusCode execute(xAOD::TauJet &tau) const override
Execute - called for each tau candidate.
Gaudi::Property< bool > m_applyLooseTrackSel
virtual StatusCode initialize() override
Tool initializer.
Gaudi::Property< std::string > m_weightfile_2p
Gaudi::Property< std::string > m_input_layer_tracks
Gaudi::Property< std::string > m_weightfile_0p
TauRecToolBase(const std::string &name)
std::string find_file(const std::string &fname) const
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
size_t nTracksCharged() const
std::vector< const TauTrack * > allTracks() const
Get the vector of const pointers to all tracks associated with this tau, regardless of classification.
size_t nTracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
TauTrack_v1 TauTrack
Definition of the current version.
Definition TauTrack.h:16
TauJet_v3 TauJet
Definition of the current "tau version".