11 std::shared_ptr<const FlavorTagInference::SaltModel> salt_model,
22 if (in_node.name == config.input_layer_scalar) {
23 scalar_input_node = &in_node;
24 ATH_MSG_DEBUG(
"Found scalar input node: " << in_node.name);
28 for (
const auto &in_node : graph_config.input_sequences) {
29 if (in_node.name == config.input_layer_tracks) {
30 track_input_node = &in_node;
31 ATH_MSG_DEBUG(
"Found track input node: " << in_node.name);
33 if (in_node.name ==
config.input_layer_clusters) {
34 cluster_input_node = &in_node;
35 ATH_MSG_DEBUG(
"Found cluster input node: " << in_node.name);
40 if (scalar_input_node) {
41 for (
const auto &in : scalar_input_node->variables) {
42 addScalarLoader(in.name, getScalarCalc(in.name));
45 ATH_MSG_ERROR(
"Scalar input node 'tau_vars' not found in the model input configuration");
46 throw std::runtime_error(
"Scalar input node 'tau_vars' not found in the model input configuration");
49 if (track_input_node) {
51 trk_config.
name =
"tautracks";
58 for (
const auto &in : track_input_node->variables) {
59 if (!
config.useTRT && (in.name ==
"eProbabilityHT")) {
60 ATH_MSG_WARNING(
"Track variable 'eProbabilityHT' requested but useTRT set to false. Using 'eProbabilityHT_noTRT' instead.");
66 addVectorLoader(
config.input_layer_tracks, std::make_shared<FlavorTagInference::ConstituentLoaderTauTrack>(trk_config));
68 ATH_MSG_ERROR(
"Track input node '" +
config.input_layer_tracks +
"' not found in the model input configuration");
69 throw std::runtime_error(
"Track input node '" +
config.input_layer_tracks +
"' not found in the model input configuration");
72 if (cluster_input_node) {
74 cls_config.
name =
"tauclusters";
81 for (
const auto &in : cluster_input_node->variables) {
84 addVectorLoader(
config.input_layer_clusters, std::make_shared<FlavorTagInference::ConstituentLoaderTauCluster>(cls_config,
config.max_dr_cluster,
config.doVertexCorrection));
86 ATH_MSG_ERROR(
"Cluster input node '" +
config.input_layer_clusters +
"' not found in the model input configuration");
87 throw std::runtime_error(
"Cluster input node '" +
config.input_layer_clusters +
"' not found in the model input configuration");
96 }
catch (
const std::out_of_range &e) {
98 throw std::runtime_error(
"Variable '" + name +
"' not defined");
104 throw std::runtime_error(
"Invalid TauJet pointer");
106 auto success = func(*tau, out);
108 throw std::runtime_error(
"Error in scalar variable calculation ");
119 out = std::abs(tau.
eta());
156 out = std::abs(ipSigLeadTrk);
198 const auto success = tau.
detail(TauDetail::dRmax,
dRmax);
223 out = std::log10(std::max(tau.
pt() / 1000., 1e-6));
238 out = std::log10(std::max(tau.
ptJetSeed(), 1e-3));
244 float absEtaLeadTrack = acc_absEtaLeadTrack(tau);
245 out = std::max(0.f, absEtaLeadTrack);
251 out = std::max(0.f, absDeltaEta);
256 float absDeltaPhi = tau.
nTracks() > 0 ? std::abs( tau.
track(0)->
track()->
p4().DeltaPhi(tau.
p4()) ) : -1111.;
257 out = std::max(0.f, absDeltaPhi);
265 if (!tracks.empty()) {
267 return lhs->
pt() > rhs->pt();
269 std::sort(tracks.begin(), tracks.end(), cmp_pt);
275 float eProbabilityNN = acc_eProbabilityNN(*xTrackParticle);
276 out = (tauLeadTrack->
pt()>2000.) ? eProbabilityNN : eProbabilityHT;
286 float emFracFixed = acc_emFracFixed(tau);
287 out = std::max(emFracFixed, 0.0f);
307 const auto success = tau.
detail(TauDetail::PSSFraction,
PSFrac);
308 out = std::max(0.f,
PSFrac);
#define ATH_MSG_WARNING(x)
std::function< float(const xAOD::IParticle *)> ScalarCalc_t
std::function< bool(const xAOD::TauJet &, float &)> ScalarCalcByRef_t
SaltModelGraphConfig::GraphConfig graph_config
SaltModelEDMLoaderBase(ISaltModelPtr salt_model)
std::string scalarInputName
Helper class to provide constant type-safe access to aux data.
ScalarCalc_t getScalarCalc(const std::string &name) const
std::function< bool(const xAOD::TauJet &, float &)> ScalarCalcByRef_t
std::function< float(const xAOD::IParticle *)> ScalarCalc_t
TauGNNDataLoader(std::shared_ptr< const FlavorTagInference::SaltModel > salt_model, const Config &config)
static const std::unordered_map< std::string, ScalarCalcByRef_t > m_func_map
AsgMessaging(const std::string &name)
Constructor with a name.
Class providing the definition of the 4-vector interface.
double ptDetectorAxis() const
virtual FourMom_t p4() const
The full 4-momentum of the particle.
virtual double pt() const
The transverse momentum ($p_{T}$) of the particle.
double ptIntermediateAxis() const
bool detail(TauJetParameters::Detail detail, int &value) const
Get and set values of common details variables via enum.
const TauTrack * track(size_t i, TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged, int *container_index=0) const
Get the pointer to a given tauTrack associated with this tau /*container index needed by trackNonCons...
std::vector< const TauTrack * > allTracks() const
Get the v<const pointer> to all tracks associated with this tau, regardless of classification.
virtual double eta() const
The pseudorapidity ($\eta$) of the particle.
size_t nTracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
virtual double pt() const
The transverse momentum ($p_{T}$) of the particle.
const TrackParticle * track() const
virtual FourMom_t p4() const override final
The full 4-momentum of the particle.
bool summaryValue(uint8_t &value, const SummaryType &information) const
Accessor for TrackSummary values.
virtual double eta() const override final
The pseudorapidity ($\eta$) of the particle.
This file contains "getter" functions used for accessing tagger inputs from the EDM.
bool etHotShotWinOverPtLeadTrk(const xAOD::TauJet &tau, float &out)
bool ptDetectorAxis(const xAOD::TauJet &tau, float &out)
bool ClustersMeanCenterLambda(const xAOD::TauJet &tau, float &out)
bool pt(const xAOD::TauJet &tau, float &out)
bool ptIntermediateAxis(const xAOD::TauJet &tau, float &out)
bool EMFracFixed(const xAOD::TauJet &tau, float &out)
bool isolFrac(const xAOD::TauJet &tau, float &out)
bool massTrkSys(const xAOD::TauJet &tau, float &out)
bool PSFrac(const xAOD::TauJet &tau, float &out)
bool ClustersMeanFirstEngDens(const xAOD::TauJet &tau, float &out)
bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, float &out)
bool EMPOverTrkSysP(const xAOD::TauJet &tau, float &out)
bool leadTrackProbNNorHT(const xAOD::TauJet &tau, float &out)
bool ptJetSeed_log(const xAOD::TauJet &tau, float &out)
bool leadTrackDeltaEta(const xAOD::TauJet &tau, float &out)
bool etOverPtLeadTrk(const xAOD::TauJet &tau, float &out)
bool absEta(const xAOD::TauJet &tau, float &out)
bool hadLeakFracFixed(const xAOD::TauJet &tau, float &out)
bool leadTrackDeltaPhi(const xAOD::TauJet &tau, float &out)
xAOD::TauJetParameters::Detail TauDetail
bool mEflowApprox(const xAOD::TauJet &tau, float &out)
bool innerTrkAvgDist(const xAOD::TauJet &tau, float &out)
bool centFrac(const xAOD::TauJet &tau, float &out)
bool ptRatioEflowApprox(const xAOD::TauJet &tau, float &out)
bool trFlightPathSig(const xAOD::TauJet &tau, float &out)
bool ClustersMeanSecondLambda(const xAOD::TauJet &tau, float &out)
bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, float &out)
bool ClustersMeanEMProbability(const xAOD::TauJet &tau, float &out)
bool dRmax(const xAOD::TauJet &tau, float &out)
bool pt_tau_log(const xAOD::TauJet &tau, float &out)
bool absipSigLeadTrk(const xAOD::TauJet &tau, float &out)
bool SumPtTrkFrac(const xAOD::TauJet &tau, float &out)
bool absleadTrackEta(const xAOD::TauJet &tau, float &out)
void sort(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of sort for DataVector/List.
Detail
Enum for tau parameters - used mainly for backward compatibility with the analysis code.
TrackParticle_v1 TrackParticle
Reference the current persistent version:
TauTrack_v1 TauTrack
Definition of the current version.
TauJet_v3 TauJet
Definition of the current "tau version".
@ eProbabilityHT
Electron probability from High Threshold (HT) information [float].