ATLAS Offline Software
TauTrackRNNClassifier.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef TAURECTOOLS_TAUTRACKRNNCLASSIFIER_H
6 #define TAURECTOOLS_TAUTRACKRNNCLASSIFIER_H
7 
8 // ASG include(s)
9 #include "AsgTools/AsgTool.h"
12 
13 // xAOD include(s)
14 #include "xAODTau/TauJet.h"
17 
18 // local include(s)
22 
23 #include <memory>
24 
32 namespace tauRecTools
33 {
34 
// Forward declaration: TauTrackRNNClassifier (declared before TrackRNN in this
// header) holds a ToolHandleArray<TrackRNN>, so the name must be known first.
class TrackRNN;

// Containers for the RNN inputs. Several input kinds are currently supported:
//  * ValueMap  - named scalar values, for simple rank-1 inputs
//  * VectorMap - named vectors of values, for sequence inputs
using ValueMap = std::map<std::string, double>;
using VectorMap = std::map<std::string, std::vector<double>>;

// Groupings of the maps above under a string key
// (presumably the network's input-node name -- confirm against the lwtnn usage).
using NodeMap = std::map<std::string, ValueMap>;
using SeqNodeMap = std::map<std::string, VectorMap>;
45 
46 //______________________________________________________________________________
// TauTrackRNNClassifier: tau track-classifier tool (derives from TauRecToolBase)
// that drives the TrackRNN sub-tools configured through its "Classifiers" property.
// NOTE(review): this rendered listing omits header lines 47 (the
// `class TauTrackRNNClassifier` line itself), 52, 55 and 67. Per the
// cross-reference index, line 67 declares `bool m_classifyLRT`, and a destructor
// ~TauTrackRNNClassifier() exists (defined at TauTrackRNNClassifier.cxx:30),
// presumably declared on one of the missing lines.
48  : public TauRecToolBase
49 {
50 public:
51 
53 
    // Construct the tool; `name` is the tool instance name
    // (defaults to "TauTrackRNNClassifier").
    // NOTE(review): callable with one argument and not `explicit`, so it allows
    // implicit conversion from std::string; consider `explicit` if the ASG tool
    // machinery permits it.
54  TauTrackRNNClassifier(const std::string& name="TauTrackRNNClassifier");
56 
57  // retrieve all track classifier sub tools
58  virtual StatusCode initialize() override;
59  // pass all tracks in the tau cone to all track classifier sub tools
    // pTau: the tau being processed; tauTrackContainer: presumably the container
    // that owns pTau's TauTracks -- confirm in the .cxx.
60  virtual StatusCode executeTrackClassifier(xAOD::TauJet& pTau, xAOD::TauTrackContainer& tauTrackContainer) const override;
61 
62  private:
    // TrackRNN sub-tools the tracks are passed to (job property "Classifiers",
    // empty by default).
63  ToolHandleArray<TrackRNN> m_vClassifier {this, "Classifiers", {}};
64 
    // StoreGate key of the vertex container to read (job property
    // "Key_vertexInputContainer", default "PrimaryVertices"); presumably handed on
    // to TrackRNN::classifyTracks.
65  SG::ReadHandleKey<xAOD::VertexContainer> m_vertexContainerKey {this, "Key_vertexInputContainer", "PrimaryVertices", "Vertex container key"};
66 
68 
69 }; // class TauTrackRNNClassifier
70 
71 //______________________________________________________________________________
// TrackRNN: track-classifier sub-tool (derives from TauRecToolBase) that owns a
// single lwtDev::LightweightGraph (m_RNNClassifier). TauTrackRNNClassifier holds
// instances of it in its "Classifiers" ToolHandleArray.
// NOTE(review): header lines 75-78 are absent from this rendered listing.
72 class TrackRNN
73  : public TauRecToolBase
74 {
79 
80  public:
81 
    // Construct the sub-tool with its tool instance name.
    // NOTE(review): single-argument and not `explicit`, so it allows implicit
    // conversion from std::string.
82  TrackRNN(const std::string& name);
83  ~TrackRNN();
84 
85  // Configure the MVA object and build a general map to store variables
86  // for possible MVA inputs. Only variables defined in the weights file
87  // are passed to the MVA object.
    // NOTE(review): this comment originally said "root weights file"; parse_json.h
    // is among this header's includes, so the weights may be a JSON file -- confirm.
88  virtual StatusCode initialize() override;
89 
90  // executes MVA object to get the RNN scores and set classification flags
    //   vTracks         - tracks to score; non-const, presumably so flags can be set on them
    //   xTau            - the tau the tracks belong to (presumably)
    //   vertexContainer - vertex collection; presumably forwarded to calculateVars
    //   skipTracks      - defaults to false; its effect is implemented in
    //                     TauTrackRNNClassifier.cxx (classifyTracks, line 161)
91  StatusCode classifyTracks(std::vector<xAOD::TauTrack*>& vTracks,
92  xAOD::TauJet& xTau,
93  const xAOD::VertexContainer* vertexContainer,
94  bool skipTracks=false) const;
95 
96 private:
97  // set RNN input variables in the corresponding map entries
    // The last argument is the output map of sequence inputs.
    // NOTE(review): it is named `valueMap` but has type VectorMap, not ValueMap;
    // a name such as `vectorMap` would avoid confusion with the ValueMap alias.
98  StatusCode calculateVars(const std::vector<xAOD::TauTrack*>& vTracks,
99  const xAOD::TauJet& xTau,
100  const xAOD::VertexContainer* vertexContainer,
101  VectorMap& valueMap) const;
102 
103  // configurable variables
    // m_inputWeightsPath: location of the network weights; m_nMaxNtracks: presumably
    // the maximum number of tracks fed to the RNN.
    // NOTE(review): m_nMaxNtracks has no in-class initializer; presumably it is
    // bound to a property in the constructor -- confirm it is always initialized.
104  std::string m_inputWeightsPath;
105  unsigned int m_nMaxNtracks;
106 
    // Network owned by this tool; presumably the "MVA object" built in initialize()
    // and evaluated in classifyTracks.
107  std::unique_ptr<lwtDev::LightweightGraph> m_RNNClassifier;
108 
109 }; // class TrackRNN
110 
111 } // namespace tauRecTools
112 
113 #endif // TAURECTOOLS_TAUTRACKRNNCLASSIFIER_H
tauRecTools::NodeMap
std::map< std::string, ValueMap > NodeMap
Definition: TauTrackRNNClassifier.h:43
tauRecTools::TrackRNN
Definition: TauTrackRNNClassifier.h:74
ASG_TOOL_CLASS2
#define ASG_TOOL_CLASS2(CLASSNAME, INT1, INT2)
Definition: AsgToolMacros.h:77
tauRecTools::TauTrackRNNClassifier::m_vertexContainerKey
SG::ReadHandleKey< xAOD::VertexContainer > m_vertexContainerKey
Definition: TauTrackRNNClassifier.h:65
LightweightGraph.h
TauRecToolBase.h
TauRecToolBase
The base class for all tau tools.
Definition: TauRecToolBase.h:21
tauRecTools::VectorMap
std::map< std::string, std::vector< double > > VectorMap
Definition: TauTrackRNNClassifier.h:41
TauTrackContainer.h
tauRecTools::TrackRNN::classifyTracks
StatusCode classifyTracks(std::vector< xAOD::TauTrack * > &vTracks, xAOD::TauJet &xTau, const xAOD::VertexContainer *vertexContainer, bool skipTracks=false) const
Definition: TauTrackRNNClassifier.cxx:161
SG::ReadHandleKey
Property holding a SG store/key/clid from which a ReadHandle is made.
Definition: StoreGate/StoreGate/ReadHandleKey.h:39
tauRecTools::TauTrackRNNClassifier::m_vClassifier
ToolHandleArray< TrackRNN > m_vClassifier
Definition: TauTrackRNNClassifier.h:63
tauRecTools::TauTrackRNNClassifier::m_classifyLRT
bool m_classifyLRT
Definition: TauTrackRNNClassifier.h:67
tauRecTools::TrackRNN::m_RNNClassifier
std::unique_ptr< lwtDev::LightweightGraph > m_RNNClassifier
Definition: TauTrackRNNClassifier.h:107
tauRecTools::TrackRNN::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: TauTrackRNNClassifier.cxx:142
tauRecTools::TrackRNN::calculateVars
StatusCode calculateVars(const std::vector< xAOD::TauTrack * > &vTracks, const xAOD::TauJet &xTau, const xAOD::VertexContainer *vertexContainer, VectorMap &valueMap) const
Definition: TauTrackRNNClassifier.cxx:246
tauRecTools::TauTrackRNNClassifier::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: TauTrackRNNClassifier.cxx:35
ToolHandleArray.h
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ReadHandleKey.h
Property holding a SG store/key/clid from which a ReadHandle is made.
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
tauRecTools::TauTrackRNNClassifier::TauTrackRNNClassifier
TauTrackRNNClassifier(const std::string &name="TauTrackRNNClassifier")
Definition: TauTrackRNNClassifier.cxx:24
DataVector
Derived DataVector<T>.
Definition: DataVector.h:581
parse_json.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:192
tauRecTools::ValueMap
std::map< std::string, double > ValueMap
Definition: TauTrackRNNClassifier.h:39
ITauToolBase
The base class for all tau tools.
Definition: ITauToolBase.h:30
tauRecTools::TrackRNN::m_inputWeightsPath
std::string m_inputWeightsPath
Definition: TauTrackRNNClassifier.h:104
VertexContainer.h
tauRecTools::TrackRNN::m_nMaxNtracks
unsigned int m_nMaxNtracks
Definition: TauTrackRNNClassifier.h:105
TauJet.h
tauRecTools::TauTrackRNNClassifier::executeTrackClassifier
virtual StatusCode executeTrackClassifier(xAOD::TauJet &pTau, xAOD::TauTrackContainer &tauTrackContainer) const override
Definition: TauTrackRNNClassifier.cxx:48
tauRecTools
Implementation of a TrackClassifier based on an RNN.
Definition: BDTHelper.cxx:12
tauRecTools::SeqNodeMap
std::map< std::string, VectorMap > SeqNodeMap
Definition: TauTrackRNNClassifier.h:44
tauRecTools::TauTrackRNNClassifier
Definition: TauTrackRNNClassifier.h:49
AsgTool.h
tauRecTools::TauTrackRNNClassifier::~TauTrackRNNClassifier
~TauTrackRNNClassifier()
Definition: TauTrackRNNClassifier.cxx:30