ATLAS Offline Software
Loading...
Searching...
No Matches
TauDecayModeNNClassifier.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3*/
4
5// local include(s)
7
8// helper function include(s)
10
11// standard library include(s)
12#include <array>
13#include <functional>
14#include <algorithm>
15#include <fstream>
16
// Short-hands for (non-owning, read-only) pointers to the xAOD objects used below.
17using PFOPtr = const xAOD::PFO *;
18using TrkPtr = const xAOD::TauTrack *;
// NOTE(review): the embedded listing numbers jump from 18 to 22 here, so further
// alias lines may be missing from this view -- confirm against the original file.
// Container layouts handed to the lwtnn graph in execute():
// variable name -> scalar value
22using ValueMap = std::map<std::string, double>;
// variable name -> sequence of values (one entry per object)
23using VectorMap = std::map<std::string, std::vector<double>>;
// input node name -> scalar variables of that node
24using InputMap = std::map<std::string, ValueMap>;
// input node name -> sequence variables of that node
25using InputSequenceMap = std::map<std::string, VectorMap>;
26
28 : TauRecToolBase(name)
29{
30}
31
35
37{
38 ATH_MSG_INFO("Initializing TauDecayModeNNClassifier");
39
40 // find input JSON file
41 std::string weightFile = find_file(m_weightFile);
42 if (weightFile.empty())
43 {
44 ATH_MSG_ERROR("Could not find network weights: " << m_weightFile);
45 return StatusCode::FAILURE;
46 }
47 ATH_MSG_INFO("Loaded network configuration from: " << weightFile);
48
49 // load lwt graph configuration
50 std::ifstream inputFile(weightFile);
51 lwt::GraphConfig lwtGraphConfig;
52 try
53 {
54 lwtGraphConfig = lwt::parse_json_graph(inputFile);
55 }
56 catch (const std::logic_error &e)
57 {
58 ATH_MSG_ERROR("Error parsing network config: " << e.what());
59 return StatusCode::FAILURE;
60 }
61
62 // configure neural network
63 try
64 {
65 m_lwtGraph = std::make_unique<lwt::LightweightGraph>(lwtGraphConfig, lwtGraphConfig.outputs.cbegin()->first);
66 }
67 catch (const lwt::NNConfigurationException &e)
68 {
69 ATH_MSG_ERROR("Error configuring network: " << e.what());
70 return StatusCode::FAILURE;
71 }
72
73 return StatusCode::SUCCESS;
74}
75
77{
78 // inputs
79 // ------
80 // m_inputMap will not hold any information,
81 // but it is required by the lwtnn API.
82 //
83 InputMap inputMapDummy;
84 InputSequenceMap inputSeqMap;
85 std::set<std::string> branches = {"TauTrack", "NeutralPFO", "ShotPFO", "ConvTrack"};
86 DMHelper::initMapKeys(inputSeqMap, branches);
87
88 ATH_CHECK(getInputs(xTau, inputSeqMap));
89
90 // output
91 // ------
92 ValueMap outputs;
93
94 // inference
95 // ---------
96 try
97 {
98 outputs = m_lwtGraph->compute(inputMapDummy, inputSeqMap);
99 }
100 catch (const std::exception &e)
101 {
102 ATH_MSG_ERROR("Error evaluating the network: " << e.what());
103 return StatusCode::FAILURE;
104 }
105
106 // Results
107 // -------
108 // Decay modes are "1p0n", "1p1n", "1pXn", "3p0n", "3pXn",
109 // here they are encoded as 0, 1, 2, 3, 4
110 //
111 std::array<float, DMVar::nClasses> probs;
112 // the prefix to match to output name in the json weight file
113 std::string prefix = "c_";
114 for (std::size_t i = 0; i < probs.size(); ++i)
115 {
116 probs[i] = outputs.at(prefix + DMVar::sModeNames[i]);
117 }
118
119 // Determine decay mode from classification results
120 // If requested: ensures consistency between reconstructed number of tracks and decay mode
121 // For non 1 / 3-track taus, classification is performed by maximum mode probability
122 //
123 std::array<float, DMVar::nClasses>::const_iterator itMax;
124 if (m_ensureTrackConsistency && xTau.nTracks() == 1)
125 {
126 // maximum probability of "1p0n", "1p1n", "1pXn"
127 itMax = std::max_element(probs.cbegin(), probs.cbegin() + 3);
128 }
129 else if (m_ensureTrackConsistency && xTau.nTracks() == 3)
130 {
131 // maximum probability of "3p0n", "3pXn"
132 itMax = std::max_element(probs.cbegin() + 3, probs.cend());
133 }
134 else
135 {
136 // maximum probability of all
137 itMax = std::max_element(probs.cbegin(), probs.cend());
138 }
139
140 const SG::Accessor<int> accDecayMode(m_outputName);
141 accDecayMode(xTau) = std::distance(probs.cbegin(), itMax);
142
143 if (m_decorateProb)
144 {
145 for (std::size_t i = 0; i < probs.size(); ++i)
146 {
147 const std::string probName = m_probPrefix + DMVar::sModeNames[i];
148 const SG::Accessor<float> accProb(probName);
149 accProb(xTau) = probs[i];
150 }
151 }
152
153 return StatusCode::SUCCESS;
154}
155
// Retrieve the input variables from a TauJet.
//
// Selects the objects associated with xTau, then appends one value per object
// and per variable to the sequences stored in inputSeqMap under the keys
// "TauTrack", "NeutralPFO", "ShotPFO" and "ConvTrack". These four keys must
// already exist in inputSeqMap (they are looked up with map::at).
//
// Returns StatusCode::FAILURE if a required PFO attribute cannot be read,
// StatusCode::SUCCESS otherwise.
//
// NOTE(review): the embedded listing numbers skip several lines inside this
// function (167, 200, 202-205, 281-282, 291-292, 309, 317-318), so statements
// appear to be missing from this view -- confirm against the original file.
156StatusCode TauDecayModeNNClassifier::getInputs(const xAOD::TauJet &xTau, InputSequenceMap &inputSeqMap) const
157{
158 std::vector<TrkPtr> vTauTracks;
159 std::vector<PFOPtr> vNeutralPFOs;
160 std::vector<PFOPtr> vShotPFOs;
161 std::vector<TrkPtr> vConvTracks;
162
163 // set objects
164 // -----------
165
166 // classified tau tracks
 // NOTE(review): no visible statement fills vTauTracks (listing line 167 is missing).
168
169 // neutral PFOs
170 for (std::size_t i = 0; i < xTau.nNeutralPFOs(); ++i)
171 {
172 const auto pfo = xTau.neutralPFO(i);
173 // Apply pt threshold
 // NOTE(review): the factor 1e3 presumably converts the cut from GeV to MeV -- confirm.
174 if (pfo->pt() < m_neutralPFOPtCut * 1e3)
175 continue;
176 vNeutralPFOs.push_back(pfo);
177 }
178
179 // shot PFOs
180 for (std::size_t i = 0; i < xTau.nShotPFOs(); ++i)
181 {
182 const auto pfo = xTau.shotPFO(i);
183 // skip PFOs without photons
184 int nPhotons{-1};
185 try
186 {
187 nPhotons = DMVar::pfoAttr<int>(pfo, PFOAttributes::tauShots_nPhotons);
188 }
 // pfoAttr throws when the attribute cannot be retrieved; treat that as a failure
189 catch (const std::exception &e)
190 {
191 ATH_MSG_ERROR("Error retrieving tauShots_nPhotons: " << e.what());
192 return StatusCode::FAILURE;
193 }
194 if (nPhotons < 1)
195 continue;
196 vShotPFOs.push_back(pfo);
197 }
198
199 // classified conversion tracks
 // NOTE(review): no visible statement fills vConvTracks (listing line 200 is missing).
201
 // NOTE(review): listing lines 202-205 are missing here as well.
206
207 // set variables
208 // -------------
209
210 // tau variables
 // tau four-momentum at the IntermediateAxis calibration, used as reference for dphi / deta / jetpt_log
211 const TLorentzVector &tau_p4 = xTau.p4(xAOD::TauJetParameters::TauCalibType::IntermediateAxis);
212
213 // pair: (1st) the value, (2nd) successfully retrieved
214 std::pair<float, bool> tau_etaTrkECal{0., false};
215 std::pair<float, bool> tau_phiTrkECal{0., false};
 // the ECal reference position is taken from the first track (index 0) of the tau, if there is one
216 if (xTau.nTracks() > 0)
217 {
218 TrkPtr trk = xTau.track(0);
219 if (!trk->detail(xAOD::TauJetParameters::CaloSamplingPhiEM, tau_phiTrkECal.first))
220 {
221 ATH_MSG_WARNING("Failed to retrieve extrapolated track phi in ECal");
222 }
223 else
224 {
225 tau_phiTrkECal.second = true;
226 }
227 if (!trk->detail(xAOD::TauJetParameters::CaloSamplingEtaEM, tau_etaTrkECal.first))
228 {
229 ATH_MSG_WARNING("Failed to retrieve extrapolated track eta in ECal");
230 }
231 else
232 {
233 tau_etaTrkECal.second = true;
234 }
235 }
236
237 // a function to set the common 4-momentum variables, this is needed for all later
 // (appends exactly one entry to each of the six sequences per call)
238 auto setCommonP4Vars = [&tau_p4, &tau_etaTrkECal, &tau_phiTrkECal](VectorMap &in_seq_map, const TLorentzVector &obj_p4) {
239 in_seq_map["dphiECal"].push_back(DMVar::deltaPhiECal(obj_p4, tau_phiTrkECal));
240 in_seq_map["detaECal"].push_back(DMVar::deltaEtaECal(obj_p4, tau_etaTrkECal));
241 in_seq_map["dphi"].push_back(DMVar::deltaPhi(obj_p4, tau_p4));
242 in_seq_map["deta"].push_back(DMVar::deltaEta(obj_p4, tau_p4));
243 in_seq_map["pt_log"].push_back(DMHelper::Log10Robust(obj_p4.Pt()));
244 in_seq_map["jetpt_log"].push_back(DMHelper::Log10Robust(tau_p4.Pt()));
245 };
246
247 // a function to set the track impact parameter variables
248 auto setTrackIPVars = [](VectorMap &in_seq_map, const TrkPtr &trk) {
249 in_seq_map["d0TJVA"].push_back(trk->d0TJVA());
250 in_seq_map["d0SigTJVA"].push_back(trk->d0SigTJVA());
251 in_seq_map["z0sinthetaTJVA"].push_back(trk->z0sinthetaTJVA());
252 in_seq_map["z0sinthetaSigTJVA"].push_back(trk->z0sinthetaSigTJVA());
253 };
254
255 // a function to set the neutral pfo variables
 // (may throw: every attribute is read through DMVar::pfoAttr, which throws if retrieval fails)
256 auto setNeutralPFOVars = [](VectorMap &in_seq_map, const PFOPtr &pfo) {
257 // get the attributes of a given PFO object
258 auto getAttr = std::bind(DMVar::pfoAttr<float>, pfo, std::placeholders::_1);
259 auto getAttrInt = std::bind(DMVar::pfoAttr<int>, pfo, std::placeholders::_1);
260
 // variables with the "_log" suffix are stored as log10 of the attribute, clamped from below by the given floor
261 in_seq_map["FIRST_ETA"].push_back(getAttr(PFOAttributes::cellBased_FIRST_ETA));
262 in_seq_map["SECOND_R_log"].push_back(DMHelper::Log10Robust(getAttr(PFOAttributes::cellBased_SECOND_R), 1e-3f));
263 in_seq_map["DELTA_THETA"].push_back(getAttr(PFOAttributes::cellBased_DELTA_THETA));
264 in_seq_map["CENTER_LAMBDA_log"].push_back(DMHelper::Log10Robust(getAttr(PFOAttributes::cellBased_CENTER_LAMBDA), 1e-3f));
265 in_seq_map["LONGITUDINAL"].push_back(getAttr(PFOAttributes::cellBased_LONGITUDINAL));
266 in_seq_map["ENG_FRAC_CORE"].push_back(getAttr(PFOAttributes::cellBased_ENG_FRAC_CORE));
267 in_seq_map["SECOND_ENG_DENS_log"].push_back(DMHelper::Log10Robust(getAttr(PFOAttributes::cellBased_SECOND_ENG_DENS), 1e-6f));
268 in_seq_map["NPosECells_EM1"].push_back(getAttrInt(PFOAttributes::cellBased_NPosECells_EM1));
269 in_seq_map["NPosECells_EM2"].push_back(getAttrInt(PFOAttributes::cellBased_NPosECells_EM2));
270 in_seq_map["energy_EM1"].push_back(getAttr(PFOAttributes::cellBased_energy_EM1));
271 in_seq_map["energy_EM2"].push_back(getAttr(PFOAttributes::cellBased_energy_EM2));
272 in_seq_map["EM1CoreFrac"].push_back(getAttr(PFOAttributes::cellBased_EM1CoreFrac));
273 in_seq_map["firstEtaWRTClusterPosition_EM1"].push_back(getAttr(PFOAttributes::cellBased_firstEtaWRTClusterPosition_EM1));
274 in_seq_map["firstEtaWRTClusterPosition_EM2"].push_back(getAttr(PFOAttributes::cellBased_firstEtaWRTClusterPosition_EM2));
275 in_seq_map["secondEtaWRTClusterPosition_EM1_log"].push_back(DMHelper::Log10Robust(getAttr(PFOAttributes::cellBased_secondEtaWRTClusterPosition_EM1), 1e-6f));
276 in_seq_map["secondEtaWRTClusterPosition_EM2_log"].push_back(DMHelper::Log10Robust(getAttr(PFOAttributes::cellBased_secondEtaWRTClusterPosition_EM2), 1e-6f));
277 };
278
279 // set tau tracks variables
280 VectorMap &chrg_map = inputSeqMap.at("TauTrack");
 // NOTE(review): listing lines 281-282 are missing here.
283 for (const auto &trk : vTauTracks)
284 {
285 setCommonP4Vars(chrg_map, trk->p4());
286 setTrackIPVars(chrg_map, trk);
287 }
288
289 // set Neutral PFOs variables
290 VectorMap &neut_map = inputSeqMap.at("NeutralPFO");
 // NOTE(review): listing lines 291-292 are missing here.
293 for (const auto &pfo : vNeutralPFOs)
294 {
295 setCommonP4Vars(neut_map, pfo->p4());
296 try
297 {
298 setNeutralPFOVars(neut_map, pfo);
299 }
300 catch (const std::exception &e)
301 {
302 ATH_MSG_ERROR("Error setting neutral PFO variables: " << e.what());
303 return StatusCode::FAILURE;
304 }
305 }
306
307 // set Shot PFOs variables
 // (shots only contribute the common 4-momentum variables)
308 VectorMap &shot_map = inputSeqMap.at("ShotPFO");
 // NOTE(review): listing line 309 is missing here.
310 for (const auto &pfo : vShotPFOs)
311 {
312 setCommonP4Vars(shot_map, pfo->p4());
313 }
314
315 // set Conversion tracks variables
316 VectorMap &conv_map = inputSeqMap.at("ConvTrack");
 // NOTE(review): listing lines 317-318 are missing here.
319 for (const auto &trk : vConvTracks)
320 {
321 setCommonP4Vars(conv_map, trk->p4());
322 setTrackIPVars(conv_map, trk);
323 }
324
325 return StatusCode::SUCCESS;
326}
327
328// Helper functions
329namespace tauRecTools
330{
331 const std::set<std::string> TauDecayModeNNVariable::sCommonP4Vars = {
332 "dphiECal", "detaECal", "dphi", "deta", "pt_log", "jetpt_log"};
333
334 const std::set<std::string> TauDecayModeNNVariable::sTrackIPVars = {
335 "d0TJVA", "d0SigTJVA", "z0sinthetaTJVA", "z0sinthetaSigTJVA"};
336
337 const std::set<std::string> TauDecayModeNNVariable::sNeutralPFOVars = {
338 "FIRST_ETA", "SECOND_R_log", "DELTA_THETA", "CENTER_LAMBDA_log", "LONGITUDINAL", "ENG_FRAC_CORE",
339 "SECOND_ENG_DENS_log", "NPosECells_EM1", "NPosECells_EM2", "energy_EM1", "energy_EM2", "EM1CoreFrac",
340 "firstEtaWRTClusterPosition_EM1", "firstEtaWRTClusterPosition_EM2",
341 "secondEtaWRTClusterPosition_EM1_log", "secondEtaWRTClusterPosition_EM2_log"};
342
343 const std::array<std::string, TauDecayModeNNVariable::nClasses> TauDecayModeNNVariable::sModeNames = {
344 "1p0n", "1p1n", "1pXn", "3p0n", "3pXn"};
345
346 float TauDecayModeNNVariable::deltaPhi(const TLorentzVector &p4, const TLorentzVector &p4_tau)
347 {
348 return p4_tau.DeltaPhi(p4);
349 }
350
351 float TauDecayModeNNVariable::deltaEta(const TLorentzVector &p4, const TLorentzVector &p4_tau)
352 {
353 return p4.Eta() - p4_tau.Eta();
354 }
355
356 float TauDecayModeNNVariable::deltaPhiECal(const TLorentzVector &p4, const std::pair<float, bool> &tau_phiTrkECal)
357 {
358 // if not retrieved, then set to 0. (mean value)
359 return tau_phiTrkECal.second ? TVector2::Phi_mpi_pi(p4.Phi() - tau_phiTrkECal.first) : 0.0f;
360 }
361
362 float TauDecayModeNNVariable::deltaEtaECal(const TLorentzVector &p4, const std::pair<float, bool> &tau_etaTrkECal)
363 {
364 // if not retrieved, then set to 0. (mean value)
365 return tau_etaTrkECal.second ? p4.Eta() - tau_etaTrkECal.first : 0.0f;
366 }
367
368 template <typename T>
370 {
371 T val{static_cast<T>(0)};
372 if (!pfo->attribute(attr, val))
373 {
374 throw std::runtime_error("Can not retrieve PFO attribute! enum = " + std::to_string(static_cast<unsigned>(attr)));
375 }
376 return val;
377 }
378
379 float TauDecayModeNNHelper::Log10Robust(const float val, const float min_val)
380 {
381 return TMath::Log10(std::max(val, min_val));
382 }
383
384 template <typename T>
385 void TauDecayModeNNHelper::sortAndKeep(std::vector<T> &vec, const std::size_t n_obj)
386 {
387 auto cmp_pt = [](const T lhs, const T rhs) { return lhs->pt() > rhs->pt(); };
388 std::sort(vec.begin(), vec.end(), cmp_pt);
389 if (vec.size() > n_obj)
390 {
391 vec.erase(vec.begin() + n_obj, vec.end());
392 }
393 }
394
395 template <typename T>
396 void TauDecayModeNNHelper::initMapKeys(std::map<std::string, T> &empty_map,
397 const std::set<std::string> &keys)
398 {
399 // T can be any type
400 for (const auto &key : keys)
401 {
402 empty_map[key];
403 }
404 }
405} // namespace tauRecTools
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_WARNING(x)
std::vector< size_t > vec
std::map< std::string, std::vector< double > > VectorMap
std::map< std::string, VectorMap > InputSequenceMap
const xAOD::TauTrack * TrkPtr
std::map< std::string, double > ValueMap
const xAOD::PFO * PFOPtr
std::map< std::string, ValueMap > InputMap
xAOD::PFODetails::PFOAttributes PFOAttributes
tauRecTools::TauDecayModeNNVariable DMVar
tauRecTools::TauDecayModeNNHelper DMHelper
Helper class to provide type-safe access to aux data.
virtual StatusCode execute(xAOD::TauJet &xTau) const override
Execute - called for each tau candidate.
Gaudi::Property< bool > m_decorateProb
virtual StatusCode getInputs(const xAOD::TauJet &xTau, std::map< std::string, std::map< std::string, std::vector< double > > > &inputSeqMap) const
retrieve the input variables from a TauJet
TauDecayModeNNClassifier(const std::string &name="TauDecayModeNNClassifier")
Gaudi::Property< bool > m_ensureTrackConsistency
Gaudi::Property< std::size_t > m_maxTauTracks
Gaudi::Property< std::string > m_outputName
properties of the tool
Gaudi::Property< std::string > m_weightFile
Gaudi::Property< std::size_t > m_maxConvTracks
Gaudi::Property< float > m_neutralPFOPtCut
virtual StatusCode initialize() override
Tool initializer.
std::unique_ptr< const lwt::LightweightGraph > m_lwtGraph
lwtnn graph
Gaudi::Property< std::string > m_probPrefix
Gaudi::Property< std::size_t > m_maxShotPFOs
Gaudi::Property< std::size_t > m_maxNeutralPFOs
TauRecToolBase(const std::string &name)
std::string find_file(const std::string &fname) const
A closely related class that provides helper functions.
static void sortAndKeep(std::vector< T > &vec, const std::size_t n_obj)
sort the objects and only keep the leading N objects in the vector
static float Log10Robust(const float val, const float min_val=0.)
static void initMapKeys(std::map< std::string, T > &empty_map, const std::set< std::string > &keys)
initialise the map with a set of defined keys
A closely related class that calculates the input variables.
static float deltaEta(const TLorentzVector &p4, const TLorentzVector &p4_tau)
static const std::set< std::string > sCommonP4Vars
static float deltaEtaECal(const TLorentzVector &p4, const std::pair< float, bool > &tau_etaTrkECal)
static const std::set< std::string > sTrackIPVars
static const std::set< std::string > sNeutralPFOVars
static T pfoAttr(const xAOD::PFO *pfo, const xAOD::PFODetails::PFOAttributes &attr)
retrieve the PFO attributes
static const std::array< std::string, nClasses > sModeNames
static float deltaPhi(const TLorentzVector &p4, const TLorentzVector &p4_tau)
static float deltaPhiECal(const TLorentzVector &p4, const std::pair< float, bool > &tau_phiTrkECal)
bool attribute(PFODetails::PFOAttributes AttributeType, T &anAttribute) const
get a PFO Variable via enum
size_t nNeutralPFOs() const
Get the number of neutral PFO particles associated with this tau.
size_t nShotPFOs() const
Get the number of shot PFO particles associated with this tau.
virtual FourMom_t p4() const
The full 4-momentum of the particle.
Definition TauJet_v3.cxx:96
const PFO * shotPFO(size_t i) const
Get the pointer to a given shot PFO associated with this tau.
const TauTrack * track(size_t i, TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged, int *container_index=0) const
Get the pointer to a given tauTrack associated with this tau /*container index needed by trackNonCons...
const PFO * neutralPFO(size_t i) const
Get the pointer to a given neutral PFO associated with this tau.
size_t nTracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
std::vector< const TauTrack * > tracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
Get the v<const pointer> to a given tauTrack collection associated with this tau.
bool detail(TauJetParameters::TrackDetail detail, float &value) const
void sort(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of sort for DataVector/List.
Implementation of a TrackClassifier based on an RNN.
Definition BDTHelper.cxx:12
PFO_v1 PFO
Definition of the current "pfo version".
Definition PFO.h:17
TauTrack_v1 TauTrack
Definition of the current version.
Definition TauTrack.h:16
TauJet_v3 TauJet
Definition of the current "tau version".