ATLAS Offline Software
Loading...
Searching...
No Matches
TauTrackRNNClassifier.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef TAURECTOOLS_TAUTRACKRNNCLASSIFIER_H
6#define TAURECTOOLS_TAUTRACKRNNCLASSIFIER_H
7
8// ASG include(s)
9#include "AsgTools/AsgTool.h"
13
14// xAOD include(s)
15#include "xAODTau/TauJet.h"
18
19// local include(s)
23
24#include <memory>
25
32
33namespace tauRecTools
34{
35
// Forward declaration: TauTrackRNNClassifier stores a ToolHandleArray<TrackRNN>
// before TrackRNN itself is defined further down in this header.
class TrackRNN;

// Containers for RNN inputs. Two input shapes are supported:
//  - ValueMap:  variable name -> one value   (simple rank-1 inputs)
//  - VectorMap: variable name -> a sequence  (sequence inputs, e.g. one entry per track)
using ValueMap  = std::map<std::string, double>;
using VectorMap = std::map<std::string, std::vector<double>>;

// One level up: a map from an (input-node) name to its scalar or sequence
// variables. NOTE(review): keys presumably are the network's input-node names.
using NodeMap    = std::map<std::string, ValueMap>;
using SeqNodeMap = std::map<std::string, VectorMap>;
46
// TauTrackRNNClassifier: tau-reconstruction tool (derives from TauRecToolBase)
// that hands a tau's tracks to the TrackRNN sub tools listed in its
// "Classifiers" property.
// NOTE(review): the embedded line numbering skips 48, 53 and 56, so the class
// head and possibly further declarations are missing from this listing.
47//______________________________________________________________________________
49 : public TauRecToolBase
50{
51public:
52
54
// Constructor; `name` is the tool instance name.
55 TauTrackRNNClassifier(const std::string& name="TauTrackRNNClassifier");
57
58 // retrieve all track classifier sub tools
59 virtual StatusCode initialize() override;
60 // pass all tracks in the tau cone to all track classifier sub tools
// pTau is taken by non-const reference and may therefore be updated.
// NOTE(review): tauTrackContainer is presumably the container owning the
// tau's xAOD::TauTrack objects — confirm in the implementation.
61 virtual StatusCode executeTrackClassifier(xAOD::TauJet& pTau, xAOD::TauTrackContainer& tauTrackContainer) const override;
62
63 private:
64
// Classifies the given tracks (vTracks) of xTau.
// NOTE(review): presumably the dedicated path for large-radius-tracking (LRT)
// tracks, steered by classifyLRT / classifyLRTWithDedicated — confirm.
65 StatusCode classifyLRTTracks(std::vector<xAOD::TauTrack*>& vTracks, xAOD::TauJet& xTau) const;
66
// TrackRNN sub tools that do the actual classification
// (property "Classifiers", empty by default).
67 ToolHandleArray<TrackRNN> m_vClassifier {this, "Classifiers", {}};
68
// Key of the input vertex container
// (property "Key_vertexInputContainer", default "PrimaryVertices").
69 SG::ReadHandleKey<xAOD::VertexContainer> m_vertexContainerKey {this, "Key_vertexInputContainer", "PrimaryVertices", "Vertex container key"};
70
// Steering flags as (property name, default): classifyLRT (true),
// classifyLRTWithDedicated (false), ClassifyOnlyCoreTracks (false),
// SkipBadTracks (false). Their effect is implemented outside this header.
71 Gaudi::Property<bool> m_classifyLRT{this, "classifyLRT", true};
72 Gaudi::Property<bool> m_classifyLRTWithDedicated{this, "classifyLRTWithDedicated", false};
73 Gaudi::Property<bool> m_classifyOnlyCoreTracks{this, "ClassifyOnlyCoreTracks", false};
74 Gaudi::Property<bool> m_skipBadTracks{this, "SkipBadTracks", false};
75
76}; // class TauTrackRNNClassifier
77
// TrackRNN: RNN-based track classifier sub tool (derives from TauRecToolBase);
// TauTrackRNNClassifier holds these in its "Classifiers" ToolHandleArray.
// NOTE(review): the embedded line numbering skips 79 and 82-85, so the class
// head (and, per the Doxygen residue, an ASG_TOOL_CLASS2 line) is missing here.
78//______________________________________________________________________________
80 : public TauRecToolBase
81{
86
87 public:
88
// Constructor; `name` is the tool instance name.
// NOTE(review): single-argument constructor is not marked `explicit`.
89 TrackRNN(const std::string& name);
// Destructor is only declared here.
// NOTE(review): presumably defined out-of-line so that the
// std::unique_ptr<lwtDev::LightweightGraph> member is destroyed where
// lwtDev::LightweightGraph is a complete type — confirm.
90 ~TrackRNN();
91
92 // configure the MVA object and build a general map to store variables
93 // for possible MVA inputs. Only variables defined in the weights file
94 // are passed to the MVA object
// NOTE(review): the weights location is presumably the InputWeightsPath property.
95 virtual StatusCode initialize() override;
96
97 // executes MVA object to get the RNN scores and set classification flags
// vTracks:           tracks to classify (non-const, so they may be modified)
// xTau:              tau the tracks belong to (non-const reference)
// vertexContainer:   vertex container passed as a raw pointer
//                    (NOTE(review): whether null is allowed is not visible here)
// tauTrackContainer: track container, read-only in this call
// skipTracks:        defaults to false; its meaning is not visible in this header
98 StatusCode classifyTracks(std::vector<xAOD::TauTrack*>& vTracks,
99 xAOD::TauJet& xTau,
100 const xAOD::VertexContainer* vertexContainer,
101 const xAOD::TauTrackContainer& tauTrackContainer,
102 bool skipTracks=false) const;
103
104private:
105 // set RNN input variables in the corresponding map entries
// NOTE(review): the output parameter is named `valueMap` but has type VectorMap
// (name -> sequence of values), presumably one entry per track in vTracks.
106 StatusCode calculateVars(const std::vector<xAOD::TauTrack*>& vTracks,
107 const xAOD::TauJet& xTau,
108 const xAOD::VertexContainer* vertexContainer,
109 VectorMap& valueMap) const;
110
111 // properties
// (property name, default): InputWeightsPath (""), MaxNtracks (0),
// removeDuplicateChargedTracks (false).
// NOTE(review): e.g. whether MaxNtracks == 0 means "no limit" is not visible here.
112 Gaudi::Property<std::string> m_inputWeightsPath{this, "InputWeightsPath", ""};
113 Gaudi::Property<unsigned int> m_nMaxNtracks{this, "MaxNtracks", 0};
114 Gaudi::Property<bool> m_removeDuplicateChargedTracks {this, "removeDuplicateChargedTracks", false};
115
// The RNN evaluator, exclusively owned by this tool.
// NOTE(review): presumably constructed in initialize() — confirm.
116 std::unique_ptr<lwtDev::LightweightGraph> m_RNNClassifier;
117
118}; // class TrackRNN
119
120} // namespace tauRecTools
121
122#endif // TAURECTOOLS_TAUTRACKRNNCLASSIFIER_H
#define ASG_TOOL_CLASS2(CLASSNAME, INT1, INT2)
Property holding a SG store/key/clid from which a ReadHandle is made.
The base class for all tau tools.
Property holding a SG store/key/clid from which a ReadHandle is made.
TauRecToolBase(const std::string &name)
ToolHandleArray< TrackRNN > m_vClassifier
SG::ReadHandleKey< xAOD::VertexContainer > m_vertexContainerKey
TauTrackRNNClassifier(const std::string &name="TauTrackRNNClassifier")
Gaudi::Property< bool > m_classifyLRTWithDedicated
virtual StatusCode executeTrackClassifier(xAOD::TauJet &pTau, xAOD::TauTrackContainer &tauTrackContainer) const override
virtual StatusCode initialize() override
Tool initializer.
StatusCode classifyLRTTracks(std::vector< xAOD::TauTrack * > &vTracks, xAOD::TauJet &xTau) const
Gaudi::Property< bool > m_classifyOnlyCoreTracks
Gaudi::Property< unsigned int > m_nMaxNtracks
Gaudi::Property< bool > m_removeDuplicateChargedTracks
virtual StatusCode initialize() override
Tool initializer.
Gaudi::Property< std::string > m_inputWeightsPath
StatusCode calculateVars(const std::vector< xAOD::TauTrack * > &vTracks, const xAOD::TauJet &xTau, const xAOD::VertexContainer *vertexContainer, VectorMap &valueMap) const
ASG_TOOL_CLASS2(TrackRNN, TauRecToolBase, ITauToolBase) public ~TrackRNN()
Create a proper constructor for Athena.
std::unique_ptr< lwtDev::LightweightGraph > m_RNNClassifier
StatusCode classifyTracks(std::vector< xAOD::TauTrack * > &vTracks, xAOD::TauJet &xTau, const xAOD::VertexContainer *vertexContainer, const xAOD::TauTrackContainer &tauTrackContainer, bool skipTracks=false) const
Implementation of a TrackClassifier based on an RNN.
Definition BDTHelper.cxx:12
std::map< std::string, std::vector< double > > VectorMap
std::map< std::string, double > ValueMap
std::map< std::string, ValueMap > NodeMap
std::map< std::string, VectorMap > SeqNodeMap
VertexContainer_v1 VertexContainer
Definition of the current "Vertex container version".
TauJet_v3 TauJet
Definition of the current "tau version".
TauTrackContainer_v1 TauTrackContainer
Definition of the current TauTrack container version.