ATLAS Offline Software
TauGNNEvaluator.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
7 
9 
10 #include <algorithm>
11 
12 
15 
17 
// [review] Body of TauGNNEvaluator::initialize(). Its signature (listing line 18) and the
// statements that assign m_net_* (listing lines 30, 38, 43, 49, 54 — presumably calls to
// load_network(); confirm) are missing from this listing.
// Purpose: set up either ONE prong-inclusive network or a set of prong-dependent networks,
// validate the discriminant setting, and initialize the score decoration key.
// Returns FAILURE if both configurations are requested, or if a requested/required network is null.
19  ATH_MSG_INFO("Initializing TauGNNEvaluator with "<<m_max_tracks.value()<<" tracks and "<<m_max_clusters<<" clusters...");
20 
21  // We can either use an inclusive GNN (e.g. Offline GNTauv0), or a prong-dependent GNN (e.g. HLT GNTau), not both!
22 
23  if(!m_weightfile_inclusive.empty()) { // Prong-inclusive network
// Any per-prong weight file alongside the inclusive one is a configuration error.
24  if(!m_weightfile_0p.empty() || !m_weightfile_1p.empty() || !m_weightfile_2p.empty() || !m_weightfile_3p.empty()) {
25  ATH_MSG_ERROR("Cannot load both prong-inclusive and prong-dependent networks!");
26  return StatusCode::FAILURE;
27  }
28 
29  ATH_MSG_INFO("Loading prong-inclusive TauID GNN");
31  if(!m_net_inclusive) return StatusCode::FAILURE;
32 
33  } else { // Prong-dependent networks
34 
35  // 0-prong is optional
36  if(!m_weightfile_0p.empty()) {
37  ATH_MSG_INFO("Loading 0-prong TauID GNN");
39  if(!m_net_0p) return StatusCode::FAILURE;
40  }
41 
// The 1-prong network is mandatory: a null m_net_1p fails initialization.
42  ATH_MSG_INFO("Loading 1-prong TauID GNN");
44  if(!m_net_1p) return StatusCode::FAILURE;
45 
46  // 2-prong is optional
47  if(!m_weightfile_2p.empty()) {
48  ATH_MSG_INFO("Loading 2-prong TauID GNN");
50  if(!m_net_2p) return StatusCode::FAILURE;
51  }
52 
// The 3-prong network is mandatory: a null m_net_3p fails initialization.
53  ATH_MSG_INFO("Loading 3-prong TauID GNN");
55  if(!m_net_3p) return StatusCode::FAILURE;
56  }
57 
// NOTE(review): the FATAL message below is NOT followed by `return StatusCode::FAILURE;`.
// With an out-of-range discriminant, initialize() therefore still returns SUCCESS. In execute(),
// the discriminant output then never leaves its -1111 default, because neither the NegLogPJet
// nor the PTau branch matches. Consider returning FAILURE here.
58  if(m_output_discriminant < Discriminant::NegLogPJet || m_output_discriminant > Discriminant::PTau) {
59  ATH_MSG_FATAL("Invalid TauGNNEvaluator discriminant setting: " << m_output_discriminant);
60  }
61 
// The score decoration key is only initialized when a tau container name is configured.
// The statement at listing line 63 is missing from this listing.
62  if (!m_tauContainerName.empty()){
64  ATH_CHECK(m_scoreHandleKey.initialize());
65  }
66 
67  return StatusCode::SUCCESS;
68 }
69 
/// @brief Resolve a network weight file and build a TauGNN configured from this tool's properties.
/// @param network_file  Weight-file name. It is resolved with find_file() (inherited from
///                      TauRecToolBase; the comment below says this uses PathResolver).
/// @return The constructed network, or nullptr if @p network_file is empty or cannot be resolved
///         (the latter is also logged as an error).
70 std::unique_ptr<TauGNN> TauGNNEvaluator::load_network(const std::string& network_file) const {
71  // Use PathResolver to search for the weight files
// An empty name means "no network requested"; return nullptr without logging.
72  if(network_file.empty()) return nullptr;
73 
74  const std::string pr_network_file = find_file(network_file);
75  if(pr_network_file.empty()) {
76  ATH_MSG_ERROR("Could not find network weights: " << network_file);
77  return nullptr;
78  }
79 
80  ATH_MSG_INFO("Using network config: " << pr_network_file);
81 
82  // Load the weights and create the network
// NOTE(review): `config` is declared on listing line 83, which is missing here
// (presumably a TauGNNDataLoader::Config — confirm). It is filled from the tool properties below.
84  config.nnFile = pr_network_file;
85  config.input_layer_scalar = m_input_layer_scalar.value();
86  config.input_layer_tracks = m_input_layer_tracks.value();
87  config.input_layer_clusters = m_input_layer_clusters.value();
88  config.output_node_tau = m_outnode_tau.value();
89  config.output_node_jet = m_outnode_jet.value();
90  config.n_max_tracks = m_max_tracks.value();
91  config.n_max_clusters = m_max_clusters.value();
92  config.max_dr_cluster = m_max_cluster_dr.value();
93  config.doVertexCorrection = m_doVertexCorrection.value();
94  config.trackClassification = m_doTrackClassification.value();
95  config.useTRT = m_useTRT.value();
96 
// NOTE(review): std::make_unique reports allocation failure by throwing std::bad_alloc and
// never returns null, so the `if(!net)` branch below is unreachable. Any exception thrown by
// the TauGNN constructor is not caught here and propagates to the caller.
97  std::unique_ptr<TauGNN> net = std::make_unique<TauGNN>(config);
98  if(!net) ATH_MSG_ERROR("No network configured.");
99 
100  return net;
101 }
102 
// [review] Body of TauGNNEvaluator::execute(xAOD::TauJet& tau) const. Its signature (listing
// line 103) and the declaration of `output` (listing line 105) are missing from this listing.
// Purpose:
//   - Write -1111 defaults to the discriminant, p_tau and p_jet outputs.
//   - Skip taus rejected by the pT / track-multiplicity preselections.
//   - Otherwise run the inclusive network, or the network matching the prong count.
//   - Overwrite the outputs only if the network produced the tau output node.
// Always returns SUCCESS.
104  // Output variable Decorators
106  const SG::Accessor<float> out_ptau(m_output_ptau);
107  const SG::Accessor<float> out_pjet(m_output_pjet);
// NOTE(review): out_trkclass is never used in the visible code (nor are out_vc / out_vf below).
// Confirm whether the track-classification output is meant to be decorated.
108  const SG::Decorator<char> out_trkclass("GNTau_TrackClass");
109  // Set default score and overwrite later
110  output(tau) = -1111.0f;
111  out_ptau(tau) = -1111.0f;
112  out_pjet(tau) = -1111.0f;
113 
114  //Skip execution for low-pT taus to save resources
115  if (tau.pt() < m_minTauPt) {
116  return StatusCode::SUCCESS;
117  }
118 
119  // save CPU when running PHYS derivations
120  if (m_applyLooseTrackSel) {
121  if (tau.nTracks()>5) return StatusCode::SUCCESS;
122  }
123 
124  // save CPU when running in RAWtoALL for tau trigger monitoring purpose
125  if (m_applyTightTrackSel) {
126  if (tau.nTracks()!=1 && tau.nTracks()!=3) return StatusCode::SUCCESS;
127  }
128 
129  // Network outputs
130  std::map<std::string, float> out_f;
131  std::map<std::string, std::vector<char>> out_vc;
132  std::map<std::string, std::vector<float>> out_vf;
133 
134  // Evaluate networks
// Note: this logs nTracksCharged(), which can differ from the recomputed prong count used for routing below.
135  ATH_MSG_DEBUG("Evaluating GNN for tau with nTracks = " << tau.nTracksCharged());
136  if(m_net_inclusive) {
137  std::tie(out_f, out_vc, out_vf) = m_net_inclusive->compute(tau);
138  } else {
139  // First we calculate the tau prongness
140  int n_tracks = tau.nTracksCharged();
141  // in trigger, we need to apply a min pT cut on the tracks to count the prongs,
142  // as no track classification is available
// NOTE(review): listing line 143, which opens the block closed at line 149, is missing. The
// recount below presumably runs only under a condition on that line — confirm. When it runs, it
// counts ALL associated tracks with pT > m_min_prong_track_pt, regardless of classification.
144  auto trks = tau.allTracks();
145  const float threshold = m_min_prong_track_pt;
146  n_tracks = std::count_if(trks.begin(), trks.end(),
147  [&threshold](const xAOD::TauTrack* trk) { return trk->pt() > threshold; }
148  );
149  }
150 
// Route by prong count:
//   0 -> 0p net (only if loaded); 1 -> 1p net; 2 -> 2p net if loaded, otherwise the 3p net; 3 -> 3p net.
//   Any other count (or 0 without a 0p net) runs no inference, so out_f stays empty and the defaults remain.
151  if(n_tracks == 0 && m_net_0p) std::tie(out_f, out_vc, out_vf) = m_net_0p->compute(tau);
152  else if(n_tracks == 1) std::tie(out_f, out_vc, out_vf) = m_net_1p->compute(tau);
153  else if(n_tracks == 2) {
154  if(m_net_2p) std::tie(out_f, out_vc, out_vf) = m_net_2p->compute(tau);
155  else std::tie(out_f, out_vc, out_vf) = m_net_3p->compute(tau);
156  } else if(n_tracks == 3) std::tie(out_f, out_vc, out_vf) = m_net_3p->compute(tau);
157  }
158 
159  // Store scores only if the inferences actually ran
160  if(out_f.contains(m_outnode_tau)) {
161  if(m_output_discriminant == Discriminant::NegLogPJet) {
// NOTE(review): log10(1/(1-p)) equals -log10(1-p) and diverges to +inf when p_tau == 1
// (1/(1-p) divides by zero). Consider clamping p or guarding this case.
162  output(tau) = std::log10(1/(1-out_f.at(m_outnode_tau)));
163  } else if(m_output_discriminant == Discriminant::PTau) {
164  output(tau) = out_f.at(m_outnode_tau);
165  }
166 
167  out_ptau(tau) = out_f.at(m_outnode_tau);
// NOTE(review): only the tau node's presence is checked above. If m_outnode_jet is absent
// from out_f, std::map::at throws std::out_of_range here.
168  out_pjet(tau) = out_f.at(m_outnode_jet);
169  }
170 
171  return StatusCode::SUCCESS;
172 }
173 
TauGNNEvaluator::m_doTrackClassification
Gaudi::Property< bool > m_doTrackClassification
Definition: TauGNNEvaluator.h:73
TauGNNEvaluator::m_output_pjet
Gaudi::Property< std::string > m_output_pjet
Definition: TauGNNEvaluator.h:66
ATH_MSG_FATAL
#define ATH_MSG_FATAL(x)
Definition: AthMsgStreamMacros.h:34
TauGNNEvaluator::m_applyLooseTrackSel
Gaudi::Property< bool > m_applyLooseTrackSel
Definition: TauGNNEvaluator.h:76
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
SG::Accessor< float >
TauGNNEvaluator::load_network
std::unique_ptr< TauGNN > load_network(const std::string &network_file) const
Definition: TauGNNEvaluator.cxx:70
python.base_data.config
config
Definition: base_data.py:20
xAOD::TauJet_v3::nTracks
size_t nTracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
Definition: TauJet_v3.cxx:488
TauRecToolBase
The base class for all tau tools.
Definition: TauRecToolBase.h:21
TauGNNEvaluator::m_net_3p
std::unique_ptr< TauGNN > m_net_3p
Definition: TauGNNEvaluator.h:87
TauGNNEvaluator::m_weightfile_inclusive
Gaudi::Property< std::string > m_weightfile_inclusive
Definition: TauGNNEvaluator.h:56
TauGNNEvaluator::m_max_clusters
Gaudi::Property< int > m_max_clusters
Definition: TauGNNEvaluator.h:70
TauGNNEvaluator::m_net_1p
std::unique_ptr< TauGNN > m_net_1p
Definition: TauGNNEvaluator.h:85
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
TauGNNEvaluator::m_weightfile_3p
Gaudi::Property< std::string > m_weightfile_3p
Definition: TauGNNEvaluator.h:60
TauGNNEvaluator::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: TauGNNEvaluator.cxx:18
xAOD::TauJet_v3::pt
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
TauGNNEvaluator::m_applyTightTrackSel
Gaudi::Property< bool > m_applyTightTrackSel
Definition: TauGNNEvaluator.h:77
TauGNNEvaluator.h
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
TauGNNEvaluator::m_max_cluster_dr
Gaudi::Property< float > m_max_cluster_dr
Definition: TauGNNEvaluator.h:71
xAOD::TauJet_v3::nTracksCharged
size_t nTracksCharged() const
Definition: TauJet_v3.cxx:494
TauGNNEvaluator::m_net_0p
std::unique_ptr< TauGNN > m_net_0p
Definition: TauGNNEvaluator.h:84
TauGNNEvaluator::m_useTRT
Gaudi::Property< bool > m_useTRT
Definition: TauGNNEvaluator.h:74
SG::Decorator< char >
TauGNNEvaluator::m_net_2p
std::unique_ptr< TauGNN > m_net_2p
Definition: TauGNNEvaluator.h:86
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
TauGNNEvaluator::m_weightfile_2p
Gaudi::Property< std::string > m_weightfile_2p
Definition: TauGNNEvaluator.h:59
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
TauGNNEvaluator::m_output_discriminant
Gaudi::Property< unsigned int > m_output_discriminant
Definition: TauGNNEvaluator.h:67
TauGNNEvaluator::m_tauContainerName
Gaudi::Property< std::string > m_tauContainerName
Definition: TauGNNEvaluator.h:52
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
TauGNNEvaluator::m_max_tracks
Gaudi::Property< int > m_max_tracks
Definition: TauGNNEvaluator.h:69
TauGNNEvaluator::m_net_inclusive
std::unique_ptr< TauGNN > m_net_inclusive
Definition: TauGNNEvaluator.h:83
TauGNNEvaluator::m_minTauPt
Gaudi::Property< float > m_minTauPt
Definition: TauGNNEvaluator.h:75
TauGNNEvaluator::m_outnode_tau
Gaudi::Property< std::string > m_outnode_tau
Definition: TauGNNEvaluator.h:78
merge.output
output
Definition: merge.py:16
TauGNNEvaluator::m_input_layer_tracks
Gaudi::Property< std::string > m_input_layer_tracks
Definition: TauGNNEvaluator.h:62
PathResolver.h
TauGNNEvaluator::m_input_layer_scalar
Gaudi::Property< std::string > m_input_layer_scalar
Definition: TauGNNEvaluator.h:61
TauGNNEvaluator::m_min_prong_track_pt
Gaudi::Property< float > m_min_prong_track_pt
Definition: TauGNNEvaluator.h:80
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
threshold
Definition: chainparser.cxx:74
TauRecToolBase::find_file
std::string find_file(const std::string &fname) const
Definition: TauRecToolBase.cxx:19
TauGNNEvaluator::m_doVertexCorrection
Gaudi::Property< bool > m_doVertexCorrection
Definition: TauGNNEvaluator.h:72
TauGNNEvaluator::m_outnode_jet
Gaudi::Property< std::string > m_outnode_jet
Definition: TauGNNEvaluator.h:79
xAOD::TauTrack_v1
Definition: TauTrack_v1.h:27
TauGNNEvaluator::m_weightfile_0p
Gaudi::Property< std::string > m_weightfile_0p
Definition: TauGNNEvaluator.h:57
HelperFunctions.h
TauGNNEvaluator::m_weightfile_1p
Gaudi::Property< std::string > m_weightfile_1p
Definition: TauGNNEvaluator.h:58
TauGNNEvaluator::~TauGNNEvaluator
virtual ~TauGNNEvaluator()
Definition: TauGNNEvaluator.cxx:16
TauGNNEvaluator::m_input_layer_clusters
Gaudi::Property< std::string > m_input_layer_clusters
Definition: TauGNNEvaluator.h:63
TauGNNEvaluator::TauGNNEvaluator
TauGNNEvaluator(const std::string &name="TauGNNEvaluator")
Definition: TauGNNEvaluator.cxx:13
TauGNNEvaluator::m_scoreHandleKey
SG::WriteDecorHandleKey< xAOD::TauJetContainer > m_scoreHandleKey
Definition: TauGNNEvaluator.h:53
xAOD::TauJet_v3::allTracks
std::vector< const TauTrack * > allTracks() const
Get the vector of const pointers to all tracks associated with this tau, regardless of classification.
Definition: TauJet_v3.cxx:482
TauGNNEvaluator::m_output_ptau
Gaudi::Property< std::string > m_output_ptau
Definition: TauGNNEvaluator.h:65
TauGNNDataLoader::Config
Definition: TauGNNDataLoader.h:59
TauGNNEvaluator::m_output_varname
Gaudi::Property< std::string > m_output_varname
Definition: TauGNNEvaluator.h:64
TauGNNEvaluator::execute
virtual StatusCode execute(xAOD::TauJet &tau) const override
Execute - called for each tau candidate.
Definition: TauGNNEvaluator.cxx:103