ATLAS Offline Software
TauGNNEvaluator.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
6 #include "tauRecTools/TauGNN.h"
8 
10 
11 #include <algorithm>
12 
13 
16  m_net(nullptr){
17 
18  declareProperty("NetworkFile", m_weightfile = "");
19  declareProperty("OutputVarname", m_output_varname = "GNTauScore");
20  declareProperty("OutputPTau", m_output_ptau = "GNTauProbTau");
21  declareProperty("OutputPJet", m_output_pjet = "GNTauProbJet");
22  declareProperty("MaxTracks", m_max_tracks = 30);
23  declareProperty("MaxClusters", m_max_clusters = 20);
24  declareProperty("MaxClusterDR", m_max_cluster_dr = 1.0f);
25  declareProperty("VertexCorrection", m_doVertexCorrection = true);
26  declareProperty("DecorateTracks", m_decorateTracks = false);
27  declareProperty("TrackClassification", m_doTrackClassification = true);
28  declareProperty("MinTauPt", m_minTauPt = 0.);
29 
30  // Naming conventions for the network weight files:
31  declareProperty("InputLayerScalar", m_input_layer_scalar = "tau_vars");
32  declareProperty("InputLayerTracks", m_input_layer_tracks = "track_vars");
33  declareProperty("InputLayerClusters", m_input_layer_clusters = "cluster_vars");
34  declareProperty("NodeNameTau", m_outnode_tau = "GN2TauNoAux_pb");
35  declareProperty("NodeNameJet", m_outnode_jet = "GN2TauNoAux_pu");
36  }
37 
39 
41  ATH_MSG_INFO("Initializing TauGNNEvaluator");
42 
43  std::string weightfile("");
44 
45  // Use PathResolver to search for the weight files
46  if (!m_weightfile.empty()) {
47  weightfile = find_file(m_weightfile);
48  if (weightfile.empty()) {
49  ATH_MSG_ERROR("Could not find network weights: " << m_weightfile);
50  return StatusCode::FAILURE;
51  } else {
52  ATH_MSG_INFO("Using network config: " << weightfile);
53  }
54  }
55 
56  // Set the layer and node names in the weight file
58  config.input_layer_scalar = m_input_layer_scalar;
59  config.input_layer_tracks = m_input_layer_tracks;
60  config.input_layer_clusters = m_input_layer_clusters;
61  config.output_node_tau = m_outnode_tau;
62  config.output_node_jet = m_outnode_jet;
63 
64  // Load the weights and create the network
65  if (!weightfile.empty()) {
66  m_net = std::make_unique<TauGNN>(weightfile, config);
67  if (!m_net) {
68  ATH_MSG_ERROR("No network configured.");
69  return StatusCode::FAILURE;
70  }
71  }
72 
73  return StatusCode::SUCCESS;
74 }
75 
77  // Output variable Decorators
81  const SG::AuxElement::Decorator<char> out_trkclass("GNTau_TrackClass");
82  // Set default score and overwrite later
83  output(tau) = -1111.0f;
84  out_ptau(tau) = -1111.0f;
85  out_pjet(tau) = -1111.0f;
86 
87  //Skip execution for low-pT taus to save resources
88  if (tau.pt() < m_minTauPt) {
89  return StatusCode::SUCCESS;
90  }
91 
92  // Get input objects
93  std::vector<const xAOD::TauTrack *> tracks;
94  ATH_CHECK(get_tracks(tau, tracks));
95  std::vector<xAOD::CaloVertexedTopoCluster> clusters;
97 
98  // Truncate tracks
99  int numTracksMax = std::min(m_max_tracks, tracks.size());
100  std::vector<const xAOD::TauTrack *> trackVec(tracks.begin(), tracks.begin()+numTracksMax);
101  // Evaluate networks
102  if (m_net) {
103  auto [out_f, out_vc, out_vf] = m_net->compute(tau, trackVec, clusters);
104  output(tau)=std::log10(1/(1-out_f.at(m_outnode_tau)));
105  out_ptau(tau)=out_f.at(m_outnode_tau);
106  out_pjet(tau)=out_f.at(m_outnode_jet);
107  if (m_decorateTracks){
108  for(unsigned int i=0;i<tracks.size();i++){
109  if(i<out_vc.at("track_class").size()){out_trkclass(*tracks.at(i))=out_vc.at("track_class").at(i);}
110  else{out_trkclass(*tracks.at(i))='9';} //Dummy value for tracks outside range of out_vc
111  }
112  }
113  }
114 
115  return StatusCode::SUCCESS;
116 }
117 
119  return m_net.get();
120 }
121 
122 
123 StatusCode TauGNNEvaluator::get_tracks(const xAOD::TauJet &tau, std::vector<const xAOD::TauTrack *> &out) const {
124  std::vector<const xAOD::TauTrack*> tracks = tau.allTracks();
125 
126  // Skip unclassified tracks:
127  // - the track is a LRT and classifyLRT = false
128  // - the track is not among the MaxNtracks highest-pt tracks in the track classifier
129  // - track classification is not run (trigger)
132  while(it != tracks.end()) {
133  if((*it)->flag(xAOD::TauJetParameters::unclassified)) {
134  it = tracks.erase(it);
135  }
136  else {
137  ++it;
138  }
139  }
140  }
141 
142  // Sort by descending pt
143  auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) {
144  return lhs->pt() > rhs->pt();
145  };
146  std::sort(tracks.begin(), tracks.end(), cmp_pt);
147  out = std::move(tracks);
148 
149  return StatusCode::SUCCESS;
150 }
151 
152 StatusCode TauGNNEvaluator::get_clusters(const xAOD::TauJet &tau, std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
153 
154  TLorentzVector tauAxis = tauRecTools::getTauAxis(tau, m_doVertexCorrection);
155 
156  for (const xAOD::CaloVertexedTopoCluster& vertexedCluster : tau.vertexedClusters()) {
157  TLorentzVector clusterP4 = vertexedCluster.p4();
158  if (clusterP4.DeltaR(tauAxis) > m_max_cluster_dr) continue;
159 
160  clusters.push_back(vertexedCluster);
161  }
162 
163  // Sort by descending et
164  auto et_cmp = [](const xAOD::CaloVertexedTopoCluster& lhs,
165  const xAOD::CaloVertexedTopoCluster& rhs) {
166  return lhs.p4().Et() > rhs.p4().Et();
167  };
168  std::sort(clusters.begin(), clusters.end(), et_cmp);
169 
170  // Truncate clusters
171  if (clusters.size() > m_max_clusters) {
172  clusters.resize(m_max_clusters, clusters[0]);
173  }
174 
175  return StatusCode::SUCCESS;
176 }
xAOD::iterator
JetConstituentVector::iterator iterator
Definition: JetConstituentVector.cxx:68
python.CaloRecoConfig.f
f
Definition: CaloRecoConfig.py:127
xAOD::CaloVertexedClusterBase::p4
virtual FourMom_t p4() const final
The full 4-momentum of the particle.
Definition: Event/xAOD/xAODCaloEvent/xAODCaloEvent/CaloVertexedClusterBase.h:88
TauGNNEvaluator::m_net
std::unique_ptr< TauGNN > m_net
Definition: TauGNNEvaluator.h:67
TauGNNEvaluator::m_input_layer_scalar
std::string m_input_layer_scalar
Definition: TauGNNEvaluator.h:60
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
TauGNNEvaluator::m_doTrackClassification
bool m_doTrackClassification
Definition: TauGNNEvaluator.h:56
SG::Accessor
Helper class to provide type-safe access to aux data.
Definition: Control/AthContainers/AthContainers/Accessor.h:66
AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T > &t)
Definition: AthCommonDataStore.h:145
tauRecTools::getTauAxis
TLorentzVector getTauAxis(const xAOD::TauJet &tau, bool doVertexCorrection=true)
Return the four-momentum of the tau axis. The tau axis is widely used to select clusters and cells in ...
Definition: Reconstruction/tauRecTools/Root/HelperFunctions.cxx:33
TauGNNEvaluator::get_tracks
StatusCode get_tracks(const xAOD::TauJet &tau, std::vector< const xAOD::TauTrack * > &out) const
Definition: TauGNNEvaluator.cxx:123
TauGNNEvaluator::m_max_clusters
std::size_t m_max_clusters
Definition: TauGNNEvaluator.h:52
skel.it
it
Definition: skel.GENtoEVGEN.py:423
TauRecToolBase
The base class for all tau tools.
Definition: TauRecToolBase.h:21
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:71
TauGNNEvaluator::get_gnn
const TauGNN * get_gnn() const
Definition: TauGNNEvaluator.cxx:118
TauGNNEvaluator::m_input_layer_tracks
std::string m_input_layer_tracks
Definition: TauGNNEvaluator.h:61
TauGNNEvaluator::m_input_layer_clusters
std::string m_input_layer_clusters
Definition: TauGNNEvaluator.h:62
TauGNNEvaluator::m_outnode_tau
std::string m_outnode_tau
Definition: TauGNNEvaluator.h:63
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
TauGNNEvaluator::m_output_ptau
std::string m_output_ptau
Definition: TauGNNEvaluator.h:48
TauGNNEvaluator::m_weightfile
std::string m_weightfile
Definition: TauGNNEvaluator.h:50
TauGNNEvaluator::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: TauGNNEvaluator.cxx:40
xAOD::TauJet_v3::pt
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
TauGNNEvaluator.h
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
SG::Decorator
Helper class to provide type-safe access to aux data.
Definition: Decorator.h:58
TauGNNEvaluator::m_output_pjet
std::string m_output_pjet
Definition: TauGNNEvaluator.h:49
lumiFormat.i
int i
Definition: lumiFormat.py:92
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
TauGNNEvaluator::m_decorateTracks
bool m_decorateTracks
Definition: TauGNNEvaluator.h:57
TauGNN.h
TauGNNEvaluator::m_max_tracks
std::size_t m_max_tracks
Definition: TauGNNEvaluator.h:51
min
#define min(a, b)
Definition: cfImp.cxx:40
merge.output
output
Definition: merge.py:17
PathResolver.h
TauGNNEvaluator::m_max_cluster_dr
float m_max_cluster_dr
Definition: TauGNNEvaluator.h:53
xAOD::TauTrack_v1::pt
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:192
TauRecToolBase::find_file
std::string find_file(const std::string &fname) const
Definition: TauRecToolBase.cxx:19
TauGNN
Wrapper around ONNXUtil to compute the output score of a model.
Definition: TauGNN.h:36
xAOD::TauJet_v3::vertexedClusters
std::vector< xAOD::CaloVertexedTopoCluster > vertexedClusters() const
Definition: TauJet_v3.cxx:626
TauGNNEvaluator::m_output_varname
std::string m_output_varname
Definition: TauGNNEvaluator.h:47
TauGNNEvaluator::m_minTauPt
float m_minTauPt
Definition: TauGNNEvaluator.h:54
TauGNNEvaluator::m_doVertexCorrection
bool m_doVertexCorrection
Definition: TauGNNEvaluator.h:55
xAOD::TauTrack_v1
Definition: TauTrack_v1.h:27
HelperFunctions.h
TauGNNEvaluator::m_outnode_jet
std::string m_outnode_jet
Definition: TauGNNEvaluator.h:64
config
std::vector< std::string > config
Definition: fbtTestBasics.cxx:72
RunTileMonitoring.clusters
clusters
Definition: RunTileMonitoring.py:133
TauGNNEvaluator::~TauGNNEvaluator
virtual ~TauGNNEvaluator()
Definition: TauGNNEvaluator.cxx:38
xAOD::CaloVertexedTopoCluster
Evaluate cluster kinematics with a different vertex / signal state.
Definition: Event/xAOD/xAODCaloEvent/xAODCaloEvent/CaloVertexedTopoCluster.h:38
TauGNN::Config
Definition: TauGNN.h:39
TauGNNEvaluator::TauGNNEvaluator
TauGNNEvaluator(const std::string &name="TauGNNEvaluator")
Definition: TauGNNEvaluator.cxx:14
TauGNNEvaluator::get_clusters
StatusCode get_clusters(const xAOD::TauJet &tau, std::vector< xAOD::CaloVertexedTopoCluster > &out) const
Definition: TauGNNEvaluator.cxx:152
xAOD::TauJet_v3::allTracks
std::vector< const TauTrack * > allTracks() const
Get the v<const pointer> to all tracks associated with this tau, regardless of classification.
Definition: TauJet_v3.cxx:514
TauGNNEvaluator::execute
virtual StatusCode execute(xAOD::TauJet &tau) const override
Execute - called for each tau candidate.
Definition: TauGNNEvaluator.cxx:76
xAOD::TauJetParameters::unclassified
@ unclassified
Definition: TauDefs.h:410