ATLAS Offline Software
TauGNNEvaluator.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
7 
9 
10 #include <algorithm>
11 
12 
15  m_net_inclusive(nullptr),
16  m_net_0p(nullptr), m_net_1p(nullptr), m_net_2p(nullptr), m_net_3p(nullptr) {
17 }
18 
20 
22  ATH_MSG_INFO("Initializing TauGNNEvaluator with "<<m_max_tracks.value()<<" tracks and "<<m_max_clusters<<" clusters...");
23 
24  // Set the layer and node names in the weight file
26  config.input_layer_scalar = m_input_layer_scalar;
27  config.input_layer_tracks = m_input_layer_tracks;
28  config.input_layer_clusters = m_input_layer_clusters;
29  config.output_node_tau = m_outnode_tau;
30  config.output_node_jet = m_outnode_jet;
31 
32  // We can either use an inclussive GNN (e.g. Offline GNTauv0), or a prong-dependent GNN (e.g. HLT GNTau), not both!
33 
34  if(!m_weightfile_inclusive.empty()) { // Prong-inclusive network
35  if(!m_weightfile_0p.empty() || !m_weightfile_1p.empty() || !m_weightfile_2p.empty() || !m_weightfile_3p.empty()) {
36  ATH_MSG_ERROR("Cannot load both prong-inclusive and prong-dependent networks!");
37  return StatusCode::FAILURE;
38  }
39 
40  ATH_MSG_INFO("Loading prong-inclusive TauID GNN");
42  if(!m_net_inclusive) return StatusCode::FAILURE;
43 
44  } else { // Prong-dependent networks
45 
46  // 0-prong is optional
47  if(!m_weightfile_0p.empty()) {
48  ATH_MSG_INFO("Loading 0-prong TauID GNN");
50  if(!m_net_0p) return StatusCode::FAILURE;
51  }
52 
53  ATH_MSG_INFO("Loading 1-prong TauID GNN");
55  if(!m_net_1p) return StatusCode::FAILURE;
56 
57  // 2-prong is optional
58  if(!m_weightfile_2p.empty()) {
59  ATH_MSG_INFO("Loading 2-prong TauID GNN");
61  if(!m_net_2p) return StatusCode::FAILURE;
62  }
63 
64  ATH_MSG_INFO("Loading 3-prong TauID GNN");
66  if(!m_net_3p) return StatusCode::FAILURE;
67  }
68 
69  if(m_output_discriminant < Discriminant::NegLogPJet || m_output_discriminant > Discriminant::PTau) {
70  ATH_MSG_FATAL("Invalid TauGNNEvaluator discriminant setting: " << m_output_discriminant);
71  }
72 
73  if (!m_tauContainerName.empty()){
75  ATH_CHECK(m_scoreHandleKey.initialize());
76  }
77 
78  return StatusCode::SUCCESS;
79 }
80 
81 std::unique_ptr<TauGNN> TauGNNEvaluator::load_network(const std::string& network_file, const TauGNN::Config& config) const {
82  // Use PathResolver to search for the weight files
83  if(network_file.empty()) return nullptr;
84 
85  const std::string pr_network_file = find_file(network_file);
86  if(pr_network_file.empty()) {
87  ATH_MSG_ERROR("Could not find network weights: " << network_file);
88  return nullptr;
89  }
90 
91  ATH_MSG_INFO("Using network config: " << pr_network_file);
92 
93  // Load the weights and create the network
94  std::unique_ptr<TauGNN> net = std::make_unique<TauGNN>(pr_network_file, config);
95  if(!net) ATH_MSG_ERROR("No network configured.");
96 
97  return net;
98 }
99 
101  // Output variable Decorators
103  const SG::Accessor<float> out_ptau(m_output_ptau);
104  const SG::Accessor<float> out_pjet(m_output_pjet);
105  const SG::Decorator<char> out_trkclass("GNTau_TrackClass");
106  // Set default score and overwrite later
107  output(tau) = -1111.0f;
108  out_ptau(tau) = -1111.0f;
109  out_pjet(tau) = -1111.0f;
110 
111  //Skip execution for low-pT taus to save resources
112  if (tau.pt() < m_minTauPt) {
113  return StatusCode::SUCCESS;
114  }
115 
116  // save CPU when running PHYS derivations
117  if (m_applyLooseTrackSel) {
118  if (tau.nTracks()>5) return StatusCode::SUCCESS;
119  }
120 
121  // save CPU when running in RAWtoALL for tau trigger monitoring purpose
122  if (m_applyTightTrackSel) {
123  if (tau.nTracks()!=1 && tau.nTracks()!=3) return StatusCode::SUCCESS;
124  }
125 
126  // Get input objects
127  ATH_MSG_DEBUG("Fetching Tracks");
128  std::vector<const xAOD::TauTrack *> tracks;
129  ATH_CHECK(get_tracks(tau, tracks));
130  ATH_MSG_DEBUG("Fetching clusters");
131  std::vector<xAOD::CaloVertexedTopoCluster> clusters;
133  ATH_MSG_DEBUG("Constituent fetching done...");
134 
135  // Truncate tracks
136  int numTracksMax = std::min(m_max_tracks.value(), static_cast<int>(tracks.size()));
137  std::vector<const xAOD::TauTrack *> trackVec(tracks.begin(), tracks.begin()+numTracksMax);
138 
139  // Network outputs
140  std::map<std::string, float> out_f;
141  std::map<std::string, std::vector<char>> out_vc;
142  std::map<std::string, std::vector<float>> out_vf;
143 
144  // Evaluate networks
145  if(m_net_inclusive) {
146  std::tie(out_f, out_vc, out_vf) = m_net_inclusive->compute(tau, trackVec, clusters);
147  } else {
148  // First we calculate the tau prongness
149  int n_tracks = tau.nTracksCharged();
151  n_tracks = 0;
152  for(const xAOD::TauTrack* track : tracks) {
153  if(track->pt() > m_min_prong_track_pt) n_tracks++;
154  }
155  }
156  ATH_MSG_DEBUG("Tau prongness: " << n_tracks);
157 
158  if(n_tracks == 0 && m_net_0p) std::tie(out_f, out_vc, out_vf) = m_net_0p->compute(tau, trackVec, clusters);
159  else if(n_tracks == 1) std::tie(out_f, out_vc, out_vf) = m_net_1p->compute(tau, trackVec, clusters);
160  else if(n_tracks == 2) {
161  if(m_net_2p) std::tie(out_f, out_vc, out_vf) = m_net_2p->compute(tau, trackVec, clusters);
162  else std::tie(out_f, out_vc, out_vf) = m_net_3p->compute(tau, trackVec, clusters);
163  } else if(n_tracks == 3) std::tie(out_f, out_vc, out_vf) = m_net_3p->compute(tau, trackVec, clusters);
164  }
165 
166  // Store scores only if the inferences actually ran
167  if(out_f.contains(m_outnode_tau)) {
168  if(m_output_discriminant == Discriminant::NegLogPJet) {
169  output(tau) = std::log10(1/(1-out_f.at(m_outnode_tau)));
170  } else if(m_output_discriminant == Discriminant::PTau) {
171  output(tau) = out_f.at(m_outnode_tau);
172  }
173 
174  out_ptau(tau) = out_f.at(m_outnode_tau);
175  out_pjet(tau) = out_f.at(m_outnode_jet);
176  }
177 
178  return StatusCode::SUCCESS;
179 }
180 
181 
182 StatusCode TauGNNEvaluator::get_tracks(const xAOD::TauJet &tau, std::vector<const xAOD::TauTrack *> &out) const {
183  std::vector<const xAOD::TauTrack*> tracks = tau.allTracks();
184 
185  // Skip unclassified tracks:
186  // - the track is a LRT and classifyLRT = false
187  // - the track is not among the MaxNtracks highest-pt tracks in the track classifier
188  // - track classification is not run (trigger)
191  while(it != tracks.end()) {
192  if((*it)->flag(xAOD::TauJetParameters::unclassified)) {
193  it = tracks.erase(it);
194  }
195  else {
196  ++it;
197  }
198  }
199  }
200 
201  // Sort by descending pt
202  auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) {
203  return lhs->pt() > rhs->pt();
204  };
205  std::sort(tracks.begin(), tracks.end(), cmp_pt);
206  out = std::move(tracks);
207 
208  return StatusCode::SUCCESS;
209 }
210 
211 StatusCode TauGNNEvaluator::get_clusters(const xAOD::TauJet &tau, std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
212 
213  TLorentzVector tauAxis = tauRecTools::getTauAxis(tau, m_doVertexCorrection);
214 
215  for (const xAOD::CaloVertexedTopoCluster& vertexedCluster : tau.vertexedClusters()) {
216  TLorentzVector clusterP4 = vertexedCluster.p4();
217  if (clusterP4.DeltaR(tauAxis) > m_max_cluster_dr) continue;
218 
219  clusters.push_back(vertexedCluster);
220  }
221 
222  // Sort by descending et
223  auto et_cmp = [](const xAOD::CaloVertexedTopoCluster& lhs,
224  const xAOD::CaloVertexedTopoCluster& rhs) {
225  return lhs.p4().Et() > rhs.p4().Et();
226  };
227  std::sort(clusters.begin(), clusters.end(), et_cmp);
228 
229  // Truncate clusters
230  if (static_cast<int>(clusters.size()) > m_max_clusters) {
231  clusters.resize(m_max_clusters, clusters[0]);
232  }
233 
234  return StatusCode::SUCCESS;
235 }
xAOD::iterator
JetConstituentVector::iterator iterator
Definition: JetConstituentVector.cxx:68
TauGNNEvaluator::m_doTrackClassification
Gaudi::Property< bool > m_doTrackClassification
Definition: TauGNNEvaluator.h:78
TauGNNEvaluator::m_output_pjet
Gaudi::Property< std::string > m_output_pjet
Definition: TauGNNEvaluator.h:71
ATH_MSG_FATAL
#define ATH_MSG_FATAL(x)
Definition: AthMsgStreamMacros.h:34
xAOD::CaloVertexedClusterBase::p4
virtual FourMom_t p4() const final
The full 4-momentum of the particle.
Definition: Event/xAOD/xAODCaloEvent/xAODCaloEvent/CaloVertexedClusterBase.h:88
TauGNNEvaluator::m_applyLooseTrackSel
Gaudi::Property< bool > m_applyLooseTrackSel
Definition: TauGNNEvaluator.h:80
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
SG::Accessor< float >
tauRecTools::getTauAxis
TLorentzVector getTauAxis(const xAOD::TauJet &tau, bool doVertexCorrection=true)
Return the four-momentum of the tau axis. The tau axis is widely used to select clusters and cells in ...
Definition: Reconstruction/tauRecTools/Root/HelperFunctions.cxx:33
TauGNNEvaluator::get_tracks
StatusCode get_tracks(const xAOD::TauJet &tau, std::vector< const xAOD::TauTrack * > &out) const
Definition: TauGNNEvaluator.cxx:182
min
constexpr double min()
Definition: ap_fixedTest.cxx:26
python.base_data.config
config
Definition: base_data.py:20
xAOD::TauJet_v3::nTracks
size_t nTracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
Definition: TauJet_v3.cxx:488
skel.it
it
Definition: skel.GENtoEVGEN.py:407
TauRecToolBase
The base class for all tau tools.
Definition: TauRecToolBase.h:21
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:70
TauGNNEvaluator::m_net_3p
std::unique_ptr< TauGNN > m_net_3p
Definition: TauGNNEvaluator.h:94
TauGNNEvaluator::load_network
std::unique_ptr< TauGNN > load_network(const std::string &network_file, const TauGNN::Config &config) const
Definition: TauGNNEvaluator.cxx:81
TauGNNEvaluator::m_weightfile_inclusive
Gaudi::Property< std::string > m_weightfile_inclusive
Definition: TauGNNEvaluator.h:64
TauGNNEvaluator::m_max_clusters
Gaudi::Property< int > m_max_clusters
Definition: TauGNNEvaluator.h:75
TauGNNEvaluator::m_net_1p
std::unique_ptr< TauGNN > m_net_1p
Definition: TauGNNEvaluator.h:92
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
TauGNNEvaluator::m_weightfile_3p
Gaudi::Property< std::string > m_weightfile_3p
Definition: TauGNNEvaluator.h:68
TauGNNEvaluator::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: TauGNNEvaluator.cxx:21
xAOD::TauJet_v3::pt
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
TauGNNEvaluator::m_applyTightTrackSel
Gaudi::Property< bool > m_applyTightTrackSel
Definition: TauGNNEvaluator.h:81
TauGNNEvaluator.h
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
TauGNNEvaluator::m_max_cluster_dr
Gaudi::Property< float > m_max_cluster_dr
Definition: TauGNNEvaluator.h:76
xAOD::TauJet_v3::nTracksCharged
size_t nTracksCharged() const
Definition: TauJet_v3.cxx:494
TauGNNEvaluator::m_net_0p
std::unique_ptr< TauGNN > m_net_0p
Definition: TauGNNEvaluator.h:91
SG::Decorator< char >
TauGNNEvaluator::m_net_2p
std::unique_ptr< TauGNN > m_net_2p
Definition: TauGNNEvaluator.h:93
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
TauGNNEvaluator::m_weightfile_2p
Gaudi::Property< std::string > m_weightfile_2p
Definition: TauGNNEvaluator.h:67
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
TauGNNEvaluator::m_output_discriminant
Gaudi::Property< unsigned int > m_output_discriminant
Definition: TauGNNEvaluator.h:72
TauGNNEvaluator::m_tauContainerName
Gaudi::Property< std::string > m_tauContainerName
Definition: TauGNNEvaluator.h:60
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
TauGNNEvaluator::m_max_tracks
Gaudi::Property< int > m_max_tracks
Definition: TauGNNEvaluator.h:74
TauGNNEvaluator::m_net_inclusive
std::unique_ptr< TauGNN > m_net_inclusive
Definition: TauGNNEvaluator.h:90
TauGNNEvaluator::m_minTauPt
Gaudi::Property< float > m_minTauPt
Definition: TauGNNEvaluator.h:79
TauGNNEvaluator::m_outnode_tau
Gaudi::Property< std::string > m_outnode_tau
Definition: TauGNNEvaluator.h:86
merge.output
output
Definition: merge.py:16
TauGNNEvaluator::m_input_layer_tracks
Gaudi::Property< std::string > m_input_layer_tracks
Definition: TauGNNEvaluator.h:84
PathResolver.h
TauGNNEvaluator::m_input_layer_scalar
Gaudi::Property< std::string > m_input_layer_scalar
Definition: TauGNNEvaluator.h:83
TauGNNEvaluator::m_min_prong_track_pt
Gaudi::Property< float > m_min_prong_track_pt
Definition: TauGNNEvaluator.h:82
xAOD::TauTrack_v1::pt
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
TauRecToolBase::find_file
std::string find_file(const std::string &fname) const
Definition: TauRecToolBase.cxx:19
TauGNNEvaluator::m_doVertexCorrection
Gaudi::Property< bool > m_doVertexCorrection
Definition: TauGNNEvaluator.h:77
TauGNNEvaluator::m_outnode_jet
Gaudi::Property< std::string > m_outnode_jet
Definition: TauGNNEvaluator.h:87
xAOD::TauJet_v3::vertexedClusters
std::vector< xAOD::CaloVertexedTopoCluster > vertexedClusters() const
Definition: TauJet_v3.cxx:586
xAOD::TauTrack_v1
Definition: TauTrack_v1.h:27
TauGNNEvaluator::m_weightfile_0p
Gaudi::Property< std::string > m_weightfile_0p
Definition: TauGNNEvaluator.h:65
HelperFunctions.h
TauGNNEvaluator::m_weightfile_1p
Gaudi::Property< std::string > m_weightfile_1p
Definition: TauGNNEvaluator.h:66
RunTileMonitoring.clusters
clusters
Definition: RunTileMonitoring.py:133
TauGNNEvaluator::~TauGNNEvaluator
virtual ~TauGNNEvaluator()
Definition: TauGNNEvaluator.cxx:19
xAOD::track
@ track
Definition: TrackingPrimitives.h:513
xAOD::CaloVertexedTopoCluster
Evaluate cluster kinematics with a different vertex / signal state.
Definition: Event/xAOD/xAODCaloEvent/xAODCaloEvent/CaloVertexedTopoCluster.h:38
TauGNN::Config
Definition: TauGNN.h:39
TauGNNEvaluator::m_input_layer_clusters
Gaudi::Property< std::string > m_input_layer_clusters
Definition: TauGNNEvaluator.h:85
TauGNNEvaluator::TauGNNEvaluator
TauGNNEvaluator(const std::string &name="TauGNNEvaluator")
Definition: TauGNNEvaluator.cxx:13
TauGNNEvaluator::m_scoreHandleKey
SG::WriteDecorHandleKey< xAOD::TauJetContainer > m_scoreHandleKey
Definition: TauGNNEvaluator.h:61
TauGNNEvaluator::get_clusters
StatusCode get_clusters(const xAOD::TauJet &tau, std::vector< xAOD::CaloVertexedTopoCluster > &out) const
Definition: TauGNNEvaluator.cxx:211
xAOD::TauJet_v3::allTracks
std::vector< const TauTrack * > allTracks() const
Get the v<const pointer> to all tracks associated with this tau, regardless of classification.
Definition: TauJet_v3.cxx:482
TauGNNEvaluator::m_output_ptau
Gaudi::Property< std::string > m_output_ptau
Definition: TauGNNEvaluator.h:70
TauGNNEvaluator::m_output_varname
Gaudi::Property< std::string > m_output_varname
Definition: TauGNNEvaluator.h:69
TauGNNEvaluator::execute
virtual StatusCode execute(xAOD::TauJet &tau) const override
Execute - called for each tau candidate.
Definition: TauGNNEvaluator.cxx:100
xAOD::TauJetParameters::unclassified
@ unclassified
Definition: TauDefs.h:410