ATLAS Offline Software
Loading...
Searching...
No Matches
TauGNNDataLoader Class Reference

#include <TauGNNDataLoader.h>

Inheritance diagram for TauGNNDataLoader:
Collaboration diagram for TauGNNDataLoader:

Classes

struct  Config

Public Member Functions

 TauGNNDataLoader (std::shared_ptr< const FlavorTagInference::SaltModel > salt_model, const Config &config)
 ~TauGNNDataLoader ()=default
void addScalarLoader (const std::string &varName, std::function< float(const xAOD::IParticle *)> loader)
void addVectorLoader (const std::string &vecName, std::shared_ptr< IConstituentsLoader > loader)
virtual SaltModelData loadInputs (const xAOD::IParticle *p) const final
void DumpGnnInputs (const SaltModelInputs &gnn_inputs) const
void setLevel (MSG::Level lvl)
 Change the current logging level.
Functions providing the same interface as AthMessaging
bool msgLvl (const MSG::Level lvl) const
 Test the output level of the object.
MsgStream & msg () const
 The standard message stream.
MsgStream & msg (const MSG::Level lvl) const
 The standard message stream.

Public Attributes

SaltModelGraphConfig::GraphConfig graph_config
std::string scalarInputName
std::vector< std::pair< std::string, std::function< float(const xAOD::IParticle *)> > > scalarVarLoaders
std::map< std::string, std::shared_ptr< IConstituentsLoader > > vectorVarLoaders

Private Types

using ScalarCalcByRef_t = std::function<bool(const xAOD::TauJet &, float &)>
using ScalarCalc_t = std::function<float(const xAOD::IParticle*)>

Private Member Functions

ScalarCalc_t getScalarCalc (const std::string &name) const
void initMessaging () const
 Initialize our message level and MessageSvc.

Private Attributes

std::string m_nm
 Message source name.
boost::thread_specific_ptr< MsgStream > m_msg_tls
 MsgStream instance (a std::cout like with print-out levels)
std::atomic< IMessageSvc * > m_imsg { nullptr }
 MessageSvc pointer.
std::atomic< MSG::Level > m_lvl { MSG::NIL }
 Current logging level.
std::atomic_flag m_initialized ATLAS_THREAD_SAFE = ATOMIC_FLAG_INIT
 Messaging initialized (initMessaging)

Static Private Attributes

static const std::unordered_map< std::string, ScalarCalcByRef_t > m_func_map

Detailed Description

Definition at line 57 of file TauGNNDataLoader.h.

Member Typedef Documentation

◆ ScalarCalc_t

using TauGNNDataLoader::ScalarCalc_t = std::function<float(const xAOD::IParticle*)>
private

Definition at line 80 of file TauGNNDataLoader.h.

◆ ScalarCalcByRef_t

using TauGNNDataLoader::ScalarCalcByRef_t = std::function<bool(const xAOD::TauJet &, float &)>
private

Definition at line 79 of file TauGNNDataLoader.h.

Constructor & Destructor Documentation

◆ TauGNNDataLoader()

TauGNNDataLoader::TauGNNDataLoader ( std::shared_ptr< const FlavorTagInference::SaltModel > salt_model,
const Config & config )

Definition at line 10 of file TauGNNDataLoader.cxx.

