ATLAS Offline Software
TauGNN Class Reference

Wrapper around SaltModel to compute the output score of a model.

#include <TauGNN.h>


Public Member Functions

 TauGNN (const TauGNNDataLoader::Config &config)
 
 ~TauGNN ()=default
 
std::tuple< std::map< std::string, float >, std::map< std::string, std::vector< char > >, std::map< std::string, std::vector< float > > > compute (const xAOD::TauJet &tau) const
 
void setLevel (MSG::Level lvl)
 Change the current logging level.
 

Public Attributes

std::shared_ptr< const FlavorTagInference::SaltModel > m_saltModel
 
TauGNNDataLoader m_dataloader
 
FlavorTagInference::OutputConfig gnn_output_config
 

Private Member Functions

void initMessaging () const
 Initialize our message level and MessageSvc.
 

Private Attributes

std::string m_nm
 Message source name.
 
boost::thread_specific_ptr< MsgStream > m_msg_tls
 MsgStream instance (a std::cout-like stream with print-out levels).
 
std::atomic< IMessageSvc * > m_imsg { nullptr }
 MessageSvc pointer.
 
std::atomic< MSG::Level > m_lvl { MSG::NIL }
 Current logging level.
 
std::atomic_flag m_initialized ATLAS_THREAD_SAFE = ATOMIC_FLAG_INIT
 Whether messaging has been initialized (see initMessaging()).
 

Detailed Description

Wrapper around SaltModel to compute the output score of a model.

Configures the network and computes its outputs for a given input object (an xAOD::TauJet). Retrieval of the input variables is handled internally by the TauGNNDataLoader member m_dataloader.

Author
N.M. Tamir

Definition at line 36 of file TauGNN.h.

Constructor & Destructor Documentation

◆ TauGNN()

TauGNN::TauGNN ( const TauGNNDataLoader::Config & config)

Definition at line 9 of file TauGNN.cxx.

9  :
10  asg::AsgMessaging("TauGNN"),
11  m_saltModel(std::make_shared<FlavorTagInference::SaltModel>(config.nnFile)),
13 {
14  ATH_MSG_INFO("TauGNN object initialized successfully!");
15 }
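The constructor stores the model as a std::shared_ptr<const FlavorTagInference::SaltModel> built from std::make_shared<FlavorTagInference::SaltModel>(config.nnFile). The non-const shared_ptr converts implicitly to a pointer-to-const, so afterwards the wrapper can only call const member functions of the model (such as the runInference call in compute()). The sketch below illustrates that conversion in isolation; DummyModel and DummyWrapper are hypothetical stand-ins for SaltModel and TauGNN, and the file name is a placeholder that is never opened.

```cpp
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-in for FlavorTagInference::SaltModel: it only
// remembers the path it was constructed from.
class DummyModel {
public:
  explicit DummyModel(std::string path) : m_path(std::move(path)) {}
  // const member function: callable through shared_ptr<const DummyModel>.
  const std::string& path() const { return m_path; }

private:
  std::string m_path;
};

// Hypothetical wrapper mirroring the member-initializer pattern of TauGNN.
class DummyWrapper {
public:
  explicit DummyWrapper(const std::string& nnFile)
      // make_shared<DummyModel> yields shared_ptr<DummyModel>, which
      // converts implicitly to shared_ptr<const DummyModel>.
      : m_model(std::make_shared<DummyModel>(nnFile)) {}

  // Returning a copy shares ownership; callers still get read-only access.
  std::shared_ptr<const DummyModel> model() const { return m_model; }

private:
  std::shared_ptr<const DummyModel> m_model;
};
```

Holding the model through a pointer-to-const documents that inference does not modify it, which is what makes sharing one loaded model between several owners reasonable.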

◆ ~TauGNN()

TauGNN::~TauGNN ( )
default

Member Function Documentation

◆ compute()

std::tuple< std::map< std::string, float >, std::map< std::string, std::vector< char > >, std::map< std::string, std::vector< float > > > TauGNN::compute ( const xAOD::TauJet & tau) const

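Loads the network inputs for tau through m_dataloader, runs inference with m_saltModel, and returns the three output maps produced by runInference(): scalar float outputs, char-vector outputs and float-vector outputs, each keyed by a string.

Definition at line 21 of file TauGNN.cxx.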

21  {
22  ATH_MSG_DEBUG("Computing TauGNN features...");
23  auto salt_model_input_data = m_dataloader.loadInputs(&tau);
24  // m_dataloader.DumpGnnInputs(salt_model_input_data.gnn_inputs);
25  ATH_MSG_DEBUG("Running inference...");
26  auto [out_f, out_vc, out_vf] = m_saltModel->runInference(salt_model_input_data.gnn_inputs);
27  ATH_MSG_DEBUG("Inference done.");
28  return std::make_tuple(out_f, out_vc, out_vf);
29 }
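Because compute() packs its three maps into one std::tuple, in the same order as the structured binding [out_f, out_vc, out_vf] in the listing, a caller can unpack the result with a structured binding or with std::get. The sketch below assumes nothing about the real network: fakeCompute() is a hypothetical stand-in that fills the same tuple type, and the output names "score", "track_class" and "track_weight" are placeholders (the actual names depend on the trained model and are not listed on this page).

```cpp
#include <map>
#include <string>
#include <tuple>
#include <vector>

// Same return type as TauGNN::compute().
using GnnOutputs = std::tuple<std::map<std::string, float>,
                              std::map<std::string, std::vector<char>>,
                              std::map<std::string, std::vector<float>>>;

// Hypothetical stand-in for TauGNN::compute(): names and values are
// placeholders, not real TauGNN outputs.
GnnOutputs fakeCompute() {
  std::map<std::string, float> scalars{{"score", 0.87f}};
  std::map<std::string, std::vector<char>> charVecs{{"track_class", {0, 1, 1}}};
  std::map<std::string, std::vector<float>> floatVecs{
      {"track_weight", {0.2f, 0.5f, 0.3f}}};
  return std::make_tuple(scalars, charVecs, floatVecs);
}

// Reads one scalar output, returning `fallback` if the name is absent.
float scalarOrDefault(const GnnOutputs& out, const std::string& name,
                      float fallback) {
  const auto& scalars = std::get<0>(out);  // first element: float scalars
  auto it = scalars.find(name);
  return it != scalars.end() ? it->second : fallback;
}
```

A typical call site would write auto [scalars, charVecs, floatVecs] = gnn.compute(tau); and then look up entries by name; using find() rather than at() avoids an exception when a model does not provide a given output.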

◆ initMessaging()

void AthMessaging::initMessaging ( ) const
private inherited

Initialize our message level and MessageSvc.

This method should only be called once.

Definition at line 39 of file AthMessaging.cxx.

40 {
42  m_lvl = m_imsg ?
43  static_cast<MSG::Level>( m_imsg.load()->outputLevel(m_nm) ) :
44  MSG::INFO;
45 }
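initMessaging() is a one-shot setup step, and AthMessaging records whether it has run in the std::atomic_flag member m_initialized. A common way to drive such a flag is test_and_set(): it returns the previous value, so only the first caller sees false and performs the setup, while every later caller skips it. The sketch below shows that pattern in isolation; LazyLevel, the lookupLevel callback and the integer levels are hypothetical, and this is not the AthMessaging source. Unlike std::call_once, a bare test_and_set() does not make later callers wait for the first one to finish, so they may briefly observe the unset value.

```cpp
#include <atomic>
#include <functional>
#include <utility>

// Hypothetical class illustrating lazy, once-only initialization guarded
// by a std::atomic_flag (not the AthMessaging implementation).
class LazyLevel {
public:
  explicit LazyLevel(std::function<int()> lookupLevel)
      : m_lookup(std::move(lookupLevel)) {}

  int level() const {
    // test_and_set() returns the previous value: false only on first use.
    if (!m_initialized.test_and_set()) {
      m_level = m_lookup();  // e.g. ask a message service for the level
    }
    return m_level.load();
  }

private:
  std::function<int()> m_lookup;
  mutable std::atomic<int> m_level{0};  // 0 plays the role of "unset"
  mutable std::atomic_flag m_initialized = ATOMIC_FLAG_INIT;
};
```

The members are mutable so that a const accessor can still perform the deferred setup, which matches initMessaging() being declared const.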

◆ msg() [1/2]

MsgStream & asg::AsgMessaging::msg ( ) const
inherited

The standard message stream.

Returns
A reference to the default message stream of this object.

Definition at line 49 of file AsgMessaging.cxx.

49  {
50 #ifndef XAOD_STANDALONE
52 #else // not XAOD_STANDALONE
53  return m_msg;
54 #endif // not XAOD_STANDALONE
55  }

◆ msg() [2/2]

MsgStream & asg::AsgMessaging::msg ( const MSG::Level  lvl) const
inherited

The standard message stream.

Parameters
lvl The message level to set the stream to
Returns
A reference to the default message stream, set to level "lvl"

Definition at line 57 of file AsgMessaging.cxx.

57  {
58 #ifndef XAOD_STANDALONE
60 #else // not XAOD_STANDALONE
61  m_msg << lvl;
62  return m_msg;
63 #endif // not XAOD_STANDALONE
64  }
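In the standalone branch, msg(lvl) streams the level into m_msg and returns the stream, so an expression of the form msg(lvl) << "text" sets the level of the message being composed and writes to it in one go. The sketch below reproduces that chaining with a hypothetical MiniStream and a reduced Level enum; it buffers into a std::ostringstream instead of a real MsgStream, and it assumes the usual Gaudi convention that text is kept only when the message level is at or above the stream's output threshold.

```cpp
#include <sstream>
#include <string>

// Reduced, hypothetical message-level enum (lowest to highest).
enum class Level { Verbose = 1, Debug, Info, Warning, Error };

// Hypothetical stream: streaming a Level starts a message at that level.
class MiniStream {
public:
  explicit MiniStream(Level threshold) : m_threshold(threshold) {}

  MiniStream& operator<<(Level lvl) {
    m_active = (lvl >= m_threshold);  // keep text only at/above threshold
    return *this;
  }

  MiniStream& operator<<(const std::string& text) {
    if (m_active) m_buffer << text;
    return *this;
  }

  std::string str() const { return m_buffer.str(); }

private:
  Level m_threshold;
  bool m_active = true;
  std::ostringstream m_buffer;
};

// Mirrors the shape of msg(lvl): set the level, return the stream.
MiniStream& msgAt(MiniStream& stream, Level lvl) {
  stream << lvl;
  return stream;
}
```

Returning the stream by reference is what allows further operator<< calls to be chained onto the same message.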

◆ msgLvl()

bool asg::AsgMessaging::msgLvl ( const MSG::Level  lvl) const
inherited

Test the output level of the object.

Parameters
lvl The message level to test against
Returns
A boolean indicating whether messages at the given level will be printed: true if messages at level "lvl" will be printed, false otherwise.

Definition at line 41 of file AsgMessaging.cxx.

41  {
42 #ifndef XAOD_STANDALONE
43  return ::AthMessaging::msgLvl( lvl );
44 #else // not XAOD_STANDALONE
45  return m_msg.msgLevel( lvl );
46 #endif // not XAOD_STANDALONE
47  }
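msgLvl(lvl) answers whether a message at level lvl would be printed, so testing it first avoids formatting text that would be discarded; the ATH_MSG_DEBUG-style macros used in compute() appear to rely on a guard of this kind. Assuming the usual Gaudi ordering of MSG::Level (NIL, VERBOSE, DEBUG, INFO, WARNING, ERROR, FATAL, ALWAYS, lowest to highest), the test reduces to comparing the current level with lvl. In the sketch below the enum is a local copy for illustration, and wouldPrint, describeTracks and maybeDebug are hypothetical helpers, not ATLAS functions.

```cpp
#include <atomic>
#include <string>

// Local copy of the Gaudi MSG::Level ordering, for illustration only.
enum class Level { Nil = 0, Verbose, Debug, Info, Warning, Error, Fatal, Always };

// Hypothetical helper: true if a message at `lvl` passes the current level.
bool wouldPrint(const std::atomic<Level>& current, Level lvl) {
  return current.load() <= lvl;
}

// Hypothetical "expensive" formatter; counts how often it runs.
std::string describeTracks(int nTracks, int& formatCalls) {
  ++formatCalls;
  return "tau has " + std::to_string(nTracks) + " tracks";
}

// Guarded debug message: the text is only built when it would be printed.
std::string maybeDebug(const std::atomic<Level>& current, int nTracks,
                       int& formatCalls) {
  if (wouldPrint(current, Level::Debug)) {
    return describeTracks(nTracks, formatCalls);
  }
  return {};
}
```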

◆ setLevel()

void AthMessaging::setLevel ( MSG::Level  lvl)
inherited

Change the current logging level.

Use this rather than msg().setLevel() for correct behavior in multithreaded (MT) jobs.

Definition at line 28 of file AthMessaging.cxx.

29 {
30  m_lvl = lvl;
31 }
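The advice to prefer setLevel() over msg().setLevel() in MT jobs follows from the data layout: m_lvl is a single std::atomic<MSG::Level> shared by all threads, whereas the MsgStream behind msg() is held in m_msg_tls, a boost::thread_specific_ptr, so each thread has its own stream object. Changing the level on one thread's stream therefore need not affect the others. The sketch below models this with thread_local (swapped in for boost::thread_specific_ptr) and plain int levels; Logger, streamLevel and observeAfterWorker are hypothetical.

```cpp
#include <atomic>
#include <thread>
#include <utility>

// Hypothetical logger: one shared atomic level plus a per-thread level.
class Logger {
public:
  // Shared across threads, like AthMessaging::m_lvl.
  void setLevel(int lvl) { m_level.store(lvl); }
  int level() const { return m_level.load(); }

  // Per-thread state, standing in for the MsgStream held in m_msg_tls.
  static int& streamLevel() {
    thread_local int lvl = 3;  // each thread starts from the same default
    return lvl;
  }

private:
  std::atomic<int> m_level{3};
};

// Changes the per-thread level inside a worker thread and the shared level
// via setLevel(); returns {per-thread level, shared level} as seen by the
// calling thread afterwards.
std::pair<int, int> observeAfterWorker(Logger& logger) {
  std::thread worker([&logger] {
    Logger::streamLevel() = 1;  // only the worker's copy changes
    logger.setLevel(1);         // visible to every thread
  });
  worker.join();
  return {Logger::streamLevel(), logger.level()};
}
```

After the worker finishes, the calling thread still sees its own per-thread level unchanged, while the shared level reflects the worker's setLevel() call.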

Member Data Documentation

◆ m_initialized

std::atomic_flag AthMessaging::m_initialized ATLAS_THREAD_SAFE = ATOMIC_FLAG_INIT
mutable private inherited

Whether messaging has been initialized (see initMessaging()).

Definition at line 141 of file AthMessaging.h.

◆ gnn_output_config

FlavorTagInference::OutputConfig TauGNN::gnn_output_config

Definition at line 54 of file TauGNN.h.

◆ m_dataloader

TauGNNDataLoader TauGNN::m_dataloader

Definition at line 39 of file TauGNN.h.

◆ m_imsg

std::atomic<IMessageSvc*> AthMessaging::m_imsg { nullptr }
mutable private inherited

MessageSvc pointer.

Definition at line 135 of file AthMessaging.h.

◆ m_lvl

std::atomic<MSG::Level> AthMessaging::m_lvl { MSG::NIL }
mutable private inherited

Current logging level.

Definition at line 138 of file AthMessaging.h.

◆ m_msg_tls

boost::thread_specific_ptr<MsgStream> AthMessaging::m_msg_tls
mutable private inherited

MsgStream instance (a std::cout-like stream with print-out levels).

Definition at line 132 of file AthMessaging.h.

◆ m_nm

std::string AthMessaging::m_nm
private inherited

Message source name.

Definition at line 129 of file AthMessaging.h.

◆ m_saltModel

std::shared_ptr<const FlavorTagInference::SaltModel> TauGNN::m_saltModel

Definition at line 38 of file TauGNN.h.


The documentation for this class was generated from the following files:
TauGNN.h
TauGNN.cxx