ATLAS Offline Software
TauDecayModeNNClassifier.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef TAURECTOOLS_TAUDECAYMODENNCLASSIFIER_H
6 #define TAURECTOOLS_TAUDECAYMODENNCLASSIFIER_H
7 
8 // base class include(s)
10 
11 // xAOD include(s)
12 #include "xAODTau/TauJet.h"
13 
14 // lwtnn include(s)
15 #include "lwtnn/LightweightGraph.hh"
16 #include "lwtnn/parse_json.hh"
17 #include "lwtnn/Exceptions.hh"
18 
19 // standard library include(s)
20 #include <memory>
21 #include <vector>
22 #include <set>
23 #include <map>
24 
33 {
34 public:
36 
37  explicit TauDecayModeNNClassifier(const std::string &name = "TauDecayModeNNClassifier");
38  virtual ~TauDecayModeNNClassifier();
39 
40  virtual StatusCode initialize() override;
41  virtual StatusCode execute(xAOD::TauJet &xTau) const override;
42 
43 private:
45  std::string m_outputName;
46  std::string m_probPrefix;
47  std::string m_weightFile;
48  std::size_t m_maxTauTracks;
49  std::size_t m_maxNeutralPFOs;
50  std::size_t m_maxShotPFOs;
51  std::size_t m_maxConvTracks;
55 
63  virtual StatusCode getInputs(const xAOD::TauJet &xTau,
64  std::map<std::string, std::map<std::string, std::vector<double>>> &inputSeqMap) const;
66  std::unique_ptr<const lwt::LightweightGraph> m_lwtGraph;
67 };
68 
69 namespace tauRecTools
70 {
75  {
76  public:
78  static const std::size_t nClasses = 5;
79  static const std::set<std::string> sCommonP4Vars;
80  static const std::set<std::string> sTrackIPVars;
81  static const std::set<std::string> sNeutralPFOVars;
82  static const std::array<std::string, nClasses> sModeNames;
83  static float deltaPhi(const TLorentzVector &p4, const TLorentzVector &p4_tau);
84  static float deltaEta(const TLorentzVector &p4, const TLorentzVector &p4_tau);
85  static float deltaPhiECal(const TLorentzVector &p4, const std::pair<float, bool> &tau_phiTrkECal);
86  static float deltaEtaECal(const TLorentzVector &p4, const std::pair<float, bool> &tau_etaTrkECal);
95  template <typename T>
96  static T pfoAttr(const xAOD::PFO *pfo, const xAOD::PFODetails::PFOAttributes &attr);
97  static float ptSubRatio(const xAOD::PFO *pfo);
98  static float energyFracEM2(const xAOD::PFO *pfo, float energy_em2);
99  };
100 
105  {
106  public:
108  static float Log10Robust(const float val, const float min_val = 0.);
118  template <typename T>
119  static void sortAndKeep(std::vector<T> &vec, const std::size_t n_obj);
128  template <typename T>
129  static void initMapKeys(std::map<std::string, T> &empty_map, const std::set<std::string> &keys);
130  };
131 } // namespace tauRecTools
132 
133 #endif // TAURECTOOLS_TAUDECAYMODENNCLASSIFIER_H
tauRecTools::TauDecayModeNNVariable::deltaEtaECal
static float deltaEtaECal(const TLorentzVector &p4, const std::pair< float, bool > &tau_etaTrkECal)
Definition: TauDecayModeNNClassifier.cxx:372
tauRecTools::TauDecayModeNNVariable::TauDecayModeNNVariable
TauDecayModeNNVariable()=delete
TauDecayModeNNClassifier::m_maxConvTracks
std::size_t m_maxConvTracks
Definition: TauDecayModeNNClassifier.h:51
TauDecayModeNNClassifier::getInputs
virtual StatusCode getInputs(const xAOD::TauJet &xTau, std::map< std::string, std::map< std::string, std::vector< double >>> &inputSeqMap) const
retrieve the input variables from a TauJet
Definition: TauDecayModeNNClassifier.cxx:166
tauRecTools::TauDecayModeNNVariable::nClasses
static const std::size_t nClasses
Definition: TauDecayModeNNClassifier.h:78
xAOD::PFODetails::PFOAttributes
PFOAttributes
Definition: Event/xAOD/xAODPFlow/xAODPFlow/PFODefs.h:28
ASG_TOOL_CLASS2
#define ASG_TOOL_CLASS2(CLASSNAME, INT1, INT2)
Definition: AsgToolMacros.h:77
tauRecTools::TauDecayModeNNVariable::pfoAttr
static T pfoAttr(const xAOD::PFO *pfo, const xAOD::PFODetails::PFOAttributes &attr)
retrieve the PFO attributes
tauRecTools::TauDecayModeNNVariable::deltaEta
static float deltaEta(const TLorentzVector &p4, const TLorentzVector &p4_tau)
Definition: TauDecayModeNNClassifier.cxx:361
TauRecToolBase.h
TauRecToolBase
The base class for all tau tools.
Definition: TauRecToolBase.h:21
tauRecTools::TauDecayModeNNVariable::deltaPhi
static float deltaPhi(const TLorentzVector &p4, const TLorentzVector &p4_tau)
Definition: TauDecayModeNNClassifier.cxx:356
TauDecayModeNNClassifier::m_probPrefix
std::string m_probPrefix
Definition: TauDecayModeNNClassifier.h:46
tauRecTools::TauDecayModeNNVariable::sNeutralPFOVars
static const std::set< std::string > sNeutralPFOVars
Definition: TauDecayModeNNClassifier.h:81
TauDecayModeNNClassifier::~TauDecayModeNNClassifier
virtual ~TauDecayModeNNClassifier()
Definition: TauDecayModeNNClassifier.cxx:42
vec
std::vector< size_t > vec
Definition: CombinationsGeneratorTest.cxx:12
TauDecayModeNNClassifier::TauDecayModeNNClassifier
TauDecayModeNNClassifier(const std::string &name="TauDecayModeNNClassifier")
Definition: TauDecayModeNNClassifier.cxx:27
tauRecTools::TauDecayModeNNVariable
A closely related class that calculates the input variables.
Definition: TauDecayModeNNClassifier.h:75
TauDecayModeNNClassifier::m_maxShotPFOs
std::size_t m_maxShotPFOs
Definition: TauDecayModeNNClassifier.h:50
tauRecTools::TauDecayModeNNHelper
A closely related class that provides helper functions.
Definition: TauDecayModeNNClassifier.h:105
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
tauRecTools::TauDecayModeNNHelper::sortAndKeep
static void sortAndKeep(std::vector< T > &vec, const std::size_t n_obj)
sort the objects and only keep the leading N objects in the vector
Definition: TauDecayModeNNClassifier.cxx:407
TauDecayModeNNClassifier::m_weightFile
std::string m_weightFile
Definition: TauDecayModeNNClassifier.h:47
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
TauDecayModeNNClassifier::m_ensureTrackConsistency
bool m_ensureTrackConsistency
Definition: TauDecayModeNNClassifier.h:53
tauRecTools::TauDecayModeNNVariable::sTrackIPVars
static const std::set< std::string > sTrackIPVars
Definition: TauDecayModeNNClassifier.h:80
TauDecayModeNNClassifier
Tau decay mode classifier using a neural network.
Definition: TauDecayModeNNClassifier.h:33
tauRecTools::TauDecayModeNNHelper::initMapKeys
static void initMapKeys(std::map< std::string, T > &empty_map, const std::set< std::string > &keys)
initialise the map with a set of defined keys
Definition: TauDecayModeNNClassifier.cxx:418
xAOD::PFO_v1
Class describing a particle flow object.
Definition: PFO_v1.h:35
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:192
TauDecayModeNNClassifier::m_maxTauTracks
std::size_t m_maxTauTracks
Definition: TauDecayModeNNClassifier.h:48
ITauToolBase
The base class for all tau tools.
Definition: ITauToolBase.h:30
tauRecTools::TauDecayModeNNVariable::ptSubRatio
static float ptSubRatio(const xAOD::PFO *pfo)
Definition: TauDecayModeNNClassifier.cxx:389
TauDecayModeNNClassifier::m_decorateProb
bool m_decorateProb
Definition: TauDecayModeNNClassifier.h:54
TauDecayModeNNClassifier::m_lwtGraph
std::unique_ptr< const lwt::LightweightGraph > m_lwtGraph
lwtnn graph
Definition: TauDecayModeNNClassifier.h:66
tauRecTools::TauDecayModeNNVariable::sCommonP4Vars
static const std::set< std::string > sCommonP4Vars
Definition: TauDecayModeNNClassifier.h:79
TauDecayModeNNClassifier::m_neutralPFOPtCut
float m_neutralPFOPtCut
Definition: TauDecayModeNNClassifier.h:52
Pythia8_RapidityOrderMPI.val
val
Definition: Pythia8_RapidityOrderMPI.py:14
TauJet.h
TauDecayModeNNClassifier::m_outputName
std::string m_outputName
properties of the tool
Definition: TauDecayModeNNClassifier.h:45
TauDecayModeNNClassifier::execute
virtual StatusCode execute(xAOD::TauJet &xTau) const override
Execute - called for each tau candidate.
Definition: TauDecayModeNNClassifier.cxx:86
python.Bindings.keys
keys
Definition: Control/AthenaPython/python/Bindings.py:790
tauRecTools
Implementation of a TrackClassifier based on an RNN.
Definition: BDTHelper.cxx:12
TauDecayModeNNClassifier::m_maxNeutralPFOs
std::size_t m_maxNeutralPFOs
Definition: TauDecayModeNNClassifier.h:49
tauRecTools::TauDecayModeNNHelper::TauDecayModeNNHelper
TauDecayModeNNHelper()=delete
TauDecayModeNNClassifier::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: TauDecayModeNNClassifier.cxx:46
tauRecTools::TauDecayModeNNHelper::Log10Robust
static float Log10Robust(const float val, const float min_val=0.)
Definition: TauDecayModeNNClassifier.cxx:401
tauRecTools::TauDecayModeNNVariable::energyFracEM2
static float energyFracEM2(const xAOD::PFO *pfo, float energy_em2)
Definition: TauDecayModeNNClassifier.cxx:395
tauRecTools::TauDecayModeNNVariable::deltaPhiECal
static float deltaPhiECal(const TLorentzVector &p4, const std::pair< float, bool > &tau_phiTrkECal)
Definition: TauDecayModeNNClassifier.cxx:366
tauRecTools::TauDecayModeNNVariable::sModeNames
static const std::array< std::string, nClasses > sModeNames
Definition: TauDecayModeNNClassifier.h:82