13 :
14 FlavorTagInference::SaltModelEDMLoaderBase(salt_model),
15 asg::AsgMessaging("TauGNNDataLoader")
16 {
17 scalarInputName = config.input_layer_scalar;
18 const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig* scalar_input_node = nullptr;
19 const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig* track_input_node = nullptr;
20 const FlavorTagInference::SaltModelGraphConfig::InputNodeConfig* cluster_input_node = nullptr;
21 for (const auto &in_node : graph_config.inputs) {
22 if (in_node.name == config.input_layer_scalar) {
23 scalar_input_node = &in_node;
24 ATH_MSG_DEBUG("Found scalar input node: " << in_node.name);
25 break;
26 }
27 }
28 for (const auto &in_node : graph_config.input_sequences) {
29 if (in_node.name == config.input_layer_tracks) {
30 track_input_node = &in_node;
31 ATH_MSG_DEBUG("Found track input node: " << in_node.name);
32 }
33 if (in_node.name == config.input_layer_clusters) {
34 cluster_input_node = &in_node;
35 ATH_MSG_DEBUG("Found cluster input node: " << in_node.name);
36 }
37 }
38
39 // Fill the variable names of each input layer into the corresponding vector
40 if (scalar_input_node) {
41 for (const auto &in : scalar_input_node->variables) {
42 addScalarLoader(in.name, getScalarCalc(in.name));
43 }
44 } else {
45 ATH_MSG_ERROR("Scalar input node 'tau_vars' not found in the model input configuration");
46 throw std::runtime_error("Scalar input node 'tau_vars' not found in the model input configuration");
47 }
48
49 if (track_input_node) {
50 FlavorTagInference::ConstituentsInputConfig trk_config;
51 trk_config.name = "tautracks";
52 trk_config.output_name = config.input_layer_tracks;
55 trk_config.max_n_constituents = config.n_max_tracks;
57 trk_config.inputs = {};
58 for (const auto &in : track_input_node->variables) {
59 if (!config.useTRT && (in.name == "eProbabilityHT")) {
60 ATH_MSG_WARNING("Track variable 'eProbabilityHT' requested but useTRT set to false. Using 'eProbabilityHT_noTRT' instead.");
61 trk_config.inputs.push_back({"eProbabilityHT_noTRT", FlavorTagInference::ConstituentsEDMType::CUSTOM_GETTER, false});
62 continue;
63 }
64 trk_config.inputs.push_back({in.name, FlavorTagInference::ConstituentsEDMType::CUSTOM_GETTER, false});
65 }
66 addVectorLoader(config.input_layer_tracks, std::make_shared<FlavorTagInference::ConstituentLoaderTauTrack>(trk_config));
67 } else {
68 ATH_MSG_ERROR("Track input node '" + config.input_layer_tracks + "' not found in the model input configuration");
69 throw std::runtime_error("Track input node '" + config.input_layer_tracks + "' not found in the model input configuration");
70 }
71
72 if (cluster_input_node) {
73 FlavorTagInference::ConstituentsInputConfig cls_config;
74 cls_config.name = "tauclusters";
75 cls_config.output_name = config.input_layer_clusters;
78 cls_config.max_n_constituents = config.n_max_clusters;
80 cls_config.inputs = {};
81 for (const auto &in : cluster_input_node->variables) {
82 cls_config.inputs.push_back({in.name, FlavorTagInference::ConstituentsEDMType::CUSTOM_GETTER, false});
83 }
84 addVectorLoader(config.input_layer_clusters, std::make_shared<FlavorTagInference::ConstituentLoaderTauCluster>(cls_config, config.max_dr_cluster, config.doVertexCorrection));
85 } else {
86 ATH_MSG_ERROR("Cluster input node '" + config.input_layer_clusters + "' not found in the model input configuration");
87 throw std::runtime_error("Cluster input node '" + config.input_layer_clusters + "' not found in the model input configuration");
88 }
89}
#define ATH_MSG_ERROR(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
void addVectorLoader(const std::string &vecName, std::shared_ptr< IConstituentsLoader > loader)
void addScalarLoader(const std::string &varName, std::function< float(const xAOD::IParticle *)> loader)
SaltModelGraphConfig::GraphConfig graph_config
ScalarCalc_t getScalarCalc(const std::string &name) const

◆ ~TauGNNDataLoader()

TauGNNDataLoader::~TauGNNDataLoader ( )
default

Member Function Documentation

◆ addScalarLoader()

void FlavorTagInference::SaltModelEDMLoaderBase::addScalarLoader ( const std::string & varName,
std::function< float(const xAOD::IParticle *)> loader )
inline, inherited

Definition at line 41 of file SaltModelEDMLoaderBase.h.

41 {
42 scalarVarLoaders.emplace_back(varName, loader);
43 }
std::vector< std::pair< std::string, std::function< float(const xAOD::IParticle *)> > > scalarVarLoaders

