ATLAS Offline Software
Wrapper around OnnxUtil to compute the output score of a model.
#include <TauGNN.h>
Classes
struct | Config |
Public Member Functions
TauGNN (const std::string &nnFile, const Config &config)
~TauGNN ()
std::tuple< std::map< std::string, float >, std::map< std::string, std::vector< char > >, std::map< std::string, std::vector< float > > > | compute (const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters) const |
bool | calculateInputVariables (const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters, std::map< std::string, std::map< std::string, double >> &scalarInputs, std::map< std::string, std::map< std::string, std::vector< double >>> &vectorInputs) const |
const TauGNNUtils::GNNVarCalc * | variable_calculator () const |
void | setLevel (MSG::Level lvl) |
Change the current logging level.
Public Attributes
std::shared_ptr< const FlavorTagDiscriminants::OnnxUtil > | m_onnxUtil |
FlavorTagDiscriminants::OnnxUtil::OutputConfig | gnn_output_config |
Private Types
using | Inputs = FlavorTagDiscriminants::Inputs |
using | VariableMap = std::map< std::string, double > |
using | VectorMap = std::map< std::string, std::vector< double > > |
using | InputMap = std::map< std::string, VariableMap > |
using | InputSequenceMap = std::map< std::string, VectorMap > |
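The private aliases above nest two levels deep: the outer key names an input node of the network and the inner key names a variable. The following minimal sketch reproduces the aliases standalone and fills them by hand; the node names ("tau_vars", "track_vars") and variable names are hypothetical placeholders, not names used by any real TauGNN configuration.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Standalone copies of the TauGNN private aliases.
using VariableMap = std::map<std::string, double>;             // variable name -> value
using VectorMap = std::map<std::string, std::vector<double>>;  // variable name -> one value per object
using InputMap = std::map<std::string, VariableMap>;           // input node -> scalar variables
using InputSequenceMap = std::map<std::string, VectorMap>;     // input node -> sequence variables

// Hypothetical scalar node describing the tau candidate itself.
InputMap makeScalarInputs() {
  InputMap scalars;
  scalars["tau_vars"]["pt"] = 45.0e3;
  scalars["tau_vars"]["eta"] = 1.2;
  return scalars;
}

// Hypothetical sequence node with one entry per track.
InputSequenceMap makeVectorInputs() {
  InputSequenceMap sequences;
  VectorMap& tracks = sequences["track_vars"];
  tracks["pt"] = {20.0e3, 15.0e3, 5.0e3};
  tracks["d0"] = {0.01, -0.02, 0.05};
  // Assumption: all variables of a sequence node describe the same objects,
  // so their vectors must have equal length.
  assert(tracks["pt"].size() == tracks["d0"].size());
  return sequences;
}
```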
Private Member Functions
void | initMessaging () const |
Initialize our message level and MessageSvc.
Private Attributes
const Config | m_config |
std::vector< std::string > | m_scalar_inputs |
std::vector< std::string > | m_track_inputs |
std::vector< std::string > | m_cluster_inputs |
std::vector< std::string > | m_scalarCalc_inputs |
std::vector< std::string > | m_trackCalc_inputs |
std::vector< std::string > | m_clusterCalc_inputs |
std::unique_ptr< TauGNNUtils::GNNVarCalc > | m_var_calc |
std::string | m_nm |
Message source name.
boost::thread_specific_ptr< MsgStream > | m_msg_tls |
MsgStream instance (a std::cout-like stream with print-out levels).
std::atomic< IMessageSvc * > | m_imsg { nullptr } |
MessageSvc pointer.
std::atomic< MSG::Level > | m_lvl { MSG::NIL } |
Current logging level.
std::atomic_flag | m_initialized ATLAS_THREAD_SAFE = ATOMIC_FLAG_INIT |
Messaging initialized (initMessaging).
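The attributes above mix two ownership styles: the model is held through a std::shared_ptr to a const OnnxUtil (m_onnxUtil), while the variable calculator is owned by a std::unique_ptr (m_var_calc) and presumably exposed read-only through variable_calculator(). A minimal sketch of that pattern follows, with hypothetical stub types (OnnxStub, VarCalcStub, WrapperSketch) standing in for the real ATLAS classes; it is not the TauGNN implementation.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-ins for FlavorTagDiscriminants::OnnxUtil and
// TauGNNUtils::GNNVarCalc.
struct OnnxStub { std::string modelPath; };
struct VarCalcStub { std::string name = "calc"; };

class WrapperSketch {
public:
  explicit WrapperSketch(const std::string& modelPath)
    : m_onnx(std::make_shared<OnnxStub>(OnnxStub{modelPath})),
      m_calc(std::make_unique<VarCalcStub>()) {
    assert(!modelPath.empty());  // a model file is required
  }

  // Non-owning, read-only access; the wrapper keeps sole ownership.
  const VarCalcStub* variable_calculator() const { return m_calc.get(); }

  // The model is immutable, so handing out shared copies is safe.
  std::shared_ptr<const OnnxStub> onnx() const { return m_onnx; }

private:
  std::shared_ptr<const OnnxStub> m_onnx;
  std::unique_ptr<VarCalcStub> m_calc;
};
```

Holding the model as a pointer to const means several wrapper instances could, in principle, share one loaded model without any of them being able to modify it.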
Wrapper around OnnxUtil to compute the output score of a model.
Configures the network and computes the network outputs given the input objects. Retrieval of input variables is handled internally.
TauGNN::TauGNN (const std::string & nnFile, const Config & config)
Definition at line 15 of file TauGNN.cxx.
TauGNN::~TauGNN ()
Definition at line 98 of file TauGNN.cxx.
bool TauGNN::calculateInputVariables (const xAOD::TauJet & tau,
    const std::vector< const xAOD::TauTrack * > & tracks,
    const std::vector< xAOD::CaloVertexedTopoCluster > & clusters,
    std::map< std::string, std::map< std::string, double > > & scalarInputs,
    std::map< std::string, std::map< std::string, std::vector< double > > > & vectorInputs) const
Definition at line 167 of file TauGNN.cxx.
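calculateInputVariables fills scalarInputs and vectorInputs, which have the shapes of the InputMap and InputSequenceMap aliases, and returns a bool status. A minimal sketch of the per-track part follows, assuming a hypothetical TrackStub with two fields in place of xAOD::TauTrack, a hypothetical node name "track_vars", and that a false return signals an unusable input (the meaning of the return value is not documented here).

```cpp
#include <cassert>
#include <cmath>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for xAOD::TauTrack.
struct TrackStub {
  double pt;
  double eta;
};

using VectorMap = std::map<std::string, std::vector<double>>;
using InputSequenceMap = std::map<std::string, VectorMap>;

// Fills one value per track for each requested variable under the
// hypothetical "track_vars" node. Returns false on an unknown variable
// name or a non-finite value.
bool fillTrackInputs(const std::vector<const TrackStub*>& tracks,
                     const std::vector<std::string>& varNames,
                     InputSequenceMap& vectorInputs) {
  VectorMap& node = vectorInputs["track_vars"];
  for (const std::string& var : varNames) {
    if (var != "pt" && var != "eta") {
      return false;  // no calculator for this variable
    }
    std::vector<double>& column = node[var];
    column.clear();
    column.reserve(tracks.size());
    for (const TrackStub* trk : tracks) {
      assert(trk != nullptr);
      const double value = (var == "pt") ? trk->pt : trk->eta;
      if (!std::isfinite(value)) {
        return false;
      }
      column.push_back(value);
    }
  }
  return true;
}
```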
std::tuple< std::map< std::string, float >, std::map< std::string, std::vector< char > >, std::map< std::string, std::vector< float > > >
TauGNN::compute (const xAOD::TauJet & tau,
    const std::vector< const xAOD::TauTrack * > & tracks,
    const std::vector< xAOD::CaloVertexedTopoCluster > & clusters) const
Definition at line 104 of file TauGNN.cxx.
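compute returns a std::tuple of three maps keyed by output name, holding float, std::vector<char> and std::vector<float> values. Presumably these are per-tau scores and per-object (for example per-track) outputs, but that is an assumption. The sketch below shows how a caller might unpack such a tuple with C++17 structured bindings; fakeCompute, scoreOr, tauScore and every output name are hypothetical, since the real call needs xAOD objects.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <tuple>
#include <vector>

// Same shape as the TauGNN::compute return type, spelled with aliases.
using FloatOutputs = std::map<std::string, float>;
using CharVecOutputs = std::map<std::string, std::vector<char>>;
using FloatVecOutputs = std::map<std::string, std::vector<float>>;
using GNNOutputs = std::tuple<FloatOutputs, CharVecOutputs, FloatVecOutputs>;

// Hypothetical stand-in for tauGNN.compute(tau, tracks, clusters).
GNNOutputs fakeCompute() {
  FloatOutputs scores{{"score_tau", 0.9f}, {"score_jet", 0.1f}};
  CharVecOutputs charVecs{{"track_class", {1, 0, 2}}};
  FloatVecOutputs floatVecs{{"track_weight", {0.7f, 0.2f, 0.1f}}};
  return std::make_tuple(scores, charVecs, floatVecs);
}

// Looks up one scalar output, falling back if the name is absent.
float scoreOr(const FloatOutputs& scores, const std::string& name, float fallback) {
  const auto it = scores.find(name);
  return it != scores.end() ? it->second : fallback;
}

// Example consumer using C++17 structured bindings.
float tauScore() {
  const auto [scores, charOutputs, floatOutputs] = fakeCompute();
  // Assumption: per-object outputs are aligned, one entry per track.
  assert(charOutputs.at("track_class").size() ==
         floatOutputs.at("track_weight").size());
  return scoreOr(scores, "score_tau", -1.0f);
}
```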
void AthMessaging::initMessaging () const [private, inherited]
Initialize our message level and MessageSvc.
This method should only be called once.
Definition at line 39 of file AthMessaging.cxx.
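initMessaging is documented as one-shot, and the class carries a std::atomic_flag m_initialized. One common way to combine the two is std::atomic_flag::test_and_set: only the first caller sees the flag clear and runs the initialization. The following is a hypothetical illustration of that idiom (OnceInitSketch, with a counter standing in for the MessageSvc lookup), not the AthMessaging code.

```cpp
#include <atomic>
#include <cassert>

class OnceInitSketch {
public:
  // Safe to call from any thread; only the first call runs initMessaging().
  void ensureInit() const {
    // test_and_set() sets the flag and returns its previous value.
    if (!m_initialized.test_and_set()) {
      initMessaging();
    }
  }

  int initCount() const { return m_initCalls.load(); }

private:
  // Stand-in for fetching the MessageSvc and the output level.
  void initMessaging() const {
    assert(m_initCalls.load() == 0 && "initMessaging must run only once");
    m_initCalls.fetch_add(1);
  }

  mutable std::atomic_flag m_initialized = ATOMIC_FLAG_INIT;
  mutable std::atomic<int> m_initCalls{0};
};
```

One caveat of this idiom: a second thread that arrives while the first is still inside initMessaging() returns immediately, before initialization has finished, so state read afterwards must tolerate that window or be guarded separately.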
MsgStream & asg::AsgMessaging::msg () const [inherited]
The standard message stream.
Definition at line 49 of file AsgMessaging.cxx.
MsgStream & asg::AsgMessaging::msg (const MSG::Level lvl) const [inherited]
The standard message stream.
Parameters: lvl  The message level to set the stream to.
Definition at line 57 of file AsgMessaging.cxx.
bool asg::AsgMessaging::msgLvl (const MSG::Level lvl) const [inherited]
Test the output level of the object.
Parameters: lvl  The message level to test against.
Return values: true if messages at level "lvl" will be printed.
Definition at line 41 of file AsgMessaging.cxx.
void AthMessaging::setLevel (MSG::Level lvl) [inherited]
Change the current logging level.
Use this rather than msg().setLevel() for correct behaviour in multithreaded (MT) jobs.
Definition at line 28 of file AthMessaging.cxx.
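The advice to call setLevel rather than msg().setLevel() fits a design where the level lives in one shared std::atomic (m_lvl) while each thread has its own MsgStream (m_msg_tls); presumably a level set on one thread's stream would not reach the others. The sketch below models only the shared-atomic half, with a local copy of the Gaudi MSG::Level ordering; the class name SharedLevelSketch and the relaxed memory ordering are assumptions of the sketch, not taken from AthMessaging.

```cpp
#include <atomic>
#include <cassert>

// Local copy of the Gaudi MSG::Level ordering (lower = more verbose).
enum class Level { NIL = 0, VERBOSE, DEBUG, INFO, WARNING, ERROR, FATAL, ALWAYS };

class SharedLevelSketch {
public:
  // One store visible to every thread; no per-thread stream is touched.
  void setLevel(Level lvl) { m_lvl.store(lvl, std::memory_order_relaxed); }

  // True if a message at `lvl` would be printed at the current level.
  bool msgLvl(Level lvl) const {
    const Level current = m_lvl.load(std::memory_order_relaxed);
    // In this sketch the level must be set first; per its documentation,
    // initMessaging() initializes the level via the MessageSvc instead.
    assert(current != Level::NIL && "output level not configured");
    return current <= lvl;
  }

private:
  std::atomic<Level> m_lvl{Level::NIL};
};
```

Relaxed ordering suffices in this sketch because the level is a single independent value; no other data is published through it.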
std::atomic_flag AthMessaging::m_initialized ATLAS_THREAD_SAFE = ATOMIC_FLAG_INIT [mutable, private, inherited]
Messaging initialized (initMessaging)
Definition at line 141 of file AthMessaging.h.
FlavorTagDiscriminants::OnnxUtil::OutputConfig TauGNN::gnn_output_config
std::atomic< IMessageSvc * > AthMessaging::m_imsg { nullptr } [mutable, private, inherited]
MessageSvc pointer.
Definition at line 135 of file AthMessaging.h.
std::atomic< MSG::Level > AthMessaging::m_lvl { MSG::NIL } [mutable, private, inherited]
Current logging level.
Definition at line 138 of file AthMessaging.h.
boost::thread_specific_ptr< MsgStream > AthMessaging::m_msg_tls [mutable, private, inherited]
MsgStream instance (a std::cout-like stream with print-out levels).
Definition at line 132 of file AthMessaging.h.
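m_msg_tls gives each thread its own MsgStream, so formatting a message needs no lock. The sketch below reproduces that idea with standard C++ thread_local storage in place of boost::thread_specific_ptr (a deliberate swap of technique); StreamStub and MessagingSketch are hypothetical. Because a data member cannot itself be thread_local, the per-thread streams live in a thread-local map keyed by a never-reused object id, which keeps them per object as well as per thread.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>

// Stand-in for MsgStream: a buffered stream tagged with a source name.
struct StreamStub {
  explicit StreamStub(std::string source) : name(std::move(source)) {}
  std::string name;           // message source name, cf. m_nm
  std::ostringstream buffer;  // where formatted text would go
};

class MessagingSketch {
public:
  explicit MessagingSketch(std::string name)
    : m_nm(std::move(name)), m_id(nextId()) {}

  // This thread's stream for this object, created on first use.
  StreamStub& msg() const {
    // One map per thread, keyed by a never-reused object id.
    thread_local std::map<std::uint64_t, std::unique_ptr<StreamStub>> streams;
    std::unique_ptr<StreamStub>& slot = streams[m_id];
    if (!slot) {
      slot = std::make_unique<StreamStub>(m_nm);
    }
    assert(slot->name == m_nm);
    return *slot;
  }

private:
  static std::uint64_t nextId() {
    static std::atomic<std::uint64_t> counter{0};
    return counter.fetch_add(1);
  }

  std::string m_nm;
  std::uint64_t m_id;
};
```

A limitation of this sketch: a thread's streams are released only when that thread exits, not when the owning object is destroyed.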
std::string AthMessaging::m_nm [private, inherited]
Message source name.
Definition at line 129 of file AthMessaging.h.
std::shared_ptr<const FlavorTagDiscriminants::OnnxUtil> TauGNN::m_onnxUtil