Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
TauGNNEvaluator.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
7 
9 
10 #include <algorithm>
11 
12 
15  m_net_inclusive(nullptr),
16  m_net_0p(nullptr), m_net_1p(nullptr), m_net_2p(nullptr), m_net_3p(nullptr) {
17 
18  declareProperty("NetworkFileInclusive", m_weightfile_inclusive = "");
19  declareProperty("NetworkFile0P", m_weightfile_0p = "");
20  declareProperty("NetworkFile1P", m_weightfile_1p = "");
21  declareProperty("NetworkFile2P", m_weightfile_2p = "");
22  declareProperty("NetworkFile3P", m_weightfile_3p = "");
23 
24  declareProperty("OutputVarname", m_output_varname = "GNTauScore");
25  declareProperty("OutputPTau", m_output_ptau = "GNTauProbTau");
26  declareProperty("OutputPJet", m_output_pjet = "GNTauProbJet");
27  declareProperty("OutputDiscriminant", m_output_discriminant = Discriminant::NegLogPJet,
28  "Discriminant used to calculate the output score: 0 -> -log(PJet), 1 -> PTau");
29 
30  declareProperty("MaxTracks", m_max_tracks = 30);
31  declareProperty("MaxClusters", m_max_clusters = 20);
32  declareProperty("MaxClusterDR", m_max_cluster_dr = 1.0f);
33 
34  declareProperty("VertexCorrection", m_doVertexCorrection = true);
35  declareProperty("DecorateTracks", m_decorateTracks = false);
36  declareProperty("TrackClassification", m_doTrackClassification = true);
37  declareProperty("MinTauPt", m_minTauPt = 0.);
38 
39  // Prongness selection minimum track pT
40  declareProperty("MinProngTrackPt", m_min_prong_track_pt = 0);
41 
42  // Naming conventions for the network weight files:
43  declareProperty("InputLayerScalar", m_input_layer_scalar = "tau_vars");
44  declareProperty("InputLayerTracks", m_input_layer_tracks = "track_vars");
45  declareProperty("InputLayerClusters", m_input_layer_clusters = "cluster_vars");
46  declareProperty("NodeNameTau", m_outnode_tau = "GN2TauNoAux_pb");
47  declareProperty("NodeNameJet", m_outnode_jet = "GN2TauNoAux_pu");
48  }
49 
51 
53  ATH_MSG_INFO("Initializing TauGNNEvaluator with "<<m_max_tracks<<" tracks and "<<m_max_clusters<<" clusters...");
54 
55  // Set the layer and node names in the weight file
57  config.input_layer_scalar = m_input_layer_scalar;
58  config.input_layer_tracks = m_input_layer_tracks;
59  config.input_layer_clusters = m_input_layer_clusters;
60  config.output_node_tau = m_outnode_tau;
61  config.output_node_jet = m_outnode_jet;
62 
63  // We can either use an inclussive GNN (e.g. Offline GNTauv0), or a prong-dependent GNN (e.g. HLT GNTau), not both!
64 
65  if(!m_weightfile_inclusive.empty()) { // Prong-inclusive network
66  if(!m_weightfile_0p.empty() || !m_weightfile_1p.empty() || !m_weightfile_2p.empty() || !m_weightfile_3p.empty()) {
67  ATH_MSG_ERROR("Cannot load both prong-inclusive and prong-dependent networks!");
68  return StatusCode::FAILURE;
69  }
70 
71  ATH_MSG_INFO("Loading prong-inclusive TauID GNN");
73  if(!m_net_inclusive) return StatusCode::FAILURE;
74 
75  } else { // Prong-dependent networks
76 
77  // 0-prong is optional
78  if(!m_weightfile_0p.empty()) {
79  ATH_MSG_INFO("Loading 0-prong TauID GNN");
81  if(!m_net_0p) return StatusCode::FAILURE;
82  }
83 
84  ATH_MSG_INFO("Loading 1-prong TauID GNN");
86  if(!m_net_1p) return StatusCode::FAILURE;
87 
88  // 2-prong is optional
89  if(!m_weightfile_2p.empty()) {
90  ATH_MSG_INFO("Loading 2-prong TauID GNN");
92  if(!m_net_2p) return StatusCode::FAILURE;
93  }
94 
95  ATH_MSG_INFO("Loading 3-prong TauID GNN");
97  if(!m_net_3p) return StatusCode::FAILURE;
98  }
99 
100  if(m_output_discriminant < Discriminant::NegLogPJet || m_output_discriminant > Discriminant::PTau) {
101  ATH_MSG_FATAL("Invalid TauGNNEvaluator discriminant setting: " << m_output_discriminant);
102  }
103 
104  return StatusCode::SUCCESS;
105 }
106 
107 std::unique_ptr<TauGNN> TauGNNEvaluator::load_network(const std::string& network_file, const TauGNN::Config& config) const {
108  // Use PathResolver to search for the weight files
109  if(network_file.empty()) return nullptr;
110 
111  const std::string pr_network_file = find_file(network_file);
112  if(pr_network_file.empty()) {
113  ATH_MSG_ERROR("Could not find network weights: " << network_file);
114  return nullptr;
115  }
116 
117  ATH_MSG_INFO("Using network config: " << pr_network_file);
118 
119  // Load the weights and create the network
120  std::unique_ptr<TauGNN> net = std::make_unique<TauGNN>(pr_network_file, config);
121  if(!net) ATH_MSG_ERROR("No network configured.");
122 
123  return net;
124 }
125 
127  // Output variable Decorators
129  const SG::Accessor<float> out_ptau(m_output_ptau);
130  const SG::Accessor<float> out_pjet(m_output_pjet);
131  const SG::Decorator<char> out_trkclass("GNTau_TrackClass");
132  // Set default score and overwrite later
133  output(tau) = -1111.0f;
134  out_ptau(tau) = -1111.0f;
135  out_pjet(tau) = -1111.0f;
136 
137  //Skip execution for low-pT taus to save resources
138  if (tau.pt() < m_minTauPt) {
139  return StatusCode::SUCCESS;
140  }
141 
142  // Get input objects
143  ATH_MSG_DEBUG("Fetching Tracks");
144  std::vector<const xAOD::TauTrack *> tracks;
145  ATH_CHECK(get_tracks(tau, tracks));
146  ATH_MSG_DEBUG("Fetching clusters");
147  std::vector<xAOD::CaloVertexedTopoCluster> clusters;
149  ATH_MSG_DEBUG("Constituent fetching done...");
150 
151  // Truncate tracks
152  int numTracksMax = std::min(m_max_tracks, static_cast<int>(tracks.size()));
153  std::vector<const xAOD::TauTrack *> trackVec(tracks.begin(), tracks.begin()+numTracksMax);
154 
155  // Network outputs
156  std::map<std::string, float> out_f;
157  std::map<std::string, std::vector<char>> out_vc;
158  std::map<std::string, std::vector<float>> out_vf;
159 
160  // Evaluate networks
161  if(m_net_inclusive) {
162  std::tie(out_f, out_vc, out_vf) = m_net_inclusive->compute(tau, trackVec, clusters);
163  } else {
164  // First we calculate the tau prongness
165  int n_tracks = tau.nTracksCharged();
167  n_tracks = 0;
168  for(const xAOD::TauTrack* track : tracks) {
169  if(track->pt() > m_min_prong_track_pt) n_tracks++;
170  }
171  }
172  ATH_MSG_DEBUG("Tau prongness: " << n_tracks);
173 
174  if(n_tracks == 0 && m_net_0p) std::tie(out_f, out_vc, out_vf) = m_net_0p->compute(tau, trackVec, clusters);
175  else if(n_tracks == 1) std::tie(out_f, out_vc, out_vf) = m_net_1p->compute(tau, trackVec, clusters);
176  else if(n_tracks == 2) {
177  if(m_net_2p) std::tie(out_f, out_vc, out_vf) = m_net_2p->compute(tau, trackVec, clusters);
178  else std::tie(out_f, out_vc, out_vf) = m_net_3p->compute(tau, trackVec, clusters);
179  } else if(n_tracks == 3) std::tie(out_f, out_vc, out_vf) = m_net_3p->compute(tau, trackVec, clusters);
180  }
181 
182  // Store scores only if the inferences actually ran
183  if(out_f.contains(m_outnode_tau)) {
184  if(m_output_discriminant == Discriminant::NegLogPJet) {
185  output(tau) = std::log10(1/(1-out_f.at(m_outnode_tau)));
186  } else if(m_output_discriminant == Discriminant::PTau) {
187  output(tau) = out_f.at(m_outnode_tau);
188  }
189 
190  out_ptau(tau) = out_f.at(m_outnode_tau);
191  out_pjet(tau) = out_f.at(m_outnode_jet);
192 
193  if(m_decorateTracks) {
194  for(size_t i = 0; i < tracks.size(); i++) {
195  if(i < out_vc.at("track_class").size()) out_trkclass(*tracks.at(i)) = out_vc.at("track_class").at(i);
196  else out_trkclass(*tracks.at(i)) = '9'; //Dummy value for tracks outside range of out_vc
197  }
198  }
199  }
200 
201  return StatusCode::SUCCESS;
202 }
203 
204 
205 StatusCode TauGNNEvaluator::get_tracks(const xAOD::TauJet &tau, std::vector<const xAOD::TauTrack *> &out) const {
206  std::vector<const xAOD::TauTrack*> tracks = tau.allTracks();
207 
208  // Skip unclassified tracks:
209  // - the track is a LRT and classifyLRT = false
210  // - the track is not among the MaxNtracks highest-pt tracks in the track classifier
211  // - track classification is not run (trigger)
214  while(it != tracks.end()) {
215  if((*it)->flag(xAOD::TauJetParameters::unclassified)) {
216  it = tracks.erase(it);
217  }
218  else {
219  ++it;
220  }
221  }
222  }
223 
224  // Sort by descending pt
225  auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) {
226  return lhs->pt() > rhs->pt();
227  };
228  std::sort(tracks.begin(), tracks.end(), cmp_pt);
229  out = std::move(tracks);
230 
231  return StatusCode::SUCCESS;
232 }
233 
234 StatusCode TauGNNEvaluator::get_clusters(const xAOD::TauJet &tau, std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
235 
236  TLorentzVector tauAxis = tauRecTools::getTauAxis(tau, m_doVertexCorrection);
237 
238  for (const xAOD::CaloVertexedTopoCluster& vertexedCluster : tau.vertexedClusters()) {
239  TLorentzVector clusterP4 = vertexedCluster.p4();
240  if (clusterP4.DeltaR(tauAxis) > m_max_cluster_dr) continue;
241 
242  clusters.push_back(vertexedCluster);
243  }
244 
245  // Sort by descending et
246  auto et_cmp = [](const xAOD::CaloVertexedTopoCluster& lhs,
247  const xAOD::CaloVertexedTopoCluster& rhs) {
248  return lhs.p4().Et() > rhs.p4().Et();
249  };
250  std::sort(clusters.begin(), clusters.end(), et_cmp);
251 
252  // Truncate clusters
253  if (static_cast<int>(clusters.size()) > m_max_clusters) {
254  clusters.resize(m_max_clusters, clusters[0]);
255  }
256 
257  return StatusCode::SUCCESS;
258 }
xAOD::iterator
JetConstituentVector::iterator iterator
Definition: JetConstituentVector.cxx:68
TauGNNEvaluator::m_max_clusters
int m_max_clusters
Definition: TauGNNEvaluator.h:69
ATH_MSG_FATAL
#define ATH_MSG_FATAL(x)
Definition: AthMsgStreamMacros.h:34
xAOD::CaloVertexedClusterBase::p4
virtual FourMom_t p4() const final
The full 4-momentum of the particle.
Definition: Event/xAOD/xAODCaloEvent/xAODCaloEvent/CaloVertexedClusterBase.h:88
TauGNNEvaluator::m_input_layer_scalar
std::string m_input_layer_scalar
Definition: TauGNNEvaluator.h:77
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
TauGNNEvaluator::m_doTrackClassification
bool m_doTrackClassification
Definition: TauGNNEvaluator.h:73
SG::Accessor< float >
TauGNNEvaluator::m_weightfile_2p
std::string m_weightfile_2p
Definition: TauGNNEvaluator.h:64
AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T > &t)
Definition: AthCommonDataStore.h:145
tauRecTools::getTauAxis
TLorentzVector getTauAxis(const xAOD::TauJet &tau, bool doVertexCorrection=true)
Return the four-momentum of the tau axis. The tau axis is widely used to select clusters and cells in ...
Definition: Reconstruction/tauRecTools/Root/HelperFunctions.cxx:33
TauGNNEvaluator::get_tracks
StatusCode get_tracks(const xAOD::TauJet &tau, std::vector< const xAOD::TauTrack * > &out) const
Definition: TauGNNEvaluator.cxx:205
min
constexpr double min()
Definition: ap_fixedTest.cxx:26
python.base_data.config
config
Definition: base_data.py:21
TauGNNEvaluator::m_output_discriminant
unsigned int m_output_discriminant
Definition: TauGNNEvaluator.h:59
skel.it
it
Definition: skel.GENtoEVGEN.py:407
TauRecToolBase
The base class for all tau tools.
Definition: TauRecToolBase.h:21
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:71
TauGNNEvaluator::m_net_3p
std::unique_ptr< TauGNN > m_net_3p
Definition: TauGNNEvaluator.h:88
TauGNNEvaluator::m_input_layer_tracks
std::string m_input_layer_tracks
Definition: TauGNNEvaluator.h:78
TauGNNEvaluator::load_network
std::unique_ptr< TauGNN > load_network(const std::string &network_file, const TauGNN::Config &config) const
Definition: TauGNNEvaluator.cxx:107
TauGNNEvaluator::m_input_layer_clusters
std::string m_input_layer_clusters
Definition: TauGNNEvaluator.h:79
TauGNNEvaluator::m_outnode_tau
std::string m_outnode_tau
Definition: TauGNNEvaluator.h:80
TauGNNEvaluator::m_net_1p
std::unique_ptr< TauGNN > m_net_1p
Definition: TauGNNEvaluator.h:86
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
TauGNNEvaluator::m_output_ptau
std::string m_output_ptau
Definition: TauGNNEvaluator.h:57
TauGNNEvaluator::m_max_tracks
int m_max_tracks
Definition: TauGNNEvaluator.h:68
TauGNNEvaluator::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: TauGNNEvaluator.cxx:52
xAOD::TauJet_v3::pt
virtual double pt() const
The transverse momentum ($p_{T}$) of the particle.
TauGNNEvaluator.h
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
xAOD::TauJet_v3::nTracksCharged
size_t nTracksCharged() const
Definition: TauJet_v3.cxx:532
TauGNNEvaluator::m_net_0p
std::unique_ptr< TauGNN > m_net_0p
Definition: TauGNNEvaluator.h:85
SG::Decorator< char >
TauGNNEvaluator::m_output_pjet
std::string m_output_pjet
Definition: TauGNNEvaluator.h:58
lumiFormat.i
int i
Definition: lumiFormat.py:85
TauGNNEvaluator::m_net_2p
std::unique_ptr< TauGNN > m_net_2p
Definition: TauGNNEvaluator.h:87
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
hist_file_dump.f
f
Definition: hist_file_dump.py:141
TauGNNEvaluator::m_decorateTracks
bool m_decorateTracks
Definition: TauGNNEvaluator.h:74
TauGNNEvaluator::m_weightfile_0p
std::string m_weightfile_0p
Definition: TauGNNEvaluator.h:62
TauGNNEvaluator::m_net_inclusive
std::unique_ptr< TauGNN > m_net_inclusive
Definition: TauGNNEvaluator.h:84
TauGNNEvaluator::m_weightfile_inclusive
std::string m_weightfile_inclusive
Definition: TauGNNEvaluator.h:61
merge.output
output
Definition: merge.py:17
PathResolver.h
TauGNNEvaluator::m_max_cluster_dr
float m_max_cluster_dr
Definition: TauGNNEvaluator.h:70
xAOD::TauTrack_v1::pt
virtual double pt() const
The transverse momentum ($p_{T}$) of the particle.
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:228
TauRecToolBase::find_file
std::string find_file(const std::string &fname) const
Definition: TauRecToolBase.cxx:19
xAOD::TauJet_v3::vertexedClusters
std::vector< xAOD::CaloVertexedTopoCluster > vertexedClusters() const
Definition: TauJet_v3.cxx:626
TauGNNEvaluator::m_output_varname
std::string m_output_varname
Definition: TauGNNEvaluator.h:56
TauGNNEvaluator::m_minTauPt
float m_minTauPt
Definition: TauGNNEvaluator.h:71
TauGNNEvaluator::m_doVertexCorrection
bool m_doVertexCorrection
Definition: TauGNNEvaluator.h:72
xAOD::TauTrack_v1
Definition: TauTrack_v1.h:27
HelperFunctions.h
TauGNNEvaluator::m_outnode_jet
std::string m_outnode_jet
Definition: TauGNNEvaluator.h:81
TauGNNEvaluator::m_min_prong_track_pt
float m_min_prong_track_pt
Definition: TauGNNEvaluator.h:66
RunTileMonitoring.clusters
clusters
Definition: RunTileMonitoring.py:133
TauGNNEvaluator::m_weightfile_1p
std::string m_weightfile_1p
Definition: TauGNNEvaluator.h:63
TauGNNEvaluator::~TauGNNEvaluator
virtual ~TauGNNEvaluator()
Definition: TauGNNEvaluator.cxx:50
xAOD::track
@ track
Definition: TrackingPrimitives.h:513
TauGNNEvaluator::m_weightfile_3p
std::string m_weightfile_3p
Definition: TauGNNEvaluator.h:65
xAOD::CaloVertexedTopoCluster
Evaluate cluster kinematics with a different vertex / signal state.
Definition: Event/xAOD/xAODCaloEvent/xAODCaloEvent/CaloVertexedTopoCluster.h:38
TauGNN::Config
Definition: TauGNN.h:35
TauGNNEvaluator::TauGNNEvaluator
TauGNNEvaluator(const std::string &name="TauGNNEvaluator")
Definition: TauGNNEvaluator.cxx:13
TauGNNEvaluator::get_clusters
StatusCode get_clusters(const xAOD::TauJet &tau, std::vector< xAOD::CaloVertexedTopoCluster > &out) const
Definition: TauGNNEvaluator.cxx:234
xAOD::TauJet_v3::allTracks
std::vector< const TauTrack * > allTracks() const
Get the v<const pointer> to all tracks associated with this tau, regardless of classification.
Definition: TauJet_v3.cxx:514
TauGNNEvaluator::execute
virtual StatusCode execute(xAOD::TauJet &tau) const override
Execute - called for each tau candidate.
Definition: TauGNNEvaluator.cxx:126
xAOD::TauJetParameters::unclassified
@ unclassified
Definition: TauDefs.h:410