◆ addVectorLoader()

void FlavorTagInference::SaltModelEDMLoaderBase::addVectorLoader ( const std::string & vecName,
std::shared_ptr< IConstituentsLoader > loader )
inline, inherited

Definition at line 45 of file SaltModelEDMLoaderBase.h.

45 {
46 vectorVarLoaders.try_emplace(vecName, std::move(loader));
47 }
std::map< std::string, std::shared_ptr< IConstituentsLoader > > vectorVarLoaders

◆ DumpGnnInputs()

void FlavorTagInference::SaltModelEDMLoaderBase::DumpGnnInputs ( const SaltModelInputs & gnn_inputs) const
inline, inherited

Definition at line 73 of file SaltModelEDMLoaderBase.h.

73 {
74 // Implementation for dumping GNN input data
75 std::cout << "-------- Dumping GNN Input Data --------" << std::endl;
76
77 for (const auto& [name, inputs] : gnn_inputs) {
78 std::cout << "Input Name: " << name << std::endl;
79 std::cout << " vec floats: ";
80 for (const auto& feature : inputs.first) {
81 std::cout << feature << " ";
82 }
83 std::cout << std::endl;
84 std::cout << " vec ints : ";
85 for (const auto& id : inputs.second) {
86 std::cout << id << " ";
87 }
88 std::cout << std::endl;
89 }
90 std::cout << "---------- END GNN Input Data ----------" << std::endl;
91 }

◆ getScalarCalc()

ScalarCalc_t TauGNNDataLoader::getScalarCalc ( const std::string & name) const
private

Definition at line 91 of file TauGNNDataLoader.cxx.

91 {
92 // Retrieve calculator function
93 ScalarCalcByRef_t func = nullptr;
94 try {
95 func = m_func_map.at(name);
96 } catch (const std::out_of_range &e) {
97 ATH_MSG_ERROR("Variable '" << name << "' not defined");
98 throw std::runtime_error("Variable '" + name + "' not defined");
99 }
100 return [func](const xAOD::IParticle* p) {
101 auto tau = dynamic_cast<const xAOD::TauJet*>(p);
102 float out;
103 if (!tau) {
104 throw std::runtime_error("Invalid TauJet pointer");
105 }
106 auto success = func(*tau, out);
107 if (!success) {
108 throw std::runtime_error("Error in scalar variable calculation ");
109 }
110 return out;
111 };
112}
std::function< bool(const xAOD::TauJet &, float &)> ScalarCalcByRef_t
static const std::unordered_map< std::string, ScalarCalcByRef_t > m_func_map
TauJet_v3 TauJet
Definition of the current "tau version".

◆ initMessaging()

void AthMessaging::initMessaging ( ) const
private, inherited

Initialize our message level and MessageSvc.

This method should only be called once.

Definition at line 39 of file AthMessaging.cxx.

40{
42 // If user did not set an explicit level, set a default
43 if (m_lvl == MSG::NIL) {
44 m_lvl = m_imsg ?
45 static_cast<MSG::Level>( m_imsg.load()->outputLevel(m_nm) ) :
46 MSG::INFO;
47 }
48}
std::string m_nm
Message source name.
std::atomic< IMessageSvc * > m_imsg
MessageSvc pointer.
std::atomic< MSG::Level > m_lvl
Current logging level.
IMessageSvc * getMessageSvc(bool quiet=false)

◆ loadInputs()

virtual SaltModelData FlavorTagInference::SaltModelEDMLoaderBase::loadInputs ( const xAOD::IParticle * p) const
inline, final, virtual, inherited

Definition at line 49 of file SaltModelEDMLoaderBase.h.

