ATLAS Offline Software
Loading...
Searching...
No Matches
TauJetRNNEvaluator.cxx
Go to the documentation of this file.
/*
  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
*/
4
8
10
11#include <algorithm>
12
13
15 TauRecToolBase(name),
16 m_net_0p(nullptr),
17 m_net_1p(nullptr),
18 m_net_2p(nullptr),
19 m_net_3p(nullptr) {
20}
21
22
24
26 ATH_MSG_INFO("Initializing TauJetRNNEvaluator");
27
28 std::string weightfile_0p("");
29 std::string weightfile_1p("");
30 std::string weightfile_2p("");
31 std::string weightfile_3p("");
32
33 // Use PathResolver to search for the weight files
34 if (!m_weightfile_0p.empty()) {
35 weightfile_0p = find_file(m_weightfile_0p);
36 if (weightfile_0p.empty()) {
37 ATH_MSG_ERROR("Could not find network weights: " << m_weightfile_0p);
38 return StatusCode::FAILURE;
39 } else {
40 ATH_MSG_INFO("Using network config [0-prong]: " << weightfile_0p);
41 }
42 }
43
44 if (!m_weightfile_1p.empty()) {
45 weightfile_1p = find_file(m_weightfile_1p);
46 if (weightfile_1p.empty()) {
47 ATH_MSG_ERROR("Could not find network weights: " << m_weightfile_1p);
48 return StatusCode::FAILURE;
49 } else {
50 ATH_MSG_INFO("Using network config [1-prong]: " << weightfile_1p);
51 }
52 }
53
54 if (!m_weightfile_2p.empty()) {
55 weightfile_2p = find_file(m_weightfile_2p);
56 if (weightfile_2p.empty()) {
57 ATH_MSG_ERROR("Could not find network weights: " << m_weightfile_2p);
58 return StatusCode::FAILURE;
59 } else {
60 ATH_MSG_INFO("Using network config [2-prong]: " << weightfile_2p);
61 }
62 }
63
64 if (!m_weightfile_3p.empty()) {
65 weightfile_3p = find_file(m_weightfile_3p);
66 if (weightfile_3p.empty()) {
67 ATH_MSG_ERROR("Could not find network weights: " << m_weightfile_3p);
68 return StatusCode::FAILURE;
69 } else {
70 ATH_MSG_INFO("Using network config [3-prong]: " << weightfile_3p);
71 }
72 }
73
74 // Set the layer and node names in the weight file
76 config.input_layer_scalar = m_input_layer_scalar;
77 config.input_layer_tracks = m_input_layer_tracks;
78 config.input_layer_clusters = m_input_layer_clusters;
79 config.output_layer = m_output_layer;
80 config.output_node = m_output_node;
81
82 // Load the weights and create the network
83 // 0p is for trigger only
84 if (!weightfile_0p.empty()) {
85 m_net_0p = std::make_unique<TauJetRNN>(weightfile_0p, config, m_useTRT);
86 if (!m_net_0p) {
87 ATH_MSG_ERROR("No network configured for 0-prong taus.");
88 return StatusCode::FAILURE;
89 }
90 }
91
92 m_net_1p = std::make_unique<TauJetRNN>(weightfile_1p, config, m_useTRT);
93 if (!m_net_1p) {
94 ATH_MSG_ERROR("No network configured for 1-prong taus.");
95 return StatusCode::FAILURE;
96 }
97
98 // 2p is optional
99 if (!weightfile_2p.empty()) {
100 m_net_2p = std::make_unique<TauJetRNN>(weightfile_2p, config, m_useTRT);
101 if (!m_net_2p) {
102 ATH_MSG_ERROR("No network configured for 2-prong taus.");
103 return StatusCode::FAILURE;
104 }
105 }
106
107 m_net_3p = std::make_unique<TauJetRNN>(weightfile_3p, config, m_useTRT);
108 if (!m_net_3p) {
109 ATH_MSG_ERROR("No network configured for 3-prong taus.");
110 return StatusCode::FAILURE;
111 }
112
113 return StatusCode::SUCCESS;
114}
115
117 // Output variable accessor
119
120 // Set default score and overwrite later
121 output(tau) = -1111.0f;
122
123 // save CPU when running PHYS derivations
125 if (tau.nTracks()>5) return StatusCode::SUCCESS;
126 }
127
128 const auto nTracksCharged = tau.nTracksCharged();
129
130 // Get input objects
131 std::vector<const xAOD::TauTrack *> tracks;
132 ATH_CHECK(get_tracks(tau, tracks));
133 std::vector<xAOD::CaloVertexedTopoCluster> clusters;
134 ATH_CHECK(get_clusters(tau, clusters));
135
136 // Evaluate networks
137 if (nTracksCharged==0 && m_net_0p) {
138 output(tau) = m_net_0p->compute(tau, tracks, clusters);
139 }
140 else if (nTracksCharged == 1) {
141 output(tau) = m_net_1p->compute(tau, tracks, clusters);
142 }
143 else if (nTracksCharged == 2) {
144 if(m_net_2p) {
145 output(tau) = m_net_2p->compute(tau, tracks, clusters);
146 } else {
147 output(tau) = m_net_3p->compute(tau, tracks, clusters);
148 }
149 }
150 else if (nTracksCharged > 2) {
151 output(tau) = m_net_3p->compute(tau, tracks, clusters);
152 }
153
154 return StatusCode::SUCCESS;
155}
156
157StatusCode TauJetRNNEvaluator::get_tracks(const xAOD::TauJet &tau, std::vector<const xAOD::TauTrack *> &out) const {
158 std::vector<const xAOD::TauTrack*> tracks = tau.allTracks();
159
160 // Skip unclassified tracks:
161 // - the track is a LRT and classifyLRT = false
162 // - the track is not among the MaxNtracks highest-pt tracks in the track classifier
163 // - track classification is not run (trigger)
165 std::vector<const xAOD::TauTrack*>::iterator it = tracks.begin();
166 while(it != tracks.end()) {
167 if((*it)->flag(xAOD::TauJetParameters::unclassified)) {
168 it = tracks.erase(it);
169 }
170 else {
171 ++it;
172 }
173 }
174 }
175
176 // Sort by descending pt
177 auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) {
178 return lhs->pt() > rhs->pt();
179 };
180 std::sort(tracks.begin(), tracks.end(), cmp_pt);
181
182 // Truncate tracks
183 if (tracks.size() > m_max_tracks) {
184 tracks.resize(m_max_tracks);
185 }
186 out = std::move(tracks);
187
188 return StatusCode::SUCCESS;
189}
190
191StatusCode TauJetRNNEvaluator::get_clusters(const xAOD::TauJet &tau, std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
192
193 TLorentzVector tauAxis = tauRecTools::getTauAxis(tau, m_doVertexCorrection);
194
195 std::vector<xAOD::CaloVertexedTopoCluster> vertexedClusterList = tau.vertexedClusters();
196 for (const xAOD::CaloVertexedTopoCluster& vertexedCluster : vertexedClusterList) {
197 TLorentzVector clusterP4 = vertexedCluster.p4();
198 if (clusterP4.DeltaR(tauAxis) > m_max_cluster_dr) continue;
199
200 clusters.push_back(vertexedCluster);
201 }
202
203 // Sort by descending et
204 auto et_cmp = [](const xAOD::CaloVertexedTopoCluster& lhs,
206 return lhs.p4().Et() > rhs.p4().Et();
207 };
208 std::sort(clusters.begin(), clusters.end(), et_cmp);
209
210 // Truncate clusters
211 if (clusters.size() > m_max_clusters) {
212 clusters.resize(m_max_clusters, clusters[0]);
213 }
214
215 return StatusCode::SUCCESS;
216}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
Helper class to provide type-safe access to aux data.
Gaudi::Property< std::string > m_weightfile_2p
virtual StatusCode initialize() override
Tool initializer.
Gaudi::Property< std::string > m_input_layer_clusters
virtual StatusCode execute(xAOD::TauJet &tau) const override
Execute - called for each tau candidate.
Gaudi::Property< std::string > m_input_layer_tracks
Gaudi::Property< float > m_max_cluster_dr
std::unique_ptr< TauJetRNN > m_net_0p
Gaudi::Property< bool > m_doTrackClassification
Gaudi::Property< std::string > m_weightfile_1p
Gaudi::Property< std::size_t > m_max_clusters
Gaudi::Property< bool > m_useTRT
TauJetRNNEvaluator(const std::string &name="TauJetRNNEvaluator")
StatusCode get_tracks(const xAOD::TauJet &tau, std::vector< const xAOD::TauTrack * > &out) const
std::unique_ptr< TauJetRNN > m_net_2p
Gaudi::Property< std::string > m_output_node
std::unique_ptr< TauJetRNN > m_net_3p
Gaudi::Property< std::string > m_weightfile_0p
StatusCode get_clusters(const xAOD::TauJet &tau, std::vector< xAOD::CaloVertexedTopoCluster > &out) const
Gaudi::Property< std::string > m_weightfile_3p
Gaudi::Property< std::string > m_output_layer
Gaudi::Property< std::string > m_input_layer_scalar
std::unique_ptr< TauJetRNN > m_net_1p
Gaudi::Property< bool > m_doVertexCorrection
Gaudi::Property< std::string > m_output_varname
Gaudi::Property< bool > m_applyLooseTrackSel
Gaudi::Property< std::size_t > m_max_tracks
TauRecToolBase(const std::string &name)
std::string find_file(const std::string &fname) const
virtual FourMom_t p4() const final
The full 4-momentum of the particle.
Evaluate cluster kinematics with a different vertex / signal state.
std::vector< xAOD::CaloVertexedTopoCluster > vertexedClusters() const
size_t nTracksCharged() const
std::vector< const TauTrack * > allTracks() const
Get the v<const pointer> to all tracks associated with this tau, regardless of classification.
size_t nTracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
void sort(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of sort for DataVector/List.
TLorentzVector getTauAxis(const xAOD::TauJet &tau, bool doVertexCorrection=true)
Return the four momentum of the tau axis. The tau axis is widely used to select clusters and cells in ...
TauTrack_v1 TauTrack
Definition of the current version.
Definition TauTrack.h:16
TauJet_v3 TauJet
Definition of the current "tau version".