ATLAS Offline Software
TauJetRNN.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef TAURECTOOLS_TAUJETRNN_H
6 #define TAURECTOOLS_TAUJETRNN_H
7 
#include "xAODTau/TauJet.h"
#include "xAODCaloEvent/CaloVertexedTopoCluster.h"

#include "AsgMessaging/AsgMessaging.h"

#include <map>
#include <memory>
#include <string>
#include <vector>
14 
// Forward declaration of the lwtnn graph type. TauJetRNN only stores it behind
// a std::unique_ptr, so the complete type is not needed in this header.
namespace lwt { class LightweightGraph; }
19 
// Forward declaration of the input-variable calculator, which TauJetRNN owns
// through a std::unique_ptr (complete type only needed in the .cxx file).
namespace TauJetRNNUtils { class VarCalc; }
23 
34 class TauJetRNN : public asg::AsgMessaging {
35 public:
36  // Configuration of the weight file structure
37  struct Config {
38  std::string input_layer_scalar;
39  std::string input_layer_tracks;
40  std::string input_layer_clusters;
41  std::string output_layer;
42  std::string output_node;
43  };
44 
45 public:
46  // Construct a network from the .json specification created by the lwtnn
47  // converters (kerasfunc2json.py).
48  TauJetRNN(const std::string &filename, const Config &config);
49  ~TauJetRNN();
50 
51  // Compute the signal probability in [0, 1] or a default value
52  float compute(const xAOD::TauJet &tau,
53  const std::vector<const xAOD::TauTrack *> &tracks,
54  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const;
55 
56  // Compute all input variables and store them in the maps that are passed by reference
57  bool calculateInputVariables(const xAOD::TauJet &tau,
58  const std::vector<const xAOD::TauTrack *> &tracks,
59  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters,
60  std::map<std::string, std::map<std::string, double>>& scalarInputs,
61  std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const;
62 
63  // Getter for the variable calculator
65  return m_var_calc.get();
66  }
67 
68  explicit operator bool() const {
69  return static_cast<bool>(m_graph);
70  }
71 
72 private:
73  // Abbreviations for lwtnn
74  using VariableMap = std::map<std::string, double>;
75  using VectorMap = std::map<std::string, std::vector<double>>;
76 
77  using InputMap = std::map<std::string, VariableMap>;
78  using InputSequenceMap = std::map<std::string, VectorMap>;
79 
80 private:
82  std::unique_ptr<const lwt::LightweightGraph> m_graph;
83 
84  // Names of the input variables
85  std::vector<std::string> m_scalar_inputs;
86  std::vector<std::string> m_track_inputs;
87  std::vector<std::string> m_cluster_inputs;
88 
89  // Variable calculator to calculate input variables on the fly
90  std::unique_ptr<TauJetRNNUtils::VarCalc> m_var_calc;
91 };
92 
93 #endif // TAURECTOOLS_TAUJETRNN_H
TauJetRNN::InputMap
std::map< std::string, VariableMap > InputMap
Definition: TauJetRNN.h:77
TauJetRNNUtils::VarCalc
Tool to calculate input variables for the RNN-based tau identification.
Definition: TauJetRNNUtils.h:29
TauJetRNN::compute
float compute(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters) const
Definition: TauJetRNN.cxx:91
TauJetRNN::m_graph
std::unique_ptr< const lwt::LightweightGraph > m_graph
Definition: TauJetRNN.h:82
TauJetRNN::TauJetRNN
TauJetRNN(const std::string &filename, const Config &config)
Definition: TauJetRNN.cxx:17
TauJetRNN::calculateInputVariables
bool calculateInputVariables(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters, std::map< std::string, std::map< std::string, double >> &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double >>> &vectorInputs) const
Definition: TauJetRNN.cxx:105
TauJetRNN::m_config
const Config m_config
Definition: TauJetRNN.h:81
TauJetRNN::Config::output_node
std::string output_node
Definition: TauJetRNN.h:42
TauJetRNN
Wrapper around lwtnn to compute the output score of a neural network.
Definition: TauJetRNN.h:34
TauJetRNN::m_track_inputs
std::vector< std::string > m_track_inputs
Definition: TauJetRNN.h:86
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
TauJetRNN::variable_calculator
const TauJetRNNUtils::VarCalc * variable_calculator() const
Definition: TauJetRNN.h:64
TauJetRNN::VariableMap
std::map< std::string, double > VariableMap
Definition: TauJetRNN.h:74
AsgMessaging.h
TauJetRNNUtils
Definition: TauJetRNNUtils.cxx:10
TauJetRNN::VectorMap
std::map< std::string, std::vector< double > > VectorMap
Definition: TauJetRNN.h:75
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
lwt
Definition: NnClusterizationFactory.h:52
TauJetRNN::m_scalar_inputs
std::vector< std::string > m_scalar_inputs
Definition: TauJetRNN.h:85
TauJetRNN::Config::input_layer_tracks
std::string input_layer_tracks
Definition: TauJetRNN.h:39
asg::AsgMessaging
Class mimicking the AthMessaging class from the offline software.
Definition: AsgMessaging.h:40
CaloVertexedTopoCluster.h
Evaluate cluster kinematics with a different vertex / signal state.
TauJetRNN::InputSequenceMap
std::map< std::string, VectorMap > InputSequenceMap
Definition: TauJetRNN.h:78
TauJetRNN::~TauJetRNN
~TauJetRNN()
Definition: TauJetRNN.cxx:89
TauJetRNN::m_cluster_inputs
std::vector< std::string > m_cluster_inputs
Definition: TauJetRNN.h:87
TauJet.h
RunTileMonitoring.clusters
clusters
Definition: RunTileMonitoring.py:133
CaloCellTimeCorrFiller.filename
filename
Definition: CaloCellTimeCorrFiller.py:24
TauJetRNN::Config::input_layer_clusters
std::string input_layer_clusters
Definition: TauJetRNN.h:40
xAOD::bool
setBGCode setTAP setLVL2ErrorBits bool
Definition: TrigDecision_v1.cxx:60
TauJetRNN::Config::input_layer_scalar
std::string input_layer_scalar
Definition: TauJetRNN.h:38
TauJetRNN::m_var_calc
std::unique_ptr< TauJetRNNUtils::VarCalc > m_var_calc
Definition: TauJetRNN.h:90
TauJetRNN::Config
Definition: TauJetRNN.h:37
TauJetRNN::Config::output_layer
std::string output_layer
Definition: TauJetRNN.h:41