ATLAS Offline Software
Loading...
Searching...
No Matches
TauGNNDataLoader.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
6
7using ScalarCalcByRef_t = std::function<bool(const xAOD::TauJet &, float &)>;
8using ScalarCalc_t = std::function<float(const xAOD::IParticle*)>;
9
11 std::shared_ptr<const FlavorTagInference::SaltModel> salt_model,
13) :
15 asg::AsgMessaging("TauGNNDataLoader")
16 {
// Builds the GNN input loaders from the model graph configuration:
//  - one scalar loader per variable of the scalar node named config.input_layer_scalar,
//  - one track sequence loader for the node named config.input_layer_tracks,
//  - one cluster sequence loader for the node named config.input_layer_clusters.
// If any of the three nodes is not found, an error is logged and
// std::runtime_error is thrown.
17 scalarInputName = config.input_layer_scalar;
// Pointers into graph_config; they stay nullptr when no node of that name exists.
18 const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig* scalar_input_node = nullptr;
19 const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig* track_input_node = nullptr;
20 const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig* cluster_input_node = nullptr;
// Scalar node: first entry of graph_config.inputs with a matching name.
21 for (const auto &in_node : graph_config.inputs) {
22 if (in_node.name == config.input_layer_scalar) {
23 scalar_input_node = &in_node;
24 ATH_MSG_DEBUG("Found scalar input node: " << in_node.name);
25 break;
26 }
27 }
// Track and cluster nodes are sequence inputs, matched in a single pass.
// There is no break here, so a later entry with the same name overwrites an earlier match.
28 for (const auto &in_node : graph_config.input_sequences) {
29 if (in_node.name == config.input_layer_tracks) {
30 track_input_node = &in_node;
31 ATH_MSG_DEBUG("Found track input node: " << in_node.name);
32 }
33 if (in_node.name == config.input_layer_clusters) {
34 cluster_input_node = &in_node;
35 ATH_MSG_DEBUG("Found cluster input node: " << in_node.name);
36 }
37 }
38
39 // Fill the variable names of each input layer into the corresponding vector
40 if (scalar_input_node) {
41 for (const auto &in : scalar_input_node->variables) {
// getScalarCalc throws std::runtime_error for names without a registered calculator.
42 addScalarLoader(in.name, getScalarCalc(in.name));
43 }
44 } else {
// NOTE(review): the message hard-codes 'tau_vars' rather than reporting
// config.input_layer_scalar, so it is misleading when the configured name differs.
45 ATH_MSG_ERROR("Scalar input node 'tau_vars' not found in the model input configuration");
46 throw std::runtime_error("Scalar input node 'tau_vars' not found in the model input configuration");
47 }
48
49 if (track_input_node) {
// NOTE(review): the declaration of trk_config and some of its fields are set on
// lines not visible in this listing.
51 trk_config.name = "tautracks";
52 trk_config.output_name = config.input_layer_tracks;
55 trk_config.max_n_constituents = config.n_max_tracks;
57 trk_config.inputs = {};
58 for (const auto &in : track_input_node->variables) {
// Without TRT information, a requested 'eProbabilityHT' input is served by
// 'eProbabilityHT_noTRT' instead, with a warning.
59 if (!config.useTRT && (in.name == "eProbabilityHT")) {
60 ATH_MSG_WARNING("Track variable 'eProbabilityHT' requested but useTRT set to false. Using 'eProbabilityHT_noTRT' instead.");
61 trk_config.inputs.push_back({"eProbabilityHT_noTRT", FlavorTagInference::ConstituentsEDMType::CUSTOM_GETTER, false});
62 continue;
63 }
64 trk_config.inputs.push_back({in.name, FlavorTagInference::ConstituentsEDMType::CUSTOM_GETTER, false});
65 }
66 addVectorLoader(config.input_layer_tracks, std::make_shared<FlavorTagInference::ConstituentLoaderTauTrack>(trk_config));
67 } else {
68 ATH_MSG_ERROR("Track input node '" + config.input_layer_tracks + "' not found in the model input configuration");
69 throw std::runtime_error("Track input node '" + config.input_layer_tracks + "' not found in the model input configuration");
70 }
71
72 if (cluster_input_node) {
// NOTE(review): the declaration of cls_config and some of its fields are set on
// lines not visible in this listing.
74 cls_config.name = "tauclusters";
75 cls_config.output_name = config.input_layer_clusters;
78 cls_config.max_n_constituents = config.n_max_clusters;
80 cls_config.inputs = {};
81 for (const auto &in : cluster_input_node->variables) {
82 cls_config.inputs.push_back({in.name, FlavorTagInference::ConstituentsEDMType::CUSTOM_GETTER, false});
83 }
// The cluster loader additionally receives config.max_dr_cluster and config.doVertexCorrection.
84 addVectorLoader(config.input_layer_clusters, std::make_shared<FlavorTagInference::ConstituentLoaderTauCluster>(cls_config, config.max_dr_cluster, config.doVertexCorrection));
85 } else {
86 ATH_MSG_ERROR("Cluster input node '" + config.input_layer_clusters + "' not found in the model input configuration");
87 throw std::runtime_error("Cluster input node '" + config.input_layer_clusters + "' not found in the model input configuration");
88 }
89}
90
91ScalarCalc_t TauGNNDataLoader::getScalarCalc(const std::string &name) const {
92 // Retrieve calculator function
93 ScalarCalcByRef_t func = nullptr;
94 try {
95 func = m_func_map.at(name);
96 } catch (const std::out_of_range &e) {
97 ATH_MSG_ERROR("Variable '" << name << "' not defined");
98 throw std::runtime_error("Variable '" + name + "' not defined");
99 }
100 return [func](const xAOD::IParticle* p) {
101 auto tau = dynamic_cast<const xAOD::TauJet*>(p);
102 float out;
103 if (!tau) {
104 throw std::runtime_error("Invalid TauJet pointer");
105 }
106 auto success = func(*tau, out);
107 if (!success) {
108 throw std::runtime_error("Error in scalar variable calculation ");
109 }
110 return out;
111 };
112}
113
114
115namespace TauScalarVars {
117
118bool absEta(const xAOD::TauJet &tau, float &out) {
119 out = std::abs(tau.eta());
120 return true;
121}
122
123bool centFrac(const xAOD::TauJet &tau, float &out) {
124 float centFrac;
125 const auto success = tau.detail(TauDetail::centFrac, centFrac);
126 //out = std::min(centFrac, 1.0f);
127 out = centFrac;
128 return success;
129}
130
131bool isolFrac(const xAOD::TauJet &tau, float &out) {
132 float isolFrac;
133 const auto success = tau.detail(TauDetail::isolFrac, isolFrac);
134 //out = std::min(isolFrac, 1.0f);
135 out = isolFrac;
136 return success;
137}
138
139bool etOverPtLeadTrk(const xAOD::TauJet &tau, float &out) {
140 float etOverPtLeadTrk;
141 const auto success = tau.detail(TauDetail::etOverPtLeadTrk, etOverPtLeadTrk);
142 out = etOverPtLeadTrk;
143 return success;
144}
145
146bool innerTrkAvgDist(const xAOD::TauJet &tau, float &out) {
147 float innerTrkAvgDist;
148 const auto success = tau.detail(TauDetail::innerTrkAvgDist, innerTrkAvgDist);
149 out = innerTrkAvgDist;
150 return success;
151}
152
153bool absipSigLeadTrk(const xAOD::TauJet &tau, float &out) {
154 float ipSigLeadTrk = (tau.nTracks()>0) ? tau.track(0)->d0SigTJVA() : 0.;
155 //out = std::min(std::abs(ipSigLeadTrk), 30.0f);
156 out = std::abs(ipSigLeadTrk);
157 return true;
158}
159
// Reads TauDetail::sumEMCellEtOverLeadTrkPt via TauJet::detail() and returns
// detail()'s success flag.
// NOTE(review): the declaration of the local `sumEMCellEtOverLeadTrkPt` and the
// assignment to `out` are not visible in this listing. Confirm that the local is
// initialised and that `out` is always written.
160bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, float &out) {
162 const auto success = tau.detail(TauDetail::sumEMCellEtOverLeadTrkPt, sumEMCellEtOverLeadTrkPt);
164 return success;
165}
166
167bool SumPtTrkFrac(const xAOD::TauJet &tau, float &out) {
168 float SumPtTrkFrac;
169 const auto success = tau.detail(TauDetail::SumPtTrkFrac, SumPtTrkFrac);
170 out = SumPtTrkFrac;
171 return success;
172}
173
174bool EMPOverTrkSysP(const xAOD::TauJet &tau, float &out) {
175 float EMPOverTrkSysP;
176 const auto success = tau.detail(TauDetail::EMPOverTrkSysP, EMPOverTrkSysP);
177 out = EMPOverTrkSysP;
178 return success;
179}
180
181bool ptRatioEflowApprox(const xAOD::TauJet &tau, float &out) {
182 float ptRatioEflowApprox;
183 const auto success = tau.detail(TauDetail::ptRatioEflowApprox, ptRatioEflowApprox);
184 //out = std::min(ptRatioEflowApprox, 4.0f);
185 out = ptRatioEflowApprox;
186 return success;
187}
188
189bool mEflowApprox(const xAOD::TauJet &tau, float &out) {
190 float mEflowApprox;
191 const auto success = tau.detail(TauDetail::mEflowApprox, mEflowApprox);
192 out = mEflowApprox;
193 return success;
194}
195
196bool dRmax(const xAOD::TauJet &tau, float &out) {
197 float dRmax;
198 const auto success = tau.detail(TauDetail::dRmax, dRmax);
199 out = dRmax;
200 return success;
201}
202
203bool trFlightPathSig(const xAOD::TauJet &tau, float &out) {
204 float trFlightPathSig;
205 const auto success = tau.detail(TauDetail::trFlightPathSig, trFlightPathSig);
206 out = trFlightPathSig;
207 return success;
208}
209
210bool massTrkSys(const xAOD::TauJet &tau, float &out) {
211 float massTrkSys;
212 const auto success = tau.detail(TauDetail::massTrkSys, massTrkSys);
213 out = massTrkSys;
214 return success;
215}
216
217bool pt(const xAOD::TauJet &tau, float &out) {
218 out = tau.pt();
219 return true;
220}
221
222bool pt_tau_log(const xAOD::TauJet &tau, float &out) {
223 out = std::log10(std::max(tau.pt() / 1000., 1e-6));
224 return true;
225}
226
227bool ptDetectorAxis(const xAOD::TauJet &tau, float &out) {
228 out = tau.ptDetectorAxis();
229 return true;
230}
231
232bool ptIntermediateAxis(const xAOD::TauJet &tau, float &out) {
233 out = tau.ptIntermediateAxis();
234 return true;
235}
236
237bool ptJetSeed_log(const xAOD::TauJet &tau, float &out) {
238 out = std::log10(std::max(tau.ptJetSeed(), 1e-3));
239 return true;
240}
241
242bool absleadTrackEta(const xAOD::TauJet &tau, float &out){
243 static const SG::ConstAccessor<float> acc_absEtaLeadTrack("ABS_ETA_LEAD_TRACK");
244 float absEtaLeadTrack = acc_absEtaLeadTrack(tau);
245 out = std::max(0.f, absEtaLeadTrack);
246 return true;
247}
248
249bool leadTrackDeltaEta(const xAOD::TauJet &tau, float &out){
250 float absDeltaEta = tau.nTracks() > 0 ? std::abs( tau.track(0)->track()->eta() - tau.eta() ) : -1111.;
251 out = std::max(0.f, absDeltaEta);
252 return true;
253}
254
255bool leadTrackDeltaPhi(const xAOD::TauJet &tau, float &out){
256 float absDeltaPhi = tau.nTracks() > 0 ? std::abs( tau.track(0)->track()->p4().DeltaPhi(tau.p4()) ) : -1111.;
257 out = std::max(0.f, absDeltaPhi);
258 return true;
259}
260
261bool leadTrackProbNNorHT(const xAOD::TauJet &tau, float &out){
262 auto tracks = tau.allTracks();
263
264 // Sort tracks in descending pt order
265 if (!tracks.empty()) {
266 auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) {
267 return lhs->pt() > rhs->pt();
268 };
269 std::sort(tracks.begin(), tracks.end(), cmp_pt);
270
271 const xAOD::TauTrack* tauLeadTrack = tracks.at(0);
272 const xAOD::TrackParticle* xTrackParticle = tauLeadTrack->track();
273 float eProbabilityHT = xTrackParticle->summaryValue(eProbabilityHT, xAOD::eProbabilityHT);
274 static const SG::ConstAccessor<float> acc_eProbabilityNN("eProbabilityNN");
275 float eProbabilityNN = acc_eProbabilityNN(*xTrackParticle);
276 out = (tauLeadTrack->pt()>2000.) ? eProbabilityNN : eProbabilityHT;
277 }
278 else {
279 out = 0.;
280 }
281 return true;
282}
283
284bool EMFracFixed(const xAOD::TauJet &tau, float &out){
285 static const SG::ConstAccessor<float> acc_emFracFixed("EMFracFixed");
286 float emFracFixed = acc_emFracFixed(tau);
287 out = std::max(emFracFixed, 0.0f);
288 return true;
289}
290
291bool etHotShotWinOverPtLeadTrk(const xAOD::TauJet &tau, float &out){
292 static const SG::ConstAccessor<float> acc_etHotShotWinOverPtLeadTrk("etHotShotWinOverPtLeadTrk");
293 float etHotShotWinOverPtLeadTrk = acc_etHotShotWinOverPtLeadTrk(tau);
294 out = std::max(etHotShotWinOverPtLeadTrk, 1e-6f);
295 return true;
296}
297
298bool hadLeakFracFixed(const xAOD::TauJet &tau, float &out){
299 static const SG::ConstAccessor<float> acc_hadLeakFracFixed("hadLeakFracFixed");
300 float hadLeakFracFixed = acc_hadLeakFracFixed(tau);
301 out = std::max(0.f, hadLeakFracFixed);
302 return true;
303}
304
305bool PSFrac(const xAOD::TauJet &tau, float &out){
306 float PSFrac;
307 const auto success = tau.detail(TauDetail::PSSFraction, PSFrac);
308 out = std::max(0.f,PSFrac);
309 return success;
310}
311
// Reads TauDetail::ClustersMeanCenterLambda via TauJet::detail(), floors it at 0,
// and returns detail()'s success flag.
// NOTE(review): the declaration of the local `ClustersMeanCenterLambda` is not
// visible in this listing. Confirm that it is initialised, because std::max reads
// it even when detail() fails.
312bool ClustersMeanCenterLambda(const xAOD::TauJet &tau, float &out){
314 const auto success = tau.detail(TauDetail::ClustersMeanCenterLambda, ClustersMeanCenterLambda);
315 out = std::max(0.f, ClustersMeanCenterLambda);
316 return success;
317}
318
// Reads TauDetail::ClustersMeanEMProbability via TauJet::detail(), floors it at 0,
// and returns detail()'s success flag.
// NOTE(review): the declaration of the local `ClustersMeanEMProbability` is not
// visible in this listing. Confirm that it is initialised, because std::max reads
// it even when detail() fails.
319bool ClustersMeanEMProbability(const xAOD::TauJet &tau, float &out){
321 const auto success = tau.detail(TauDetail::ClustersMeanEMProbability, ClustersMeanEMProbability);
322 out = std::max(0.f, ClustersMeanEMProbability);
323 return success;
324}
325
// Reads TauDetail::ClustersMeanFirstEngDens via TauJet::detail(), floors it at -10
// (not 0, unlike the sibling cluster variables), and returns detail()'s success flag.
// NOTE(review): the declaration of the local `ClustersMeanFirstEngDens` is not
// visible in this listing. Confirm that it is initialised, because std::max reads
// it even when detail() fails.
326bool ClustersMeanFirstEngDens(const xAOD::TauJet &tau, float &out){
328 const auto success = tau.detail(TauDetail::ClustersMeanFirstEngDens, ClustersMeanFirstEngDens);
329 out = std::max(-10.f, ClustersMeanFirstEngDens);
330 return success;
331}
332
// Reads TauDetail::ClustersMeanPresamplerFrac via TauJet::detail(), floors it at 0,
// and returns detail()'s success flag.
// NOTE(review): the declaration of the local `ClustersMeanPresamplerFrac` is not
// visible in this listing. Confirm that it is initialised, because std::max reads
// it even when detail() fails.
333bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, float &out){
335 const auto success = tau.detail(TauDetail::ClustersMeanPresamplerFrac, ClustersMeanPresamplerFrac);
336 out = std::max(0.f, ClustersMeanPresamplerFrac);
337 return success;
338}
339
// Reads TauDetail::ClustersMeanSecondLambda via TauJet::detail(), floors it at 0,
// and returns detail()'s success flag.
// NOTE(review): the declaration of the local `ClustersMeanSecondLambda` is not
// visible in this listing. Confirm that it is initialised, because std::max reads
// it even when detail() fails.
340bool ClustersMeanSecondLambda(const xAOD::TauJet &tau, float &out){
342 const auto success = tau.detail(TauDetail::ClustersMeanSecondLambda, ClustersMeanSecondLambda);
343 out = std::max(0.f, ClustersMeanSecondLambda);
344 return success;
345}
346} //namespace Scalar
#define ATH_MSG_ERROR(x)
#define ATH_MSG_WARNING(x)
if(febId1==febId2)
std::function< float(const xAOD::IParticle *)> ScalarCalc_t
std::function< bool(const xAOD::TauJet &, float &)> ScalarCalcByRef_t
SaltModelGraphConfig::GraphConfig graph_config
Helper class to provide constant type-safe access to aux data.
ScalarCalc_t getScalarCalc(const std::string &name) const
std::function< bool(const xAOD::TauJet &, float &)> ScalarCalcByRef_t
std::function< float(const xAOD::IParticle *)> ScalarCalc_t
TauGNNDataLoader(std::shared_ptr< const FlavorTagInference::SaltModel > salt_model, const Config &config)
static const std::unordered_map< std::string, ScalarCalcByRef_t > m_func_map
AsgMessaging(const std::string &name)
Constructor with a name.
Class providing the definition of the 4-vector interface.
double ptDetectorAxis() const
virtual FourMom_t p4() const
The full 4-momentum of the particle.
Definition TauJet_v3.cxx:96
virtual double pt() const
The transverse momentum ( ) of the particle.
double ptIntermediateAxis() const
bool detail(TauJetParameters::Detail detail, int &value) const
Get and set values of common details variables via enum.
const TauTrack * track(size_t i, TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged, int *container_index=0) const
Get the pointer to a given tauTrack associated with this tau /*container index needed by trackNonCons...
double ptJetSeed() const
std::vector< const TauTrack * > allTracks() const
Get the v<const pointer> to all tracks associated with this tau, regardless of classification.
virtual double eta() const
The pseudorapidity ( ) of the particle.
size_t nTracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
virtual double pt() const
The transverse momentum ( ) of the particle.
float d0SigTJVA() const
const TrackParticle * track() const
virtual FourMom_t p4() const override final
The full 4-momentum of the particle.
bool summaryValue(uint8_t &value, const SummaryType &information) const
Accessor for TrackSummary values.
virtual double eta() const override final
The pseudorapidity ( ) of the particle.
This file contains "getter" functions used for accessing tagger inputs from the EDM.
bool etHotShotWinOverPtLeadTrk(const xAOD::TauJet &tau, float &out)
bool ptDetectorAxis(const xAOD::TauJet &tau, float &out)
bool ClustersMeanCenterLambda(const xAOD::TauJet &tau, float &out)
bool pt(const xAOD::TauJet &tau, float &out)
bool ptIntermediateAxis(const xAOD::TauJet &tau, float &out)
bool EMFracFixed(const xAOD::TauJet &tau, float &out)
bool isolFrac(const xAOD::TauJet &tau, float &out)
bool massTrkSys(const xAOD::TauJet &tau, float &out)
bool PSFrac(const xAOD::TauJet &tau, float &out)
bool ClustersMeanFirstEngDens(const xAOD::TauJet &tau, float &out)
bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, float &out)
bool EMPOverTrkSysP(const xAOD::TauJet &tau, float &out)
bool leadTrackProbNNorHT(const xAOD::TauJet &tau, float &out)
bool ptJetSeed_log(const xAOD::TauJet &tau, float &out)
bool leadTrackDeltaEta(const xAOD::TauJet &tau, float &out)
bool etOverPtLeadTrk(const xAOD::TauJet &tau, float &out)
bool absEta(const xAOD::TauJet &tau, float &out)
bool hadLeakFracFixed(const xAOD::TauJet &tau, float &out)
bool leadTrackDeltaPhi(const xAOD::TauJet &tau, float &out)
xAOD::TauJetParameters::Detail TauDetail
bool mEflowApprox(const xAOD::TauJet &tau, float &out)
bool innerTrkAvgDist(const xAOD::TauJet &tau, float &out)
bool centFrac(const xAOD::TauJet &tau, float &out)
bool ptRatioEflowApprox(const xAOD::TauJet &tau, float &out)
bool trFlightPathSig(const xAOD::TauJet &tau, float &out)
bool ClustersMeanSecondLambda(const xAOD::TauJet &tau, float &out)
bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, float &out)
bool ClustersMeanEMProbability(const xAOD::TauJet &tau, float &out)
bool dRmax(const xAOD::TauJet &tau, float &out)
bool pt_tau_log(const xAOD::TauJet &tau, float &out)
bool absipSigLeadTrk(const xAOD::TauJet &tau, float &out)
bool SumPtTrkFrac(const xAOD::TauJet &tau, float &out)
bool absleadTrackEta(const xAOD::TauJet &tau, float &out)
void sort(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of sort for DataVector/List.
Detail
Enum for tau parameters - used mainly for backward compatibility with the analysis code.
Definition TauDefs.h:156
TrackParticle_v1 TrackParticle
Reference the current persistent version:
TauTrack_v1 TauTrack
Definition of the current version.
Definition TauTrack.h:16
TauJet_v3 TauJet
Definition of the current "tau version".
@ eProbabilityHT
Electron probability from High Threshold (HT) information [float].