ATLAS Offline Software
TauGNNEvaluator.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
6 #include "tauRecTools/TauGNN.h"
8 
10 
11 #include <algorithm>
12 
13 
16  m_net(nullptr){
17 
18  declareProperty("NetworkFile", m_weightfile = "");
19  declareProperty("OutputVarname", m_output_varname = "GNTauScore");
20  declareProperty("OutputPTau", m_output_ptau = "GNTauProbTau");
21  declareProperty("OutputPJet", m_output_pjet = "GNTauProbJet");
22  declareProperty("MaxTracks", m_max_tracks = 30);
23  declareProperty("MaxClusters", m_max_clusters = 20);
24  declareProperty("MaxClusterDR", m_max_cluster_dr = 1.0f);
25  declareProperty("VertexCorrection", m_doVertexCorrection = true);
26  declareProperty("DecorateTracks", m_decorateTracks = false);
27  declareProperty("TrackClassification", m_doTrackClassification = true);
28  declareProperty("MinTauPt", m_minTauPt = 0.);
29 
30  // Naming conventions for the network weight files:
31  declareProperty("InputLayerScalar", m_input_layer_scalar = "tau_vars");
32  declareProperty("InputLayerTracks", m_input_layer_tracks = "track_vars");
33  declareProperty("InputLayerClusters", m_input_layer_clusters = "cluster_vars");
34  declareProperty("NodeNameTau", m_outnode_tau = "GN2TauNoAux_pb");
35  declareProperty("NodeNameJet", m_outnode_jet = "GN2TauNoAux_pu");
36  }
37 
39 
41  ATH_MSG_INFO("Initializing TauGNNEvaluator with "<<m_max_tracks<<" tracks and "<<m_max_clusters<<" clusters...");
42 
43  std::string weightfile("");
44 
45  // Use PathResolver to search for the weight files
46  if (!m_weightfile.empty()) {
47  weightfile = find_file(m_weightfile);
48  if (weightfile.empty()) {
49  ATH_MSG_ERROR("Could not find network weights: " << m_weightfile);
50  return StatusCode::FAILURE;
51  } else {
52  ATH_MSG_INFO("Using network config: " << weightfile);
53  }
54  }
55 
56  // Set the layer and node names in the weight file
58  config.input_layer_scalar = m_input_layer_scalar;
59  config.input_layer_tracks = m_input_layer_tracks;
60  config.input_layer_clusters = m_input_layer_clusters;
61  config.output_node_tau = m_outnode_tau;
62  config.output_node_jet = m_outnode_jet;
63 
64  // Load the weights and create the network
65  if (!weightfile.empty()) {
66  m_net = std::make_unique<TauGNN>(weightfile, config);
67  if (!m_net) {
68  ATH_MSG_ERROR("No network configured.");
69  return StatusCode::FAILURE;
70  }
71  }
72 
73  return StatusCode::SUCCESS;
74 }
75 
77  // Output variable Decorators
81  const SG::AuxElement::Decorator<char> out_trkclass("GNTau_TrackClass");
82  // Set default score and overwrite later
83  output(tau) = -1111.0f;
84  out_ptau(tau) = -1111.0f;
85  out_pjet(tau) = -1111.0f;
86 
87  //Skip execution for low-pT taus to save resources
88  if (tau.pt() < m_minTauPt) {
89  return StatusCode::SUCCESS;
90  }
91 
92  // Get input objects
93  ATH_MSG_DEBUG("Fetching Tracks");
94  std::vector<const xAOD::TauTrack *> tracks;
95  ATH_CHECK(get_tracks(tau, tracks));
96  ATH_MSG_DEBUG("Fetching clusters");
97  std::vector<xAOD::CaloVertexedTopoCluster> clusters;
99  ATH_MSG_DEBUG("Constituent fetching done...");
100 
101  // Truncate tracks
102  int numTracksMax = std::min(m_max_tracks, static_cast<int>(tracks.size()));
103  std::vector<const xAOD::TauTrack *> trackVec(tracks.begin(), tracks.begin()+numTracksMax);
104  // Evaluate networks
105  if (m_net) {
106  auto [out_f, out_vc, out_vf] = m_net->compute(tau, trackVec, clusters);
107  output(tau)=std::log10(1/(1-out_f.at(m_outnode_tau)));
108  out_ptau(tau)=out_f.at(m_outnode_tau);
109  out_pjet(tau)=out_f.at(m_outnode_jet);
110  if (m_decorateTracks){
111  for(unsigned int i=0;i<tracks.size();i++){
112  if(i<out_vc.at("track_class").size()){out_trkclass(*tracks.at(i))=out_vc.at("track_class").at(i);}
113  else{out_trkclass(*tracks.at(i))='9';} //Dummy value for tracks outside range of out_vc
114  }
115  }
116  }
117 
118  return StatusCode::SUCCESS;
119 }
120 
122  return m_net.get();
123 }
124 
125 
126 StatusCode TauGNNEvaluator::get_tracks(const xAOD::TauJet &tau, std::vector<const xAOD::TauTrack *> &out) const {
127  std::vector<const xAOD::TauTrack*> tracks = tau.allTracks();
128 
129  // Skip unclassified tracks:
130  // - the track is a LRT and classifyLRT = false
131  // - the track is not among the MaxNtracks highest-pt tracks in the track classifier
132  // - track classification is not run (trigger)
135  while(it != tracks.end()) {
136  if((*it)->flag(xAOD::TauJetParameters::unclassified)) {
137  it = tracks.erase(it);
138  }
139  else {
140  ++it;
141  }
142  }
143  }
144 
145  // Sort by descending pt
146  auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) {
147  return lhs->pt() > rhs->pt();
148  };
149  std::sort(tracks.begin(), tracks.end(), cmp_pt);
150  out = std::move(tracks);
151 
152  return StatusCode::SUCCESS;
153 }
154 
155 StatusCode TauGNNEvaluator::get_clusters(const xAOD::TauJet &tau, std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
156 
157  TLorentzVector tauAxis = tauRecTools::getTauAxis(tau, m_doVertexCorrection);
158 
159  for (const xAOD::CaloVertexedTopoCluster& vertexedCluster : tau.vertexedClusters()) {
160  TLorentzVector clusterP4 = vertexedCluster.p4();
161  if (clusterP4.DeltaR(tauAxis) > m_max_cluster_dr) continue;
162 
163  clusters.push_back(vertexedCluster);
164  }
165 
166  // Sort by descending et
167  auto et_cmp = [](const xAOD::CaloVertexedTopoCluster& lhs,
168  const xAOD::CaloVertexedTopoCluster& rhs) {
169  return lhs.p4().Et() > rhs.p4().Et();
170  };
171  std::sort(clusters.begin(), clusters.end(), et_cmp);
172 
173  // Truncate clusters
174  if (static_cast<int>(clusters.size()) > m_max_clusters) {
175  clusters.resize(m_max_clusters, clusters[0]);
176  }
177 
178  return StatusCode::SUCCESS;
179 }
xAOD::iterator
JetConstituentVector::iterator iterator
Definition: JetConstituentVector.cxx:68
TauGNNEvaluator::m_max_clusters
int m_max_clusters
Definition: TauGNNEvaluator.h:52
xAOD::CaloVertexedClusterBase::p4
virtual FourMom_t p4() const final
The full 4-momentum of the particle.
Definition: Event/xAOD/xAODCaloEvent/xAODCaloEvent/CaloVertexedClusterBase.h:88
TauGNNEvaluator::m_net
std::unique_ptr< TauGNN > m_net
Definition: TauGNNEvaluator.h:67
TauGNNEvaluator::m_input_layer_scalar
std::string m_input_layer_scalar
Definition: TauGNNEvaluator.h:60
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
TauGNNEvaluator::m_doTrackClassification
bool m_doTrackClassification
Definition: TauGNNEvaluator.h:56
SG::Accessor
Helper class to provide type-safe access to aux data.
Definition: Control/AthContainers/AthContainers/Accessor.h:68
AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T > &t)
Definition: AthCommonDataStore.h:145
tauRecTools::getTauAxis
TLorentzVector getTauAxis(const xAOD::TauJet &tau, bool doVertexCorrection=true)
Return the four-momentum of the tau axis. The tau axis is widely used to select clusters and cells in ...
Definition: Reconstruction/tauRecTools/Root/HelperFunctions.cxx:33
TauGNNEvaluator::get_tracks
StatusCode get_tracks(const xAOD::TauJet &tau, std::vector< const xAOD::TauTrack * > &out) const
Definition: TauGNNEvaluator.cxx:126
skel.it
it
Definition: skel.GENtoEVGEN.py:396
TauRecToolBase
The base class for all tau tools.
Definition: TauRecToolBase.h:21
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:71
TauGNNEvaluator::get_gnn
const TauGNN * get_gnn() const
Definition: TauGNNEvaluator.cxx:121
TauGNNEvaluator::m_input_layer_tracks
std::string m_input_layer_tracks
Definition: TauGNNEvaluator.h:61
TauGNNEvaluator::m_input_layer_clusters
std::string m_input_layer_clusters
Definition: TauGNNEvaluator.h:62
TauGNNEvaluator::m_outnode_tau
std::string m_outnode_tau
Definition: TauGNNEvaluator.h:63
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
TauGNNEvaluator::m_output_ptau
std::string m_output_ptau
Definition: TauGNNEvaluator.h:48
TauGNNEvaluator::m_max_tracks
int m_max_tracks
Definition: TauGNNEvaluator.h:51
TauGNNEvaluator::m_weightfile
std::string m_weightfile
Definition: TauGNNEvaluator.h:50
TauGNNEvaluator::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: TauGNNEvaluator.cxx:40
xAOD::TauJet_v3::pt
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
TauGNNEvaluator.h
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
SG::Decorator
Helper class to provide type-safe access to aux data.
Definition: Decorator.h:59
TauGNNEvaluator::m_output_pjet
std::string m_output_pjet
Definition: TauGNNEvaluator.h:49
lumiFormat.i
int i
Definition: lumiFormat.py:85
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
hist_file_dump.f
f
Definition: hist_file_dump.py:135
TauGNNEvaluator::m_decorateTracks
bool m_decorateTracks
Definition: TauGNNEvaluator.h:57
TauGNN.h
min
#define min(a, b)
Definition: cfImp.cxx:40
merge.output
output
Definition: merge.py:17
PathResolver.h
TauGNNEvaluator::m_max_cluster_dr
float m_max_cluster_dr
Definition: TauGNNEvaluator.h:53
xAOD::TauTrack_v1::pt
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:221
TauRecToolBase::find_file
std::string find_file(const std::string &fname) const
Definition: TauRecToolBase.cxx:19
TauGNN
Wrapper around ONNXUtil to compute the output score of a model.
Definition: TauGNN.h:36
xAOD::TauJet_v3::vertexedClusters
std::vector< xAOD::CaloVertexedTopoCluster > vertexedClusters() const
Definition: TauJet_v3.cxx:626
TauGNNEvaluator::m_output_varname
std::string m_output_varname
Definition: TauGNNEvaluator.h:47
TauGNNEvaluator::m_minTauPt
float m_minTauPt
Definition: TauGNNEvaluator.h:54
TauGNNEvaluator::m_doVertexCorrection
bool m_doVertexCorrection
Definition: TauGNNEvaluator.h:55
xAOD::TauTrack_v1
Definition: TauTrack_v1.h:27
HelperFunctions.h
TauGNNEvaluator::m_outnode_jet
std::string m_outnode_jet
Definition: TauGNNEvaluator.h:64
config
std::vector< std::string > config
Definition: fbtTestBasics.cxx:74
RunTileMonitoring.clusters
clusters
Definition: RunTileMonitoring.py:133
TauGNNEvaluator::~TauGNNEvaluator
virtual ~TauGNNEvaluator()
Definition: TauGNNEvaluator.cxx:38
xAOD::CaloVertexedTopoCluster
Evaluate cluster kinematics with a different vertex / signal state.
Definition: Event/xAOD/xAODCaloEvent/xAODCaloEvent/CaloVertexedTopoCluster.h:38
TauGNN::Config
Definition: TauGNN.h:39
TauGNNEvaluator::TauGNNEvaluator
TauGNNEvaluator(const std::string &name="TauGNNEvaluator")
Definition: TauGNNEvaluator.cxx:14
TauGNNEvaluator::get_clusters
StatusCode get_clusters(const xAOD::TauJet &tau, std::vector< xAOD::CaloVertexedTopoCluster > &out) const
Definition: TauGNNEvaluator.cxx:155
xAOD::TauJet_v3::allTracks
std::vector< const TauTrack * > allTracks() const
Get the v<const pointer> to all tracks associated with this tau, regardless of classification.
Definition: TauJet_v3.cxx:514
TauGNNEvaluator::execute
virtual StatusCode execute(xAOD::TauJet &tau) const override
Execute - called for each tau candidate.
Definition: TauGNNEvaluator.cxx:76
xAOD::TauJetParameters::unclassified
@ unclassified
Definition: TauDefs.h:410