ATLAS Offline Software
Public Types | Public Member Functions | Private Member Functions | Static Private Member Functions | Private Attributes | Static Private Attributes | List of all members
TauGNNUtils::GNNVarCalc Class Reference

Tool to calculate input variables for the GNN-based tau identification. More...

#include <TauGNNUtils.h>

Inheritance diagram for TauGNNUtils::GNNVarCalc:
Collaboration diagram for TauGNNUtils::GNNVarCalc:

Public Types

using ScalarCalc = std::function< bool(const xAOD::TauJet &, float &)>
 
using TrackCalc = std::function< bool(const xAOD::TauJet &, const xAOD::TauTrack &, float &)>
 
using ClusterCalc = std::function< bool(const xAOD::TauJet &, const xAOD::CaloVertexedTopoCluster &, float &)>
 

Public Member Functions

 GNNVarCalc (bool useTRT)
 
 ~GNNVarCalc ()=default
 
float compute (const std::string &name, const xAOD::TauJet &tau) const
 
std::vector< float > compute (const std::string &name, const xAOD::TauJet &tau, const std::vector< const xAOD::TauTrack * > &tracks) const
 
std::vector< float > compute (const std::string &name, const xAOD::TauJet &tau, const std::vector< xAOD::CaloVertexedTopoCluster > &clusters) const
 
void setLevel (MSG::Level lvl)
 Change the current logging level. More...
 

Private Member Functions

void initMessaging () const
 Initialize our message level and MessageSvc. More...
 

Static Private Member Functions

static void initialize_map (bool useTRT)
 

Private Attributes

std::string m_nm
 Message source name. More...
 
boost::thread_specific_ptr< MsgStream > m_msg_tls
 MsgStream instance (a std::cout like with print-out levels) More...
 
std::atomic< IMessageSvc * > m_imsg { nullptr }
 MessageSvc pointer. More...
 
std::atomic< MSG::Level > m_lvl { MSG::NIL }
 Current logging level. More...
 

Static Private Attributes

static std::once_flag m_init_flag
 
static const std::unordered_map< std::string, ScalarCalc > m_scalar_map
 
static std::unordered_map< std::string, TrackCalc > m_track_map ATLAS_THREAD_SAFE
 
static const std::unordered_map< std::string, ClusterCalc > m_cluster_map
 

Detailed Description

Tool to calculate input variables for the GNN-based tau identification.

Used to calculate, on the fly, the input variables for the ONNX GNN-based tau identification, by providing a mapping between variable names (strings) and the functions that calculate these variables.

Author
C. Deutsch
W. Davey
N.M. Tamir
D. Qichen

Definition at line 275 of file TauGNNUtils.h.

Member Typedef Documentation

◆ ClusterCalc

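using TauGNNUtils::GNNVarCalc::ClusterCalc = std::function<bool(const xAOD::TauJet &, const xAOD::CaloVertexedTopoCluster &, float &)>

Definition at line 280 of file TauGNNUtils.h.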

◆ ScalarCalc

using TauGNNUtils::GNNVarCalc::ScalarCalc = std::function<bool(const xAOD::TauJet &, float &)>

Definition at line 278 of file TauGNNUtils.h.

◆ TrackCalc

using TauGNNUtils::GNNVarCalc::TrackCalc = std::function<bool(const xAOD::TauJet &, const xAOD::TauTrack &, float &)>

Definition at line 279 of file TauGNNUtils.h.

Constructor & Destructor Documentation

◆ GNNVarCalc()

TauGNNUtils::GNNVarCalc::GNNVarCalc ( bool  useTRT)

Definition at line 14 of file TauGNNUtils.cxx.

14  :
15  asg::AsgMessaging("TauGNNUtils::GNNVarCalc") {
16  initialize_map(useTRT);
17 }

◆ ~GNNVarCalc()

TauGNNUtils::GNNVarCalc::~GNNVarCalc ( )
default

Member Function Documentation

◆ compute() [1/3]

float TauGNNUtils::GNNVarCalc::compute ( const std::string &  name,
const xAOD::TauJet &  tau 
) const

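Definition at line 25 of file TauGNNUtils.cxx.

Throws std::out_of_range (rethrown after logging an error) if name is not a registered scalar variable, and std::runtime_error if the calculation function reports failure.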

25  {
26  // Retrieve calculator function
27  ScalarCalc func = nullptr;
28  try {
29  func = m_scalar_map.at(name);
30  } catch (const std::out_of_range &e) {
31  ATH_MSG_ERROR("Variable '" << name << "' not defined");
32  throw;
33  }
34 
35  // Calculate variable
36  float out;
37  bool success = func(tau, out);
38  if (!success) {
39  ATH_MSG_ERROR("Error in scalar variable calculation");
40  throw std::runtime_error("Error in scalar variable calculation");
41  }
42  return out;
43 }

◆ compute() [2/3]

std::vector< float > TauGNNUtils::GNNVarCalc::compute ( const std::string &  name,
const xAOD::TauJet &  tau,
const std::vector< const xAOD::TauTrack * > &  tracks 
) const

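Definition at line 45 of file TauGNNUtils.cxx.

Throws std::runtime_error if name is not a registered track variable (unlike the scalar and cluster overloads, which rethrow std::out_of_range), and std::runtime_error if the calculation fails for any track.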

46  {
47  std::vector<float> out;
48  out.reserve(tracks.size());
49 
50  // Retrieve calculator function
51  TrackCalc func = nullptr;
52  try {
53  func = m_track_map.at(name);
54  } catch (const std::out_of_range &e) {
55  ATH_MSG_ERROR("Variable '" << name << "' not defined");
56  throw std::runtime_error("Variable '" + name + "' not defined");
57  }
58 
59  // Calculate variables for selected tracks
60  bool success = true;
61  float value;
62  for (const auto *const trk : tracks) {
63  success = success && func(tau, *trk, value);
64  out.push_back(value);
65  }
66  if (!success) {
67  ATH_MSG_ERROR("Error in track variable calculation");
68  throw std::runtime_error("Error in track variable calculation");
69  }
70  return out;
71 }

◆ compute() [3/3]

std::vector< float > TauGNNUtils::GNNVarCalc::compute ( const std::string &  name,
const xAOD::TauJet &  tau,
const std::vector< xAOD::CaloVertexedTopoCluster > &  clusters 
) const

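Definition at line 73 of file TauGNNUtils.cxx.

Throws std::out_of_range (rethrown after logging an error) if name is not a registered cluster variable, and std::runtime_error if the calculation fails for any cluster.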

74  {
75  std::vector<float> out;
76  out.reserve(clusters.size());
77 
78  // Retrieve calculator function
79  ClusterCalc func = nullptr;
80  try {
81  func = m_cluster_map.at(name);
82  } catch (const std::out_of_range &e) {
83  ATH_MSG_ERROR("Variable '" << name << "' not defined");
84  throw;
85  }
86 
87  // Calculate variables for selected clusters
88  bool success = true;
89  float value;
90  for (const xAOD::CaloVertexedTopoCluster& cluster : clusters) {
91  success = success && func(tau, cluster, value);
92  out.push_back(value);
93  }
94  if (!success) {
95  ATH_MSG_ERROR("Error in cluster variable calculation");
96  throw std::runtime_error("Error in cluster variable calculation");
97  }
98  return out;
99 }

◆ initialize_map()

void TauGNNUtils::GNNVarCalc::initialize_map ( bool  useTRT)
static, private

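Definition at line 19 of file TauGNNUtils.cxx.

Note: m_track_map and m_init_flag are static, and the body runs only once (std::call_once). The useTRT value passed to the first GNNVarCalc constructed therefore decides, for every instance, whether "eProbabilityHT" is mapped to eProbabilityHT_noTRT. Instances constructed later with a different useTRT do not change the mapping.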

19  {
20  std::call_once(m_init_flag, [useTRT]() {
21  if(!useTRT) m_track_map["eProbabilityHT"] = Variables::Track::eProbabilityHT_noTRT;
22  });
23 }

◆ initMessaging()

void AthMessaging::initMessaging ( ) const
private, inherited

Initialize our message level and MessageSvc.

This method should only be called once.

Definition at line 39 of file AthMessaging.cxx.

40 {
42  m_lvl = m_imsg ?
43  static_cast<MSG::Level>( m_imsg.load()->outputLevel(m_nm) ) :
44  MSG::INFO;
45 }

◆ msg() [1/2]

MsgStream & asg::AsgMessaging::msg ( ) const
inherited

The standard message stream.

Returns
A reference to the default message stream of this object.

Definition at line 49 of file AsgMessaging.cxx.

49  {
50 #ifndef XAOD_STANDALONE
52 #else // not XAOD_STANDALONE
53  return m_msg;
54 #endif // not XAOD_STANDALONE
55  }

◆ msg() [2/2]

MsgStream & asg::AsgMessaging::msg ( const MSG::Level  lvl) const
inherited

The standard message stream.

Parameters
lvl: The message level to set the stream to
Returns
A reference to the default message stream, set to level "lvl"

Definition at line 57 of file AsgMessaging.cxx.

57  {
58 #ifndef XAOD_STANDALONE
60 #else // not XAOD_STANDALONE
61  m_msg << lvl;
62  return m_msg;
63 #endif // not XAOD_STANDALONE
64  }

◆ msgLvl()

bool asg::AsgMessaging::msgLvl ( const MSG::Level  lvl) const
inherited

Test the output level of the object.

Parameters
lvl: The message level to test against
Returns
boolean Indicating whether messages at the given level will be printed
true if messages at level "lvl" will be printed

Definition at line 41 of file AsgMessaging.cxx.

41  {
42 #ifndef XAOD_STANDALONE
43  return ::AthMessaging::msgLvl( lvl );
44 #else // not XAOD_STANDALONE
45  return m_msg.msgLevel( lvl );
46 #endif // not XAOD_STANDALONE
47  }

◆ setLevel()

void AthMessaging::setLevel ( MSG::Level  lvl)
inherited

Change the current logging level.

Use this rather than msg().setLevel() for correct operation in multithreaded (MT) jobs.

Definition at line 28 of file AthMessaging.cxx.

29 {
30  m_lvl = lvl;
31 }

Member Data Documentation

◆ m_track_map

std::unordered_map<std::string, TrackCalc> TauGNNUtils::GNNVarCalc::m_track_map ATLAS_THREAD_SAFE
inline, static, private
Initial value:
= {
{"pt_jetseed_log", Variables::Track::pt_jetseed_log},
{"z0sinThetaTJVA_abs_log", Variables::Track::z0sinThetaTJVA_abs_log},
{"z0sinthetaTJVA", Variables::Track::z0sinthetaTJVA},
{"z0sinthetaSigTJVA", Variables::Track::z0sinthetaSigTJVA},
{"dEtaJetSeedAxis", Variables::Track::dEtaJetSeedAxis},
{"dPhiJetSeedAxis", Variables::Track::dPhiJetSeedAxis},
{"nInnermostPixelHits", Variables::Track::nInnermostPixelHits},
{"numberOfInnermostPixelLayerHits", Variables::Track::numberOfInnermostPixelLayerHits},
{"nIBLHitsAndExp", Variables::Track::nIBLHitsAndExp},
{"nPixelHitsPlusDeadSensors", Variables::Track::nPixelHitsPlusDeadSensors},
{"nSCTHitsPlusDeadSensors", Variables::Track::nSCTHitsPlusDeadSensors},
}

Definition at line 321 of file TauGNNUtils.h.

◆ m_cluster_map

const std::unordered_map<std::string, ClusterCalc> TauGNNUtils::GNNVarCalc::m_cluster_map
inline, static, private
Initial value:

Definition at line 348 of file TauGNNUtils.h.

◆ m_imsg

std::atomic<IMessageSvc*> AthMessaging::m_imsg { nullptr }
mutable, private, inherited

MessageSvc pointer.

Definition at line 135 of file AthMessaging.h.

◆ m_init_flag

std::once_flag TauGNNUtils::GNNVarCalc::m_init_flag
inline, static, private

Definition at line 301 of file TauGNNUtils.h.

◆ m_lvl

std::atomic<MSG::Level> AthMessaging::m_lvl { MSG::NIL }
mutable, private, inherited

Current logging level.

Definition at line 138 of file AthMessaging.h.

◆ m_msg_tls

boost::thread_specific_ptr<MsgStream> AthMessaging::m_msg_tls
mutable, private, inherited

MsgStream instance (a std::cout like with print-out levels)

Definition at line 132 of file AthMessaging.h.

◆ m_nm

std::string AthMessaging::m_nm
private, inherited

Message source name.

Definition at line 129 of file AthMessaging.h.

◆ m_scalar_map

const std::unordered_map<std::string, ScalarCalc> TauGNNUtils::GNNVarCalc::m_scalar_map
inline, static, private
Initial value:

Definition at line 304 of file TauGNNUtils.h.


The documentation for this class was generated from the following files:
AthMessaging::m_lvl
std::atomic< MSG::Level > m_lvl
Current logging level.
Definition: AthMessaging.h:138
TauGNNUtils::Variables::Track::nIBLHitsAndExp
bool nIBLHitsAndExp(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:445
TauGNNUtils::Variables::Track::pt_jetseed_log
bool pt_jetseed_log(const xAOD::TauJet &tau, const xAOD::TauTrack &, float &out)
Definition: TauGNNUtils.cxx:366
TauGNNUtils::Variables::Cluster::e
bool e(const xAOD::TauJet &, const xAOD::CaloVertexedTopoCluster &cluster, float &out)
Definition: TauGNNUtils.cxx:778
TauGNNUtils::Variables::Track::d0SigTJVA
bool d0SigTJVA(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:396
TauGNNUtils::Variables::Scalar::dRmax
bool dRmax(const xAOD::TauJet &tau, float &out)
Definition: TauGNNUtils.cxx:185
TauGNNUtils::Variables::Track::z0sinThetaTJVA_abs_log
bool z0sinThetaTJVA_abs_log(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:376
TauGNNUtils::Variables::Track::nPixelHitsPlusDeadSensors
bool nPixelHitsPlusDeadSensors(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:453
TauGNNUtils::GNNVarCalc::TrackCalc
std::function< bool(const xAOD::TauJet &, const xAOD::TauTrack &, float &)> TrackCalc
Definition: TauGNNUtils.h:279
TauGNNUtils::Variables::Scalar::etOverPtLeadTrk
bool etOverPtLeadTrk(const xAOD::TauJet &tau, float &out)
Definition: TauGNNUtils.cxx:128
TauGNNUtils::Variables::Track::d0_abs_log
bool d0_abs_log(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:371
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:70
athena.value
value
Definition: athena.py:124
TauGNNUtils::Variables::Scalar::pt
bool pt(const xAOD::TauJet &tau, float &out)
Definition: TauGNNUtils.cxx:206
TauGNNUtils::GNNVarCalc::m_init_flag
static std::once_flag m_init_flag
Definition: TauGNNUtils.h:301
TauGNNUtils::Variables::Track::z0sinthetaSigTJVA
bool z0sinthetaSigTJVA(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:386
TauGNNUtils::Variables::Cluster::dPhi
bool dPhi(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, float &out)
Definition: TauGNNUtils.cxx:705
TauGNNUtils::Variables::Scalar::centFrac
bool centFrac(const xAOD::TauJet &tau, float &out)
Definition: TauGNNUtils.cxx:112
AthMessaging::m_imsg
std::atomic< IMessageSvc * > m_imsg
MessageSvc pointer.
Definition: AthMessaging.h:135
Athena::getMessageSvc
IMessageSvc * getMessageSvc(bool quiet=false)
Definition: getMessageSvc.cxx:20
TauGNNUtils::GNNVarCalc::ScalarCalc
std::function< bool(const xAOD::TauJet &, float &)> ScalarCalc
Definition: TauGNNUtils.h:278
TauGNNUtils::Variables::Track::nSCTHits
bool nSCTHits(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:437
TauGNNUtils::Variables::Scalar::ptRatioEflowApprox
bool ptRatioEflowApprox(const xAOD::TauJet &tau, float &out)
Definition: TauGNNUtils.cxx:170
TauGNNUtils::Variables::Scalar::sumEMCellEtOverLeadTrkPt
bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, float &out)
Definition: TauGNNUtils.cxx:149
TauGNNUtils::Variables::Track::trackPhi
bool trackPhi(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:356
TauGNNUtils::Variables::Track::numberOfInnermostPixelLayerHits
bool numberOfInnermostPixelLayerHits(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:522
TrigConf::MSGTC::Level
Level
Definition: Trigger/TrigConfiguration/TrigConfBase/TrigConfBase/MsgStream.h:21
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
TauGNNUtils::Variables::Track::eProbabilityHT
bool eProbabilityHT(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:469
TauGNNUtils::Variables::Track::trackPt
bool trackPt(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:346
TauGNNUtils::Variables::Cluster::SECOND_LAMBDA
bool SECOND_LAMBDA(const xAOD::TauJet &, const xAOD::CaloVertexedTopoCluster &cluster, float &out)
Definition: TauGNNUtils.cxx:717
TauGNNUtils::Variables::Scalar::isolFrac
bool isolFrac(const xAOD::TauJet &tau, float &out)
Definition: TauGNNUtils.cxx:120
TauGNNUtils::Variables::Track::dEtaJetSeedAxis
bool dEtaJetSeedAxis(const xAOD::TauJet &tau, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:406
TauGNNUtils::Variables::Track::dEta
bool dEta(const xAOD::TauJet &tau, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:401
TauGNNUtils::Variables::Scalar::massTrkSys
bool massTrkSys(const xAOD::TauJet &tau, float &out)
Definition: TauGNNUtils.cxx:199
TauGNNUtils::Variables::Cluster::SECOND_R
bool SECOND_R(const xAOD::TauJet &, const xAOD::CaloVertexedTopoCluster &cluster, float &out)
Definition: TauGNNUtils.cxx:710
TauGNNUtils::GNNVarCalc::initialize_map
static void initialize_map(bool useTRT)
Definition: TauGNNUtils.cxx:19
TauGNNUtils::Variables::Scalar::mEflowApprox
bool mEflowApprox(const xAOD::TauJet &tau, float &out)
Definition: TauGNNUtils.cxx:178
TauGNNUtils::Variables::Scalar::absipSigLeadTrk
bool absipSigLeadTrk(const xAOD::TauJet &tau, float &out)
Definition: TauGNNUtils.cxx:142
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
TauGNNUtils::Variables::Track::pt_tau_log
bool pt_tau_log(const xAOD::TauJet &tau, const xAOD::TauTrack &, float &out)
Definition: TauGNNUtils.cxx:361
asg::AsgMessaging
Class mimicking the AthMessaging class from the offline software.
Definition: AsgMessaging.h:40
TauGNNUtils::Variables::Track::nPixelHits
bool nPixelHits(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:430
TauGNNUtils::Variables::Track::pt_log
bool pt_log(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:341
TauGNNUtils::Variables::Scalar::EMPOverTrkSysP
bool EMPOverTrkSysP(const xAOD::TauJet &tau, float &out)
Definition: TauGNNUtils.cxx:163
TauGNNUtils::Variables::Cluster::dEta
bool dEta(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, float &out)
Definition: TauGNNUtils.cxx:700
python.Constants.INFO
int INFO
Definition: Control/AthenaCommon/python/Constants.py:15
AthMessaging::m_nm
std::string m_nm
Message source name.
Definition: AthMessaging.h:129
TauGNNUtils::Variables::Track::dPhi
bool dPhi(const xAOD::TauJet &tau, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:412
RunTileMonitoring.clusters
clusters
Definition: RunTileMonitoring.py:133
TauGNNUtils::Variables::Track::nSCTHitsPlusDeadSensors
bool nSCTHitsPlusDeadSensors(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:461
TauGNNUtils::GNNVarCalc::m_cluster_map
static const std::unordered_map< std::string, ClusterCalc > m_cluster_map
Definition: TauGNNUtils.h:348
TauGNNUtils::Variables::Cluster::et
bool et(const xAOD::TauJet &, const xAOD::CaloVertexedTopoCluster &cluster, float &out)
Definition: TauGNNUtils.cxx:783
TauGNNUtils::GNNVarCalc::ClusterCalc
std::function< bool(const xAOD::TauJet &, const xAOD::CaloVertexedTopoCluster &, float &)> ClusterCalc
Definition: TauGNNUtils.h:280
TauGNNUtils::Variables::Scalar::innerTrkAvgDist
bool innerTrkAvgDist(const xAOD::TauJet &tau, float &out)
Definition: TauGNNUtils.cxx:135
TauGNNUtils::Variables::Scalar::trFlightPathSig
bool trFlightPathSig(const xAOD::TauJet &tau, float &out)
Definition: TauGNNUtils.cxx:192
xAOD::CaloVertexedTopoCluster
Evaluate cluster kinematics with a different vertex / signal state.
Definition: Event/xAOD/xAODCaloEvent/xAODCaloEvent/CaloVertexedTopoCluster.h:38
TauGNNUtils::Variables::Track::eProbabilityHT_noTRT
bool eProbabilityHT_noTRT(const xAOD::TauJet &, const xAOD::TauTrack &, float &out)
Definition: TauGNNUtils.cxx:476
TauGNNUtils::Variables::Track::nInnermostPixelHits
bool nInnermostPixelHits(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:423
TauGNNUtils::Variables::Track::trackEta
bool trackEta(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:351
TauGNNUtils::Variables::Scalar::SumPtTrkFrac
bool SumPtTrkFrac(const xAOD::TauJet &tau, float &out)
Definition: TauGNNUtils.cxx:156
TauGNNUtils::Variables::Cluster::CENTER_LAMBDA
bool CENTER_LAMBDA(const xAOD::TauJet &, const xAOD::CaloVertexedTopoCluster &cluster, float &out)
Definition: TauGNNUtils.cxx:724
TauGNNUtils::Variables::Track::z0sinthetaTJVA
bool z0sinthetaTJVA(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:381
TauGNNUtils::Variables::Track::dPhiJetSeedAxis
bool dPhiJetSeedAxis(const xAOD::TauJet &tau, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:417
python.AutoConfigFlags.msg
msg
Definition: AutoConfigFlags.py:7
TauGNNUtils::Variables::Track::d0TJVA
bool d0TJVA(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:391
TauGNNUtils::GNNVarCalc::m_scalar_map
static const std::unordered_map< std::string, ScalarCalc > m_scalar_map
Definition: TauGNNUtils.h:304