ATLAS Offline Software
Loading...
Searching...
No Matches
TauGNNDataLoader.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
5#pragma once
6
7#include "xAODTau/TauJet.h"
8
10
17
18// Functions to calculate (scalar) input variables
19// Returns a status code indicating success
20namespace TauScalarVars{
21 bool eta(const xAOD::TauJet &tau, float &out);
22 bool absEta(const xAOD::TauJet &tau, float &out);
23 bool centFrac(const xAOD::TauJet &tau, float &out);
24 bool isolFrac(const xAOD::TauJet &tau, float &out);
25 bool etOverPtLeadTrk(const xAOD::TauJet &tau, float &out);
26 bool innerTrkAvgDist(const xAOD::TauJet &tau, float &out);
27 bool absipSigLeadTrk(const xAOD::TauJet &tau, float &out);
28 bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, float &out);
29 bool SumPtTrkFrac(const xAOD::TauJet &tau, float &out);
30 bool EMPOverTrkSysP(const xAOD::TauJet &tau, float &out);
31 bool ptRatioEflowApprox(const xAOD::TauJet &tau, float &out);
32 bool mEflowApprox(const xAOD::TauJet &tau, float &out);
33 bool dRmax(const xAOD::TauJet &tau, float &out);
34 bool trFlightPathSig(const xAOD::TauJet &tau, float &out);
35 bool massTrkSys(const xAOD::TauJet &tau, float &out);
36 bool pt(const xAOD::TauJet &tau, float &out);
37 bool pt_tau_log(const xAOD::TauJet &tau, float &out);
38 bool ptDetectorAxis(const xAOD::TauJet &tau, float &out);
39 bool ptIntermediateAxis(const xAOD::TauJet &tau, float &out);
40 bool ptJetSeed(const xAOD::TauJet &tau, float &out);
41 bool etaJetSeed(const xAOD::TauJet &tau, float &out);
42
43 //functions to calculate input variables needed for the eVeto RNN
44 bool ptJetSeed_log (const xAOD::TauJet &tau, float &out);
45 bool absleadTrackEta (const xAOD::TauJet &tau, float &out);
46 bool leadTrackDeltaEta (const xAOD::TauJet &tau, float &out);
47 bool leadTrackDeltaPhi (const xAOD::TauJet &tau, float &out);
48 bool leadTrackProbNNorHT (const xAOD::TauJet &tau, float &out);
49 bool EMFracFixed (const xAOD::TauJet &tau, float &out);
50 bool etHotShotWinOverPtLeadTrk (const xAOD::TauJet &tau, float &out);
51 bool hadLeakFracFixed (const xAOD::TauJet &tau, float &out);
52 bool PSFrac (const xAOD::TauJet &tau, float &out);
53 bool ClustersMeanCenterLambda (const xAOD::TauJet &tau, float &out);
54 bool ClustersMeanEMProbability (const xAOD::TauJet &tau, float &out);
55 bool ClustersMeanFirstEngDens (const xAOD::TauJet &tau, float &out);
56 bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, float &out);
57 bool ClustersMeanSecondLambda (const xAOD::TauJet &tau, float &out);
58 bool EMPOverTrkSysP (const xAOD::TauJet &tau, float &out);
59}//namespace TauScalarVars
60
62 public:
63 struct Config {
64 std::string nnFile;
65 std::string input_layer_scalar;
66 std::string input_layer_tracks;
68 std::string input_layer_hits;
69 std::string output_node_tau;
70 std::string output_node_jet;
74 size_t n_max_hits;
77 bool useTRT;
78 std::string hits_decor_name;
79 };
81 std::shared_ptr<const FlavorTagInference::SaltModel> salt_model,
82 const Config& config
83 );
 /// Explicitly defaulted destructor.
 /// NOTE(review): a user-declared destructor (even `= default`) suppresses the
 /// implicitly-declared move constructor and move assignment, so moves of this
 /// class fall back to copies (if copying is available). Consider removing this
 /// declaration (Rule of Zero) if move semantics are wanted.
 84 ~TauGNNDataLoader() = default;
