ATLAS Offline Software
Loading...
Searching...
No Matches
TauGNNEvaluator.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
7
9
10#include <algorithm>
11
12
13TauGNNEvaluator::TauGNNEvaluator(const std::string &name):
14 TauRecToolBase(name) {}
15
17
19 ATH_MSG_INFO("Initializing TauGNNEvaluator with "<<m_max_tracks.value()<<" tracks, "<<m_max_clusters<<" clusters, and "<<m_max_hits<<" hits...");
20
21 // We can either use an inclusive GNN (e.g. Offline GNTauv0), or a prong-dependent GNN (e.g. HLT GNTau), not both!
22
23 if(!m_weightfile_inclusive.empty()) { // Prong-inclusive network
24 if(!m_weightfile_0p.empty() || !m_weightfile_1p.empty() || !m_weightfile_2p.empty() || !m_weightfile_3p.empty()) {
25 ATH_MSG_ERROR("Cannot load both prong-inclusive and prong-dependent networks!");
26 return StatusCode::FAILURE;
27 }
28
29 ATH_MSG_INFO("Loading prong-inclusive TauID GNN");
31 if(!m_net_inclusive) return StatusCode::FAILURE;
32
33 } else { // Prong-dependent networks
34
35 // 0-prong is optional
36 if(!m_weightfile_0p.empty()) {
37 ATH_MSG_INFO("Loading 0-prong TauID GNN");
39 if(!m_net_0p) return StatusCode::FAILURE;
40 }
41
42 ATH_MSG_INFO("Loading 1-prong TauID GNN");
44 if(!m_net_1p) return StatusCode::FAILURE;
45
46 // 2-prong is optional
47 if(!m_weightfile_2p.empty()) {
48 ATH_MSG_INFO("Loading 2-prong TauID GNN");
50 if(!m_net_2p) return StatusCode::FAILURE;
51 }
52
53 ATH_MSG_INFO("Loading 3-prong TauID GNN");
55 if(!m_net_3p) return StatusCode::FAILURE;
56 }
57
59 ATH_MSG_FATAL("Invalid TauGNNEvaluator discriminant setting: " << m_output_discriminant);
60 }
61
62 if(!m_tauContainerName.empty()) {
63 // We should move to using WriteDecorHandles in the future, but for now
64 // we create keys to enforce data-dependencies in the scheduler
65
68 ATH_CHECK(m_scoreHandleKey.initialize());
69 }
70
72 ATH_CHECK(m_pTauHandleKey.initialize());
73
75 ATH_CHECK(m_pJetHandleKey.initialize());
76 }
77
78 if(!m_tauContainerName.empty() && !m_hitsHandleKey.empty()) {
81 ATH_CHECK(m_hitsHandleKey.initialize());
82 } else if (m_max_hits > 0) {
83 ATH_MSG_ERROR("TauContainerName and HitsHandleKey must be provided to read hits for the GNN evaluation");
84 return StatusCode::FAILURE;
85 }
86
87 return StatusCode::SUCCESS;
88}
89
90std::unique_ptr<TauGNN> TauGNNEvaluator::load_network(const std::string& network_file) const {
91 // Use PathResolver to search for the weight files
92 if(network_file.empty()) return nullptr;
93
94 const std::string pr_network_file = find_file(network_file);
95 if(pr_network_file.empty()) {
96 ATH_MSG_ERROR("Could not find network weights: " << network_file);
97 return nullptr;
98 }
99
100 ATH_MSG_INFO("Using network config: " << pr_network_file);
101
102 // Load the weights and create the network
104 config.nnFile = pr_network_file;
105 config.input_layer_scalar = m_input_layer_scalar.value();
106 config.input_layer_tracks = m_input_layer_tracks.value();
107 config.input_layer_clusters = m_input_layer_clusters.value();
108 config.input_layer_hits = m_input_layer_hits.value();
109 config.output_node_tau = m_outnode_tau.value();
110 config.output_node_jet = m_outnode_jet.value();
111 config.n_max_tracks = m_max_tracks.value();
112 config.n_max_clusters = m_max_clusters.value();
113 config.max_dr_cluster = m_max_cluster_dr.value();
114 config.n_max_hits = m_max_hits.value();
115 config.doVertexCorrection = m_doVertexCorrection.value();
116 config.trackClassification = m_doTrackClassification.value();
117 config.useTRT = m_useTRT.value();
118 config.hits_decor_name = m_hits_decor_name;
119
120 std::unique_ptr<TauGNN> net = std::make_unique<TauGNN>(config);
121 if(!net) ATH_MSG_ERROR("No network configured.");
122
123 return net;
124}
125
127 // Output variable Decorators
129 const SG::Accessor<float> out_ptau(m_output_ptau);
130 const SG::Accessor<float> out_pjet(m_output_pjet);
131 const SG::Decorator<char> out_trkclass("GNTau_TrackClass");
132 // Set default score and overwrite later
133 if(m_output_discriminant != Discriminant::Disabled) output(tau) = -1111.0f;
134 out_ptau(tau) = -1111.0f;
135 out_pjet(tau) = -1111.0f;
136
137 //Skip execution for low-pT taus to save resources
138 if (tau.pt() < m_minTauPt) {
139 return StatusCode::SUCCESS;
140 }
141
142 // save CPU when running PHYS derivations
144 if (tau.nTracks()>5) return StatusCode::SUCCESS;
145 }
146
147 // save CPU when running in RAWtoALL for tau trigger monitoring purpose
149 if (tau.nTracks()!=1 && tau.nTracks()!=3) return StatusCode::SUCCESS;
150 }
151
152 // Network outputs
153 std::map<std::string, float> out_f;
154 std::map<std::string, std::vector<char>> out_vc;
155 std::map<std::string, std::vector<float>> out_vf;
156
157 // Evaluate networks
158 ATH_MSG_DEBUG("Evaluating GNN for tau with nTracks = " << tau.nTracksCharged());
159 if(m_net_inclusive) {
160 std::tie(out_f, out_vc, out_vf) = m_net_inclusive->compute(tau);
161 } else {
162 // First we calculate the tau prongness
163 int n_tracks = tau.nTracksCharged();
164 // in trigger, we need to apply a min pT cut on the tracks to count the prongs,
165 // as no track classification is available
167 auto trks = tau.allTracks();
168 const float threshold = m_min_prong_track_pt;
169 n_tracks = std::count_if(trks.begin(), trks.end(),
170 [&threshold](const xAOD::TauTrack* trk) { return trk->pt() > threshold; }
171 );
172 }
173
174 if(n_tracks == 0 && m_net_0p) std::tie(out_f, out_vc, out_vf) = m_net_0p->compute(tau);
175 else if(n_tracks == 1) std::tie(out_f, out_vc, out_vf) = m_net_1p->compute(tau);
176 else if(n_tracks == 2) {
177 if(m_net_2p) std::tie(out_f, out_vc, out_vf) = m_net_2p->compute(tau);
178 else std::tie(out_f, out_vc, out_vf) = m_net_3p->compute(tau);
179 } else if(n_tracks == 3) std::tie(out_f, out_vc, out_vf) = m_net_3p->compute(tau);
180 }
181
182 // Store scores only if the inferences actually ran
183 if(out_f.contains(m_outnode_tau)) {
185 output(tau) = std::log10(1/(1-out_f.at(m_outnode_tau)));
187 output(tau) = out_f.at(m_outnode_tau);
188 }
189
190 out_ptau(tau) = out_f.at(m_outnode_tau);
191 out_pjet(tau) = out_f.at(m_outnode_jet);
192 }
193
194 return StatusCode::SUCCESS;
195}
196
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_FATAL(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_DEBUG(x)
Helper class to provide type-safe access to aux data.
Helper class to provide type-safe access to aux data.
Definition Decorator.h:59
Gaudi::Property< std::string > m_weightfile_inclusive
TauGNNEvaluator(const std::string &name="TauGNNEvaluator")
Gaudi::Property< int > m_max_tracks
Gaudi::Property< std::string > m_input_layer_hits
std::unique_ptr< TauGNN > m_net_1p
Gaudi::Property< std::string > m_input_layer_scalar
std::string m_hits_decor_name
SG::WriteDecorHandleKey< xAOD::TauJetContainer > m_scoreHandleKey
virtual ~TauGNNEvaluator()
Gaudi::Property< float > m_max_cluster_dr
std::unique_ptr< TauGNN > m_net_3p
Gaudi::Property< float > m_minTauPt
std::unique_ptr< TauGNN > load_network(const std::string &network_file) const
Gaudi::Property< int > m_output_discriminant
Gaudi::Property< std::string > m_input_layer_clusters
Gaudi::Property< int > m_max_clusters
Gaudi::Property< float > m_min_prong_track_pt
Gaudi::Property< std::string > m_outnode_tau
std::unique_ptr< TauGNN > m_net_0p
Gaudi::Property< std::string > m_tauContainerName
Gaudi::Property< bool > m_doVertexCorrection
Gaudi::Property< int > m_max_hits
Gaudi::Property< bool > m_applyTightTrackSel
Gaudi::Property< std::string > m_output_varname
Gaudi::Property< std::string > m_output_pjet
Gaudi::Property< bool > m_useTRT
std::unique_ptr< TauGNN > m_net_inclusive
std::unique_ptr< TauGNN > m_net_2p
SG::WriteDecorHandleKey< xAOD::TauJetContainer > m_pJetHandleKey
Gaudi::Property< bool > m_doTrackClassification
SG::WriteDecorHandleKey< xAOD::TauJetContainer > m_pTauHandleKey
Gaudi::Property< std::string > m_outnode_jet
Gaudi::Property< std::string > m_output_ptau
Gaudi::Property< std::string > m_weightfile_1p
Gaudi::Property< std::string > m_weightfile_3p
SG::ReadDecorHandleKey< xAOD::TauJetContainer > m_hitsHandleKey
virtual StatusCode execute(xAOD::TauJet &tau) const override
Execute - called for each tau candidate.
Gaudi::Property< bool > m_applyLooseTrackSel
virtual StatusCode initialize() override
Tool initializer.
Gaudi::Property< std::string > m_weightfile_2p
Gaudi::Property< std::string > m_input_layer_tracks
Gaudi::Property< std::string > m_weightfile_0p
TauRecToolBase(const std::string &name)
std::string find_file(const std::string &fname) const
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
size_t nTracksCharged() const
std::vector< const TauTrack * > allTracks() const
Get the v<const pointer> to all tracks associated with this tau, regardless of classification.
size_t nTracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
TauTrack_v1 TauTrack
Definition of the current version.
Definition TauTrack.h:16
TauJet_v3 TauJet
Definition of the current "tau version".