ATLAS Offline Software
Loading...
Searching...
No Matches
TauGNNDataLoader.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
6
// Signature of the per-variable calculators in TauScalarVars: fill `out` from a
// tau and return whether the value could be computed.
7using ScalarCalcByRef_t = std::function<bool(const xAOD::TauJet &, float &)>;
// Signature of the callables handed to addScalarLoader(): map a generic particle
// to a float (built from a ScalarCalcByRef_t by getScalarCalc()).
8using ScalarCalc_t = std::function<float(const xAOD::IParticle*)>;
9
11 std::shared_ptr<const FlavorTagInference::SaltModel> salt_model,
13) :
15 asg::AsgMessaging("TauGNNDataLoader")
16 {
// Binds the configured input-layer names to the Salt model's graph inputs and
// registers one scalar calculator per scalar variable, plus a constituent loader
// for each of tracks, clusters and hits that the graph provides.
// Throws std::runtime_error when a configured layer name is not found in the graph.
// NOTE(review): several source lines of this constructor (signature line,
// graph_config / hit_input_node declarations, trk_config / cls_config
// declarations and some of their fields) are not visible in this listing.
17 scalarInputName = config.input_layer_scalar;
18 const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig* scalar_input_node = nullptr;
19 const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig* track_input_node = nullptr;
20 const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig* cluster_input_node = nullptr;
// Scalar (per-tau) inputs: stop at the first node whose name matches.
22 for (const auto &in_node : graph_config.inputs) {
23 if (in_node.name == config.input_layer_scalar) {
24 scalar_input_node = &in_node;
25 ATH_MSG_DEBUG("Found scalar input node: " << in_node.name);
26 break;
27 }
28 }
// Sequence (constituent) inputs: no early exit, so if a name occurs more than
// once the last matching node is kept.
29 for (const auto &in_node : graph_config.input_sequences) {
30 if (in_node.name == config.input_layer_tracks) {
31 track_input_node = &in_node;
32 ATH_MSG_DEBUG("Found track input node: " << in_node.name);
33 }
34 if (in_node.name == config.input_layer_clusters) {
35 cluster_input_node = &in_node;
36 ATH_MSG_DEBUG("Found cluster input node: " << in_node.name);
37 }
38 if (in_node.name == config.input_layer_hits) {
39 hit_input_node = &in_node;
40 ATH_MSG_DEBUG("Found hit input node: " << in_node.name);
41 }
42 }
43
44 // Fill the variable names of each input layer into the corresponding vector
// Scalar layer: one calculator per model variable. getScalarCalc() throws for
// names that have no registered calculator.
45 if (scalar_input_node) {
46 for (const auto &in : scalar_input_node->variables) {
47 addScalarLoader(in.name, getScalarCalc(in.name));
48 }
49 } else if(!config.input_layer_scalar.empty()) {
// A missing node is only an error when the layer name was actually configured.
50 ATH_MSG_ERROR("Scalar input node '" + config.input_layer_scalar + "' not found in the model input configuration");
51 throw std::runtime_error("Scalar input node '" + config.input_layer_scalar + "' not found in the model input configuration");
52 }
53
// Track layer: one CUSTOM_GETTER input per model variable, in model order.
54 if (track_input_node) {
56 trk_config.name = "tautracks";
57 trk_config.output_name = config.input_layer_tracks;
60 trk_config.max_n_constituents = config.n_max_tracks;
62 trk_config.inputs = {};
63 for (const auto &in : track_input_node->variables) {
// Without TRT, the eProbabilityHT slot is filled by eProbabilityHT_noTRT, so the
// input keeps its position in the vector.
64 if (!config.useTRT && (in.name == "eProbabilityHT")) {
65 ATH_MSG_WARNING("Track variable 'eProbabilityHT' requested but useTRT set to false. Using 'eProbabilityHT_noTRT' instead.");
66 trk_config.inputs.push_back({"eProbabilityHT_noTRT", FlavorTagInference::ConstituentsEDMType::CUSTOM_GETTER, false});
67 continue;
68 }
69 trk_config.inputs.push_back({in.name, FlavorTagInference::ConstituentsEDMType::CUSTOM_GETTER, false});
70 }
71 addVectorLoader(config.input_layer_tracks, std::make_shared<FlavorTagInference::ConstituentLoaderTauTrack>(trk_config));
72 } else if(!config.input_layer_tracks.empty() && config.n_max_tracks > 0) {
// Sequence layers are only required when named AND requested with n_max > 0.
73 ATH_MSG_ERROR("Track input node '" + config.input_layer_tracks + "' not found in the model input configuration");
74 throw std::runtime_error("Track input node '" + config.input_layer_tracks + "' not found in the model input configuration");
75 }
76
// Cluster layer: same pattern; the loader also receives max_dr_cluster and
// doVertexCorrection from the config.
77 if (cluster_input_node) {
79 cls_config.name = "tauclusters";
80 cls_config.output_name = config.input_layer_clusters;
83 cls_config.max_n_constituents = config.n_max_clusters;
85 cls_config.inputs = {};
86 for (const auto &in : cluster_input_node->variables) {
87 cls_config.inputs.push_back({in.name, FlavorTagInference::ConstituentsEDMType::CUSTOM_GETTER, false});
88 }
89 addVectorLoader(config.input_layer_clusters, std::make_shared<FlavorTagInference::ConstituentLoaderTauCluster>(cls_config, config.max_dr_cluster, config.doVertexCorrection));
90 } else if(!config.input_layer_clusters.empty() && config.n_max_clusters > 0) {
91 ATH_MSG_ERROR("Cluster input node '" + config.input_layer_clusters + "' not found in the model input configuration");
92 throw std::runtime_error("Cluster input node '" + config.input_layer_clusters + "' not found in the model input configuration");
93 }
94
// Hit layer: same pattern; the loader also receives hits_decor_name.
// NOTE(review): this branch reuses the name cls_config, presumably a fresh local
// declared on source line 96 (not visible here) — confirm.
95 if (hit_input_node) {
97 cls_config.name = "tauhits";
98 cls_config.output_name = config.input_layer_hits;
101 cls_config.max_n_constituents = config.n_max_hits;
103 cls_config.inputs = {};
104 for (const auto &in : hit_input_node->variables) {
105 cls_config.inputs.push_back({in.name, FlavorTagInference::ConstituentsEDMType::CUSTOM_GETTER, false});
106 }
107 addVectorLoader(config.input_layer_hits, std::make_shared<FlavorTagInference::ConstituentLoaderTauHit>(cls_config, config.hits_decor_name));
108 } else if(!config.input_layer_hits.empty() && config.n_max_hits > 0) {
109 ATH_MSG_ERROR("Hit input node '" + config.input_layer_hits + "' not found in the model input configuration");
110 throw std::runtime_error("Hit input node '" + config.input_layer_hits + "' not found in the model input configuration");
111 }
112}
113
114ScalarCalc_t TauGNNDataLoader::getScalarCalc(const std::string &name) const {
115 // Retrieve calculator function
116 ScalarCalcByRef_t func = nullptr;
117 try {
118 func = m_func_map.at(name);
119 } catch (const std::out_of_range &e) {
120 ATH_MSG_ERROR("Variable '" << name << "' not defined");
121 throw std::runtime_error("Variable '" + name + "' not defined");
122 }
123 return [func](const xAOD::IParticle* p) {
124 auto tau = dynamic_cast<const xAOD::TauJet*>(p);
125 float out;
126 if (!tau) {
127 throw std::runtime_error("Invalid TauJet pointer");
128 }
129 auto success = func(*tau, out);
130 if (!success) {
131 throw std::runtime_error("Error in scalar variable calculation ");
132 }
133 return out;
134 };
135}
136
137
138namespace TauScalarVars {
140
141bool eta(const xAOD::TauJet &tau, float &out) {
142 out = tau.eta();
143 return true;
144}
145
146bool absEta(const xAOD::TauJet &tau, float &out) {
147 out = std::abs(tau.eta());
148 return true;
149}
150
151bool centFrac(const xAOD::TauJet &tau, float &out) {
152 float centFrac;
153 const auto success = tau.detail(TauDetail::centFrac, centFrac);
154 //out = std::min(centFrac, 1.0f);
155 out = centFrac;
156 return success;
157}
158
159bool isolFrac(const xAOD::TauJet &tau, float &out) {
160 float isolFrac;
161 const auto success = tau.detail(TauDetail::isolFrac, isolFrac);
162 //out = std::min(isolFrac, 1.0f);
163 out = isolFrac;
164 return success;
165}
166
167bool etOverPtLeadTrk(const xAOD::TauJet &tau, float &out) {
168 float etOverPtLeadTrk;
169 const auto success = tau.detail(TauDetail::etOverPtLeadTrk, etOverPtLeadTrk);
170 out = etOverPtLeadTrk;
171 return success;
172}
173
174bool innerTrkAvgDist(const xAOD::TauJet &tau, float &out) {
175 float innerTrkAvgDist;
176 const auto success = tau.detail(TauDetail::innerTrkAvgDist, innerTrkAvgDist);
177 out = innerTrkAvgDist;
178 return success;
179}
180
181bool absipSigLeadTrk(const xAOD::TauJet &tau, float &out) {
182 float ipSigLeadTrk = (tau.nTracks()>0) ? tau.track(0)->d0SigTJVA() : 0.;
183 //out = std::min(std::abs(ipSigLeadTrk), 30.0f);
184 out = std::abs(ipSigLeadTrk);
185 return true;
186}
187
// Reads the TauJet detail sumEMCellEtOverLeadTrkPt and returns whether it was
// available (the bool from tau.detail()).
// NOTE(review): source lines 189 (the local declaration) and 191 (presumably the
// assignment to `out`) are not visible in this listing — confirm. Consider
// zero-initialising the local so `out` is defined when the detail is missing.
188bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, float &out) {
190 const auto success = tau.detail(TauDetail::sumEMCellEtOverLeadTrkPt, sumEMCellEtOverLeadTrkPt);
192 return success;
193}
194
195bool SumPtTrkFrac(const xAOD::TauJet &tau, float &out) {
196 float SumPtTrkFrac;
197 const auto success = tau.detail(TauDetail::SumPtTrkFrac, SumPtTrkFrac);
198 out = SumPtTrkFrac;
199 return success;
200}
201
202bool EMPOverTrkSysP(const xAOD::TauJet &tau, float &out) {
203 float EMPOverTrkSysP;
204 const auto success = tau.detail(TauDetail::EMPOverTrkSysP, EMPOverTrkSysP);
205 out = EMPOverTrkSysP;
206 return success;
207}
208
209bool ptRatioEflowApprox(const xAOD::TauJet &tau, float &out) {
210 float ptRatioEflowApprox;
211 const auto success = tau.detail(TauDetail::ptRatioEflowApprox, ptRatioEflowApprox);
212 //out = std::min(ptRatioEflowApprox, 4.0f);
213 out = ptRatioEflowApprox;
214 return success;
215}
216
217bool mEflowApprox(const xAOD::TauJet &tau, float &out) {
218 float mEflowApprox;
219 const auto success = tau.detail(TauDetail::mEflowApprox, mEflowApprox);
220 out = mEflowApprox;
221 return success;
222}
223
224bool dRmax(const xAOD::TauJet &tau, float &out) {
225 float dRmax;
226 const auto success = tau.detail(TauDetail::dRmax, dRmax);
227 out = dRmax;
228 return success;
229}
230
231bool trFlightPathSig(const xAOD::TauJet &tau, float &out) {
232 float trFlightPathSig;
233 const auto success = tau.detail(TauDetail::trFlightPathSig, trFlightPathSig);
234 out = trFlightPathSig;
235 return success;
236}
237
238bool massTrkSys(const xAOD::TauJet &tau, float &out) {
239 float massTrkSys;
240 const auto success = tau.detail(TauDetail::massTrkSys, massTrkSys);
241 out = massTrkSys;
242 return success;
243}
244
245bool pt(const xAOD::TauJet &tau, float &out) {
246 out = tau.pt();
247 return true;
248}
249
250bool pt_tau_log(const xAOD::TauJet &tau, float &out) {
251 out = std::log10(std::max(tau.pt() / 1000., 1e-6));
252 return true;
253}
254
255bool ptDetectorAxis(const xAOD::TauJet &tau, float &out) {
256 out = tau.ptDetectorAxis();
257 return true;
258}
259
260bool ptIntermediateAxis(const xAOD::TauJet &tau, float &out) {
261 out = tau.ptIntermediateAxis();
262 return true;
263}
264
265bool ptJetSeed(const xAOD::TauJet &tau, float &out) {
266 out = tau.ptJetSeed();
267 return true;
268}
269
270bool etaJetSeed(const xAOD::TauJet &tau, float &out) {
271 out = tau.etaJetSeed();
272 return true;
273}
274
275bool ptJetSeed_log(const xAOD::TauJet &tau, float &out) {
276 out = std::log10(std::max(tau.ptJetSeed(), 1e-3));
277 return true;
278}
279
280bool absleadTrackEta(const xAOD::TauJet &tau, float &out){
281 static const SG::ConstAccessor<float> acc_absEtaLeadTrack("ABS_ETA_LEAD_TRACK");
282 float absEtaLeadTrack = acc_absEtaLeadTrack(tau);
283 out = std::max(0.f, absEtaLeadTrack);
284 return true;
285}
286
287bool leadTrackDeltaEta(const xAOD::TauJet &tau, float &out){
288 float absDeltaEta = tau.nTracks() > 0 ? std::abs( tau.track(0)->track()->eta() - tau.eta() ) : -1111.;
289 out = std::max(0.f, absDeltaEta);
290 return true;
291}
292
293bool leadTrackDeltaPhi(const xAOD::TauJet &tau, float &out){
294 float absDeltaPhi = tau.nTracks() > 0 ? std::abs( tau.track(0)->track()->p4().DeltaPhi(tau.p4()) ) : -1111.;
295 out = std::max(0.f, absDeltaPhi);
296 return true;
297}
298
299bool leadTrackProbNNorHT(const xAOD::TauJet &tau, float &out){
300 auto tracks = tau.allTracks();
301
302 // Sort tracks in descending pt order
303 if (!tracks.empty()) {
304 auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) {
305 return lhs->pt() > rhs->pt();
306 };
307 std::sort(tracks.begin(), tracks.end(), cmp_pt);
308
309 const xAOD::TauTrack* tauLeadTrack = tracks.at(0);
310 const xAOD::TrackParticle* xTrackParticle = tauLeadTrack->track();
311 float eProbabilityHT = xTrackParticle->summaryValue(eProbabilityHT, xAOD::eProbabilityHT);
312 static const SG::ConstAccessor<float> acc_eProbabilityNN("eProbabilityNN");
313 float eProbabilityNN = acc_eProbabilityNN(*xTrackParticle);
314 out = (tauLeadTrack->pt()>2000.) ? eProbabilityNN : eProbabilityHT;
315 }
316 else {
317 out = 0.;
318 }
319 return true;
320}
321
322bool EMFracFixed(const xAOD::TauJet &tau, float &out){
323 static const SG::ConstAccessor<float> acc_emFracFixed("EMFracFixed");
324 float emFracFixed = acc_emFracFixed(tau);
325 out = std::max(emFracFixed, 0.0f);
326 return true;
327}
328
329bool etHotShotWinOverPtLeadTrk(const xAOD::TauJet &tau, float &out){
330 static const SG::ConstAccessor<float> acc_etHotShotWinOverPtLeadTrk("etHotShotWinOverPtLeadTrk");
331 float etHotShotWinOverPtLeadTrk = acc_etHotShotWinOverPtLeadTrk(tau);
332 out = std::max(etHotShotWinOverPtLeadTrk, 1e-6f);
333 return true;
334}
335
336bool hadLeakFracFixed(const xAOD::TauJet &tau, float &out){
337 static const SG::ConstAccessor<float> acc_hadLeakFracFixed("hadLeakFracFixed");
338 float hadLeakFracFixed = acc_hadLeakFracFixed(tau);
339 out = std::max(0.f, hadLeakFracFixed);
340 return true;
341}
342
343bool PSFrac(const xAOD::TauJet &tau, float &out){
344 float PSFrac;
345 const auto success = tau.detail(TauDetail::PSSFraction, PSFrac);
346 out = std::max(0.f,PSFrac);
347 return success;
348}
349
350bool ClustersMeanCenterLambda(const xAOD::TauJet &tau, float &out){
352 const auto success = tau.detail(TauDetail::ClustersMeanCenterLambda, ClustersMeanCenterLambda);
353 out = std::max(0.f, ClustersMeanCenterLambda);
354 return success;
355}
356
357bool ClustersMeanEMProbability(const xAOD::TauJet &tau, float &out){
359 const auto success = tau.detail(TauDetail::ClustersMeanEMProbability, ClustersMeanEMProbability);
360 out = std::max(0.f, ClustersMeanEMProbability);
361 return success;
362}
363
364bool ClustersMeanFirstEngDens(const xAOD::TauJet &tau, float &out){
366 const auto success = tau.detail(TauDetail::ClustersMeanFirstEngDens, ClustersMeanFirstEngDens);
367 out = std::max(-10.f, ClustersMeanFirstEngDens);
368 return success;
369}
370
371bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, float &out){
373 const auto success = tau.detail(TauDetail::ClustersMeanPresamplerFrac, ClustersMeanPresamplerFrac);
374 out = std::max(0.f, ClustersMeanPresamplerFrac);
375 return success;
376}
377
378bool ClustersMeanSecondLambda(const xAOD::TauJet &tau, float &out){
380 const auto success = tau.detail(TauDetail::ClustersMeanSecondLambda, ClustersMeanSecondLambda);
381 out = std::max(0.f, ClustersMeanSecondLambda);
382 return success;
383}
384} //namespace Scalar
Scalar eta() const
pseudorapidity method
#define ATH_MSG_ERROR(x)
#define ATH_MSG_WARNING(x)
if(febId1==febId2)
std::function< float(const xAOD::IParticle *)> ScalarCalc_t
std::function< bool(const xAOD::TauJet &, float &)> ScalarCalcByRef_t
SaltModelGraphConfig::GraphConfig graph_config
Helper class to provide constant type-safe access to aux data.
ScalarCalc_t getScalarCalc(const std::string &name) const
std::function< bool(const xAOD::TauJet &, float &)> ScalarCalcByRef_t
std::function< float(const xAOD::IParticle *)> ScalarCalc_t
TauGNNDataLoader(std::shared_ptr< const FlavorTagInference::SaltModel > salt_model, const Config &config)
static const std::unordered_map< std::string, ScalarCalcByRef_t > m_func_map
AsgMessaging(const std::string &name)
Constructor with a name.
Class providing the definition of the 4-vector interface.
double ptDetectorAxis() const
virtual FourMom_t p4() const
The full 4-momentum of the particle.
Definition TauJet_v3.cxx:96
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
double ptIntermediateAxis() const
bool detail(TauJetParameters::Detail detail, int &value) const
Get and set values of common details variables via enum.
const TauTrack * track(size_t i, TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged, int *container_index=0) const
Get the pointer to a given tauTrack associated with this tau /*container index needed by trackNonCons...
double ptJetSeed() const
double etaJetSeed() const
std::vector< const TauTrack * > allTracks() const
Get the v<const pointer> to all tracks associated with this tau, regardless of classification.
virtual double eta() const
The pseudorapidity ($\eta$) of the particle.
size_t nTracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
float d0SigTJVA() const
const TrackParticle * track() const
virtual FourMom_t p4() const override final
The full 4-momentum of the particle.
bool summaryValue(uint8_t &value, const SummaryType &information) const
Accessor for TrackSummary values.
virtual double eta() const override final
The pseudorapidity ($\eta$) of the particle.
This file contains "getter" functions used for accessing tagger inputs from the EDM.
bool etHotShotWinOverPtLeadTrk(const xAOD::TauJet &tau, float &out)
bool ptDetectorAxis(const xAOD::TauJet &tau, float &out)
bool ClustersMeanCenterLambda(const xAOD::TauJet &tau, float &out)
bool pt(const xAOD::TauJet &tau, float &out)
bool ptIntermediateAxis(const xAOD::TauJet &tau, float &out)
bool EMFracFixed(const xAOD::TauJet &tau, float &out)
bool isolFrac(const xAOD::TauJet &tau, float &out)
bool massTrkSys(const xAOD::TauJet &tau, float &out)
bool PSFrac(const xAOD::TauJet &tau, float &out)
bool ClustersMeanFirstEngDens(const xAOD::TauJet &tau, float &out)
bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, float &out)
bool EMPOverTrkSysP(const xAOD::TauJet &tau, float &out)
bool leadTrackProbNNorHT(const xAOD::TauJet &tau, float &out)
bool ptJetSeed_log(const xAOD::TauJet &tau, float &out)
bool leadTrackDeltaEta(const xAOD::TauJet &tau, float &out)
bool etaJetSeed(const xAOD::TauJet &tau, float &out)
bool etOverPtLeadTrk(const xAOD::TauJet &tau, float &out)
bool absEta(const xAOD::TauJet &tau, float &out)
bool hadLeakFracFixed(const xAOD::TauJet &tau, float &out)
bool leadTrackDeltaPhi(const xAOD::TauJet &tau, float &out)
xAOD::TauJetParameters::Detail TauDetail
bool mEflowApprox(const xAOD::TauJet &tau, float &out)
bool innerTrkAvgDist(const xAOD::TauJet &tau, float &out)
bool centFrac(const xAOD::TauJet &tau, float &out)
bool ptRatioEflowApprox(const xAOD::TauJet &tau, float &out)
bool trFlightPathSig(const xAOD::TauJet &tau, float &out)
bool ClustersMeanSecondLambda(const xAOD::TauJet &tau, float &out)
bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, float &out)
bool ClustersMeanEMProbability(const xAOD::TauJet &tau, float &out)
bool dRmax(const xAOD::TauJet &tau, float &out)
bool pt_tau_log(const xAOD::TauJet &tau, float &out)
bool absipSigLeadTrk(const xAOD::TauJet &tau, float &out)
bool SumPtTrkFrac(const xAOD::TauJet &tau, float &out)
bool ptJetSeed(const xAOD::TauJet &tau, float &out)
bool absleadTrackEta(const xAOD::TauJet &tau, float &out)
void sort(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of sort for DataVector/List.
Detail
Enum for tau parameters - used mainly for backward compatibility with the analysis code.
Definition TauDefs.h:156
TrackParticle_v1 TrackParticle
Reference the current persistent version:
TauTrack_v1 TauTrack
Definition of the current version.
Definition TauTrack.h:16
TauJet_v3 TauJet
Definition of the current "tau version".
@ eProbabilityHT
Electron probability from High Threshold (HT) information [float].