ATLAS Offline Software
TauJetRNNEvaluator.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
8 
10 
11 #include <algorithm>
12 
13 
16  m_net_0p(nullptr),
17  m_net_1p(nullptr),
18  m_net_2p(nullptr),
19  m_net_3p(nullptr) {
20 }
21 
22 
24 
26  ATH_MSG_INFO("Initializing TauJetRNNEvaluator");
27 
28  std::string weightfile_0p("");
29  std::string weightfile_1p("");
30  std::string weightfile_2p("");
31  std::string weightfile_3p("");
32 
33  // Use PathResolver to search for the weight files
34  if (!m_weightfile_0p.empty()) {
35  weightfile_0p = find_file(m_weightfile_0p);
36  if (weightfile_0p.empty()) {
37  ATH_MSG_ERROR("Could not find network weights: " << m_weightfile_0p);
38  return StatusCode::FAILURE;
39  } else {
40  ATH_MSG_INFO("Using network config [0-prong]: " << weightfile_0p);
41  }
42  }
43 
44  if (!m_weightfile_1p.empty()) {
45  weightfile_1p = find_file(m_weightfile_1p);
46  if (weightfile_1p.empty()) {
47  ATH_MSG_ERROR("Could not find network weights: " << m_weightfile_1p);
48  return StatusCode::FAILURE;
49  } else {
50  ATH_MSG_INFO("Using network config [1-prong]: " << weightfile_1p);
51  }
52  }
53 
54  if (!m_weightfile_2p.empty()) {
55  weightfile_2p = find_file(m_weightfile_2p);
56  if (weightfile_2p.empty()) {
57  ATH_MSG_ERROR("Could not find network weights: " << m_weightfile_2p);
58  return StatusCode::FAILURE;
59  } else {
60  ATH_MSG_INFO("Using network config [2-prong]: " << weightfile_2p);
61  }
62  }
63 
64  if (!m_weightfile_3p.empty()) {
65  weightfile_3p = find_file(m_weightfile_3p);
66  if (weightfile_3p.empty()) {
67  ATH_MSG_ERROR("Could not find network weights: " << m_weightfile_3p);
68  return StatusCode::FAILURE;
69  } else {
70  ATH_MSG_INFO("Using network config [3-prong]: " << weightfile_3p);
71  }
72  }
73 
74  // Set the layer and node names in the weight file
76  config.input_layer_scalar = m_input_layer_scalar;
77  config.input_layer_tracks = m_input_layer_tracks;
78  config.input_layer_clusters = m_input_layer_clusters;
79  config.output_layer = m_output_layer;
80  config.output_node = m_output_node;
81 
82  // Load the weights and create the network
83  // 0p is for trigger only
84  if (!weightfile_0p.empty()) {
85  m_net_0p = std::make_unique<TauJetRNN>(weightfile_0p, config);
86  if (!m_net_0p) {
87  ATH_MSG_ERROR("No network configured for 0-prong taus.");
88  return StatusCode::FAILURE;
89  }
90  }
91 
92  m_net_1p = std::make_unique<TauJetRNN>(weightfile_1p, config);
93  if (!m_net_1p) {
94  ATH_MSG_ERROR("No network configured for 1-prong taus.");
95  return StatusCode::FAILURE;
96  }
97 
98  // 2p is optional
99  if (!weightfile_2p.empty()) {
100  m_net_2p = std::make_unique<TauJetRNN>(weightfile_2p, config);
101  if (!m_net_2p) {
102  ATH_MSG_ERROR("No network configured for 2-prong taus.");
103  return StatusCode::FAILURE;
104  }
105  }
106 
107  m_net_3p = std::make_unique<TauJetRNN>(weightfile_3p, config);
108  if (!m_net_3p) {
109  ATH_MSG_ERROR("No network configured for 3-prong taus.");
110  return StatusCode::FAILURE;
111  }
112 
113  return StatusCode::SUCCESS;
114 }
115 
117  // Output variable accessor
119 
120  // Set default score and overwrite later
121  output(tau) = -1111.0f;
122 
123  // save CPU when running PHYS derivations
124  if (m_applyLooseTrackSel) {
125  if (tau.nTracks()>5) return StatusCode::SUCCESS;
126  }
127 
128  const auto nTracksCharged = tau.nTracksCharged();
129 
130  // Get input objects
131  std::vector<const xAOD::TauTrack *> tracks;
132  ATH_CHECK(get_tracks(tau, tracks));
133  std::vector<xAOD::CaloVertexedTopoCluster> clusters;
135 
136  // Evaluate networks
137  if (nTracksCharged==0 && m_net_0p) {
138  output(tau) = m_net_0p->compute(tau, tracks, clusters);
139  }
140  else if (nTracksCharged == 1) {
141  output(tau) = m_net_1p->compute(tau, tracks, clusters);
142  }
143  else if (nTracksCharged == 2) {
144  if(m_net_2p) {
145  output(tau) = m_net_2p->compute(tau, tracks, clusters);
146  } else {
147  output(tau) = m_net_3p->compute(tau, tracks, clusters);
148  }
149  }
150  else if (nTracksCharged > 2) {
151  output(tau) = m_net_3p->compute(tau, tracks, clusters);
152  }
153 
154  return StatusCode::SUCCESS;
155 }
156 
158  return m_net_0p.get();
159 }
160 
162  return m_net_1p.get();
163 }
164 
166  return m_net_2p.get();
167 }
168 
170  return m_net_3p.get();
171 }
172 
173 StatusCode TauJetRNNEvaluator::get_tracks(const xAOD::TauJet &tau, std::vector<const xAOD::TauTrack *> &out) const {
174  std::vector<const xAOD::TauTrack*> tracks = tau.allTracks();
175 
176  // Skip unclassified tracks:
177  // - the track is a LRT and classifyLRT = false
178  // - the track is not among the MaxNtracks highest-pt tracks in the track classifier
179  // - track classification is not run (trigger)
182  while(it != tracks.end()) {
183  if((*it)->flag(xAOD::TauJetParameters::unclassified)) {
184  it = tracks.erase(it);
185  }
186  else {
187  ++it;
188  }
189  }
190  }
191 
192  // Sort by descending pt
193  auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) {
194  return lhs->pt() > rhs->pt();
195  };
196  std::sort(tracks.begin(), tracks.end(), cmp_pt);
197 
198  // Truncate tracks
199  if (tracks.size() > m_max_tracks) {
200  tracks.resize(m_max_tracks);
201  }
202  out = std::move(tracks);
203 
204  return StatusCode::SUCCESS;
205 }
206 
207 StatusCode TauJetRNNEvaluator::get_clusters(const xAOD::TauJet &tau, std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
208 
209  TLorentzVector tauAxis = tauRecTools::getTauAxis(tau, m_doVertexCorrection);
210 
211  std::vector<xAOD::CaloVertexedTopoCluster> vertexedClusterList = tau.vertexedClusters();
212  for (const xAOD::CaloVertexedTopoCluster& vertexedCluster : vertexedClusterList) {
213  TLorentzVector clusterP4 = vertexedCluster.p4();
214  if (clusterP4.DeltaR(tauAxis) > m_max_cluster_dr) continue;
215 
216  clusters.push_back(vertexedCluster);
217  }
218 
219  // Sort by descending et
220  auto et_cmp = [](const xAOD::CaloVertexedTopoCluster& lhs,
221  const xAOD::CaloVertexedTopoCluster& rhs) {
222  return lhs.p4().Et() > rhs.p4().Et();
223  };
224  std::sort(clusters.begin(), clusters.end(), et_cmp);
225 
226  // Truncate clusters
227  if (clusters.size() > m_max_clusters) {
228  clusters.resize(m_max_clusters, clusters[0]);
229  }
230 
231  return StatusCode::SUCCESS;
232 }
xAOD::iterator
JetConstituentVector::iterator iterator
Definition: JetConstituentVector.cxx:68
TauJetRNNEvaluator::get_rnn_0p
const TauJetRNN * get_rnn_0p() const
Definition: TauJetRNNEvaluator.cxx:157
xAOD::CaloVertexedClusterBase::p4
virtual FourMom_t p4() const final
The full 4-momentum of the particle.
Definition: Event/xAOD/xAODCaloEvent/xAODCaloEvent/CaloVertexedClusterBase.h:88
TauJetRNNEvaluator::get_rnn_3p
const TauJetRNN * get_rnn_3p() const
Definition: TauJetRNNEvaluator.cxx:169
TauJetRNNEvaluator.h
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
SG::Accessor< float >
tauRecTools::getTauAxis
TLorentzVector getTauAxis(const xAOD::TauJet &tau, bool doVertexCorrection=true)
Return the four-momentum of the tau axis. The tau axis is widely used to select clusters and cells in ...
Definition: Reconstruction/tauRecTools/Root/HelperFunctions.cxx:33
TauJetRNNEvaluator::get_tracks
StatusCode get_tracks(const xAOD::TauJet &tau, std::vector< const xAOD::TauTrack * > &out) const
Definition: TauJetRNNEvaluator.cxx:173
python.base_data.config
config
Definition: base_data.py:20
xAOD::TauJet_v3::nTracks
size_t nTracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
Definition: TauJet_v3.cxx:488
skel.it
it
Definition: skel.GENtoEVGEN.py:407
TauRecToolBase
The base class for all tau tools.
Definition: TauRecToolBase.h:21
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:70
TauJetRNNEvaluator::~TauJetRNNEvaluator
virtual ~TauJetRNNEvaluator()
Definition: TauJetRNNEvaluator.cxx:23
TauJetRNNEvaluator::m_max_clusters
Gaudi::Property< std::size_t > m_max_clusters
Definition: TauJetRNNEvaluator.h:60
TauJetRNN
Wrapper around lwtnn to compute the output score of a neural network.
Definition: TauJetRNN.h:34
TauJetRNNEvaluator::m_output_layer
Gaudi::Property< std::string > m_output_layer
Definition: TauJetRNNEvaluator.h:67
TauJetRNNEvaluator::m_max_cluster_dr
Gaudi::Property< float > m_max_cluster_dr
Definition: TauJetRNNEvaluator.h:61
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
TauJetRNNEvaluator::m_doVertexCorrection
Gaudi::Property< bool > m_doVertexCorrection
Definition: TauJetRNNEvaluator.h:62
TauJetRNNEvaluator::m_output_varname
Gaudi::Property< std::string > m_output_varname
Definition: TauJetRNNEvaluator.h:58
TauJetRNNEvaluator::get_rnn_2p
const TauJetRNN * get_rnn_2p() const
Definition: TauJetRNNEvaluator.cxx:165
TauJetRNNEvaluator::m_output_node
Gaudi::Property< std::string > m_output_node
Definition: TauJetRNNEvaluator.h:68
TauJetRNNEvaluator::get_rnn_1p
const TauJetRNN * get_rnn_1p() const
Definition: TauJetRNNEvaluator.cxx:161
TauJetRNNEvaluator::m_input_layer_tracks
Gaudi::Property< std::string > m_input_layer_tracks
Definition: TauJetRNNEvaluator.h:65
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
xAOD::TauJet_v3::nTracksCharged
size_t nTracksCharged() const
Definition: TauJet_v3.cxx:494
TauJetRNNEvaluator::m_input_layer_scalar
Gaudi::Property< std::string > m_input_layer_scalar
Definition: TauJetRNNEvaluator.h:64
TauJetRNNEvaluator::m_weightfile_0p
Gaudi::Property< std::string > m_weightfile_0p
Definition: TauJetRNNEvaluator.h:54
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
TauJetRNNEvaluator::m_net_1p
std::unique_ptr< TauJetRNN > m_net_1p
Definition: TauJetRNNEvaluator.h:73
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
TauJetRNNEvaluator::m_weightfile_2p
Gaudi::Property< std::string > m_weightfile_2p
Definition: TauJetRNNEvaluator.h:56
TauJetRNN.h
merge.output
output
Definition: merge.py:16
PathResolver.h
xAOD::TauTrack_v1::pt
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
TauJetRNNEvaluator::m_net_0p
std::unique_ptr< TauJetRNN > m_net_0p
Definition: TauJetRNNEvaluator.h:72
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
TauRecToolBase::find_file
std::string find_file(const std::string &fname) const
Definition: TauRecToolBase.cxx:19
xAOD::TauJet_v3::vertexedClusters
std::vector< xAOD::CaloVertexedTopoCluster > vertexedClusters() const
Definition: TauJet_v3.cxx:586
TauJetRNNEvaluator::m_net_2p
std::unique_ptr< TauJetRNN > m_net_2p
Definition: TauJetRNNEvaluator.h:74
TauJetRNNEvaluator::m_applyLooseTrackSel
Gaudi::Property< bool > m_applyLooseTrackSel
Definition: TauJetRNNEvaluator.h:69
xAOD::TauTrack_v1
Definition: TauTrack_v1.h:27
TauJetRNNEvaluator::execute
virtual StatusCode execute(xAOD::TauJet &tau) const override
Execute - called for each tau candidate.
Definition: TauJetRNNEvaluator.cxx:116
TauJetRNNEvaluator::m_weightfile_1p
Gaudi::Property< std::string > m_weightfile_1p
Definition: TauJetRNNEvaluator.h:55
HelperFunctions.h
RunTileMonitoring.clusters
clusters
Definition: RunTileMonitoring.py:133
TauJetRNNEvaluator::TauJetRNNEvaluator
TauJetRNNEvaluator(const std::string &name="TauJetRNNEvaluator")
Definition: TauJetRNNEvaluator.cxx:14
TauJetRNNEvaluator::get_clusters
StatusCode get_clusters(const xAOD::TauJet &tau, std::vector< xAOD::CaloVertexedTopoCluster > &out) const
Definition: TauJetRNNEvaluator.cxx:207
TauJetRNNEvaluator::m_weightfile_3p
Gaudi::Property< std::string > m_weightfile_3p
Definition: TauJetRNNEvaluator.h:57
TauJetRNNEvaluator::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: TauJetRNNEvaluator.cxx:25
TauJetRNNEvaluator::m_input_layer_clusters
Gaudi::Property< std::string > m_input_layer_clusters
Definition: TauJetRNNEvaluator.h:66
xAOD::CaloVertexedTopoCluster
Evaluate cluster kinematics with a different vertex / signal state.
Definition: Event/xAOD/xAODCaloEvent/xAODCaloEvent/CaloVertexedTopoCluster.h:38
TauJetRNN::Config
Definition: TauJetRNN.h:37
TauJetRNNEvaluator::m_doTrackClassification
Gaudi::Property< bool > m_doTrackClassification
Definition: TauJetRNNEvaluator.h:63
xAOD::TauJet_v3::allTracks
std::vector< const TauTrack * > allTracks() const
Get the v<const pointer> to all tracks associated with this tau, regardless of classification.
Definition: TauJet_v3.cxx:482
TauJetRNNEvaluator::m_net_3p
std::unique_ptr< TauJetRNN > m_net_3p
Definition: TauJetRNNEvaluator.h:75
xAOD::TauJetParameters::unclassified
@ unclassified
Definition: TauDefs.h:410
TauJetRNNEvaluator::m_max_tracks
Gaudi::Property< std::size_t > m_max_tracks
Definition: TauJetRNNEvaluator.h:59