49 {
50 SaltModelData salt_model_data;
51 // loading scalar inputs.
52 std::vector<float> scalar_feat;
53 for (const auto& varLoader : scalarVarLoaders) {
54 std::string varName = varLoader.first;
55 scalar_feat.push_back(varLoader.second(p));
56 }
57 std::vector<int64_t> scalar_feat_dim = {1, static_cast<int64_t>(scalar_feat.size())};
58 Inputs scalar_inputs(scalar_feat, scalar_feat_dim);
59 salt_model_data.gnn_inputs.insert({scalarInputName, scalar_inputs});
60
61 //load vector inputs.
62 for (auto loader : vectorVarLoaders) {
63 std::string input_name = loader.first;
64 auto [input_data, input_objects] = loader.second->getData(*p);
65
66 salt_model_data.gnn_inputs.insert({input_name, input_data});
67 salt_model_data.num_inputs += input_data.first.size();
68 salt_model_data.constituents[input_name] = input_objects;
69 }
70 return salt_model_data;
71 }
std::pair< std::vector< float >, std::vector< int64_t > > Inputs
str varName
end cluster ToT and charge

◆ msg() [1/2]

MsgStream & asg::AsgMessaging::msg ( ) const
inherited

The standard message stream.

Returns
A reference to the default message stream of this object.

Definition at line 49 of file AsgMessaging.cxx.

49 {
50#ifndef XAOD_STANDALONE
51 return ::AthMessaging::msg();
52#else // not XAOD_STANDALONE
53 return m_msg;
54#endif // not XAOD_STANDALONE
55 }

◆ msg() [2/2]

MsgStream & asg::AsgMessaging::msg ( const MSG::Level lvl) const
inherited

The standard message stream.

Parameters
lvl: The message level to set the stream to
Returns
A reference to the default message stream, set to level "lvl"

Definition at line 57 of file AsgMessaging.cxx.

57 {
58#ifndef XAOD_STANDALONE
59 return ::AthMessaging::msg( lvl );
60#else // not XAOD_STANDALONE
61 m_msg << lvl;
62 return m_msg;
63#endif // not XAOD_STANDALONE
64 }

◆ msgLvl()

bool asg::AsgMessaging::msgLvl ( const MSG::Level lvl) const
inherited

Test the output level of the object.

Parameters
lvl: The message level to test against
Returns
boolean Indicating if messages at given level will be printed
true If messages at level "lvl" will be printed

Definition at line 41 of file AsgMessaging.cxx.

41 {
42#ifndef XAOD_STANDALONE
43 return ::AthMessaging::msgLvl( lvl );
44#else // not XAOD_STANDALONE
45 return m_msg.msgLevel( lvl );
46#endif // not XAOD_STANDALONE
47 }

◆ setLevel()

void AthMessaging::setLevel ( MSG::Level lvl)
inherited

Change the current logging level.

Use this rather than msg().setLevel() for proper operation with MT.

Definition at line 28 of file AthMessaging.cxx.

29{
30 m_lvl = lvl;
31}

Member Data Documentation

◆ ATLAS_THREAD_SAFE

std::atomic_flag m_initialized AthMessaging::ATLAS_THREAD_SAFE = ATOMIC_FLAG_INIT
mutable, private, inherited

Messaging initialized (initMessaging)

Definition at line 141 of file AthMessaging.h.

◆ graph_config

SaltModelGraphConfig::GraphConfig FlavorTagInference::SaltModelEDMLoaderBase::graph_config
inherited

Definition at line 35 of file SaltModelEDMLoaderBase.h.

◆ m_func_map

