ATLAS Offline Software
TauJetRNNEvaluator.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
3 */
4 
8 
10 
11 #include <algorithm>
12 
13 
// NOTE(review): the constructor signature and base-class initializer
// (source lines 14-15) are missing from this listing.
// The four network handles start out empty; initialize() creates them.
16  m_net_0p(nullptr),
17  m_net_1p(nullptr),
18  m_net_2p(nullptr),
19  m_net_3p(nullptr) {
20 
// Job properties: declareProperty(<property name>, <member> = <default>).
// Weight-file names are resolved with find_file() in initialize(); an empty
// 0P or 2P name means that (optional) network is not built.
21  declareProperty("NetworkFile0P", m_weightfile_0p = "");
22  declareProperty("NetworkFile1P", m_weightfile_1p = "");
23  declareProperty("NetworkFile2P", m_weightfile_2p = "");
24  declareProperty("NetworkFile3P", m_weightfile_3p = "");
// Presumably the name of the float decoration execute() writes through its
// `output` accessor (that declaration is not visible here) - confirm.
25  declareProperty("OutputVarname", m_output_varname = "RNNJetScore");
// Input limits applied in get_tracks() (MaxTracks) and get_clusters()
// (MaxClusters, and the MaxClusterDR cone around the tau axis).
26  declareProperty("MaxTracks", m_max_tracks = 10);
27  declareProperty("MaxClusters", m_max_clusters = 6);
28  declareProperty("MaxClusterDR", m_max_cluster_dr = 1.0f);
// Passed to tauRecTools::getTauAxis() in get_clusters().
29  declareProperty("VertexCorrection", m_doVertexCorrection = true);
// Presumably gates the removal of unclassified tracks in get_tracks(); that
// use is not visible in this listing - confirm.
30  declareProperty("TrackClassification", m_doTrackClassification = true);
31 
32  // Naming conventions for the network weight files:
// (copied into the TauJetRNN::Config passed to every network in initialize())
33  declareProperty("InputLayerScalar", m_input_layer_scalar = "scalar");
34  declareProperty("InputLayerTracks", m_input_layer_tracks = "tracks");
35  declareProperty("InputLayerClusters", m_input_layer_clusters = "clusters");
36  declareProperty("OutputLayer", m_output_layer = "rnnid_output");
37  declareProperty("OutputNode", m_output_node = "sig_prob");
38  }
39 
41 
43  ATH_MSG_INFO("Initializing TauJetRNNEvaluator");
44 
45  std::string weightfile_0p("");
46  std::string weightfile_1p("");
47  std::string weightfile_2p("");
48  std::string weightfile_3p("");
49 
50  // Use PathResolver to search for the weight files
51  if (!m_weightfile_0p.empty()) {
52  weightfile_0p = find_file(m_weightfile_0p);
53  if (weightfile_0p.empty()) {
54  ATH_MSG_ERROR("Could not find network weights: " << m_weightfile_0p);
55  return StatusCode::FAILURE;
56  } else {
57  ATH_MSG_INFO("Using network config [0-prong]: " << weightfile_0p);
58  }
59  }
60 
61  if (!m_weightfile_1p.empty()) {
62  weightfile_1p = find_file(m_weightfile_1p);
63  if (weightfile_1p.empty()) {
64  ATH_MSG_ERROR("Could not find network weights: " << m_weightfile_1p);
65  return StatusCode::FAILURE;
66  } else {
67  ATH_MSG_INFO("Using network config [1-prong]: " << weightfile_1p);
68  }
69  }
70 
71  if (!m_weightfile_2p.empty()) {
72  weightfile_2p = find_file(m_weightfile_2p);
73  if (weightfile_2p.empty()) {
74  ATH_MSG_ERROR("Could not find network weights: " << m_weightfile_2p);
75  return StatusCode::FAILURE;
76  } else {
77  ATH_MSG_INFO("Using network config [2-prong]: " << weightfile_2p);
78  }
79  }
80 
81  if (!m_weightfile_3p.empty()) {
82  weightfile_3p = find_file(m_weightfile_3p);
83  if (weightfile_3p.empty()) {
84  ATH_MSG_ERROR("Could not find network weights: " << m_weightfile_3p);
85  return StatusCode::FAILURE;
86  } else {
87  ATH_MSG_INFO("Using network config [3-prong]: " << weightfile_3p);
88  }
89  }
90 
91  // Set the layer and node names in the weight file
93  config.input_layer_scalar = m_input_layer_scalar;
94  config.input_layer_tracks = m_input_layer_tracks;
95  config.input_layer_clusters = m_input_layer_clusters;
96  config.output_layer = m_output_layer;
97  config.output_node = m_output_node;
98 
99  // Load the weights and create the network
100  // 0p is for trigger only
101  if (!weightfile_0p.empty()) {
102  m_net_0p = std::make_unique<TauJetRNN>(weightfile_0p, config);
103  if (!m_net_0p) {
104  ATH_MSG_ERROR("No network configured for 0-prong taus.");
105  return StatusCode::FAILURE;
106  }
107  }
108 
109  m_net_1p = std::make_unique<TauJetRNN>(weightfile_1p, config);
110  if (!m_net_1p) {
111  ATH_MSG_ERROR("No network configured for 1-prong taus.");
112  return StatusCode::FAILURE;
113  }
114 
115  // 2p is optional
116  if (!weightfile_2p.empty()) {
117  m_net_2p = std::make_unique<TauJetRNN>(weightfile_2p, config);
118  if (!m_net_2p) {
119  ATH_MSG_ERROR("No network configured for 2-prong taus.");
120  return StatusCode::FAILURE;
121  }
122  }
123 
124  m_net_3p = std::make_unique<TauJetRNN>(weightfile_3p, config);
125  if (!m_net_3p) {
126  ATH_MSG_ERROR("No network configured for 3-prong taus.");
127  return StatusCode::FAILURE;
128  }
129 
130  return StatusCode::SUCCESS;
131 }
132 
// Scores one tau candidate and stores the result through the `output`
// accessor. The network is chosen by tau.nTracksCharged():
//   0  -> m_net_0p (only if that optional network was built)
//   1  -> m_net_1p
//   2  -> m_net_2p, or m_net_3p when no 2p network exists
//   >2 -> m_net_3p
// When no branch applies (0 tracks and no 0p network), the default score
// -1111 is left in place.
// NOTE(review): the function signature (source line 133), the accessor
// declaration (135) and the line filling `clusters` (146, presumably a
// get_clusters() call) are missing from this listing.
134  // Output variable accessor
136 
137  // Set default score and overwrite later
138  output(tau) = -1111.0f;
139 
140  const auto nTracksCharged = tau.nTracksCharged();
141 
142  // Get input objects
143  std::vector<const xAOD::TauTrack *> tracks;
144  ATH_CHECK(get_tracks(tau, tracks));
145  std::vector<xAOD::CaloVertexedTopoCluster> clusters;
147 
148  // Evaluate networks
// NOTE(review): m_net_1p and m_net_3p are dereferenced without a null check;
// this relies on initialize() always creating them.
149  if (nTracksCharged==0 && m_net_0p) {
150  output(tau) = m_net_0p->compute(tau, tracks, clusters);
151  }
152  else if (nTracksCharged == 1) {
153  output(tau) = m_net_1p->compute(tau, tracks, clusters);
154  }
155  else if (nTracksCharged == 2) {
// Use the dedicated 2-prong network if it was configured, else fall back to 3-prong.
156  if(m_net_2p) {
157  output(tau) = m_net_2p->compute(tau, tracks, clusters);
158  } else {
159  output(tau) = m_net_3p->compute(tau, tracks, clusters);
160  }
161  }
162  else if (nTracksCharged > 2) {
163  output(tau) = m_net_3p->compute(tau, tracks, clusters);
164  }
165 
166  return StatusCode::SUCCESS;
167 }
168 
// Non-owning accessors for the networks built in initialize(); ownership
// stays with the std::unique_ptr members. get_rnn_0p() and get_rnn_2p() may
// return nullptr because those networks are optional.
// NOTE(review): the four function signatures (source lines 169, 173, 177
// and 181) are missing from this listing.
170  return m_net_0p.get();
171 }
172 
174  return m_net_1p.get();
175 }
176 
178  return m_net_2p.get();
179 }
180 
182  return m_net_3p.get();
183 }
184 
// Builds the track input for the networks from tau.allTracks():
// drops tracks flagged xAOD::TauJetParameters::unclassified, sorts the rest
// by descending pt and keeps at most m_max_tracks of them. The result
// replaces the contents of `out`. Always returns SUCCESS.
// NOTE(review): source lines 192-193 are missing from this listing. They open
// the scope closed at line 202 and declare the iterator `it` (presumably a
// check of m_doTrackClassification - confirm).
185 StatusCode TauJetRNNEvaluator::get_tracks(const xAOD::TauJet &tau, std::vector<const xAOD::TauTrack *> &out) const {
186  std::vector<const xAOD::TauTrack*> tracks = tau.allTracks();
187 
188  // Skip unclassified tracks:
189  // - the track is a LRT and classifyLRT = false
190  // - the track is not among the MaxNtracks highest-pt tracks in the track classifier
191  // - track classification is not run (trigger)
// erase() returns an iterator to the next element, so `it` is advanced only
// when nothing was erased.
// NOTE(review): erasing one element at a time is quadratic in the worst case;
// std::remove_if followed by a single erase would do this in one pass.
194  while(it != tracks.end()) {
195  if((*it)->flag(xAOD::TauJetParameters::unclassified)) {
196  it = tracks.erase(it);
197  }
198  else {
199  ++it;
200  }
201  }
202  }
203 
204  // Sort by descending pt
205  auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) {
206  return lhs->pt() > rhs->pt();
207  };
208  std::sort(tracks.begin(), tracks.end(), cmp_pt);
209 
210  // Truncate tracks
211  if (tracks.size() > m_max_tracks) {
212  tracks.resize(m_max_tracks);
213  }
214  out = std::move(tracks);
215 
216  return StatusCode::SUCCESS;
217 }
218 
219 StatusCode TauJetRNNEvaluator::get_clusters(const xAOD::TauJet &tau, std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
220 
221  TLorentzVector tauAxis = tauRecTools::getTauAxis(tau, m_doVertexCorrection);
222 
223  std::vector<xAOD::CaloVertexedTopoCluster> vertexedClusterList = tau.vertexedClusters();
224  for (const xAOD::CaloVertexedTopoCluster& vertexedCluster : vertexedClusterList) {
225  TLorentzVector clusterP4 = vertexedCluster.p4();
226  if (clusterP4.DeltaR(tauAxis) > m_max_cluster_dr) continue;
227 
228  clusters.push_back(vertexedCluster);
229  }
230 
231  // Sort by descending et
232  auto et_cmp = [](const xAOD::CaloVertexedTopoCluster& lhs,
233  const xAOD::CaloVertexedTopoCluster& rhs) {
234  return lhs.p4().Et() > rhs.p4().Et();
235  };
236  std::sort(clusters.begin(), clusters.end(), et_cmp);
237 
238  // Truncate clusters
239  if (clusters.size() > m_max_clusters) {
240  clusters.resize(m_max_clusters, clusters[0]);
241  }
242 
243  return StatusCode::SUCCESS;
244 }
xAOD::iterator
JetConstituentVector::iterator iterator
Definition: JetConstituentVector.cxx:68
TauJetRNNEvaluator::m_input_layer_tracks
std::string m_input_layer_tracks
Definition: TauJetRNNEvaluator.h:64
TauJetRNNEvaluator::get_rnn_0p
const TauJetRNN * get_rnn_0p() const
Definition: TauJetRNNEvaluator.cxx:169
TauJetRNNEvaluator::m_max_tracks
std::size_t m_max_tracks
Definition: TauJetRNNEvaluator.h:56
xAOD::CaloVertexedClusterBase::p4
virtual FourMom_t p4() const final
The full 4-momentum of the particle.
Definition: Event/xAOD/xAODCaloEvent/xAODCaloEvent/CaloVertexedClusterBase.h:88
TauJetRNNEvaluator::get_rnn_3p
const TauJetRNN * get_rnn_3p() const
Definition: TauJetRNNEvaluator.cxx:181
TauJetRNNEvaluator::m_doTrackClassification
bool m_doTrackClassification
Definition: TauJetRNNEvaluator.h:60
TauJetRNNEvaluator.h
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
TauJetRNNEvaluator::m_doVertexCorrection
bool m_doVertexCorrection
Definition: TauJetRNNEvaluator.h:59
SG::Accessor
Helper class to provide type-safe access to aux data.
Definition: Control/AthContainers/AthContainers/Accessor.h:68
TauJetRNNEvaluator::m_output_node
std::string m_output_node
Definition: TauJetRNNEvaluator.h:67
AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T > &t)
Definition: AthCommonDataStore.h:145
tauRecTools::getTauAxis
TLorentzVector getTauAxis(const xAOD::TauJet &tau, bool doVertexCorrection=true)
Return the four momentum of the tau axis The tau axis is widely used to select clusters and cells in ...
Definition: Reconstruction/tauRecTools/Root/HelperFunctions.cxx:33
TauJetRNNEvaluator::get_tracks
StatusCode get_tracks(const xAOD::TauJet &tau, std::vector< const xAOD::TauTrack * > &out) const
Definition: TauJetRNNEvaluator.cxx:185
skel.it
it
Definition: skel.GENtoEVGEN.py:396
TauRecToolBase
The base class for all tau tools.
Definition: TauRecToolBase.h:21
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:71
TauJetRNNEvaluator::~TauJetRNNEvaluator
virtual ~TauJetRNNEvaluator()
Definition: TauJetRNNEvaluator.cxx:40
TauJetRNNEvaluator::m_max_clusters
std::size_t m_max_clusters
Definition: TauJetRNNEvaluator.h:57
TauJetRNN
Wrapper around lwtnn to compute the output score of a neural network.
Definition: TauJetRNN.h:34
TauJetRNNEvaluator::m_output_varname
std::string m_output_varname
Definition: TauJetRNNEvaluator.h:51
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
TauJetRNNEvaluator::get_rnn_2p
const TauJetRNN * get_rnn_2p() const
Definition: TauJetRNNEvaluator.cxx:177
TauJetRNNEvaluator::m_weightfile_1p
std::string m_weightfile_1p
Definition: TauJetRNNEvaluator.h:53
TauJetRNNEvaluator::get_rnn_1p
const TauJetRNN * get_rnn_1p() const
Definition: TauJetRNNEvaluator.cxx:173
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
xAOD::TauJet_v3::nTracksCharged
size_t nTracksCharged() const
Definition: TauJet_v3.cxx:532
TauJetRNNEvaluator::m_input_layer_clusters
std::string m_input_layer_clusters
Definition: TauJetRNNEvaluator.h:65
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
TauJetRNNEvaluator::m_net_1p
std::unique_ptr< TauJetRNN > m_net_1p
Definition: TauJetRNNEvaluator.h:71
TauJetRNNEvaluator::m_weightfile_0p
std::string m_weightfile_0p
Definition: TauJetRNNEvaluator.h:52
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
hist_file_dump.f
f
Definition: hist_file_dump.py:135
TauJetRNNEvaluator::m_output_layer
std::string m_output_layer
Definition: TauJetRNNEvaluator.h:66
TauJetRNNEvaluator::m_max_cluster_dr
float m_max_cluster_dr
Definition: TauJetRNNEvaluator.h:58
TauJetRNN.h
merge.output
output
Definition: merge.py:17
PathResolver.h
xAOD::TauTrack_v1::pt
virtual double pt() const
The transverse momentum ( ) of the particle.
TauJetRNNEvaluator::m_net_0p
std::unique_ptr< TauJetRNN > m_net_0p
Definition: TauJetRNNEvaluator.h:70
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:221
TauJetRNNEvaluator::m_weightfile_2p
std::string m_weightfile_2p
Definition: TauJetRNNEvaluator.h:54
TauRecToolBase::find_file
std::string find_file(const std::string &fname) const
Definition: TauRecToolBase.cxx:19
xAOD::TauJet_v3::vertexedClusters
std::vector< xAOD::CaloVertexedTopoCluster > vertexedClusters() const
Definition: TauJet_v3.cxx:626
TauJetRNNEvaluator::m_net_2p
std::unique_ptr< TauJetRNN > m_net_2p
Definition: TauJetRNNEvaluator.h:72
xAOD::TauTrack_v1
Definition: TauTrack_v1.h:27
TauJetRNNEvaluator::execute
virtual StatusCode execute(xAOD::TauJet &tau) const override
Execute - called for each tau candidate.
Definition: TauJetRNNEvaluator.cxx:133
HelperFunctions.h
config
std::vector< std::string > config
Definition: fbtTestBasics.cxx:74
RunTileMonitoring.clusters
clusters
Definition: RunTileMonitoring.py:133
TauJetRNNEvaluator::TauJetRNNEvaluator
TauJetRNNEvaluator(const std::string &name="TauJetRNNEvaluator")
Definition: TauJetRNNEvaluator.cxx:14
TauJetRNNEvaluator::get_clusters
StatusCode get_clusters(const xAOD::TauJet &tau, std::vector< xAOD::CaloVertexedTopoCluster > &out) const
Definition: TauJetRNNEvaluator.cxx:219
TauJetRNNEvaluator::m_weightfile_3p
std::string m_weightfile_3p
Definition: TauJetRNNEvaluator.h:55
TauJetRNNEvaluator::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: TauJetRNNEvaluator.cxx:42
xAOD::CaloVertexedTopoCluster
Evaluate cluster kinematics with a different vertex / signal state.
Definition: Event/xAOD/xAODCaloEvent/xAODCaloEvent/CaloVertexedTopoCluster.h:38
TauJetRNN::Config
Definition: TauJetRNN.h:37
xAOD::TauJet_v3::allTracks
std::vector< const TauTrack * > allTracks() const
Get the v<const pointer> to all tracks associated with this tau, regardless of classification.
Definition: TauJet_v3.cxx:514
TauJetRNNEvaluator::m_net_3p
std::unique_ptr< TauJetRNN > m_net_3p
Definition: TauJetRNNEvaluator.h:73
TauJetRNNEvaluator::m_input_layer_scalar
std::string m_input_layer_scalar
Definition: TauJetRNNEvaluator.h:63
xAOD::TauJetParameters::unclassified
@ unclassified
Definition: TauDefs.h:410