ATLAS Offline Software
Loading...
Searching...
No Matches
MvaTESEvaluator.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5// local include(s)
7
8
9MvaTESEvaluator::MvaTESEvaluator(const std::string& name)
10 : TauRecToolBase(name) {
11}
12
13
15
// Load the default BDT from the weight file named by the m_sWeightFileName
// property, resolved to a path through find_file(). ATH_CHECK propagates a
// failed BDTHelper::initialize() as the tool's initialization status.
16 const std::string weightFile = find_file(m_sWeightFileName);
17 m_bdtHelper = std::make_unique<tauRecTools::BDTHelper>();
18 ATH_CHECK(m_bdtHelper->initialize(weightFile));
19
// Optionally load a second BDT, which execute() uses (offline only) for taus
// with nTracks()==0. It is created only when m_sWeightFileName0p is
// non-empty. Otherwise m_bdtHelper0p is not created here, and execute()
// tests it for null before use.
20 if (!m_sWeightFileName0p.empty()) {
21 const std::string weightFile0p = find_file(m_sWeightFileName0p);
22 m_bdtHelper0p = std::make_unique<tauRecTools::BDTHelper>();
23 ATH_CHECK(m_bdtHelper0p->initialize(weightFile0p));
24 }
25
26 return StatusCode::SUCCESS;
27}
28
29
// Compute the MVA-calibrated pT of one tau candidate and overwrite its
// four-momentum with it through xTau.setP4() (mass is always set to 0).
//
// The configuration is selected by inTrigger():
//  - offline: BDT inputs use the "TauJetsAuxDyn." names. The BDT response is
//    multiplied by ptCombined, and eta/phi come from the PanTau cell-based
//    axis. If the optional second BDT (m_bdtHelper0p) was loaded and the tau
//    has nTracks()==0, that BDT and its own input set replace the default ones.
//  - trigger: BDT inputs use the "TrigTauJetsAuxDyn." names. The default BDT's
//    response is multiplied by ptDetectorAxis, and eta/phi come from the
//    detector axis.
// The calibrated pT is floored at 1 (presumably 1 MeV -- TODO confirm units).
//
// @param xTau  tau candidate; read through aux-data accessors and modified
//              in place.
// @return StatusCode::SUCCESS on every path visible here, including the
//         zero-pT fallbacks.
//
// NOTE(review): `vars`, the object whose float members are bound below, is
// declared on a line that is not visible in this view.
30StatusCode MvaTESEvaluator::execute(xAOD::TauJet& xTau) const {
31
// Maps each BDT input-variable expression (presumably as named in the weight
// file -- verify) to the address of the `vars` member holding its value.
// Only addresses are stored here. The values are filled in further down,
// before getResponse() reads them through this map.
32 std::map<TString, float*> availableVars;
34
35 // Declare input variables to the reader
36 if (!inTrigger()) {
37 availableVars.insert( std::make_pair("TauJetsAuxDyn.mu", &vars.mu) );
38 availableVars.insert( std::make_pair("TauJetsAuxDyn.nVtxPU", &vars.nVtxPU) );
39 availableVars.insert( std::make_pair("TauJetsAuxDyn.rho", &vars.rho) );
40 availableVars.insert( std::make_pair("TauJetsAuxDyn.ClustersMeanCenterLambda", &vars.center_lambda) );
41 availableVars.insert( std::make_pair("TauJetsAuxDyn.ClustersMeanFirstEngDens", &vars.first_eng_dens) );
42 availableVars.insert( std::make_pair("TauJetsAuxDyn.ClustersMeanSecondLambda", &vars.second_lambda) );
43 availableVars.insert( std::make_pair("TauJetsAuxDyn.ClustersMeanPresamplerFrac", &vars.presampler_frac) );
44 availableVars.insert( std::make_pair("TauJetsAuxDyn.ClustersMeanEMProbability", &vars.eprobability) );
45 availableVars.insert( std::make_pair("TauJetsAuxDyn.ptIntermediateAxisEM/TauJetsAuxDyn.ptIntermediateAxis", &vars.ptEM_D_ptLC) );
46 availableVars.insert( std::make_pair("TauJetsAuxDyn.ptIntermediateAxis/TauJetsAuxDyn.ptCombined", &vars.ptLC_D_ptCombined) );
47 availableVars.insert( std::make_pair("TauJetsAuxDyn.etaPanTauCellBased", &vars.etaConstituent) );
// The extra inputs depend on which BDT will be evaluated. This condition must
// stay identical to the one that selects the BDT in the offline branch below.
48 if (m_bdtHelper0p && xTau.nTracks()==0) {
49 availableVars.insert( std::make_pair("log(TauJetsAuxDyn.ptCombined)", &vars.logPtCombined) );
50 availableVars.insert( std::make_pair("TauJetsAuxDyn.LeadClusterFrac", &vars.lead_cluster_frac) );
51 availableVars.insert( std::make_pair("TauJetsAuxDyn.centFrac", &vars.centFrac) );
52 availableVars.insert( std::make_pair("TauJetsAuxDyn.UpsilonCluster", &vars.upsilon_cluster) );
53 availableVars.insert( std::make_pair("TauJetsAuxDyn.ptJetSeed/TauJetsAuxDyn.ptCombined", &vars.ptSeed_D_ptCombined) );
54 }
55 else {
56 availableVars.insert( std::make_pair("TauJetsAuxDyn.ptPanTauCellBased/TauJetsAuxDyn.ptCombined", &vars.ptConstituent_D_ptCombined) );
57 availableVars.insert( std::make_pair("TauJetsAuxDyn.ptCombined", &vars.ptCombined) );
58 availableVars.insert( std::make_pair("TauJetsAuxDyn.PanTau_BDTValue_1p0n_vs_1p1n", &vars.PanTauBDT_1p0n_vs_1p1n) );
59 availableVars.insert( std::make_pair("TauJetsAuxDyn.PanTau_BDTValue_1p1n_vs_1pXn", &vars.PanTauBDT_1p1n_vs_1pXn) );
60 availableVars.insert( std::make_pair("TauJetsAuxDyn.PanTau_BDTValue_3p0n_vs_3pXn", &vars.PanTauBDT_3p0n_vs_3pXn) );
61 availableVars.insert( std::make_pair("TauJetsAuxDyn.nTracks", &vars.nTracks) );
62 availableVars.insert( std::make_pair("TauJetsAuxDyn.PFOEngRelDiff", &vars.PFOEngRelDiff) );
63 }
64 }
// Trigger input set. The trigger branch below always evaluates m_bdtHelper.
65 else {
66 availableVars.insert( std::make_pair("TrigTauJetsAuxDyn.mu", &vars.mu) );
67 availableVars.insert( std::make_pair("TrigTauJetsAuxDyn.ClustersMeanCenterLambda", &vars.center_lambda) );
68 availableVars.insert( std::make_pair("TrigTauJetsAuxDyn.ClustersMeanFirstEngDens", &vars.first_eng_dens) );
69 availableVars.insert( std::make_pair("TrigTauJetsAuxDyn.ClustersMeanSecondLambda", &vars.second_lambda) );
70 availableVars.insert( std::make_pair("TrigTauJetsAuxDyn.ClustersMeanPresamplerFrac", &vars.presampler_frac) );
71 availableVars.insert( std::make_pair("TrigTauJetsAuxDyn.ClustersMeanEMProbability", &vars.eprobability) );
72 availableVars.insert( std::make_pair("TrigTauJetsAuxDyn.LeadClusterFrac", &vars.lead_cluster_frac) );
73 availableVars.insert( std::make_pair("TrigTauJetsAuxDyn.SecondClusterFrac", &vars.second_cluster_frac) );
74 availableVars.insert( std::make_pair("TrigTauJetsAuxDyn.ThirdClusterFrac", &vars.third_cluster_frac) );
75 availableVars.insert( std::make_pair("TrigTauJetsAuxDyn.UpsilonCluster", &vars.upsilon_cluster) );
76 availableVars.insert( std::make_pair("log(TrigTauJetsAuxDyn.ptDetectorAxis)", &vars.logPtDetectorAxis) );
77 availableVars.insert( std::make_pair("TrigTauJetsAuxDyn.etaDetectorAxis", &vars.etaDetectorAxis) );
78 availableVars.insert( std::make_pair("TrigTauJetsAuxDyn.ptIntermediateAxisEM/TrigTauJetsAuxDyn.ptDetectorAxis", &vars.ptEM_D_ptLC) );
79 availableVars.insert( std::make_pair("TrigTauJetsAuxDyn.ptDetectorAxis/TrigTauJetsAuxDyn.ptJetSeed", &vars.ptDetectorAxis_D_ptJetSeed) );
80 availableVars.insert( std::make_pair("TrigTauJetsAuxDyn.centFrac", &vars.centFrac) );
81 }
82
// Inputs shared by both configurations are read once, before the branch.
83 // Retrieve average pileup
84 static const SG::ConstAccessor<float> acc_mu("mu");
85 vars.mu = acc_mu(xTau);
86
// NOTE(review): the reads of the cluster moments (center_lambda,
// first_eng_dens, second_lambda, presampler_frac, eprobability) are on lines
// that are not visible in this view.
87 // Retrieve cluster moments
93
// ptIntermediateAxisEM feeds the ptEM_D_ptLC input in both configurations.
// Offline it is divided by ptIntermediateAxis; in the trigger it is divided
// by ptDetectorAxis.
94 static const SG::ConstAccessor<float> acc_ptIntermediateAxisEM("ptIntermediateAxisEM");
95 float ptEM = acc_ptIntermediateAxisEM(xTau);
96
97 if (!inTrigger()) {
98 static const SG::ConstAccessor<float> acc_ptCombined("ptCombined");
99 float ptCombined = acc_ptCombined(xTau);
100
// ptCombined scales the BDT response and is the denominator of several inputs
// below. A zero value therefore takes a fallback: pT = 1 with the PanTau
// cell-based eta/phi, and no BDT evaluation.
101 if (ptCombined==0.) {
103 // apply MVA calibration as default
104 xTau.setP4(1., xTau.etaPanTauCellBased(), xTau.phiPanTauCellBased(), 0.);
105 return StatusCode::SUCCESS;
106 }
107
// nVtxPU is read as an int and stored in the float-bound input.
108 static const SG::ConstAccessor<int> acc_nVtxPU("nVtxPU");
109 vars.nVtxPU = acc_nVtxPU(xTau);
110
111 static const SG::ConstAccessor<float> acc_rho("rho");
112 vars.rho = acc_rho(xTau);
113
114 float ptLC = xTau.ptIntermediateAxis();
115
116 float ptConstituent = xTau.ptPanTauCellBased();
// NOTE(review): vars.etaConstituent is bound to "TauJetsAuxDyn.etaPanTauCellBased"
// and used as the final eta below, but it is not assigned on any line visible
// here. Presumably it is set on an elided line -- verify.
118
// ptLC has no early-return guard (unlike ptCombined), so its ratio is
// protected inline and set to 0 when ptLC is zero.
119 vars.ptEM_D_ptLC = (ptLC != 0.) ? ptEM / ptLC : 0.;
120 vars.ptLC_D_ptCombined = ptLC / ptCombined;
121
122 float ptMVA = 0.;
123
// Must match the input-set selection above: availableVars only holds the
// inputs of the BDT chosen there.
124 if (m_bdtHelper0p && xTau.nTracks()==0) {
125 vars.logPtCombined = std::log(ptCombined);
126 vars.ptSeed_D_ptCombined = xTau.ptJetSeed() / ptCombined;
127
128 static const SG::ConstAccessor<float> acc_UpsilonCluster("UpsilonCluster");
129 vars.upsilon_cluster = acc_UpsilonCluster(xTau);
130
131 static const SG::ConstAccessor<float> acc_LeadClusterFrac("LeadClusterFrac");
132 vars.lead_cluster_frac = acc_LeadClusterFrac(xTau);
133
// NOTE(review): vars.centFrac is an input of this BDT but is not assigned on
// any line visible here. Presumably it is set on an elided line -- verify.
135
// The BDT response is multiplied by ptCombined to obtain the calibrated pT.
136 ptMVA = float( ptCombined * m_bdtHelper0p->getResponse(availableVars) );
137 }
138 else {
139 vars.ptCombined = ptCombined;
140 vars.ptConstituent_D_ptCombined = ptConstituent / ptCombined;
141
142 // Retrieve substructure info
143 static const SG::ConstAccessor<float> acc_PanTauBDT_1p0n_vs_1p1n("PanTau_BDTValue_1p0n_vs_1p1n");
144 static const SG::ConstAccessor<float> acc_PanTauBDT_1p1n_vs_1pXn("PanTau_BDTValue_1p1n_vs_1pXn");
145 static const SG::ConstAccessor<float> acc_PanTauBDT_3p0n_vs_3pXn("PanTau_BDTValue_3p0n_vs_3pXn");
146 // BDT values are initialised to -1111, while actual scores (when evaluated) are within [-5,1], so take max between BDT score and -5-epsilon
147 vars.PanTauBDT_1p0n_vs_1p1n = std::max(acc_PanTauBDT_1p0n_vs_1p1n(xTau), -5.1f);
148 vars.PanTauBDT_1p1n_vs_1pXn = std::max(acc_PanTauBDT_1p1n_vs_1pXn(xTau), -5.1f);
149 vars.PanTauBDT_3p0n_vs_3pXn = std::max(acc_PanTauBDT_3p0n_vs_3pXn(xTau), -5.1f);
150 vars.nTracks = static_cast<float>(xTau.nTracks());
// NOTE(review): vars.PFOEngRelDiff is an input of this BDT but is not assigned
// on any line visible here. Presumably it is set on an elided line -- verify.
152
// The default BDT's response is multiplied by ptCombined.
153 ptMVA = float( ptCombined * m_bdtHelper->getResponse(availableVars) );
154 }
155
// Floor the calibrated pT at 1 (presumably MeV -- TODO confirm units).
156 if (ptMVA<1.) ptMVA=1.;
158 // apply MVA calibration as default
159 xTau.setP4(ptMVA, vars.etaConstituent, xTau.phiPanTauCellBased(), 0.);
160 }
161 else {
// ptDetectorAxis is taken in a log and used as a divisor below. ptJetSeed is
// presumably the divisor of the ptDetectorAxis_D_ptJetSeed input, which is
// assigned on an elided line -- verify. If either is zero, fall back to
// pT = 1 with the detector-axis eta/phi, and skip the BDT.
162 // protection but should never happen
163 if (xTau.ptDetectorAxis()==0. || xTau.ptJetSeed()==0.) {
165 xTau.setP4(1., xTau.etaDetectorAxis(), xTau.phiDetectorAxis(), 0.);
166 return StatusCode::SUCCESS;
167 }
168
169 vars.logPtDetectorAxis = std::log(xTau.ptDetectorAxis());
170 vars.etaDetectorAxis = xTau.etaDetectorAxis();
171 vars.ptEM_D_ptLC = ptEM / xTau.ptDetectorAxis();
// NOTE(review): vars.ptDetectorAxis_D_ptJetSeed and vars.centFrac are trigger
// inputs but are not assigned on any line visible here. Presumably they are
// set on elided lines -- verify.
173
174 static const SG::ConstAccessor<float> acc_UpsilonCluster("UpsilonCluster");
175 static const SG::ConstAccessor<float> acc_LeadClusterFrac("LeadClusterFrac");
176 static const SG::ConstAccessor<float> acc_SecondClusterFrac("SecondClusterFrac");
177 static const SG::ConstAccessor<float> acc_ThirdClusterFrac("ThirdClusterFrac");
178
179 vars.upsilon_cluster = acc_UpsilonCluster(xTau);
180 vars.lead_cluster_frac = acc_LeadClusterFrac(xTau);
181 vars.second_cluster_frac = acc_SecondClusterFrac(xTau);
182 vars.third_cluster_frac = acc_ThirdClusterFrac(xTau);
183
185
// Calibrated pT = ptDetectorAxis times the BDT response, floored at 1.
186 float ptMVA = float( xTau.ptDetectorAxis() * m_bdtHelper->getResponse(availableVars) );
187 if (ptMVA<1.) ptMVA=1.;
188
190 // apply MVA calibration
191 xTau.setP4(ptMVA, vars.etaDetectorAxis, xTau.phiDetectorAxis(), 0.);
192 }
193
// Log the final four-momentum components after calibration.
194 ATH_MSG_DEBUG("final calib:" << xTau.pt() << " " << xTau.eta() << " " << xTau.phi() << " " << xTau.e());
195
196 return StatusCode::SUCCESS;
197}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_DEBUG(x)
Gaudi::Property< std::string > m_sWeightFileName
std::unique_ptr< tauRecTools::BDTHelper > m_bdtHelper
std::unique_ptr< tauRecTools::BDTHelper > m_bdtHelper0p
virtual StatusCode execute(xAOD::TauJet &xTau) const override
Execute - called for each tau candidate.
virtual StatusCode initialize() override
Tool initializer.
MvaTESEvaluator(const std::string &name="MvaTESEvaluator")
Gaudi::Property< std::string > m_sWeightFileName0p
Helper class to provide constant type-safe access to aux data.
TauRecToolBase(const std::string &name)
std::string find_file(const std::string &fname) const
bool inTrigger() const
virtual double phi() const
The azimuthal angle ($\phi$) of the particle.
double ptPanTauCellBased() const
double ptDetectorAxis() const
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
virtual double e() const
The total energy of the particle.
Definition TauJet_v3.cxx:87
double phiDetectorAxis() const
double etaDetectorAxis() const
double ptIntermediateAxis() const
bool detail(TauJetParameters::Detail detail, int &value) const
Get and set values of common details variables via enum.
void setP4(double pt, double eta, double phi, double m)
Set methods for IParticle values.
double ptJetSeed() const
double etaPanTauCellBased() const
virtual double eta() const
The pseudorapidity ($\eta$) of the particle.
size_t nTracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
double phiPanTauCellBased() const
@ centFrac
Get centrality fraction.
Definition TauDefs.h:200
TauJet_v3 TauJet
Definition of the current "tau version".