ATLAS Offline Software
Loading...
Searching...
No Matches
TauTrackRNNClassifier.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef TAURECTOOLS_TAUTRACKRNNCLASSIFIER_H
6#define TAURECTOOLS_TAUTRACKRNNCLASSIFIER_H
7
8// ASG include(s)
9#include "AsgTools/AsgTool.h"
13
14// xAOD include(s)
15#include "xAODTau/TauJet.h"
18
19// local include(s)
23
24#include <memory>
25
32
33namespace tauRecTools
34{
35
class TrackRNN;

// Containers used to assemble the inputs of the track-classification network.
// Several input types are currently supported:
//
// ValueMap: simple rank-1 inputs, i.e. one scalar value per variable name.
using ValueMap = std::map<std::string, double>;
// VectorMap: sequence inputs, i.e. one sequence of values per variable name
// (TrackRNN::calculateVars fills a VectorMap with the per-track inputs).
using VectorMap = std::map<std::string, std::vector<double>>;

// NodeMap / SeqNodeMap group several ValueMaps / VectorMaps under a name.
// NOTE(review): presumably the name is the input-node name expected by
// lwtDev::LightweightGraph — confirm against the network configuration.
using NodeMap = std::map<std::string, ValueMap>;
using SeqNodeMap = std::map<std::string, VectorMap>;
46
47//______________________________________________________________________________
49 : public TauRecToolBase
50{
// Track-classification tool derived from TauRecToolBase. Per the comments on
// initialize() and executeTrackClassifier() below, it retrieves a set of
// TrackRNN sub-tools (configured through the "Classifiers" property) and
// hands the tracks of a tau candidate to each of them.
51public:
52
54
// Construct the tool; the instance name defaults to "TauTrackRNNClassifier".
// NOTE(review): this single-argument constructor is not `explicit`, so it
// allows implicit conversion from std::string.
55 TauTrackRNNClassifier(const std::string& name="TauTrackRNNClassifier");
57
58 // retrieve all track classifier sub tools
59 virtual StatusCode initialize() override;
60 // pass all tracks in the tau cone to all track classifier sub tools
// pTau and tauTrackContainer are taken by non-const reference, so the
// implementation may modify them although the method is const on the tool
// (presumably to store classification results — confirm in the .cxx).
61 virtual StatusCode executeTrackClassifier(xAOD::TauJet& pTau, xAOD::TauTrackContainer& tauTrackContainer) const override;
62
63 private:
// TrackRNN sub-tools to run (property "Classifiers", empty by default).
64 ToolHandleArray<TrackRNN> m_vClassifier {this, "Classifiers", {}};
65
// Key of the vertex container to read (default "PrimaryVertices").
// NOTE(review): presumably the source of the vertexContainer argument passed
// to TrackRNN::classifyTracks — confirm in the .cxx.
66 SG::ReadHandleKey<xAOD::VertexContainer> m_vertexContainerKey {this, "Key_vertexInputContainer", "PrimaryVertices", "Vertex container key"};
67
// Configuration flags. None of the declarations carries a property
// description, so the meanings below are inferred from the names only.
// NOTE(review): confirm each against the implementation.
//  - "classifyLRT" (default true): presumably whether large-radius (LRT)
//    tracks are classified. Unlike the other property names, it starts with
//    a lower-case letter.
//  - "ClassifyOnlyCoreTracks" (default false): presumably restricts the
//    classification to tracks in the core cone.
//  - "SkipBadTracks" (default false): presumably forwarded as the skipTracks
//    argument of TrackRNN::classifyTracks.
68 Gaudi::Property<bool> m_classifyLRT{this, "classifyLRT", true};
69 Gaudi::Property<bool> m_classifyOnlyCoreTracks{this, "ClassifyOnlyCoreTracks", false};
70 Gaudi::Property<bool> m_skipBadTracks{this, "SkipBadTracks", false};
71
72}; // class TauTrackRNNClassifier
73
74//______________________________________________________________________________
76 : public TauRecToolBase
77{
// Track-classifier sub-tool derived from TauRecToolBase. It owns an
// lwtDev::LightweightGraph (m_RNNClassifier) configured from the weights
// file given by "InputWeightsPath". TauTrackRNNClassifier holds instances of
// it in its "Classifiers" ToolHandleArray.
82
83 public:
84
// NOTE(review): this single-argument constructor is not `explicit`.
85 TrackRNN(const std::string& name);
// Destructor declared out of line. NOTE(review): presumably so that
// std::unique_ptr<lwtDev::LightweightGraph> is destroyed where
// LightweightGraph is a complete type — confirm in the .cxx.
86 ~TrackRNN();
87
88 // configure the MVA object and build a general map to store variables
89 // for possible MVA inputs. Only Variables defined in the root weights file
90 // are passed to the MVA object
91 virtual StatusCode initialize() override;
92
93 // executes MVA object to get the RNN scores and set classification flags
// Parameters:
//   vTracks           - tracks to classify; passed by non-const reference, so
//                       the tracks may be updated in place
//   xTau              - tau candidate owning the tracks (non-const)
//   vertexContainer   - vertex container, passed as a raw pointer
//                       (NOTE(review): confirm whether nullptr is accepted)
//   tauTrackContainer - read-only track container
//   skipTracks        - defaults to false; exact effect is defined in the .cxx
// Returns a StatusCode.
94 StatusCode classifyTracks(std::vector<xAOD::TauTrack*>& vTracks,
95 xAOD::TauJet& xTau,
96 const xAOD::VertexContainer* vertexContainer,
97 const xAOD::TauTrackContainer& tauTrackContainer,
98 bool skipTracks=false) const;
99
100private:
101 // set RNN input variables in the corresponding map entries
// Fills `valueMap` with sequence inputs (one std::vector<double> per variable).
// NOTE(review): the parameter is named valueMap but its type is VectorMap,
// not ValueMap.
102 StatusCode calculateVars(const std::vector<xAOD::TauTrack*>& vTracks,
103 const xAOD::TauJet& xTau,
104 const xAOD::VertexContainer* vertexContainer,
105 VectorMap& valueMap) const;
106
107 // properties
// "InputWeightsPath" (default empty): weights file used to configure the graph.
// "MaxNtracks" (default 0): NOTE(review): presumably caps the number of tracks
//   fed to the network; what 0 means is not visible here — confirm in the .cxx.
// "removeDuplicateChargedTracks" (default false): behaviour defined in the .cxx.
108 Gaudi::Property<std::string> m_inputWeightsPath{this, "InputWeightsPath", ""};
109 Gaudi::Property<unsigned int> m_nMaxNtracks{this, "MaxNtracks", 0};
110 Gaudi::Property<bool> m_removeDuplicateChargedTracks {this, "removeDuplicateChargedTracks", false};
111
// Network evaluated by classifyTracks; exclusively owned by this tool.
112 std::unique_ptr<lwtDev::LightweightGraph> m_RNNClassifier;
113
114}; // class TrackRNN
115
116} // namespace tauRecTools
117
118#endif // TAURECTOOLS_TAUTRACKRNNCLASSIFIER_H
#define ASG_TOOL_CLASS2(CLASSNAME, INT1, INT2)
Property holding a SG store/key/clid from which a ReadHandle is made.
The base class for all tau tools.
Property holding a SG store/key/clid from which a ReadHandle is made.
TauRecToolBase(const std::string &name)
ToolHandleArray< TrackRNN > m_vClassifier
SG::ReadHandleKey< xAOD::VertexContainer > m_vertexContainerKey
TauTrackRNNClassifier(const std::string &name="TauTrackRNNClassifier")
virtual StatusCode executeTrackClassifier(xAOD::TauJet &pTau, xAOD::TauTrackContainer &tauTrackContainer) const override
virtual StatusCode initialize() override
Tool initializer.
Gaudi::Property< bool > m_classifyOnlyCoreTracks
Gaudi::Property< unsigned int > m_nMaxNtracks
Gaudi::Property< bool > m_removeDuplicateChargedTracks
virtual StatusCode initialize() override
Tool initializer.
Gaudi::Property< std::string > m_inputWeightsPath
StatusCode calculateVars(const std::vector< xAOD::TauTrack * > &vTracks, const xAOD::TauJet &xTau, const xAOD::VertexContainer *vertexContainer, VectorMap &valueMap) const
ASG_TOOL_CLASS2(TrackRNN, TauRecToolBase, ITauToolBase) public ~TrackRNN()
Create a proper constructor for Athena.
std::unique_ptr< lwtDev::LightweightGraph > m_RNNClassifier
StatusCode classifyTracks(std::vector< xAOD::TauTrack * > &vTracks, xAOD::TauJet &xTau, const xAOD::VertexContainer *vertexContainer, const xAOD::TauTrackContainer &tauTrackContainer, bool skipTracks=false) const
Implementation of a TrackClassifier based on an RNN.
Definition BDTHelper.cxx:12
std::map< std::string, std::vector< double > > VectorMap
std::map< std::string, double > ValueMap
std::map< std::string, ValueMap > NodeMap
std::map< std::string, VectorMap > SeqNodeMap
VertexContainer_v1 VertexContainer
Definition of the current "Vertex container version".
TauJet_v3 TauJet
Definition of the current "tau version".
TauTrackContainer_v1 TauTrackContainer
Definition of the current TauTrack container version.