ATLAS Offline Software
TauJetRNN Class Reference

Wrapper around lwtnn to compute the output score of a neural network.

#include <TauJetRNN.h>


Classes

struct  Config
 

Public Member Functions

 TauJetRNN (const std::string &filename, const Config &config)
 
 ~TauJetRNN ()
 
float compute (const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters) const
 
bool calculateInputVariables (const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters, std::map< std::string, std::map< std::string, double >> &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double >>> &vectorInputs) const
 
const TauJetRNNUtils::VarCalc * variable_calculator () const
 
 operator bool () const
 
void setLevel (MSG::Level lvl)
 Change the current logging level.
 

Private Types

using VariableMap = std::map< std::string, double >
 
using VectorMap = std::map< std::string, std::vector< double > >
 
using InputMap = std::map< std::string, VariableMap >
 
using InputSequenceMap = std::map< std::string, VectorMap >
 

Private Member Functions

void initMessaging () const
 Initialize our message level and MessageSvc.
 

Private Attributes

const Config m_config
 
std::unique_ptr< const lwt::LightweightGraph > m_graph
 
std::vector< std::string > m_scalar_inputs
 
std::vector< std::string > m_track_inputs
 
std::vector< std::string > m_cluster_inputs
 
std::unique_ptr< TauJetRNNUtils::VarCalc > m_var_calc
 
std::string m_nm
 Message source name.
 
boost::thread_specific_ptr< MsgStream > m_msg_tls
 MsgStream instance (a std::cout-like stream with print-out levels)
 
std::atomic< IMessageSvc * > m_imsg { nullptr }
 MessageSvc pointer.
 
std::atomic< MSG::Level > m_lvl { MSG::NIL }
 Current logging level.
 
std::atomic_flag m_initialized ATLAS_THREAD_SAFE = ATOMIC_FLAG_INIT
 Messaging initialized (initMessaging)
 

Detailed Description

Wrapper around lwtnn to compute the output score of a neural network.

Configures the network and computes the network outputs given the input objects. Retrieval of input variables is handled internally.

Author
C. Deutsch
W. Davey

Definition at line 34 of file TauJetRNN.h.

Member Typedef Documentation

◆ InputMap

using TauJetRNN::InputMap = std::map<std::string, VariableMap>
private

Definition at line 77 of file TauJetRNN.h.

◆ InputSequenceMap

using TauJetRNN::InputSequenceMap = std::map<std::string, VectorMap>
private

Definition at line 78 of file TauJetRNN.h.

◆ VariableMap

using TauJetRNN::VariableMap = std::map<std::string, double>
private

Definition at line 74 of file TauJetRNN.h.

◆ VectorMap

using TauJetRNN::VectorMap = std::map<std::string, std::vector<double> >
private

Definition at line 75 of file TauJetRNN.h.
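The four aliases nest two levels of std::map: the outer key is an lwtnn input-layer name (taken from Config) and the inner key is a variable name. The value is a single double for scalar inputs, or one entry per track or cluster for sequence inputs. The standalone sketch below fills such maps by hand. `fillExampleInputs` and every layer and variable name in it are hypothetical placeholders, not the names used by the production network.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Local copies of the private TauJetRNN aliases, for illustration only.
using VariableMap = std::map<std::string, double>;
using VectorMap = std::map<std::string, std::vector<double>>;
using InputMap = std::map<std::string, VariableMap>;
using InputSequenceMap = std::map<std::string, VectorMap>;

// Hypothetical helper: fills one scalar layer and one sequence layer.
// All layer and variable names below are placeholders.
void fillExampleInputs(InputMap &scalars, InputSequenceMap &sequences) {
  // operator[] default-constructs the inner map on first use, the same
  // idiom calculateInputVariables relies on.
  scalars["scalar_layer"]["var_a"] = 0.8;
  scalars["scalar_layer"]["var_b"] = 1.5;

  // A sequence variable holds one entry per object (here, per track).
  sequences["track_layer"]["trk_var_a"] = {1.2, 0.9, 0.4};
  sequences["track_layer"]["trk_var_b"] = {-2.0, -1.1, -0.7};
}
```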

Constructor & Destructor Documentation

◆ TauJetRNN()

TauJetRNN::TauJetRNN ( const std::string &  filename,
const Config &  config 
)

Definition at line 17 of file TauJetRNN.cxx.

  : asg::AsgMessaging("TauJetRNN"), m_config(config), m_graph(nullptr) {
  // Load the json file defining the network
  std::ifstream input_file(filename);
  lwt::GraphConfig lwtnn_config;
  try {
    lwtnn_config = lwt::parse_json_graph(input_file);
  } catch (const std::logic_error &e) {
    ATH_MSG_ERROR("Error parsing network config: " << e.what());
    throw;
  }

  // Search for input layer names specified in 'config'
  auto node_is_scalar = [&config](const lwt::InputNodeConfig &in_node) {
    return in_node.name == config.input_layer_scalar;
  };
  auto node_is_track = [&config](const lwt::InputNodeConfig &in_node) {
    return in_node.name == config.input_layer_tracks;
  };
  auto node_is_cluster = [&config](const lwt::InputNodeConfig &in_node) {
    return in_node.name == config.input_layer_clusters;
  };

  auto scalar_node = std::find_if(lwtnn_config.inputs.cbegin(),
                                  lwtnn_config.inputs.cend(),
                                  node_is_scalar);

  auto track_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
                                 lwtnn_config.input_sequences.cend(),
                                 node_is_track);

  auto cluster_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
                                   lwtnn_config.input_sequences.cend(),
                                   node_is_cluster);

  // Check which input layers were found
  auto has_scalar_node = scalar_node != lwtnn_config.inputs.cend();
  auto has_track_node = track_node != lwtnn_config.input_sequences.cend();
  auto has_cluster_node = cluster_node != lwtnn_config.input_sequences.cend();

  // Fill the variable names of each input layer into the corresponding vector
  if (has_scalar_node) {
    for (const auto &in : scalar_node->variables) {
      m_scalar_inputs.push_back(in.name);
    }
  }

  if (has_track_node) {
    for (const auto &in : track_node->variables) {
      m_track_inputs.push_back(in.name);
    }
  }

  if (has_cluster_node) {
    for (const auto &in : cluster_node->variables) {
      m_cluster_inputs.push_back(in.name);
    }
  }

  // Configure the network
  try {
    m_graph = std::make_unique<lwt::LightweightGraph>(
        lwtnn_config, config.output_layer);
  } catch (const lwt::NNConfigurationException &e) {
    ATH_MSG_ERROR(e.what());
    throw;
  }

  // Load the variable calculator
  m_var_calc = TauJetRNNUtils::get_calculator(m_scalar_inputs, m_track_inputs,
                                              m_cluster_inputs);
}
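The constructor locates each configured input layer by name with std::find_if and copies that layer's variable names; a layer that is not found simply leaves its name vector empty. The sketch below reproduces this lookup on simplified stand-in types. `InputNode`, `Variable` and `variableNamesFor` are hypothetical: they keep only the `name` and `variables` fields the constructor reads from lwt::InputNodeConfig, which carries more information in lwtnn itself.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for lwtnn's input-node configuration types.
struct Variable {
  std::string name;
};
struct InputNode {
  std::string name;
  std::vector<Variable> variables;
};

// Returns the variable names of the node called `layer`, or an empty
// vector if no such node exists (mirroring the has_*_node checks).
std::vector<std::string> variableNamesFor(const std::vector<InputNode> &nodes,
                                          const std::string &layer) {
  auto it = std::find_if(nodes.cbegin(), nodes.cend(),
                         [&layer](const InputNode &n) { return n.name == layer; });
  std::vector<std::string> names;
  if (it != nodes.cend()) {
    for (const auto &v : it->variables) {
      names.push_back(v.name);
    }
  }
  return names;
}
```

Because a missing layer is not an error here, a network without, for example, a cluster sequence still configures cleanly; only the lwt::LightweightGraph construction step throws on an inconsistent configuration.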

◆ ~TauJetRNN()

TauJetRNN::~TauJetRNN ( )

Definition at line 89 of file TauJetRNN.cxx.

TauJetRNN::~TauJetRNN() {}

Member Function Documentation

◆ calculateInputVariables()

bool TauJetRNN::calculateInputVariables ( const xAOD::TauJet &  tau,
const std::vector< const xAOD::TauTrack * > &  tracks,
const std::vector< xAOD::CaloVertexedTopoCluster > &  clusters,
std::map< std::string, std::map< std::string, double >> &  scalarInputs,
std::map< std::string, std::map< std::string, std::vector< double >>> &  vectorInputs 
) const

Definition at line 105 of file TauJetRNN.cxx.

{
  scalarInputs.clear();
  vectorInputs.clear();
  // Populate input (sequence) map with input variables
  for (const auto &varname : m_scalar_inputs) {
    if (!m_var_calc->compute(varname, tau,
                             scalarInputs[m_config.input_layer_scalar][varname])) {
      ATH_MSG_WARNING("Error computing '" << varname
                      << "' returning default");
      return false;
    }
  }

  for (const auto &varname : m_track_inputs) {
    if (!m_var_calc->compute(varname, tau, tracks,
                             vectorInputs[m_config.input_layer_tracks][varname])) {
      ATH_MSG_WARNING("Error computing '" << varname
                      << "' returning default");
      return false;
    }
  }

  for (const auto &varname : m_cluster_inputs) {
    if (!m_var_calc->compute(varname, tau, clusters,
                             vectorInputs[m_config.input_layer_clusters][varname])) {
      ATH_MSG_WARNING("Error computing '" << varname
                      << "' returning default");
      return false;
    }
  }
  return true;
}
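calculateInputVariables stops at the first variable that cannot be computed, so a single failure makes the whole evaluation fall back to the default score. The standalone sketch below shows that all-or-nothing loop for one scalar layer. `Calculator` is a hypothetical callback standing in for TauJetRNNUtils::VarCalc::compute, and `fillScalarLayer` is an illustrative name, not part of the real class.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for VarCalc::compute: writes the value and
// returns false if the variable cannot be computed.
using Calculator = std::function<bool(const std::string &, double &)>;

// Fills inputs[layer][name] for every requested variable and aborts on
// the first failure, as the scalar loop in calculateInputVariables does.
bool fillScalarLayer(const std::vector<std::string> &names,
                     const std::string &layer, const Calculator &calc,
                     std::map<std::string, std::map<std::string, double>> &inputs) {
  inputs.clear();
  for (const auto &name : names) {
    if (!calc(name, inputs[layer][name])) {
      return false;  // the caller is expected to fall back to a default score
    }
  }
  return true;
}
```

Note that on failure the map is left partially filled; compute discards it and returns its default value, so the partial contents are never passed to the network.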

◆ compute()

float TauJetRNN::compute ( const xAOD::TauJet &  tau,
const std::vector< const xAOD::TauTrack * > &  tracks,
const std::vector< xAOD::CaloVertexedTopoCluster > &  clusters 
) const

Definition at line 91 of file TauJetRNN.cxx.

{
  InputMap scalarInputs;
  InputSequenceMap vectorInputs;
  if (!calculateInputVariables(tau, tracks, clusters, scalarInputs, vectorInputs)) {
    return -1111.0;
  }
  // Compute the network outputs
  const auto outputs = m_graph->compute(scalarInputs, vectorInputs);
  // Return value of the output neuron
  return outputs.at(m_config.output_node);
}
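compute returns -1111 when the inputs cannot be calculated. Otherwise it reads the configured output node with std::map::at, which throws std::out_of_range if Config::output_node does not name one of the network outputs. The sketch below isolates that last step, assuming the graph outputs arrive as a std::map<std::string, double> (which is how lwtnn's lwt::ValueMap is defined). `pickOutput` and the output-node names in the usage are hypothetical; `kDefaultScore` is just a named copy of the literal used in compute.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

// Named copy of the sentinel literal returned by TauJetRNN::compute.
constexpr float kDefaultScore = -1111.0f;

// Hypothetical helper mirroring the end of TauJetRNN::compute:
// a failed input calculation short-circuits to the sentinel; otherwise the
// chosen output node is read with at(), which throws if it is absent.
float pickOutput(bool inputsOk, const std::map<std::string, double> &outputs,
                 const std::string &outputNode) {
  if (!inputsOk) {
    return kDefaultScore;
  }
  return static_cast<float>(outputs.at(outputNode));
}
```

A misspelled output_node therefore surfaces as an exception on the first evaluated tau, not as a silently wrong score.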

◆ initMessaging()

void AthMessaging::initMessaging ( ) const
private inherited

Initialize our message level and MessageSvc.

This method should only be called once.

Definition at line 39 of file AthMessaging.cxx.

{
  m_imsg = Athena::getMessageSvc();
  m_lvl = m_imsg ?
    static_cast<MSG::Level>( m_imsg.load()->outputLevel(m_nm) ) :
    MSG::INFO;
}
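initMessaging stores the MessageSvc pointer and then takes the output level for this message source from the service, falling back to INFO when no service is available. The sketch below reproduces that decision with hypothetical types only: `Level` mirrors a few Gaudi levels, `FakeMsgSvc` stands in for IMessageSvc, and, unlike the real method, the service is passed in explicitly instead of being obtained from Athena::getMessageSvc().

```cpp
#include <atomic>
#include <cassert>
#include <map>
#include <string>

// Hypothetical subset of message levels, for illustration only.
enum class Level { Nil, Debug, Info, Warning };

// Hypothetical stand-in for IMessageSvc: per-source output levels.
struct FakeMsgSvc {
  std::map<std::string, Level> levels;
  Level outputLevel(const std::string &source) const {
    auto it = levels.find(source);
    return it != levels.end() ? it->second : Level::Info;
  }
};

struct MessagingSketch {
  std::string name;
  std::atomic<const FakeMsgSvc *> svc{nullptr};
  std::atomic<Level> lvl{Level::Nil};

  // Same shape as AthMessaging::initMessaging: ask the service for this
  // source's level if a service exists, otherwise default to Info.
  void initMessaging(const FakeMsgSvc *service) {
    svc = service;
    lvl = svc ? svc.load()->outputLevel(name) : Level::Info;
  }
};
```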

◆ msg() [1/2]

MsgStream & asg::AsgMessaging::msg ( ) const
inherited

The standard message stream.

Returns
A reference to the default message stream of this object.

Definition at line 49 of file AsgMessaging.cxx.

{
#ifndef XAOD_STANDALONE
  return ::AthMessaging::msg();
#else // not XAOD_STANDALONE
  return m_msg;
#endif // not XAOD_STANDALONE
}

◆ msg() [2/2]

MsgStream & asg::AsgMessaging::msg ( const MSG::Level  lvl) const
inherited

The standard message stream.

Parameters
lvl  The message level to set the stream to
Returns
A reference to the default message stream, set to level "lvl"

Definition at line 57 of file AsgMessaging.cxx.

{
#ifndef XAOD_STANDALONE
  return ::AthMessaging::msg( lvl );
#else // not XAOD_STANDALONE
  m_msg << lvl;
  return m_msg;
#endif // not XAOD_STANDALONE
}

◆ msgLvl()

bool asg::AsgMessaging::msgLvl ( const MSG::Level  lvl) const
inherited

Test the output level of the object.

Parameters
lvl  The message level to test against
Returns
A boolean indicating whether messages at the given level will be printed
true  if messages at level "lvl" will be printed

Definition at line 41 of file AsgMessaging.cxx.

{
#ifndef XAOD_STANDALONE
  return ::AthMessaging::msgLvl( lvl );
#else // not XAOD_STANDALONE
  return m_msg.msgLevel( lvl );
#endif // not XAOD_STANDALONE
}
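msgLvl answers whether a message at level lvl would be printed. In Gaudi the levels are ordered VERBOSE < DEBUG < INFO < WARNING < ERROR < FATAL, and a message is shown when its level is at or above the configured output level. The sketch below expresses that comparison with a hypothetical enum `Lvl` that mirrors this ordering; it is not Gaudi's MSG::Level type, and `wouldPrint` is an illustrative helper.

```cpp
#include <cassert>

// Hypothetical enum mirroring the ordering of Gaudi's MSG::Level values.
enum class Lvl { Verbose = 1, Debug, Info, Warning, Error, Fatal };

// A message at level `msgLevel` is printed when it is at or above the
// configured output threshold.
constexpr bool wouldPrint(Lvl threshold, Lvl msgLevel) {
  return threshold <= msgLevel;
}

static_assert(wouldPrint(Lvl::Info, Lvl::Warning), "WARNING passes an INFO threshold");
static_assert(!wouldPrint(Lvl::Info, Lvl::Debug), "DEBUG is suppressed at INFO");
```

Testing the level before building a message is what lets logging macros such as ATH_MSG_DEBUG skip the formatting work entirely when the message would be discarded.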

◆ operator bool()

TauJetRNN::operator bool ( ) const
inline explicit

Definition at line 68 of file TauJetRNN.h.

{
  return static_cast<bool>(m_graph);
}
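Because the conversion is declared explicit, a TauJetRNN can be tested with `if (rnn)` or `!rnn`, but it cannot silently turn into a bool in an assignment or arithmetic expression. The sketch below demonstrates the same pattern on a hypothetical `Wrapper` class that owns a hypothetical `Graph` through a std::unique_ptr, as TauJetRNN owns its lwt::LightweightGraph.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for the network graph type.
struct Graph {};

// Mirrors TauJetRNN's validity check: true only if a graph is owned.
class Wrapper {
public:
  explicit Wrapper(std::unique_ptr<const Graph> g) : m_graph(std::move(g)) {}
  explicit operator bool() const { return static_cast<bool>(m_graph); }

private:
  std::unique_ptr<const Graph> m_graph;
};

// `explicit` blocks implicit conversions such as `bool b = w;`,
// while still allowing contextual uses like `if (w)` and `!w`.
static_assert(!std::is_convertible<Wrapper, bool>::value, "no implicit conversion");
static_assert(std::is_constructible<bool, Wrapper>::value, "explicit conversion works");
```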

◆ setLevel()

void AthMessaging::setLevel ( MSG::Level  lvl)
inherited

Change the current logging level.

Use this rather than msg().setLevel() for correct operation in multi-threaded (MT) jobs.

Definition at line 28 of file AthMessaging.cxx.

{
  m_lvl = lvl;
}
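The MT note follows from the data members: m_lvl is a single std::atomic shared by all threads, whereas m_msg_tls gives each thread its own MsgStream, so a level set on one thread's stream is presumably not seen by the others. The sketch below illustrates the difference with plain standard-library stand-ins (a shared std::atomic<int> and a thread_local int). The names and the integer levels are hypothetical, and the example needs a toolchain with std::thread support.

```cpp
#include <atomic>
#include <cassert>
#include <thread>

// Hypothetical illustration: a level kept in one shared std::atomic (like
// AthMessaging::m_lvl) versus a level kept per thread (analogous to each
// thread owning its own MsgStream through m_msg_tls).
std::atomic<int> sharedLevel{3};
thread_local int perThreadLevel = 3;

// Changes both levels on the calling thread, then reports what a newly
// started worker thread observes.
void observeFromWorker(int newLevel, int &sharedSeen, int &perThreadSeen) {
  sharedLevel = newLevel;     // like setLevel(): visible to every thread
  perThreadLevel = newLevel;  // like msg().setLevel(): this thread only
  std::thread worker([&] {
    sharedSeen = sharedLevel.load();
    perThreadSeen = perThreadLevel;  // the worker's own copy, still 3
  });
  worker.join();
}
```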

◆ variable_calculator()

const TauJetRNNUtils::VarCalc* TauJetRNN::variable_calculator ( ) const
inline

Definition at line 64 of file TauJetRNN.h.

{
  return m_var_calc.get();
}

Member Data Documentation

◆ m_initialized

std::atomic_flag AthMessaging::m_initialized ATLAS_THREAD_SAFE = ATOMIC_FLAG_INIT
mutable private inherited

Messaging initialized (initMessaging)

Definition at line 141 of file AthMessaging.h.

◆ m_cluster_inputs

std::vector<std::string> TauJetRNN::m_cluster_inputs
private

Definition at line 87 of file TauJetRNN.h.

◆ m_config

const Config TauJetRNN::m_config
private

Definition at line 81 of file TauJetRNN.h.

◆ m_graph

std::unique_ptr<const lwt::LightweightGraph> TauJetRNN::m_graph
private

Definition at line 82 of file TauJetRNN.h.

◆ m_imsg

std::atomic<IMessageSvc*> AthMessaging::m_imsg { nullptr }
mutable private inherited

MessageSvc pointer.

Definition at line 135 of file AthMessaging.h.

◆ m_lvl

std::atomic<MSG::Level> AthMessaging::m_lvl { MSG::NIL }
mutable private inherited

Current logging level.

Definition at line 138 of file AthMessaging.h.

◆ m_msg_tls

boost::thread_specific_ptr<MsgStream> AthMessaging::m_msg_tls
mutable private inherited

MsgStream instance (a std::cout-like stream with print-out levels)

Definition at line 132 of file AthMessaging.h.

◆ m_nm

std::string AthMessaging::m_nm
private inherited

Message source name.

Definition at line 129 of file AthMessaging.h.

◆ m_scalar_inputs

std::vector<std::string> TauJetRNN::m_scalar_inputs
private

Definition at line 85 of file TauJetRNN.h.

◆ m_track_inputs

std::vector<std::string> TauJetRNN::m_track_inputs
private

Definition at line 86 of file TauJetRNN.h.

◆ m_var_calc

std::unique_ptr<TauJetRNNUtils::VarCalc> TauJetRNN::m_var_calc
private

Definition at line 90 of file TauJetRNN.h.


The documentation for this class was generated from the following files:
TauJetRNN.h
TauJetRNN.cxx