85 private:
86 using ScalarCalcByRef_t = std::function<bool(const xAOD::TauJet &, float &)>;
87 using ScalarCalc_t = std::function<float(const xAOD::IParticle*)>;
88 ScalarCalc_t getScalarCalc(const std::string &name) const;
89 inline static const std::unordered_map<std::string, ScalarCalcByRef_t> m_func_map = {
90 {"isolFrac", TauScalarVars::isolFrac},
91 {"centFrac", TauScalarVars::centFrac},
92 {"etOverPtLeadTrk", TauScalarVars::etOverPtLeadTrk},
93 {"innerTrkAvgDist", TauScalarVars::innerTrkAvgDist},
94 {"absipSigLeadTrk", TauScalarVars::absipSigLeadTrk},
95 {"SumPtTrkFrac", TauScalarVars::SumPtTrkFrac},
96 {"sumEMCellEtOverLeadTrkPt", TauScalarVars::sumEMCellEtOverLeadTrkPt},
97 {"EMPOverTrkSysP", TauScalarVars::EMPOverTrkSysP},
98 {"ptRatioEflowApprox", TauScalarVars::ptRatioEflowApprox},
99 {"mEflowApprox", TauScalarVars::mEflowApprox},
100 {"dRmax", TauScalarVars::dRmax},
101 {"trFlightPathSig", TauScalarVars::trFlightPathSig},
102 {"massTrkSys", TauScalarVars::massTrkSys},
103 {"pt", TauScalarVars::pt},
104 {"eta", TauScalarVars::eta},
105 {"ptJetSeed", TauScalarVars::ptJetSeed},
106 {"etaJetSSeed", TauScalarVars::etaJetSeed}
107 };
108};
Scalar eta() const
pseudorapidity method
~TauGNNDataLoader()=default
ScalarCalc_t getScalarCalc(const std::string &name) const
std::function< bool(const xAOD::TauJet &, float &)> ScalarCalcByRef_t
std::function< float(const xAOD::IParticle *)> ScalarCalc_t
TauGNNDataLoader(std::shared_ptr< const FlavorTagInference::SaltModel > salt_model, const Config &config)
static const std::unordered_map< std::string, ScalarCalcByRef_t > m_func_map
Class mimicking the AthMessaging class from the offline software.
Class providing the definition of the 4-vector interface.
bool etHotShotWinOverPtLeadTrk(const xAOD::TauJet &tau, float &out)
bool ptDetectorAxis(const xAOD::TauJet &tau, float &out)
bool ClustersMeanCenterLambda(const xAOD::TauJet &tau, float &out)
bool pt(const xAOD::TauJet &tau, float &out)
bool ptIntermediateAxis(const xAOD::TauJet &tau, float &out)
bool EMFracFixed(const xAOD::TauJet &tau, float &out)
bool isolFrac(const xAOD::TauJet &tau, float &out)
bool massTrkSys(const xAOD::TauJet &tau, float &out)
bool PSFrac(const xAOD::TauJet &tau, float &out)
bool ClustersMeanFirstEngDens(const xAOD::TauJet &tau, float &out)
bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, float &out)
bool EMPOverTrkSysP(const xAOD::TauJet &tau, float &out)
bool leadTrackProbNNorHT(const xAOD::TauJet &tau, float &out)
bool ptJetSeed_log(const xAOD::TauJet &tau, float &out)
bool leadTrackDeltaEta(const xAOD::TauJet &tau, float &out)
bool etaJetSeed(const xAOD::TauJet &tau, float &out)
bool etOverPtLeadTrk(const xAOD::TauJet &tau, float &out)
bool absEta(const xAOD::TauJet &tau, float &out)
bool hadLeakFracFixed(const xAOD::TauJet &tau, float &out)
bool leadTrackDeltaPhi(const xAOD::TauJet &tau, float &out)
bool mEflowApprox(const xAOD::TauJet &tau, float &out)
bool innerTrkAvgDist(const xAOD::TauJet &tau, float &out)
bool centFrac(const xAOD::TauJet &tau, float &out)
bool ptRatioEflowApprox(const xAOD::TauJet &tau, float &out)
bool trFlightPathSig(const xAOD::TauJet &tau, float &out)
bool ClustersMeanSecondLambda(const xAOD::TauJet &tau, float &out)
bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, float &out)
bool ClustersMeanEMProbability(const xAOD::TauJet &tau, float &out)
bool dRmax(const xAOD::TauJet &tau, float &out)
bool pt_tau_log(const xAOD::TauJet &tau, float &out)
bool absipSigLeadTrk(const xAOD::TauJet &tau, float &out)
bool SumPtTrkFrac(const xAOD::TauJet &tau, float &out)
bool eta(const xAOD::TauJet &tau, float &out)
bool ptJetSeed(const xAOD::TauJet &tau, float &out)
bool absleadTrackEta(const xAOD::TauJet &tau, float &out)
TauJet_v3 TauJet
Definition of the current "tau version".