ATLAS Offline Software
Loading...
Searching...
No Matches
TauDecayModeNNClassifier.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef TAURECTOOLS_TAUDECAYMODENNCLASSIFIER_H
6#define TAURECTOOLS_TAUDECAYMODENNCLASSIFIER_H
7
8// base class include(s)
10
12
13// xAOD include(s)
14#include "xAODTau/TauJet.h"
15
16// lwtnn include(s)
17#include "lwtnn/LightweightGraph.hh"
18#include "lwtnn/parse_json.hh"
19#include "lwtnn/Exceptions.hh"
20
21// standard library include(s)
22#include <memory>
23#include <vector>
24#include <set>
25#include <map>
26
33
35{
// NOTE(review): the class declaration line (listing line 34) is not visible in
// this listing. Judging by the constructor name below this is the body of
// TauDecayModeNNClassifier, presumably derived from TauRecToolBase ("the base
// class for all tau tools") - confirm against the real header.
36public:
38
// Constructor. The tool name defaults to "TauDecayModeNNClassifier".
39 explicit TauDecayModeNNClassifier(const std::string &name = "TauDecayModeNNClassifier");
41
// Tool initializer.
42 virtual StatusCode initialize() override;
// Execute - called for each tau candidate. The method is const but receives
// the tau by non-const reference.
43 virtual StatusCode execute(xAOD::TauJet &xTau) const override;
44
45private:
// Properties of the tool. Each initializer gives the owner (this), the
// configurable property name and its default value.
// NOTE(review): the semantics of the individual properties are not shown here;
// the names suggest output decoration names, the network weight file, per-type
// object multiplicity caps and a pT threshold - confirm in the .cxx.
47 Gaudi::Property<std::string> m_outputName{this, "OutputName", "NNDecayMode"};
48 Gaudi::Property<std::string> m_probPrefix{this, "ProbPrefix", "NNDecayModeProb_"};
// Defaults to an empty string.
49 Gaudi::Property<std::string> m_weightFile{this, "WeightFile", ""};
50 Gaudi::Property<std::size_t> m_maxTauTracks{this, "MaxTauTracks", 3};
51 Gaudi::Property<std::size_t> m_maxNeutralPFOs{this, "MaxNeutralPFOs", 8};
52 Gaudi::Property<std::size_t> m_maxShotPFOs{this, "MaxShotPFOs", 6};
53 Gaudi::Property<std::size_t> m_maxConvTracks{this, "MaxConvTracks", 4};
// NOTE(review): the unit of the 1.5 default is not stated here (presumably
// GeV) - confirm.
54 Gaudi::Property<float> m_neutralPFOPtCut{this, "NeutralPFOPtCut", 1.5};
55 Gaudi::Property<bool> m_ensureTrackConsistency{this, "EnsureTrackConsistency", true};
56 Gaudi::Property<bool> m_decorateProb{this, "DecorateProb", true};
57
// Retrieve the input variables from a TauJet.
// xTau        - tau candidate to read from (const reference).
// inputSeqMap - nested map (string -> string -> vector<double>) taken by
//               non-const reference; presumably filled by this call and keyed
//               by network input node and variable name - TODO confirm.
// Returns a StatusCode.
66 virtual StatusCode getInputs(const xAOD::TauJet &xTau,
67 std::map<std::string, std::map<std::string, std::vector<double>>> &inputSeqMap) const;
// lwtnn graph, held through a unique_ptr to a const LightweightGraph.
69 std::unique_ptr<const lwt::LightweightGraph> m_lwtGraph;
70};
71
// Helper declarations for the decay-mode NN classifier.
// NOTE(review): this namespace holds two class bodies whose declaration lines
// are not visible in this listing (the numbering jumps 73 -> 78 and
// 101 -> 106), so the helper class names cannot be read from here.
// NOTE(review): std::array, std::pair and std::string are used below, but
// <array>, <utility> and <string> are not among the includes visible in this
// listing - confirm they are provided transitively or add them.
72namespace tauRecTools
73{
// First helper class: static constants and accessors only.
78 {
79 public:
// Number of classes; also the size of sModeNames.
81 static const std::size_t nClasses = 5;
// Static sets of strings, only declared here (defined elsewhere).
// NOTE(review): presumably the names of the input variables in each
// category - confirm in the .cxx.
82 static const std::set<std::string> sCommonP4Vars;
83 static const std::set<std::string> sTrackIPVars;
84 static const std::set<std::string> sNeutralPFOVars;
// One string per class (nClasses entries).
85 static const std::array<std::string, nClasses> sModeNames;
// Take a four-vector and the tau four-vector and return a float.
// NOTE(review): presumably the phi / eta difference between the object and
// the tau; sign convention and phi wrapping are not visible here - confirm.
86 static float deltaPhi(const TLorentzVector &p4, const TLorentzVector &p4_tau);
87 static float deltaEta(const TLorentzVector &p4, const TLorentzVector &p4_tau);
// As above, but the tau side is given as a (float, bool) pair.
// NOTE(review): presumably the track position at the ECal plus a validity
// flag; the behaviour when the flag is false is not visible here - confirm.
88 static float deltaPhiECal(const TLorentzVector &p4, const std::pair<float, bool> &tau_phiTrkECal);
89 static float deltaEtaECal(const TLorentzVector &p4, const std::pair<float, bool> &tau_etaTrkECal);
// Retrieve the PFO attributes: returns the attribute `attr` of `pfo` as type T.
98 template <typename T>
99 static T pfoAttr(const xAOD::PFO *pfo, const xAOD::PFODetails::PFOAttributes &attr);
100 };
101
// Second helper class: static utility functions only.
106 {
107 public:
// Takes a value and a lower bound min_val (default 0.) and returns a float.
// NOTE(review): presumably log10 of val guarded against values at or below
// min_val; the exact guard is not visible here - confirm.
109 static float Log10Robust(const float val, const float min_val = 0.);
// Sort the objects and only keep the leading n_obj objects in the vector.
// NOTE(review): the sort key and order are not visible here - confirm.
119 template <typename T>
120 static void sortAndKeep(std::vector<T> &vec, const std::size_t n_obj);
// Initialise the map with a set of defined keys.
129 template <typename T>
130 static void initMapKeys(std::map<std::string, T> &empty_map, const std::set<std::string> &keys);
131 };
132} // namespace tauRecTools
133
134#endif // TAURECTOOLS_TAUDECAYMODENNCLASSIFIER_H
#define ASG_TOOL_CLASS2(CLASSNAME, INT1, INT2)
std::vector< size_t > vec
The base class for all tau tools.
virtual StatusCode execute(xAOD::TauJet &xTau) const override
Execute - called for each tau candidate.
Gaudi::Property< bool > m_decorateProb
virtual StatusCode getInputs(const xAOD::TauJet &xTau, std::map< std::string, std::map< std::string, std::vector< double > > > &inputSeqMap) const
retrieve the input variables from a TauJet
TauDecayModeNNClassifier(const std::string &name="TauDecayModeNNClassifier")
Gaudi::Property< bool > m_ensureTrackConsistency
Gaudi::Property< std::size_t > m_maxTauTracks
Gaudi::Property< std::string > m_outputName
properties of the tool
Gaudi::Property< std::string > m_weightFile
Gaudi::Property< std::size_t > m_maxConvTracks
Gaudi::Property< float > m_neutralPFOPtCut
virtual StatusCode initialize() override
Tool initializer.
std::unique_ptr< const lwt::LightweightGraph > m_lwtGraph
lwtnn graph
Gaudi::Property< std::string > m_probPrefix
Gaudi::Property< std::size_t > m_maxShotPFOs
Gaudi::Property< std::size_t > m_maxNeutralPFOs
TauRecToolBase(const std::string &name)
static void sortAndKeep(std::vector< T > &vec, const std::size_t n_obj)
sort the objects and only keep the leading N objects in the vector
static float Log10Robust(const float val, const float min_val=0.)
static void initMapKeys(std::map< std::string, T > &empty_map, const std::set< std::string > &keys)
initialise the map with a set of defined keys
static float deltaEta(const TLorentzVector &p4, const TLorentzVector &p4_tau)
static const std::set< std::string > sCommonP4Vars
static float deltaEtaECal(const TLorentzVector &p4, const std::pair< float, bool > &tau_etaTrkECal)
static const std::set< std::string > sTrackIPVars
static const std::set< std::string > sNeutralPFOVars
static T pfoAttr(const xAOD::PFO *pfo, const xAOD::PFODetails::PFOAttributes &attr)
retrieve the PFO attributes
static const std::array< std::string, nClasses > sModeNames
static float deltaPhi(const TLorentzVector &p4, const TLorentzVector &p4_tau)
static float deltaPhiECal(const TLorentzVector &p4, const std::pair< float, bool > &tau_phiTrkECal)
Implementation of a TrackClassifier based on an RNN.
Definition BDTHelper.cxx:12
PFO_v1 PFO
Definition of the current "pfo version".
Definition PFO.h:17
TauJet_v3 TauJet
Definition of the current "tau version".