ATLAS Offline Software
Loading...
Searching...
No Matches
TauJetRNN.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef TAURECTOOLS_TAUJETRNN_H
6#define TAURECTOOLS_TAUJETRNN_H
7
8#include "xAODTau/TauJet.h"
10
12
13#include <memory>
14
// Forward declaration of the lwtnn graph type. TauJetRNN holds it only through
// std::unique_ptr, so the complete type is required only in the translation
// unit that defines TauJetRNN's out-of-line destructor.
namespace lwt {
  class LightweightGraph;
}

// Forward declaration of the input-variable calculator. It is owned through
// std::unique_ptr and exposed only as a const pointer, so its full definition
// is not needed in this header.
namespace TauJetRNNUtils {
  class VarCalc;
}

35public:
// Configuration of the weight file structure: names of the input layers,
// output layer and output node inside the lwtnn .json graph.
struct Config {
  std::string input_layer_scalar;    // input layer for scalar variables
  std::string input_layer_tracks;    // input layer for the per-track sequence
  std::string input_layer_clusters;  // input layer for the per-cluster sequence
  std::string output_layer;          // name of the output layer
  std::string output_node;           // name of the node read from output_layer
};

45public:
46 // Construct a network from the .json specification created by the lwtnn
47 // converters (kerasfunc2json.py).
48 TauJetRNN(const std::string &filename, const Config &config, bool useTRT);
49 ~TauJetRNN();
50
51 // Compute the signal probability in [0, 1] or a default value
52 float compute(const xAOD::TauJet &tau,
53 const std::vector<const xAOD::TauTrack *> &tracks,
54 const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const;
55
56 // Compute all input variables and store them in the maps that are passed by reference
58 const std::vector<const xAOD::TauTrack *> &tracks,
59 const std::vector<xAOD::CaloVertexedTopoCluster> &clusters,
60 std::map<std::string, std::map<std::string, double>>& scalarInputs,
61 std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const;
62
63 // Getter for the variable calculator
65 return m_var_calc.get();
66 }
67
68 explicit operator bool() const {
69 return static_cast<bool>(m_graph);
70 }
71
72private:
// Abbreviations for the lwtnn input containers:
//   VariableMap      : variable name -> value
//   VectorMap        : variable name -> sequence of values
//   InputMap         : input layer name -> VariableMap
//   InputSequenceMap : input layer name -> VectorMap
using VariableMap = std::map<std::string, double>;
using VectorMap = std::map<std::string, std::vector<double>>;

using InputMap = std::map<std::string, VariableMap>;
using InputSequenceMap = std::map<std::string, VectorMap>;

80private:
82 std::unique_ptr<const lwt::LightweightGraph> m_graph;
83
84 // Names of the input variables
85 std::vector<std::string> m_scalar_inputs;
86 std::vector<std::string> m_track_inputs;
87 std::vector<std::string> m_cluster_inputs;
88
89 // Variable calculator to calculate input variables on the fly
90 std::unique_ptr<TauJetRNNUtils::VarCalc> m_var_calc;
91
92 bool m_useTRT = true;
93};
94
95#endif // TAURECTOOLS_TAUJETRNN_H
Evaluate cluster kinematics with a different vertex / signal state.
Tool to calculate input variables for the RNN-based tau identification.
std::map< std::string, double > VariableMap
Definition TauJetRNN.h:74
const Config m_config
Definition TauJetRNN.h:81
std::vector< std::string > m_track_inputs
Definition TauJetRNN.h:86
std::vector< std::string > m_scalar_inputs
Definition TauJetRNN.h:85
std::unique_ptr< TauJetRNNUtils::VarCalc > m_var_calc
Definition TauJetRNN.h:90
std::map< std::string, VectorMap > InputSequenceMap
Definition TauJetRNN.h:78
std::map< std::string, VariableMap > InputMap
Definition TauJetRNN.h:77
bool calculateInputVariables(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters, std::map< std::string, std::map< std::string, double > > &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double > > > &vectorInputs) const
const TauJetRNNUtils::VarCalc * variable_calculator() const
Definition TauJetRNN.h:64
std::map< std::string, std::vector< double > > VectorMap
Definition TauJetRNN.h:75
bool m_useTRT
Definition TauJetRNN.h:92
std::unique_ptr< const lwt::LightweightGraph > m_graph
Definition TauJetRNN.h:82
std::vector< std::string > m_cluster_inputs
Definition TauJetRNN.h:87
TauJetRNN(const std::string &filename, const Config &config, bool useTRT)
Definition TauJetRNN.cxx:17
float compute(const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters) const
Definition TauJetRNN.cxx:91
Class mimicking the AthMessaging class from the offline software.
TauJet_v3 TauJet
Definition of the current "tau version".
std::string input_layer_clusters
Definition TauJetRNN.h:40
std::string input_layer_tracks
Definition TauJetRNN.h:39
std::string output_layer
Definition TauJetRNN.h:41
std::string input_layer_scalar
Definition TauJetRNN.h:38
std::string output_node
Definition TauJetRNN.h:42