11 std::shared_ptr<const FlavorTagInference::SaltModel> salt_model,
23 if (in_node.name == config.input_layer_scalar) {
24 scalar_input_node = &in_node;
25 ATH_MSG_DEBUG(
"Found scalar input node: " << in_node.name);
29 for (
const auto &in_node : graph_config.input_sequences) {
30 if (in_node.name == config.input_layer_tracks) {
31 track_input_node = &in_node;
32 ATH_MSG_DEBUG(
"Found track input node: " << in_node.name);
34 if (in_node.name ==
config.input_layer_clusters) {
35 cluster_input_node = &in_node;
36 ATH_MSG_DEBUG(
"Found cluster input node: " << in_node.name);
38 if (in_node.name ==
config.input_layer_hits) {
39 hit_input_node = &in_node;
40 ATH_MSG_DEBUG(
"Found hit input node: " << in_node.name);
45 if (scalar_input_node) {
46 for (
const auto &in : scalar_input_node->variables) {
47 addScalarLoader(in.name, getScalarCalc(in.name));
49 }
else if(!
config.input_layer_scalar.empty()) {
50 ATH_MSG_ERROR(
"Scalar input node '" + config.input_layer_scalar +
"' not found in the model input configuration");
51 throw std::runtime_error(
"Scalar input node '" + config.input_layer_scalar +
"' not found in the model input configuration");
54 if (track_input_node) {
56 trk_config.
name =
"tautracks";
63 for (
const auto &in : track_input_node->variables) {
64 if (!
config.useTRT && (in.name ==
"eProbabilityHT")) {
65 ATH_MSG_WARNING(
"Track variable 'eProbabilityHT' requested but useTRT set to false. Using 'eProbabilityHT_noTRT' instead.");
71 addVectorLoader(
config.input_layer_tracks, std::make_shared<FlavorTagInference::ConstituentLoaderTauTrack>(trk_config));
72 }
else if(!
config.input_layer_tracks.empty() &&
config.n_max_tracks > 0) {
73 ATH_MSG_ERROR(
"Track input node '" + config.input_layer_tracks +
"' not found in the model input configuration");
74 throw std::runtime_error(
"Track input node '" + config.input_layer_tracks +
"' not found in the model input configuration");
77 if (cluster_input_node) {
79 cls_config.
name =
"tauclusters";
86 for (
const auto &in : cluster_input_node->variables) {
89 addVectorLoader(
config.input_layer_clusters, std::make_shared<FlavorTagInference::ConstituentLoaderTauCluster>(cls_config,
config.max_dr_cluster,
config.doVertexCorrection));
90 }
else if(!
config.input_layer_clusters.empty() &&
config.n_max_clusters > 0) {
91 ATH_MSG_ERROR(
"Cluster input node '" + config.input_layer_clusters +
"' not found in the model input configuration");
92 throw std::runtime_error(
"Cluster input node '" + config.input_layer_clusters +
"' not found in the model input configuration");
97 cls_config.
name =
"tauhits";
104 for (
const auto &in : hit_input_node->variables) {
107 addVectorLoader(
config.input_layer_hits, std::make_shared<FlavorTagInference::ConstituentLoaderTauHit>(cls_config,
config.hits_decor_name));
108 }
else if(!
config.input_layer_hits.empty() &&
config.n_max_hits > 0) {
109 ATH_MSG_ERROR(
"Hit input node '" + config.input_layer_hits +
"' not found in the model input configuration");
110 throw std::runtime_error(
"Hit input node '" + config.input_layer_hits +
"' not found in the model input configuration");
119 }
catch (
const std::out_of_range &e) {
121 throw std::runtime_error(
"Variable '" + name +
"' not defined");
127 throw std::runtime_error(
"Invalid TauJet pointer");
129 auto success = func(*tau, out);
131 throw std::runtime_error(
"Error in scalar variable calculation ");
147 out = std::abs(tau.
eta());
184 out = std::abs(ipSigLeadTrk);
226 const auto success = tau.
detail(TauDetail::dRmax,
dRmax);
251 out = std::log10(std::max(tau.
pt() / 1000., 1e-6));
276 out = std::log10(std::max(tau.
ptJetSeed(), 1e-3));
282 float absEtaLeadTrack = acc_absEtaLeadTrack(tau);
283 out = std::max(0.f, absEtaLeadTrack);
289 out = std::max(0.f, absDeltaEta);
294 float absDeltaPhi = tau.
nTracks() > 0 ? std::abs( tau.
track(0)->
track()->
p4().DeltaPhi(tau.
p4()) ) : -1111.;
295 out = std::max(0.f, absDeltaPhi);
303 if (!tracks.empty()) {
305 return lhs->
pt() > rhs->pt();
307 std::sort(tracks.begin(), tracks.end(), cmp_pt);
313 float eProbabilityNN = acc_eProbabilityNN(*xTrackParticle);
314 out = (tauLeadTrack->
pt()>2000.) ? eProbabilityNN : eProbabilityHT;
324 float emFracFixed = acc_emFracFixed(tau);
325 out = std::max(emFracFixed, 0.0f);
345 const auto success = tau.
detail(TauDetail::PSSFraction,
PSFrac);
346 out = std::max(0.f,
PSFrac);
Scalar eta() const
pseudorapidity method
#define ATH_MSG_WARNING(x)
std::function< float(const xAOD::IParticle *)> ScalarCalc_t
std::function< bool(const xAOD::TauJet &, float &)> ScalarCalcByRef_t
SaltModelGraphConfig::GraphConfig graph_config
SaltModelEDMLoaderBase(ISaltModelPtr salt_model)
std::string scalarInputName
Helper class to provide constant type-safe access to aux data.
ScalarCalc_t getScalarCalc(const std::string &name) const
std::function< bool(const xAOD::TauJet &, float &)> ScalarCalcByRef_t
std::function< float(const xAOD::IParticle *)> ScalarCalc_t
TauGNNDataLoader(std::shared_ptr< const FlavorTagInference::SaltModel > salt_model, const Config &config)
static const std::unordered_map< std::string, ScalarCalcByRef_t > m_func_map
AsgMessaging(const std::string &name)
Constructor with a name.
Class providing the definition of the 4-vector interface.
double ptDetectorAxis() const
virtual FourMom_t p4() const
The full 4-momentum of the particle.
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
double ptIntermediateAxis() const
bool detail(TauJetParameters::Detail detail, int &value) const
Get and set values of common details variables via enum.
const TauTrack * track(size_t i, TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged, int *container_index=0) const
Get the pointer to a given tauTrack associated with this tau /*container index needed by trackNonCons...
double etaJetSeed() const
std::vector< const TauTrack * > allTracks() const
Get the v<const pointer> to all tracks associated with this tau, regardless of classification.
virtual double eta() const
The pseudorapidity ($\eta$) of the particle.
size_t nTracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
const TrackParticle * track() const
virtual FourMom_t p4() const override final
The full 4-momentum of the particle.
bool summaryValue(uint8_t &value, const SummaryType &information) const
Accessor for TrackSummary values.
virtual double eta() const override final
The pseudorapidity ($\eta$) of the particle.
This file contains "getter" functions used for accessing tagger inputs from the EDM.
bool etHotShotWinOverPtLeadTrk(const xAOD::TauJet &tau, float &out)
bool ptDetectorAxis(const xAOD::TauJet &tau, float &out)
bool ClustersMeanCenterLambda(const xAOD::TauJet &tau, float &out)
bool pt(const xAOD::TauJet &tau, float &out)
bool ptIntermediateAxis(const xAOD::TauJet &tau, float &out)
bool EMFracFixed(const xAOD::TauJet &tau, float &out)
bool isolFrac(const xAOD::TauJet &tau, float &out)
bool massTrkSys(const xAOD::TauJet &tau, float &out)
bool PSFrac(const xAOD::TauJet &tau, float &out)
bool ClustersMeanFirstEngDens(const xAOD::TauJet &tau, float &out)
bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, float &out)
bool EMPOverTrkSysP(const xAOD::TauJet &tau, float &out)
bool leadTrackProbNNorHT(const xAOD::TauJet &tau, float &out)
bool ptJetSeed_log(const xAOD::TauJet &tau, float &out)
bool leadTrackDeltaEta(const xAOD::TauJet &tau, float &out)
bool etaJetSeed(const xAOD::TauJet &tau, float &out)
bool etOverPtLeadTrk(const xAOD::TauJet &tau, float &out)
bool absEta(const xAOD::TauJet &tau, float &out)
bool hadLeakFracFixed(const xAOD::TauJet &tau, float &out)
bool leadTrackDeltaPhi(const xAOD::TauJet &tau, float &out)
xAOD::TauJetParameters::Detail TauDetail
bool mEflowApprox(const xAOD::TauJet &tau, float &out)
bool innerTrkAvgDist(const xAOD::TauJet &tau, float &out)
bool centFrac(const xAOD::TauJet &tau, float &out)
bool ptRatioEflowApprox(const xAOD::TauJet &tau, float &out)
bool trFlightPathSig(const xAOD::TauJet &tau, float &out)
bool ClustersMeanSecondLambda(const xAOD::TauJet &tau, float &out)
bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, float &out)
bool ClustersMeanEMProbability(const xAOD::TauJet &tau, float &out)
bool dRmax(const xAOD::TauJet &tau, float &out)
bool pt_tau_log(const xAOD::TauJet &tau, float &out)
bool absipSigLeadTrk(const xAOD::TauJet &tau, float &out)
bool SumPtTrkFrac(const xAOD::TauJet &tau, float &out)
bool ptJetSeed(const xAOD::TauJet &tau, float &out)
bool absleadTrackEta(const xAOD::TauJet &tau, float &out)
void sort(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of sort for DataVector/List.
Detail
Enum for tau parameters - used mainly for backward compatibility with the analysis code.
TrackParticle_v1 TrackParticle
Reference the current persistent version:
TauTrack_v1 TauTrack
Definition of the current version.
TauJet_v3 TauJet
Definition of the current "tau version".
@ eProbabilityHT
Electron probability from High Threshold (HT) information [float].