ATLAS Offline Software
TauJetRNNEvaluator.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef TAURECTOOLS_TAUJETRNNEVALUATOR_H
6 #define TAURECTOOLS_TAUJETRNNEVALUATOR_H
7 
9 
10 #include "xAODTau/TauJet.h"
12 
13 #include <memory>
14 
15 class TauJetRNN;
16 
28 public:
30 
31  TauJetRNNEvaluator(const std::string &name = "TauJetRNNEvaluator");
32  virtual ~TauJetRNNEvaluator();
33 
34  virtual StatusCode initialize() override;
35  virtual StatusCode execute(xAOD::TauJet &tau) const override;
36  // Getter for the underlying RNN implementation
37  const TauJetRNN* get_rnn_0p() const;
38  const TauJetRNN* get_rnn_1p() const;
39  const TauJetRNN* get_rnn_2p() const;
40  const TauJetRNN* get_rnn_3p() const;
41 
42  // Selects tracks to be used as input to the network
44  std::vector<const xAOD::TauTrack *> &out) const;
45 
46  // Selects clusters to be used as input to the network
48  std::vector<xAOD::CaloVertexedTopoCluster> &out) const;
49 
50 private:
51  std::string m_output_varname;
52  std::string m_weightfile_0p;
53  std::string m_weightfile_1p;
54  std::string m_weightfile_2p;
55  std::string m_weightfile_3p;
56  std::size_t m_max_tracks;
57  std::size_t m_max_clusters;
61 
62  // Configuration of the weight file
63  std::string m_input_layer_scalar;
64  std::string m_input_layer_tracks;
66  std::string m_output_layer;
67  std::string m_output_node;
68 
69  // Wrappers for lwtnn
70  std::unique_ptr<TauJetRNN> m_net_0p;
71  std::unique_ptr<TauJetRNN> m_net_1p;
72  std::unique_ptr<TauJetRNN> m_net_2p;
73  std::unique_ptr<TauJetRNN> m_net_3p;
74 };
75 
76 #endif // TAURECTOOLS_TAUJETRNNEVALUATOR_H
TauJetRNNEvaluator::m_input_layer_tracks
std::string m_input_layer_tracks
Definition: TauJetRNNEvaluator.h:64
TauJetRNNEvaluator::get_rnn_0p
const TauJetRNN * get_rnn_0p() const
Definition: TauJetRNNEvaluator.cxx:169
TauJetRNNEvaluator::m_max_tracks
std::size_t m_max_tracks
Definition: TauJetRNNEvaluator.h:56
TauJetRNNEvaluator::get_rnn_3p
const TauJetRNN * get_rnn_3p() const
Definition: TauJetRNNEvaluator.cxx:181
TauJetRNNEvaluator::m_doTrackClassification
bool m_doTrackClassification
Definition: TauJetRNNEvaluator.h:60
TauJetRNNEvaluator::m_doVertexCorrection
bool m_doVertexCorrection
Definition: TauJetRNNEvaluator.h:59
TauJetRNNEvaluator::m_output_node
std::string m_output_node
Definition: TauJetRNNEvaluator.h:67
ASG_TOOL_CLASS2
#define ASG_TOOL_CLASS2(CLASSNAME, INT1, INT2)
Definition: AsgToolMacros.h:77
TauJetRNNEvaluator::get_tracks
StatusCode get_tracks(const xAOD::TauJet &tau, std::vector< const xAOD::TauTrack * > &out) const
Definition: TauJetRNNEvaluator.cxx:185
TauRecToolBase.h
TauRecToolBase
The base class for all tau tools.
Definition: TauRecToolBase.h:21
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:71
TauJetRNNEvaluator::~TauJetRNNEvaluator
virtual ~TauJetRNNEvaluator()
Definition: TauJetRNNEvaluator.cxx:40
TauJetRNNEvaluator::m_max_clusters
std::size_t m_max_clusters
Definition: TauJetRNNEvaluator.h:57
TauJetRNN
Wrapper around lwtnn to compute the output score of a neural network.
Definition: TauJetRNN.h:34
TauJetRNNEvaluator::m_output_varname
std::string m_output_varname
Definition: TauJetRNNEvaluator.h:51
TauJetRNNEvaluator::get_rnn_2p
const TauJetRNN * get_rnn_2p() const
Definition: TauJetRNNEvaluator.cxx:177
TauJetRNNEvaluator::m_weightfile_1p
std::string m_weightfile_1p
Definition: TauJetRNNEvaluator.h:53
TauJetRNNEvaluator::get_rnn_1p
const TauJetRNN * get_rnn_1p() const
Definition: TauJetRNNEvaluator.cxx:173
TauJetRNNEvaluator::m_input_layer_clusters
std::string m_input_layer_clusters
Definition: TauJetRNNEvaluator.h:65
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
TauJetRNNEvaluator::m_net_1p
std::unique_ptr< TauJetRNN > m_net_1p
Definition: TauJetRNNEvaluator.h:71
TauJetRNNEvaluator::m_weightfile_0p
std::string m_weightfile_0p
Definition: TauJetRNNEvaluator.h:52
TauJetRNNEvaluator::m_output_layer
std::string m_output_layer
Definition: TauJetRNNEvaluator.h:66
TauJetRNNEvaluator::m_max_cluster_dr
float m_max_cluster_dr
Definition: TauJetRNNEvaluator.h:58
TauJetRNNEvaluator::m_net_0p
std::unique_ptr< TauJetRNN > m_net_0p
Definition: TauJetRNNEvaluator.h:70
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:228
TauJetRNNEvaluator::m_weightfile_2p
std::string m_weightfile_2p
Definition: TauJetRNNEvaluator.h:54
CaloVertexedTopoCluster.h
Evaluate cluster kinematics with a different vertex / signal state.
TauJetRNNEvaluator::m_net_2p
std::unique_ptr< TauJetRNN > m_net_2p
Definition: TauJetRNNEvaluator.h:72
ITauToolBase
The base class for all tau tools.
Definition: ITauToolBase.h:30
TauJetRNNEvaluator::execute
virtual StatusCode execute(xAOD::TauJet &tau) const override
Execute - called for each tau candidate.
Definition: TauJetRNNEvaluator.cxx:133
TauJet.h
TauJetRNNEvaluator::TauJetRNNEvaluator
TauJetRNNEvaluator(const std::string &name="TauJetRNNEvaluator")
Definition: TauJetRNNEvaluator.cxx:14
TauJetRNNEvaluator::get_clusters
StatusCode get_clusters(const xAOD::TauJet &tau, std::vector< xAOD::CaloVertexedTopoCluster > &out) const
Definition: TauJetRNNEvaluator.cxx:219
TauJetRNNEvaluator::m_weightfile_3p
std::string m_weightfile_3p
Definition: TauJetRNNEvaluator.h:55
TauJetRNNEvaluator::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: TauJetRNNEvaluator.cxx:42
TauJetRNNEvaluator
Tool to calculate a tau identification score based on neural networks.
Definition: TauJetRNNEvaluator.h:27
TauJetRNNEvaluator::m_net_3p
std::unique_ptr< TauJetRNN > m_net_3p
Definition: TauJetRNNEvaluator.h:73
TauJetRNNEvaluator::m_input_layer_scalar
std::string m_input_layer_scalar
Definition: TauJetRNNEvaluator.h:63