ATLAS Offline Software
Loading...
Searching...
No Matches
TauGNNDataLoader.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5#pragma once
6
7#include "xAODTau/TauJet.h"
8
10
16
17// Functions to calculate (scalar) input variables
18// Returns a status code indicating success
19namespace TauScalarVars{
20 bool absEta(const xAOD::TauJet &tau, float &out);
21 bool centFrac(const xAOD::TauJet &tau, float &out);
22 bool isolFrac(const xAOD::TauJet &tau, float &out);
23 bool etOverPtLeadTrk(const xAOD::TauJet &tau, float &out);
24 bool innerTrkAvgDist(const xAOD::TauJet &tau, float &out);
25 bool absipSigLeadTrk(const xAOD::TauJet &tau, float &out);
26 bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, float &out);
27 bool SumPtTrkFrac(const xAOD::TauJet &tau, float &out);
28 bool EMPOverTrkSysP(const xAOD::TauJet &tau, float &out);
29 bool ptRatioEflowApprox(const xAOD::TauJet &tau, float &out);
30 bool mEflowApprox(const xAOD::TauJet &tau, float &out);
31 bool dRmax(const xAOD::TauJet &tau, float &out);
32 bool trFlightPathSig(const xAOD::TauJet &tau, float &out);
33 bool massTrkSys(const xAOD::TauJet &tau, float &out);
34 bool pt(const xAOD::TauJet &tau, float &out);
35 bool pt_tau_log(const xAOD::TauJet &tau, float &out);
36 bool ptDetectorAxis(const xAOD::TauJet &tau, float &out);
37 bool ptIntermediateAxis(const xAOD::TauJet &tau, float &out);
38
39 //functions to calculate input variables needed for the eVeto RNN
40 bool ptJetSeed_log (const xAOD::TauJet &tau, float &out);
41 bool absleadTrackEta (const xAOD::TauJet &tau, float &out);
42 bool leadTrackDeltaEta (const xAOD::TauJet &tau, float &out);
43 bool leadTrackDeltaPhi (const xAOD::TauJet &tau, float &out);
44 bool leadTrackProbNNorHT (const xAOD::TauJet &tau, float &out);
45 bool EMFracFixed (const xAOD::TauJet &tau, float &out);
46 bool etHotShotWinOverPtLeadTrk (const xAOD::TauJet &tau, float &out);
47 bool hadLeakFracFixed (const xAOD::TauJet &tau, float &out);
48 bool PSFrac (const xAOD::TauJet &tau, float &out);
49 bool ClustersMeanCenterLambda (const xAOD::TauJet &tau, float &out);
50 bool ClustersMeanEMProbability (const xAOD::TauJet &tau, float &out);
51 bool ClustersMeanFirstEngDens (const xAOD::TauJet &tau, float &out);
52 bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, float &out);
53 bool ClustersMeanSecondLambda (const xAOD::TauJet &tau, float &out);
54 bool EMPOverTrkSysP (const xAOD::TauJet &tau, float &out);
55}//namespace TauScalarVars
56
58 public:
59 struct Config {
60 std::string nnFile;
61 std::string input_layer_scalar;
62 std::string input_layer_tracks;
64 std::string output_node_tau;
65 std::string output_node_jet;
71 bool useTRT;
72 };
74 std::shared_ptr<const FlavorTagInference::SaltModel> salt_model,
75 const Config& config
76 );
77 ~TauGNNDataLoader() = default;
78 private:
79 using ScalarCalcByRef_t = std::function<bool(const xAOD::TauJet &, float &)>;
80 using ScalarCalc_t = std::function<float(const xAOD::IParticle*)>;
81 ScalarCalc_t getScalarCalc(const std::string &name) const;
82 inline static const std::unordered_map<std::string, ScalarCalcByRef_t> m_func_map = {
83 {"isolFrac", TauScalarVars::isolFrac},
84 {"centFrac", TauScalarVars::centFrac},
85 {"etOverPtLeadTrk", TauScalarVars::etOverPtLeadTrk},
86 {"innerTrkAvgDist", TauScalarVars::innerTrkAvgDist},
87 {"absipSigLeadTrk", TauScalarVars::absipSigLeadTrk},
88 {"SumPtTrkFrac", TauScalarVars::SumPtTrkFrac},
89 {"sumEMCellEtOverLeadTrkPt", TauScalarVars::sumEMCellEtOverLeadTrkPt},
90 {"EMPOverTrkSysP", TauScalarVars::EMPOverTrkSysP},
91 {"ptRatioEflowApprox", TauScalarVars::ptRatioEflowApprox},
92 {"mEflowApprox", TauScalarVars::mEflowApprox},
93 {"dRmax", TauScalarVars::dRmax},
94 {"trFlightPathSig", TauScalarVars::trFlightPathSig},
95 {"massTrkSys", TauScalarVars::massTrkSys},
96 {"pt", TauScalarVars::pt}
97 };
98};
~TauGNNDataLoader()=default
ScalarCalc_t getScalarCalc(const std::string &name) const
std::function< bool(const xAOD::TauJet &, float &)> ScalarCalcByRef_t
std::function< float(const xAOD::IParticle *)> ScalarCalc_t
TauGNNDataLoader(std::shared_ptr< const FlavorTagInference::SaltModel > salt_model, const Config &config)
static const std::unordered_map< std::string, ScalarCalcByRef_t > m_func_map
Class mimicking the AthMessaging class from the offline software.
Class providing the definition of the 4-vector interface.
bool etHotShotWinOverPtLeadTrk(const xAOD::TauJet &tau, float &out)
bool ptDetectorAxis(const xAOD::TauJet &tau, float &out)
bool ClustersMeanCenterLambda(const xAOD::TauJet &tau, float &out)
bool pt(const xAOD::TauJet &tau, float &out)
bool ptIntermediateAxis(const xAOD::TauJet &tau, float &out)
bool EMFracFixed(const xAOD::TauJet &tau, float &out)
bool isolFrac(const xAOD::TauJet &tau, float &out)
bool massTrkSys(const xAOD::TauJet &tau, float &out)
bool PSFrac(const xAOD::TauJet &tau, float &out)
bool ClustersMeanFirstEngDens(const xAOD::TauJet &tau, float &out)
bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, float &out)
bool EMPOverTrkSysP(const xAOD::TauJet &tau, float &out)
bool leadTrackProbNNorHT(const xAOD::TauJet &tau, float &out)
bool ptJetSeed_log(const xAOD::TauJet &tau, float &out)
bool leadTrackDeltaEta(const xAOD::TauJet &tau, float &out)
bool etOverPtLeadTrk(const xAOD::TauJet &tau, float &out)
bool absEta(const xAOD::TauJet &tau, float &out)
bool hadLeakFracFixed(const xAOD::TauJet &tau, float &out)
bool leadTrackDeltaPhi(const xAOD::TauJet &tau, float &out)
bool mEflowApprox(const xAOD::TauJet &tau, float &out)
bool innerTrkAvgDist(const xAOD::TauJet &tau, float &out)
bool centFrac(const xAOD::TauJet &tau, float &out)
bool ptRatioEflowApprox(const xAOD::TauJet &tau, float &out)
bool trFlightPathSig(const xAOD::TauJet &tau, float &out)
bool ClustersMeanSecondLambda(const xAOD::TauJet &tau, float &out)
bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, float &out)
bool ClustersMeanEMProbability(const xAOD::TauJet &tau, float &out)
bool dRmax(const xAOD::TauJet &tau, float &out)
bool pt_tau_log(const xAOD::TauJet &tau, float &out)
bool absipSigLeadTrk(const xAOD::TauJet &tau, float &out)
bool SumPtTrkFrac(const xAOD::TauJet &tau, float &out)
bool absleadTrackEta(const xAOD::TauJet &tau, float &out)
TauJet_v3 TauJet
Definition of the current "tau version".