Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
TauJetRNNEvaluator.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef TAURECTOOLS_TAUJETRNNEVALUATOR_H
6 #define TAURECTOOLS_TAUJETRNNEVALUATOR_H
7 
10 
11 #include "xAODTau/TauJet.h"
13 
14 #include <memory>
15 
16 class TauJetRNN;
17 
29 public:
31 
32  TauJetRNNEvaluator(const std::string &name = "TauJetRNNEvaluator");
33  virtual ~TauJetRNNEvaluator();
34 
35  virtual StatusCode initialize() override;
36  virtual StatusCode execute(xAOD::TauJet &tau) const override;
37  // Getters for the underlying RNN implementations (one each for 0p, 1p, 2p and 3p)
38  const TauJetRNN* get_rnn_0p() const;
39  const TauJetRNN* get_rnn_1p() const;
40  const TauJetRNN* get_rnn_2p() const;
41  const TauJetRNN* get_rnn_3p() const;
42 
43  // Selects tracks to be used as input to the network
45  std::vector<const xAOD::TauTrack *> &out) const;
46 
47  // Selects clusters to be used as input to the network
49  std::vector<xAOD::CaloVertexedTopoCluster> &out) const;
50 
51 private:
52 
53  // properties
54  Gaudi::Property<std::string> m_weightfile_0p{this, "NetworkFile0P", ""};
55  Gaudi::Property<std::string> m_weightfile_1p{this, "NetworkFile1P", ""};
56  Gaudi::Property<std::string> m_weightfile_2p{this, "NetworkFile2P", ""};
57  Gaudi::Property<std::string> m_weightfile_3p{this, "NetworkFile3P", ""};
58  Gaudi::Property<std::string> m_output_varname{this, "OutputVarname", "RNNJetScore"};
59  Gaudi::Property<std::size_t> m_max_tracks{this, "MaxTracks", 10};
60  Gaudi::Property<std::size_t> m_max_clusters{this, "MaxClusters", 6};
61  Gaudi::Property<float> m_max_cluster_dr{this, "MaxClusterDR", 1.0f};
62  Gaudi::Property<bool> m_doVertexCorrection{this, "VertexCorrection", true};
63  Gaudi::Property<bool> m_doTrackClassification{this, "TrackClassification", true};
64  Gaudi::Property<std::string> m_input_layer_scalar{this, "InputLayerScalar", "scalar"};
65  Gaudi::Property<std::string> m_input_layer_tracks{this, "InputLayerTracks", "tracks"};
66  Gaudi::Property<std::string> m_input_layer_clusters{this, "InputLayerClusters", "clusters"};
67  Gaudi::Property<std::string> m_output_layer{this, "OutputLayer", "rnnid_output"};
68  Gaudi::Property<std::string> m_output_node{this, "OutputNode", "sig_prob"};
69 
70  // Wrappers for lwtnn
71  std::unique_ptr<TauJetRNN> m_net_0p;
72  std::unique_ptr<TauJetRNN> m_net_1p;
73  std::unique_ptr<TauJetRNN> m_net_2p;
74  std::unique_ptr<TauJetRNN> m_net_3p;
75 };
76 
77 #endif // TAURECTOOLS_TAUJETRNNEVALUATOR_H
TauJetRNNEvaluator::get_rnn_0p
const TauJetRNN * get_rnn_0p() const
Definition: TauJetRNNEvaluator.cxx:152
PropertyWrapper.h
TauJetRNNEvaluator::get_rnn_3p
const TauJetRNN * get_rnn_3p() const
Definition: TauJetRNNEvaluator.cxx:164
ASG_TOOL_CLASS2
#define ASG_TOOL_CLASS2(CLASSNAME, INT1, INT2)
Definition: AsgToolMacros.h:77
TauJetRNNEvaluator::get_tracks
StatusCode get_tracks(const xAOD::TauJet &tau, std::vector< const xAOD::TauTrack * > &out) const
Definition: TauJetRNNEvaluator.cxx:168
TauRecToolBase.h
TauRecToolBase
The base class for all tau tools.
Definition: TauRecToolBase.h:21
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:71
TauJetRNNEvaluator::~TauJetRNNEvaluator
virtual ~TauJetRNNEvaluator()
Definition: TauJetRNNEvaluator.cxx:23
TauJetRNNEvaluator::m_max_clusters
Gaudi::Property< std::size_t > m_max_clusters
Definition: TauJetRNNEvaluator.h:60
TauJetRNN
Wrapper around lwtnn to compute the output score of a neural network.
Definition: TauJetRNN.h:34
TauJetRNNEvaluator::m_output_layer
Gaudi::Property< std::string > m_output_layer
Definition: TauJetRNNEvaluator.h:67
TauJetRNNEvaluator::m_max_cluster_dr
Gaudi::Property< float > m_max_cluster_dr
Definition: TauJetRNNEvaluator.h:61
TauJetRNNEvaluator::m_doVertexCorrection
Gaudi::Property< bool > m_doVertexCorrection
Definition: TauJetRNNEvaluator.h:62
TauJetRNNEvaluator::m_output_varname
Gaudi::Property< std::string > m_output_varname
Definition: TauJetRNNEvaluator.h:58
TauJetRNNEvaluator::get_rnn_2p
const TauJetRNN * get_rnn_2p() const
Definition: TauJetRNNEvaluator.cxx:160
TauJetRNNEvaluator::m_output_node
Gaudi::Property< std::string > m_output_node
Definition: TauJetRNNEvaluator.h:68
TauJetRNNEvaluator::get_rnn_1p
const TauJetRNN * get_rnn_1p() const
Definition: TauJetRNNEvaluator.cxx:156
TauJetRNNEvaluator::m_input_layer_tracks
Gaudi::Property< std::string > m_input_layer_tracks
Definition: TauJetRNNEvaluator.h:65
TauJetRNNEvaluator::m_input_layer_scalar
Gaudi::Property< std::string > m_input_layer_scalar
Definition: TauJetRNNEvaluator.h:64
TauJetRNNEvaluator::m_weightfile_0p
Gaudi::Property< std::string > m_weightfile_0p
Definition: TauJetRNNEvaluator.h:54
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
TauJetRNNEvaluator::m_net_1p
std::unique_ptr< TauJetRNN > m_net_1p
Definition: TauJetRNNEvaluator.h:72
TauJetRNNEvaluator::m_weightfile_2p
Gaudi::Property< std::string > m_weightfile_2p
Definition: TauJetRNNEvaluator.h:56
TauJetRNNEvaluator::m_net_0p
std::unique_ptr< TauJetRNN > m_net_0p
Definition: TauJetRNNEvaluator.h:71
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
CaloVertexedTopoCluster.h
Evaluate cluster kinematics with a different vertex / signal state.
TauJetRNNEvaluator::m_net_2p
std::unique_ptr< TauJetRNN > m_net_2p
Definition: TauJetRNNEvaluator.h:73
ITauToolBase
The base class for all tau tools.
Definition: ITauToolBase.h:30
TauJetRNNEvaluator::execute
virtual StatusCode execute(xAOD::TauJet &tau) const override
Execute - called for each tau candidate.
Definition: TauJetRNNEvaluator.cxx:116
TauJetRNNEvaluator::m_weightfile_1p
Gaudi::Property< std::string > m_weightfile_1p
Definition: TauJetRNNEvaluator.h:55
TauJet.h
TauJetRNNEvaluator::TauJetRNNEvaluator
TauJetRNNEvaluator(const std::string &name="TauJetRNNEvaluator")
Definition: TauJetRNNEvaluator.cxx:14
TauJetRNNEvaluator::get_clusters
StatusCode get_clusters(const xAOD::TauJet &tau, std::vector< xAOD::CaloVertexedTopoCluster > &out) const
Definition: TauJetRNNEvaluator.cxx:202
TauJetRNNEvaluator::m_weightfile_3p
Gaudi::Property< std::string > m_weightfile_3p
Definition: TauJetRNNEvaluator.h:57
TauJetRNNEvaluator::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: TauJetRNNEvaluator.cxx:25
TauJetRNNEvaluator::m_input_layer_clusters
Gaudi::Property< std::string > m_input_layer_clusters
Definition: TauJetRNNEvaluator.h:66
TauJetRNNEvaluator::m_doTrackClassification
Gaudi::Property< bool > m_doTrackClassification
Definition: TauJetRNNEvaluator.h:63
TauJetRNNEvaluator
Tool to calculate a tau identification score based on neural networks.
Definition: TauJetRNNEvaluator.h:28
TauJetRNNEvaluator::m_net_3p
std::unique_ptr< TauJetRNN > m_net_3p
Definition: TauJetRNNEvaluator.h:74
TauJetRNNEvaluator::m_max_tracks
Gaudi::Property< std::size_t > m_max_tracks
Definition: TauJetRNNEvaluator.h:59