const std::unordered_map<std::string, ScalarCalcByRef_t> TauGNNDataLoader::m_func_map
inline, static, private
Initial value:
= {
{"isolFrac", TauScalarVars::isolFrac},
{"centFrac", TauScalarVars::centFrac},
{"etOverPtLeadTrk", TauScalarVars::etOverPtLeadTrk},
{"innerTrkAvgDist", TauScalarVars::innerTrkAvgDist},
{"absipSigLeadTrk", TauScalarVars::absipSigLeadTrk},
{"SumPtTrkFrac", TauScalarVars::SumPtTrkFrac},
{"sumEMCellEtOverLeadTrkPt", TauScalarVars::sumEMCellEtOverLeadTrkPt},
{"EMPOverTrkSysP", TauScalarVars::EMPOverTrkSysP},
{"ptRatioEflowApprox", TauScalarVars::ptRatioEflowApprox},
{"mEflowApprox", TauScalarVars::mEflowApprox},
{"dRmax", TauScalarVars::dRmax},
{"trFlightPathSig", TauScalarVars::trFlightPathSig},
{"massTrkSys", TauScalarVars::massTrkSys},
{"pt", TauScalarVars::pt}
}
bool pt(const xAOD::TauJet &tau, float &out)
bool isolFrac(const xAOD::TauJet &tau, float &out)
bool massTrkSys(const xAOD::TauJet &tau, float &out)
bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, float &out)
bool EMPOverTrkSysP(const xAOD::TauJet &tau, float &out)
bool etOverPtLeadTrk(const xAOD::TauJet &tau, float &out)
bool mEflowApprox(const xAOD::TauJet &tau, float &out)
bool innerTrkAvgDist(const xAOD::TauJet &tau, float &out)
bool centFrac(const xAOD::TauJet &tau, float &out)
bool ptRatioEflowApprox(const xAOD::TauJet &tau, float &out)
bool trFlightPathSig(const xAOD::TauJet &tau, float &out)
bool dRmax(const xAOD::TauJet &tau, float &out)
bool absipSigLeadTrk(const xAOD::TauJet &tau, float &out)
bool SumPtTrkFrac(const xAOD::TauJet &tau, float &out)

Definition at line 82 of file TauGNNDataLoader.h.

82 {
83 {"isolFrac", TauScalarVars::isolFrac},
84 {"centFrac", TauScalarVars::centFrac},
85 {"etOverPtLeadTrk", TauScalarVars::etOverPtLeadTrk},
86 {"innerTrkAvgDist", TauScalarVars::innerTrkAvgDist},
87 {"absipSigLeadTrk", TauScalarVars::absipSigLeadTrk},
88 {"SumPtTrkFrac", TauScalarVars::SumPtTrkFrac},
89 {"sumEMCellEtOverLeadTrkPt", TauScalarVars::sumEMCellEtOverLeadTrkPt},
90 {"EMPOverTrkSysP", TauScalarVars::EMPOverTrkSysP},
91 {"ptRatioEflowApprox", TauScalarVars::ptRatioEflowApprox},
92 {"mEflowApprox", TauScalarVars::mEflowApprox},
93 {"dRmax", TauScalarVars::dRmax},
94 {"trFlightPathSig", TauScalarVars::trFlightPathSig},
95 {"massTrkSys", TauScalarVars::massTrkSys},
96 {"pt", TauScalarVars::pt}
97 };

◆ m_imsg

std::atomic<IMessageSvc*> AthMessaging::m_imsg { nullptr }
mutable, private, inherited

MessageSvc pointer.

Definition at line 135 of file AthMessaging.h.

135{ nullptr };

◆ m_lvl

std::atomic<MSG::Level> AthMessaging::m_lvl { MSG::NIL }
mutable, private, inherited

Current logging level.

Definition at line 138 of file AthMessaging.h.

138{ MSG::NIL };

◆ m_msg_tls

boost::thread_specific_ptr<MsgStream> AthMessaging::m_msg_tls
mutable, private, inherited

MsgStream instance (a std::cout like with print-out levels)

Definition at line 132 of file AthMessaging.h.

◆ m_nm

std::string AthMessaging::m_nm
private, inherited

Message source name.

Definition at line 129 of file AthMessaging.h.

◆ scalarInputName

std::string FlavorTagInference::SaltModelEDMLoaderBase::scalarInputName
inherited

Definition at line 36 of file SaltModelEDMLoaderBase.h.

◆ scalarVarLoaders

std::vector<std::pair<std::string , std::function<float(const xAOD::IParticle* )> > > FlavorTagInference::SaltModelEDMLoaderBase::scalarVarLoaders
inherited

Definition at line 37 of file SaltModelEDMLoaderBase.h.

◆ vectorVarLoaders

std::map<std::string , std::shared_ptr<IConstituentsLoader> > FlavorTagInference::SaltModelEDMLoaderBase::vectorVarLoaders
inherited

Definition at line 38 of file SaltModelEDMLoaderBase.h.


The documentation for this class was generated from the following files: