ATLAS Offline Software
TauTrackRNNClassifier.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef TAURECTOOLS_TAUTRACKRNNCLASSIFIER_H
6 #define TAURECTOOLS_TAUTRACKRNNCLASSIFIER_H
7 
8 // ASG include(s)
9 #include "AsgTools/AsgTool.h"
13 
14 // xAOD include(s)
15 #include "xAODTau/TauJet.h"
18 
19 // local include(s)
23 
#include <map>
24 #include <memory>
#include <string>
#include <vector>
25 
33 namespace tauRecTools
34 {
35 
// Forward declaration: TauTrackRNNClassifier holds a ToolHandleArray<TrackRNN>
// before TrackRNN itself is defined further down in this header.
class TrackRNN;

// Input containers used to feed the track RNN. Several input types are
// currently supported.

/// Simple (rank-1) inputs: one double per variable name.
using ValueMap = std::map<std::string, double>;

/// Sequence inputs: one std::vector<double> per variable name
/// (presumably one entry per track — confirm in TrackRNN::calculateVars).
using VectorMap = std::map<std::string, std::vector<double> >;

/// Named groups of simple inputs, keyed by a string
/// (presumably the graph input-node name, mirroring lwtnn's NodeMap — confirm).
using NodeMap = std::map<std::string, ValueMap>;

/// Named groups of sequence inputs, keyed by a string
/// (presumably the graph input-node name, mirroring lwtnn's SeqNodeMap — confirm).
using SeqNodeMap = std::map<std::string, VectorMap>;
46 
47 //______________________________________________________________________________
// TauTrackRNNClassifier: tau track-classification tool deriving from TauRecToolBase.
// It holds an array of TrackRNN sub-tools (property "Classifiers") and, in
// executeTrackClassifier(), passes the tau's tracks on to those sub-tools.
// NOTE(review): this rendered listing skips original lines 48 (the class head),
// 53 and 56; the cross-reference index mentions ASG_TOOL_CLASS2 and
// ~TauTrackRNNClassifier(), which presumably live there — confirm against the header.
49  : public TauRecToolBase
50 {
51 public:
52 
54 
  // @param name  tool instance name; defaults to "TauTrackRNNClassifier".
55  TauTrackRNNClassifier(const std::string& name="TauTrackRNNClassifier");
57 
58  // retrieve all track classifier sub tools
  // (the TrackRNN tools configured through the "Classifiers" property).
59  virtual StatusCode initialize() override;
60  // pass all tracks in the tau cone to all track classifier sub tools
  // @param pTau               the tau whose tracks are classified (non-const)
  // @param tauTrackContainer  tau-track container; passed non-const, presumably
  //                           because track classification results are written
  //                           back — confirm in the .cxx
  // NOTE(review): with ClassifyOnlyCoreTracks=true presumably only core tracks
  // are passed on, not "all tracks" — confirm in the .cxx.
61  virtual StatusCode executeTrackClassifier(xAOD::TauJet& pTau, xAOD::TauTrackContainer& tauTrackContainer) const override;
62 
63  private:
  // TrackRNN sub-tools, configured via the "Classifiers" property (empty by default).
64  ToolHandleArray<TrackRNN> m_vClassifier {this, "Classifiers", {}};
65 
  // Key of the vertex container to read; property "Key_vertexInputContainer",
  // default "PrimaryVertices".
66  SG::ReadHandleKey<xAOD::VertexContainer> m_vertexContainerKey {this, "Key_vertexInputContainer", "PrimaryVertices", "Vertex container key"};
67 
  // Property "classifyLRT" (default true); presumably enables classification of
  // large-radius-tracking (LRT) tracks — confirm in the .cxx.
68  Gaudi::Property<bool> m_classifyLRT{this, "classifyLRT", true};
  // Property "ClassifyOnlyCoreTracks" (default false); presumably restricts the
  // classification to core tracks — confirm in the .cxx.
69  Gaudi::Property<bool> m_classifyOnlyCoreTracks{this, "ClassifyOnlyCoreTracks", false};
70 
71 }; // class TauTrackRNNClassifier
72 
73 //______________________________________________________________________________
// TrackRNN: sub-tool (deriving from TauRecToolBase) that evaluates an RNN
// (lwtDev::LightweightGraph, held in m_RNNClassifier) on a tau's tracks.
// Instances are held by TauTrackRNNClassifier through its "Classifiers" property.
74 class TrackRNN
75  : public TauRecToolBase
76 {
// NOTE(review): original lines 77-80 are skipped in this rendered listing
// (the index mentions ASG_TOOL_CLASS2); confirm against the header.
81 
82  public:
83 
  // @param name  tool instance name (no default).
  // NOTE(review): this single-argument constructor is not marked explicit.
84  TrackRNN(const std::string& name);
  // Declared out of line; presumably so that lwtDev::LightweightGraph is a
  // complete type where m_RNNClassifier is destroyed — confirm.
85  ~TrackRNN();
86 
87  // configure the MVA object and build a general map to store variables
88  // for possible MVA inputs. Only Variables defined in the root weights file
89  // are passed to the MVA object
  // The weights presumably come from the "InputWeightsPath" property.
  // NOTE(review): the index lists lwtnn's parse_json.h among this header's
  // includes, so "root weights file" may be stale wording — confirm the format.
90  virtual StatusCode initialize() override;
91 
92  // executes MVA object to get the RNN scores and set classification flags
  // @param vTracks          tracks to classify (non-const pointers, so the
  //                         flags can be set on them)
  // @param xTau             the tau being processed (non-const)
  // @param vertexContainer  vertex container; whether nullptr is accepted is
  //                         not visible here — confirm in the .cxx
  // @param skipTracks       default false; its effect when true is defined in
  //                         the .cxx — TODO document
  // @return StatusCode of the evaluation
93  StatusCode classifyTracks(std::vector<xAOD::TauTrack*>& vTracks,
94  xAOD::TauJet& xTau,
95  const xAOD::VertexContainer* vertexContainer,
96  bool skipTracks=false) const;
97 
98 private:
99  // set RNN input variables in the corresponding map entries
  // NOTE(review): the output parameter is named valueMap, but its type is
  // VectorMap (one std::vector<double> per name), not ValueMap.
100  StatusCode calculateVars(const std::vector<xAOD::TauTrack*>& vTracks,
101  const xAOD::TauJet& xTau,
102  const xAOD::VertexContainer* vertexContainer,
103  VectorMap& valueMap) const;
104 
105  // properties
  // "InputWeightsPath": path to the network weights file (default empty).
106  Gaudi::Property<std::string> m_inputWeightsPath{this, "InputWeightsPath", ""};
  // "MaxNtracks" (default 0): presumably caps the number of tracks fed to the
  // RNN; the meaning of 0 is not visible here — confirm in the .cxx.
107  Gaudi::Property<unsigned int> m_nMaxNtracks{this, "MaxNtracks", 0};
108 
  // RNN graph owned by this tool; presumably built in initialize() — confirm.
109  std::unique_ptr<lwtDev::LightweightGraph> m_RNNClassifier;
110 
111 }; // class TrackRNN
112 
113 } // namespace tauRecTools
114 
115 #endif // TAURECTOOLS_TAUTRACKRNNCLASSIFIER_H
tauRecTools::NodeMap
std::map< std::string, ValueMap > NodeMap
Definition: TauTrackRNNClassifier.h:44
PropertyWrapper.h
tauRecTools::TrackRNN
Definition: TauTrackRNNClassifier.h:76
ASG_TOOL_CLASS2
#define ASG_TOOL_CLASS2(CLASSNAME, INT1, INT2)
Definition: AsgToolMacros.h:77
tauRecTools::TauTrackRNNClassifier::m_vertexContainerKey
SG::ReadHandleKey< xAOD::VertexContainer > m_vertexContainerKey
Definition: TauTrackRNNClassifier.h:66
LightweightGraph.h
TauRecToolBase.h
TauRecToolBase
The base class for all tau tools.
Definition: TauRecToolBase.h:21
tauRecTools::VectorMap
std::map< std::string, std::vector< double > > VectorMap
Definition: TauTrackRNNClassifier.h:42
TauTrackContainer.h
tauRecTools::TrackRNN::classifyTracks
StatusCode classifyTracks(std::vector< xAOD::TauTrack * > &vTracks, xAOD::TauJet &xTau, const xAOD::VertexContainer *vertexContainer, bool skipTracks=false) const
Definition: TauTrackRNNClassifier.cxx:181
SG::ReadHandleKey
Property holding a SG store/key/clid from which a ReadHandle is made.
Definition: StoreGate/StoreGate/ReadHandleKey.h:39
tauRecTools::TauTrackRNNClassifier::m_vClassifier
ToolHandleArray< TrackRNN > m_vClassifier
Definition: TauTrackRNNClassifier.h:64
tauRecTools::TrackRNN::m_RNNClassifier
std::unique_ptr< lwtDev::LightweightGraph > m_RNNClassifier
Definition: TauTrackRNNClassifier.h:109
tauRecTools::TrackRNN::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: TauTrackRNNClassifier.cxx:162
tauRecTools::TrackRNN::calculateVars
StatusCode calculateVars(const std::vector< xAOD::TauTrack * > &vTracks, const xAOD::TauJet &xTau, const xAOD::VertexContainer *vertexContainer, VectorMap &valueMap) const
Definition: TauTrackRNNClassifier.cxx:266
tauRecTools::TauTrackRNNClassifier::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: TauTrackRNNClassifier.cxx:34
ToolHandleArray.h
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ReadHandleKey.h
Property holding a SG store/key/clid from which a ReadHandle is made.
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
tauRecTools::TrackRNN::m_inputWeightsPath
Gaudi::Property< std::string > m_inputWeightsPath
Definition: TauTrackRNNClassifier.h:106
tauRecTools::TauTrackRNNClassifier::TauTrackRNNClassifier
TauTrackRNNClassifier(const std::string &name="TauTrackRNNClassifier")
Definition: TauTrackRNNClassifier.cxx:24
tauRecTools::TauTrackRNNClassifier::m_classifyOnlyCoreTracks
Gaudi::Property< bool > m_classifyOnlyCoreTracks
Definition: TauTrackRNNClassifier.h:69
DataVector
Derived DataVector<T>.
Definition: DataVector.h:794
parse_json.h
tauRecTools::TrackRNN::m_nMaxNtracks
Gaudi::Property< unsigned int > m_nMaxNtracks
Definition: TauTrackRNNClassifier.h:107
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
tauRecTools::ValueMap
std::map< std::string, double > ValueMap
Definition: TauTrackRNNClassifier.h:40
ITauToolBase
The base class for all tau tools.
Definition: ITauToolBase.h:30
tauRecTools::TauTrackRNNClassifier::m_classifyLRT
Gaudi::Property< bool > m_classifyLRT
Definition: TauTrackRNNClassifier.h:68
VertexContainer.h
TauJet.h
tauRecTools::TauTrackRNNClassifier::executeTrackClassifier
virtual StatusCode executeTrackClassifier(xAOD::TauJet &pTau, xAOD::TauTrackContainer &tauTrackContainer) const override
Definition: TauTrackRNNClassifier.cxx:47
tauRecTools
Implementation of a TrackClassifier based on an RNN.
Definition: BDTHelper.cxx:12
tauRecTools::SeqNodeMap
std::map< std::string, VectorMap > SeqNodeMap
Definition: TauTrackRNNClassifier.h:45
tauRecTools::TauTrackRNNClassifier
Definition: TauTrackRNNClassifier.h:50
AsgTool.h
tauRecTools::TauTrackRNNClassifier::~TauTrackRNNClassifier
~TauTrackRNNClassifier()
Definition: TauTrackRNNClassifier.cxx:29