ATLAS Offline Software
TauGNN Class Reference

Wrapper around ONNXUtil to compute the output score of a model.

#include <TauGNN.h>


Classes

struct  Config
 

Public Member Functions

 TauGNN (const std::string &nnFile, const Config &config)
 
 ~TauGNN ()
 
std::tuple< std::map< std::string, float >, std::map< std::string, std::vector< char > >, std::map< std::string, std::vector< float > > > compute (const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters) const
 
bool calculateInputVariables (const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters, std::map< std::string, std::map< std::string, double >> &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double >>> &vectorInputs) const
 
const TauGNNUtils::GNNVarCalc * variable_calculator () const
 
void setLevel (MSG::Level lvl)
 Change the current logging level.
 

Public Attributes

std::shared_ptr< const FlavorTagDiscriminants::OnnxUtil > m_onnxUtil
 
FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config
 

Private Types

using Inputs = FlavorTagDiscriminants::Inputs
 
using VariableMap = std::map< std::string, double >
 
using VectorMap = std::map< std::string, std::vector< double > >
 
using InputMap = std::map< std::string, VariableMap >
 
using InputSequenceMap = std::map< std::string, VectorMap >
 

Private Member Functions

void initMessaging () const
 Initialize our message level and MessageSvc.
 

Private Attributes

const Config m_config
 
std::vector< std::string > m_scalar_inputs
 
std::vector< std::string > m_track_inputs
 
std::vector< std::string > m_cluster_inputs
 
std::vector< std::string > m_scalarCalc_inputs
 
std::vector< std::string > m_trackCalc_inputs
 
std::vector< std::string > m_clusterCalc_inputs
 
std::unique_ptr< TauGNNUtils::GNNVarCalc > m_var_calc
 
std::string m_nm
 Message source name.
 
boost::thread_specific_ptr< MsgStream > m_msg_tls
 MsgStream instance (a std::cout-like stream with print-out levels)
 
std::atomic< IMessageSvc * > m_imsg { nullptr }
 MessageSvc pointer.
 
std::atomic< MSG::Level > m_lvl { MSG::NIL }
 Current logging level.
 
std::atomic_flag m_initialized ATLAS_THREAD_SAFE = ATOMIC_FLAG_INIT
 Messaging initialized (initMessaging)
 

Detailed Description

Wrapper around ONNXUtil to compute the output score of a model.

Configures the network and computes the network outputs given the input objects. Retrieval of input variables is handled internally.

Author
N.M. Tamir

Definition at line 36 of file TauGNN.h.
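
The tuple returned by compute() groups the network outputs by type: scalar floats, char vectors and float vectors, each keyed by output node name. The following is a minimal sketch of reading one scalar output from such a tuple. It uses only standard-library types; fakeCompute(), scalarOutput() and the node names "score", "flags" and "probs" are hypothetical stand-ins invented for illustration and are not part of TauGNN.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <tuple>
#include <vector>

// Same shape as the return type of TauGNN::compute().
using FloatOutputs    = std::map<std::string, float>;
using VecCharOutputs  = std::map<std::string, std::vector<char>>;
using VecFloatOutputs = std::map<std::string, std::vector<float>>;
using GnnOutputs = std::tuple<FloatOutputs, VecCharOutputs, VecFloatOutputs>;

// Hypothetical stand-in for a call to TauGNN::compute(); the output node
// names used here are invented for this sketch.
GnnOutputs fakeCompute() {
  FloatOutputs floats;
  floats["score"] = 0.87f;
  VecCharOutputs vecChars;
  vecChars["flags"] = std::vector<char>{1, 0};
  VecFloatOutputs vecFloats;
  vecFloats["probs"] = std::vector<float>{0.1f, 0.2f, 0.7f};
  return std::make_tuple(floats, vecChars, vecFloats);
}

// Look up one scalar output by node name, returning `fallback` when the
// model does not provide a node with that name.
float scalarOutput(const GnnOutputs& out, const std::string& name, float fallback) {
  const FloatOutputs& floats = std::get<0>(out);
  const auto it = floats.find(name);
  return it != floats.end() ? it->second : fallback;
}
```

Looking outputs up with find() rather than operator[] keeps the tuple const and makes a missing node an explicit case instead of a silently inserted zero.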

Member Typedef Documentation

◆ InputMap

using TauGNN::InputMap = std::map<std::string, VariableMap>
private

Definition at line 81 of file TauGNN.h.

◆ Inputs

using TauGNN::Inputs = FlavorTagDiscriminants::Inputs
private

Definition at line 76 of file TauGNN.h.

◆ InputSequenceMap

using TauGNN::InputSequenceMap = std::map<std::string, VectorMap>
private

Definition at line 82 of file TauGNN.h.

◆ VariableMap

using TauGNN::VariableMap = std::map<std::string, double>
private

Definition at line 78 of file TauGNN.h.

◆ VectorMap

using TauGNN::VectorMap = std::map<std::string, std::vector<double> >
private

Definition at line 79 of file TauGNN.h.

Constructor & Destructor Documentation

◆ TauGNN()

TauGNN::TauGNN ( const std::string &  nnFile,
const Config &  config 
)

Definition at line 15 of file TauGNN.cxx.

  :
    asg::AsgMessaging("TauGNN"),
    m_onnxUtil(std::make_shared<FlavorTagDiscriminants::OnnxUtil>(nnFile)),
    m_config{config}
  {
    //==================================================//
    // This part is ported from FTagDiscriminant GNN.cxx//
    //==================================================//

    // get the configuration of the model outputs
    FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig();

    //Let's see the output!
    for (const auto& out_node: gnn_output_config) {
      if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::FLOAT) ATH_MSG_INFO("Found output FLOAT node named:" << out_node.name);
      if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECCHAR) ATH_MSG_INFO("Found output VECCHAR node named:" << out_node.name);
      if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECFLOAT) ATH_MSG_INFO("Found output VECFLOAT node named:" << out_node.name);
    }

    //Get model config (for inputs)
    auto lwtnn_config = m_onnxUtil->getLwtConfig();

    //===================================================//
    // This part is ported from tauRecTools TauJetRNN.cxx//
    //===================================================//

    // Search for input layer names specified in 'config'
    auto node_is_scalar = [&config](const lwt::InputNodeConfig &in_node) {
      return in_node.name == config.input_layer_scalar;
    };
    auto node_is_track = [&config](const lwt::InputNodeConfig &in_node) {
      return in_node.name == config.input_layer_tracks;
    };
    auto node_is_cluster = [&config](const lwt::InputNodeConfig &in_node) {
      return in_node.name == config.input_layer_clusters;
    };

    auto scalar_node = std::find_if(lwtnn_config.inputs.cbegin(),
                                    lwtnn_config.inputs.cend(),
                                    node_is_scalar);

    auto track_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
                                   lwtnn_config.input_sequences.cend(),
                                   node_is_track);

    auto cluster_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
                                     lwtnn_config.input_sequences.cend(),
                                     node_is_cluster);

    // Check which input layers were found
    auto has_scalar_node = scalar_node != lwtnn_config.inputs.cend();
    auto has_track_node = track_node != lwtnn_config.input_sequences.cend();
    auto has_cluster_node = cluster_node != lwtnn_config.input_sequences.cend();
    if(!has_scalar_node) ATH_MSG_WARNING("No scalar node with name "<<config.input_layer_scalar<<" found!");
    if(!has_track_node) ATH_MSG_WARNING("No track node with name "<<config.input_layer_tracks<<" found!");
    if(!has_cluster_node) ATH_MSG_WARNING("No cluster node with name "<<config.input_layer_clusters<<" found!");

    // Fill the variable names of each input layer into the corresponding vector
    if (has_scalar_node) {
      for (const auto &in : scalar_node->variables) {
        std::string name = in.name;
        m_scalarCalc_inputs.push_back(name);
      }
    }

    if (has_track_node) {
      for (const auto &in : track_node->variables) {
        std::string name = in.name;
        m_trackCalc_inputs.push_back(name);
      }
    }

    if (has_cluster_node) {
      for (const auto &in : cluster_node->variables) {
        std::string name = in.name;
        m_clusterCalc_inputs.push_back(name);
      }
    }
    // Load the variable calculator
    m_var_calc = TauGNNUtils::get_calculator(m_scalarCalc_inputs, m_trackCalc_inputs, m_clusterCalc_inputs);
    ATH_MSG_INFO("TauGNN object initialized successfully!");
  }
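
The node lookup in the constructor follows one pattern three times: find the input node whose name matches the configured layer, then copy out its variable names. The sketch below isolates that pattern with standard-library types only. Variable, InputNode and variableNames() are hypothetical stand-ins for the lwtnn configuration types and are assumptions of this example, not part of the ATLAS code.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Minimal stand-ins for the lwtnn config types used by the constructor.
struct Variable { std::string name; };
struct InputNode {
  std::string name;
  std::vector<Variable> variables;
};

// Return the variable names of the node called `layer`, or an empty list
// when no such node exists (the constructor only warns in that case).
std::vector<std::string> variableNames(const std::vector<InputNode>& nodes,
                                       const std::string& layer) {
  const auto node = std::find_if(
      nodes.cbegin(), nodes.cend(),
      [&layer](const InputNode& n) { return n.name == layer; });
  std::vector<std::string> names;
  if (node != nodes.cend()) {
    for (const auto& var : node->variables) names.push_back(var.name);
  }
  return names;
}
```

Because a missing node yields an empty name list rather than an error, later code that indexes the first variable name has to cope with the list being empty.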

◆ ~TauGNN()

TauGNN::~TauGNN ( )

Definition at line 98 of file TauGNN.cxx.

{}

Member Function Documentation

◆ calculateInputVariables()

bool TauGNN::calculateInputVariables ( const xAOD::TauJet &  tau,
const std::vector< const xAOD::TauTrack * > &  tracks,
const std::vector< xAOD::CaloVertexedTopoCluster > &  clusters,
std::map< std::string, std::map< std::string, double >> &  scalarInputs,
std::map< std::string, std::map< std::string, std::vector< double >>> &  vectorInputs 
) const

Definition at line 167 of file TauGNN.cxx.

{
  scalarInputs.clear();
  vectorInputs.clear();
  // Populate input (sequence) map with input variables
  for (const auto &varname : m_scalarCalc_inputs) {
    if (!m_var_calc->compute(varname, tau,
                             scalarInputs[m_config.input_layer_scalar][varname])) {
      ATH_MSG_WARNING("Error computing '" << varname
                      << "' returning default");
      return false;
    }
  }

  for (const auto &varname : m_trackCalc_inputs) {
    if (!m_var_calc->compute(varname, tau, tracks,
                             vectorInputs[m_config.input_layer_tracks][varname])) {
      ATH_MSG_WARNING("Error computing '" << varname
                      << "' returning default");
      return false;
    }
  }

  for (const auto &varname : m_clusterCalc_inputs) {
    if (!m_var_calc->compute(varname, tau, clusters,
                             vectorInputs[m_config.input_layer_clusters][varname])) {
      ATH_MSG_WARNING("Error computing '" << varname
                      << "' returning default");
      return false;
    }
  }
  return true;
}
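
The function above fills a two-level map (layer name, then variable name) and stops at the first variable that cannot be computed. A minimal sketch of that fill-and-fail-fast pattern for the scalar case is shown here. ScalarCalc and fillScalars() are hypothetical names introduced for illustration; the real calculator is TauGNNUtils::GNNVarCalc, whose interface is not reproduced here.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <vector>

using VariableMap = std::map<std::string, double>;
using InputMap = std::map<std::string, VariableMap>;

// Hypothetical calculator entry: writes the value into `out` and reports
// whether the computation succeeded.
using ScalarCalc = std::function<bool(double& out)>;

// Fill out[layer][name] for every requested name. Returns false at the
// first unknown or failing variable, leaving `out` partially filled.
bool fillScalars(const std::map<std::string, ScalarCalc>& calcs,
                 const std::vector<std::string>& names,
                 const std::string& layer,
                 InputMap& out) {
  out.clear();
  for (const auto& name : names) {
    const auto calc = calcs.find(name);
    if (calc == calcs.end()) return false;
    // operator[] creates the slot, the calculator writes into it in place.
    if (!calc->second(out[layer][name])) return false;
  }
  return true;
}
```

As in the original, a false return leaves the output map in a partially filled state, so callers should discard it rather than run inference on it.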

◆ compute()

std::tuple< std::map< std::string, float >, std::map< std::string, std::vector< char > >, std::map< std::string, std::vector< float > > > TauGNN::compute ( const xAOD::TauJet &  tau,
const std::vector< const xAOD::TauTrack * > &  tracks,
const std::vector< xAOD::CaloVertexedTopoCluster > &  clusters 
) const

Definition at line 104 of file TauGNN.cxx.

{
  InputMap scalarInputs;
  InputSequenceMap vectorInputs;
  std::map<std::string, Inputs> gnn_input;
  ATH_MSG_DEBUG("Starting compute...");
  //Prepare input variables
  if (!calculateInputVariables(tau, tracks, clusters, scalarInputs, vectorInputs)) {
    ATH_MSG_FATAL("Failed calculateInputVariables");
    throw StatusCode::FAILURE;
  }

  // Add TauJet-level features to the input
  std::vector<float> tau_feats;
  for (const auto &varname : m_scalarCalc_inputs) {
    tau_feats.push_back(static_cast<float>(scalarInputs[m_config.input_layer_scalar][varname]));
  }
  std::vector<int64_t> tau_feats_dim = {1, static_cast<int64_t>(tau_feats.size())};
  Inputs tau_info (tau_feats, tau_feats_dim);
  gnn_input.insert({"tau_vars", tau_info});

  //Add track-level features to the input
  std::vector<float> trk_feats;
  int num_nodes=static_cast<int>(vectorInputs[m_config.input_layer_tracks][m_trackCalc_inputs.at(0)].size());
  int num_node_vars=static_cast<int>(m_trackCalc_inputs.size());
  trk_feats.resize(num_nodes * num_node_vars);
  int var_idx=0;
  for (const auto &varname : m_trackCalc_inputs) {
    for (int node_idx=0; node_idx<num_nodes; node_idx++){
      trk_feats.at(node_idx*num_node_vars + var_idx)
        = static_cast<float>(vectorInputs[m_config.input_layer_tracks][varname].at(node_idx));
    }
    var_idx++;
  }
  std::vector<int64_t> trk_feats_dim = {num_nodes, num_node_vars};
  Inputs trk_info (trk_feats, trk_feats_dim);
  gnn_input.insert({"track_vars", trk_info});

  //Add cluster-level features to the input
  std::vector<float> cls_feats;
  num_nodes=static_cast<int>(vectorInputs[m_config.input_layer_clusters][m_clusterCalc_inputs.at(0)].size());
  num_node_vars=static_cast<int>(m_clusterCalc_inputs.size());
  cls_feats.resize(num_nodes * num_node_vars);
  var_idx=0;
  for (const auto &varname : m_clusterCalc_inputs) {
    for (int node_idx=0; node_idx<num_nodes; node_idx++){
      cls_feats.at(node_idx*num_node_vars + var_idx)
        = static_cast<float>(vectorInputs[m_config.input_layer_clusters][varname].at(node_idx));
    }
    var_idx++;
  }
  std::vector<int64_t> cls_feats_dim = {num_nodes, num_node_vars};
  Inputs cls_info (cls_feats, cls_feats_dim);
  gnn_input.insert({"cluster_vars", cls_info});

  //RUN THE INFERENCE!!!
  ATH_MSG_DEBUG("Prepared inputs, running inference...");
  auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input);
  ATH_MSG_DEBUG("Finished compute!");
  return std::make_tuple(out_f, out_vc, out_vf);
}
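
The track and cluster blocks above both pack per-variable sequences into one flat buffer of shape {num_nodes, num_node_vars}, stored row-major so that element (node, var) sits at index node * num_node_vars + var. The sketch below shows that packing in isolation. flattenRowMajor() is a hypothetical helper written for this illustration; it assumes every listed variable has a sequence of the same length.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Flatten per-variable sequences into one row-major buffer of shape
// {num_nodes, num_vars}; element (node, var) lives at node * num_vars + var.
std::pair<std::vector<float>, std::vector<std::int64_t>>
flattenRowMajor(const std::map<std::string, std::vector<double>>& values,
                const std::vector<std::string>& varOrder) {
  const std::size_t numVars = varOrder.size();
  // An empty variable list yields an empty tensor instead of throwing.
  const std::size_t numNodes =
      numVars == 0 ? 0 : values.at(varOrder.front()).size();
  std::vector<float> flat(numNodes * numVars);
  for (std::size_t var = 0; var < numVars; ++var) {
    const std::vector<double>& column = values.at(varOrder[var]);
    for (std::size_t node = 0; node < numNodes; ++node) {
      flat[node * numVars + var] = static_cast<float>(column.at(node));
    }
  }
  std::vector<std::int64_t> dims;
  dims.push_back(static_cast<std::int64_t>(numNodes));
  dims.push_back(static_cast<std::int64_t>(numVars));
  return std::make_pair(flat, dims);
}
```

For two variables "pt" = {1, 2, 3} and "eta" = {10, 20, 30} in that order, the buffer is {1, 10, 2, 20, 3, 30} with dimensions {3, 2}: all variables of the first node come first, then those of the second node, and so on.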

◆ initMessaging()

void AthMessaging::initMessaging ( ) const
private, inherited

Initialize our message level and MessageSvc.

This method should only be called once.

Definition at line 39 of file AthMessaging.cxx.

{
  m_imsg = Athena::getMessageSvc();
  m_lvl = m_imsg ?
    static_cast<MSG::Level>( m_imsg.load()->outputLevel(m_nm) ) :
    MSG::INFO;
}
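
The "called once" requirement pairs naturally with the std::atomic_flag member m_initialized listed under Private Attributes. The sketch below shows one common way such a flag can guard lazy initialization. LazyLevel and its members are hypothetical and do not reproduce the AthMessaging implementation; the level value 3 is an arbitrary placeholder.

```cpp
#include <atomic>
#include <cassert>

// Hypothetical class that initializes its level lazily, exactly once.
class LazyLevel {
public:
  // test_and_set() returns the previous flag value, so only the first
  // caller sees false and runs init().
  int level() const {
    if (!m_initialized.test_and_set()) init();
    return m_level.load();
  }
  int initCalls() const { return m_initCalls.load(); }

private:
  void init() const {
    m_level = 3;  // placeholder for a level fetched from a message service
    ++m_initCalls;
  }
  mutable std::atomic_flag m_initialized = ATOMIC_FLAG_INIT;
  mutable std::atomic<int> m_level{0};
  mutable std::atomic<int> m_initCalls{0};
};
```

Note that this simple form guarantees init() runs only once, but a second thread arriving while init() is still in progress can briefly read the uninitialized level; code that cannot tolerate that would need std::call_once or an equivalent.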

◆ msg() [1/2]

MsgStream & asg::AsgMessaging::msg ( ) const
inherited

The standard message stream.

Returns
A reference to the default message stream of this object.

Definition at line 49 of file AsgMessaging.cxx.

{
#ifndef XAOD_STANDALONE
  return ::AthMessaging::msg();
#else // not XAOD_STANDALONE
  return m_msg;
#endif // not XAOD_STANDALONE
}

◆ msg() [2/2]

MsgStream & asg::AsgMessaging::msg ( const MSG::Level  lvl) const
inherited

The standard message stream.

Parameters
lvl  The message level to set the stream to
Returns
A reference to the default message stream, set to level "lvl"

Definition at line 57 of file AsgMessaging.cxx.

{
#ifndef XAOD_STANDALONE
  return ::AthMessaging::msg( lvl );
#else // not XAOD_STANDALONE
  m_msg << lvl;
  return m_msg;
#endif // not XAOD_STANDALONE
}

◆ msgLvl()

bool asg::AsgMessaging::msgLvl ( const MSG::Level  lvl) const
inherited

Test the output level of the object.

Parameters
lvl  The message level to test against
Returns
boolean Indicating whether messages at the given level will be printed
true If messages at level "lvl" will be printed

Definition at line 41 of file AsgMessaging.cxx.

{
#ifndef XAOD_STANDALONE
  return ::AthMessaging::msgLvl( lvl );
#else // not XAOD_STANDALONE
  return m_msg.msgLevel( lvl );
#endif // not XAOD_STANDALONE
}

◆ setLevel()

void AthMessaging::setLevel ( MSG::Level  lvl)
inherited

Change the current logging level.

Use this rather than msg().setLevel() for proper operation with MT.

Definition at line 28 of file AthMessaging.cxx.

{
  m_lvl = lvl;
}

◆ variable_calculator()

const TauGNNUtils::GNNVarCalc* TauGNN::variable_calculator ( ) const
inline

Definition at line 68 of file TauGNN.h.

{
  return m_var_calc.get();
}

Member Data Documentation

◆ ATLAS_THREAD_SAFE

std::atomic_flag m_initialized AthMessaging::ATLAS_THREAD_SAFE = ATOMIC_FLAG_INIT
mutable, private, inherited

Messaging initialized (initMessaging)

Definition at line 141 of file AthMessaging.h.

◆ gnn_output_config

FlavorTagDiscriminants::OnnxUtil::OutputConfig TauGNN::gnn_output_config

Definition at line 73 of file TauGNN.h.

◆ m_cluster_inputs

std::vector<std::string> TauGNN::m_cluster_inputs
private

Definition at line 90 of file TauGNN.h.

◆ m_clusterCalc_inputs

std::vector<std::string> TauGNN::m_clusterCalc_inputs
private

Definition at line 94 of file TauGNN.h.

◆ m_config

const Config TauGNN::m_config
private

Definition at line 85 of file TauGNN.h.

◆ m_imsg

std::atomic<IMessageSvc*> AthMessaging::m_imsg { nullptr }
mutable, private, inherited

MessageSvc pointer.

Definition at line 135 of file AthMessaging.h.

◆ m_lvl

std::atomic<MSG::Level> AthMessaging::m_lvl { MSG::NIL }
mutable, private, inherited

Current logging level.

Definition at line 138 of file AthMessaging.h.

◆ m_msg_tls

boost::thread_specific_ptr<MsgStream> AthMessaging::m_msg_tls
mutable, private, inherited

MsgStream instance (a std::cout-like stream with print-out levels)

Definition at line 132 of file AthMessaging.h.

◆ m_nm

std::string AthMessaging::m_nm
private, inherited

Message source name.

Definition at line 129 of file AthMessaging.h.

◆ m_onnxUtil

std::shared_ptr<const FlavorTagDiscriminants::OnnxUtil> TauGNN::m_onnxUtil

Definition at line 46 of file TauGNN.h.

◆ m_scalar_inputs

std::vector<std::string> TauGNN::m_scalar_inputs
private

Definition at line 88 of file TauGNN.h.

◆ m_scalarCalc_inputs

std::vector<std::string> TauGNN::m_scalarCalc_inputs
private

Definition at line 92 of file TauGNN.h.

◆ m_track_inputs

std::vector<std::string> TauGNN::m_track_inputs
private

Definition at line 89 of file TauGNN.h.

◆ m_trackCalc_inputs

std::vector<std::string> TauGNN::m_trackCalc_inputs
private

Definition at line 93 of file TauGNN.h.

◆ m_var_calc

std::unique_ptr<TauGNNUtils::GNNVarCalc> TauGNN::m_var_calc
private

Definition at line 97 of file TauGNN.h.


The documentation for this class was generated from the following files:
TauGNN.h
TauGNN.cxx