23using VectorMap = std::map<std::string, std::vector<double>>;
24using InputMap = std::map<std::string, ValueMap>;
42 if (weightFile.empty())
45 return StatusCode::FAILURE;
47 ATH_MSG_INFO(
"Loaded network configuration from: " << weightFile);
50 std::ifstream inputFile(weightFile);
51 lwt::GraphConfig lwtGraphConfig;
54 lwtGraphConfig = lwt::parse_json_graph(inputFile);
56 catch (
const std::logic_error &e)
59 return StatusCode::FAILURE;
65 m_lwtGraph = std::make_unique<lwt::LightweightGraph>(lwtGraphConfig, lwtGraphConfig.outputs.cbegin()->first);
67 catch (
const lwt::NNConfigurationException &e)
70 return StatusCode::FAILURE;
73 return StatusCode::SUCCESS;
85 std::set<std::string> branches = {
"TauTrack",
"NeutralPFO",
"ShotPFO",
"ConvTrack"};
98 outputs =
m_lwtGraph->compute(inputMapDummy, inputSeqMap);
100 catch (
const std::exception &e)
103 return StatusCode::FAILURE;
111 std::array<float, DMVar::nClasses> probs;
113 std::string prefix =
"c_";
114 for (std::size_t i = 0; i < probs.size(); ++i)
123 std::array<float, DMVar::nClasses>::const_iterator itMax;
127 itMax = std::max_element(probs.cbegin(), probs.cbegin() + 3);
132 itMax = std::max_element(probs.cbegin() + 3, probs.cend());
137 itMax = std::max_element(probs.cbegin(), probs.cend());
141 accDecayMode(xTau) = std::distance(probs.cbegin(), itMax);
145 for (std::size_t i = 0; i < probs.size(); ++i)
149 accProb(xTau) = probs[i];
153 return StatusCode::SUCCESS;
158 std::vector<TrkPtr> vTauTracks;
159 std::vector<PFOPtr> vNeutralPFOs;
160 std::vector<PFOPtr> vShotPFOs;
161 std::vector<TrkPtr> vConvTracks;
176 vNeutralPFOs.push_back(pfo);
180 for (std::size_t i = 0; i < xTau.
nShotPFOs(); ++i)
182 const auto pfo = xTau.
shotPFO(i);
189 catch (
const std::exception &e)
191 ATH_MSG_ERROR(
"Error retrieving tauShots_nPhotons: " << e.what());
192 return StatusCode::FAILURE;
196 vShotPFOs.push_back(pfo);
214 std::pair<float, bool> tau_etaTrkECal{0.,
false};
215 std::pair<float, bool> tau_phiTrkECal{0.,
false};
225 tau_phiTrkECal.second =
true;
233 tau_etaTrkECal.second =
true;
238 auto setCommonP4Vars = [&tau_p4, &tau_etaTrkECal, &tau_phiTrkECal](
VectorMap &in_seq_map,
const TLorentzVector &obj_p4) {
249 in_seq_map[
"d0TJVA"].push_back(trk->d0TJVA());
250 in_seq_map[
"d0SigTJVA"].push_back(trk->d0SigTJVA());
251 in_seq_map[
"z0sinthetaTJVA"].push_back(trk->z0sinthetaTJVA());
252 in_seq_map[
"z0sinthetaSigTJVA"].push_back(trk->z0sinthetaSigTJVA());
256 auto setNeutralPFOVars = [](
VectorMap &in_seq_map,
const PFOPtr &pfo) {
261 in_seq_map[
"FIRST_ETA"].push_back(getAttr(PFOAttributes::cellBased_FIRST_ETA));
262 in_seq_map[
"SECOND_R_log"].push_back(
DMHelper::Log10Robust(getAttr(PFOAttributes::cellBased_SECOND_R), 1e-3f));
263 in_seq_map[
"DELTA_THETA"].push_back(getAttr(PFOAttributes::cellBased_DELTA_THETA));
264 in_seq_map[
"CENTER_LAMBDA_log"].push_back(
DMHelper::Log10Robust(getAttr(PFOAttributes::cellBased_CENTER_LAMBDA), 1e-3f));
265 in_seq_map[
"LONGITUDINAL"].push_back(getAttr(PFOAttributes::cellBased_LONGITUDINAL));
266 in_seq_map[
"ENG_FRAC_CORE"].push_back(getAttr(PFOAttributes::cellBased_ENG_FRAC_CORE));
267 in_seq_map[
"SECOND_ENG_DENS_log"].push_back(
DMHelper::Log10Robust(getAttr(PFOAttributes::cellBased_SECOND_ENG_DENS), 1e-6f));
268 in_seq_map[
"NPosECells_EM1"].push_back(getAttrInt(PFOAttributes::cellBased_NPosECells_EM1));
269 in_seq_map[
"NPosECells_EM2"].push_back(getAttrInt(PFOAttributes::cellBased_NPosECells_EM2));
270 in_seq_map[
"energy_EM1"].push_back(getAttr(PFOAttributes::cellBased_energy_EM1));
271 in_seq_map[
"energy_EM2"].push_back(getAttr(PFOAttributes::cellBased_energy_EM2));
272 in_seq_map[
"EM1CoreFrac"].push_back(getAttr(PFOAttributes::cellBased_EM1CoreFrac));
273 in_seq_map[
"firstEtaWRTClusterPosition_EM1"].push_back(getAttr(PFOAttributes::cellBased_firstEtaWRTClusterPosition_EM1));
274 in_seq_map[
"firstEtaWRTClusterPosition_EM2"].push_back(getAttr(PFOAttributes::cellBased_firstEtaWRTClusterPosition_EM2));
275 in_seq_map[
"secondEtaWRTClusterPosition_EM1_log"].push_back(
DMHelper::Log10Robust(getAttr(PFOAttributes::cellBased_secondEtaWRTClusterPosition_EM1), 1e-6f));
276 in_seq_map[
"secondEtaWRTClusterPosition_EM2_log"].push_back(
DMHelper::Log10Robust(getAttr(PFOAttributes::cellBased_secondEtaWRTClusterPosition_EM2), 1e-6f));
280 VectorMap &chrg_map = inputSeqMap.at(
"TauTrack");
283 for (
const auto &trk : vTauTracks)
285 setCommonP4Vars(chrg_map, trk->p4());
286 setTrackIPVars(chrg_map, trk);
290 VectorMap &neut_map = inputSeqMap.at(
"NeutralPFO");
293 for (
const auto &pfo : vNeutralPFOs)
295 setCommonP4Vars(neut_map, pfo->p4());
298 setNeutralPFOVars(neut_map, pfo);
300 catch (
const std::exception &e)
302 ATH_MSG_ERROR(
"Error setting neutral PFO variables: " << e.what());
303 return StatusCode::FAILURE;
308 VectorMap &shot_map = inputSeqMap.at(
"ShotPFO");
310 for (
const auto &pfo : vShotPFOs)
312 setCommonP4Vars(shot_map, pfo->p4());
316 VectorMap &conv_map = inputSeqMap.at(
"ConvTrack");
319 for (
const auto &trk : vConvTracks)
321 setCommonP4Vars(conv_map, trk->p4());
322 setTrackIPVars(conv_map, trk);
325 return StatusCode::SUCCESS;
332 "dphiECal",
"detaECal",
"dphi",
"deta",
"pt_log",
"jetpt_log"};
335 "d0TJVA",
"d0SigTJVA",
"z0sinthetaTJVA",
"z0sinthetaSigTJVA"};
338 "FIRST_ETA",
"SECOND_R_log",
"DELTA_THETA",
"CENTER_LAMBDA_log",
"LONGITUDINAL",
"ENG_FRAC_CORE",
339 "SECOND_ENG_DENS_log",
"NPosECells_EM1",
"NPosECells_EM2",
"energy_EM1",
"energy_EM2",
"EM1CoreFrac",
340 "firstEtaWRTClusterPosition_EM1",
"firstEtaWRTClusterPosition_EM2",
341 "secondEtaWRTClusterPosition_EM1_log",
"secondEtaWRTClusterPosition_EM2_log"};
344 "1p0n",
"1p1n",
"1pXn",
"3p0n",
"3pXn"};
348 return p4_tau.DeltaPhi(p4);
353 return p4.Eta() - p4_tau.Eta();
359 return tau_phiTrkECal.second ? TVector2::Phi_mpi_pi(p4.Phi() - tau_phiTrkECal.first) : 0.0f;
365 return tau_etaTrkECal.second ? p4.Eta() - tau_etaTrkECal.first : 0.0f;
368 template <
typename T>
371 T val{
static_cast<T
>(0)};
374 throw std::runtime_error(
"Can not retrieve PFO attribute! enum = " + std::to_string(
static_cast<unsigned>(attr)));
381 return TMath::Log10(std::max(val, min_val));
384 template <
typename T>
387 auto cmp_pt = [](
const T lhs,
const T rhs) {
return lhs->pt() > rhs->pt(); };
389 if (
vec.size() > n_obj)
391 vec.erase(
vec.begin() + n_obj,
vec.end());
395 template <
typename T>
397 const std::set<std::string> &keys)
400 for (
const auto &key : keys)
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_WARNING(x)
std::vector< size_t > vec
std::map< std::string, std::vector< double > > VectorMap
std::map< std::string, VectorMap > InputSequenceMap
const xAOD::TauTrack * TrkPtr
std::map< std::string, double > ValueMap
std::map< std::string, ValueMap > InputMap
xAOD::PFODetails::PFOAttributes PFOAttributes
tauRecTools::TauDecayModeNNVariable DMVar
tauRecTools::TauDecayModeNNHelper DMHelper
Helper class to provide type-safe access to aux data.
virtual StatusCode execute(xAOD::TauJet &xTau) const override
Execute - called for each tau candidate.
Gaudi::Property< bool > m_decorateProb
virtual StatusCode getInputs(const xAOD::TauJet &xTau, std::map< std::string, std::map< std::string, std::vector< double > > > &inputSeqMap) const
retrieve the input variables from a TauJet
TauDecayModeNNClassifier(const std::string &name="TauDecayModeNNClassifier")
Gaudi::Property< bool > m_ensureTrackConsistency
Gaudi::Property< std::size_t > m_maxTauTracks
Gaudi::Property< std::string > m_outputName
properties of the tool
Gaudi::Property< std::string > m_weightFile
Gaudi::Property< std::size_t > m_maxConvTracks
Gaudi::Property< float > m_neutralPFOPtCut
virtual StatusCode initialize() override
Tool initializer.
std::unique_ptr< const lwt::LightweightGraph > m_lwtGraph
lwtnn graph
virtual ~TauDecayModeNNClassifier()
Gaudi::Property< std::string > m_probPrefix
Gaudi::Property< std::size_t > m_maxShotPFOs
Gaudi::Property< std::size_t > m_maxNeutralPFOs
bool attribute(PFODetails::PFOAttributes AttributeType, T &anAttribute) const
get a PFO Variable via enum
size_t nNeutralPFOs() const
Get the number of neutral PFO particles associated with this tau.
size_t nShotPFOs() const
Get the number of shot PFO particles associated with this tau.
virtual FourMom_t p4() const
The full 4-momentum of the particle.
const PFO * shotPFO(size_t i) const
Get the pointer to a given shot PFO associated with this tau.
const TauTrack * track(size_t i, TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged, int *container_index=0) const
Get the pointer to a given tauTrack associated with this tau /*container index needed by trackNonCons...
const PFO * neutralPFO(size_t i) const
Get the pointer to a given neutral PFO associated with this tau.
size_t nTracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
std::vector< const TauTrack * > tracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
Get the v<const pointer> to a given tauTrack collection associated with this tau.
bool detail(TauJetParameters::TrackDetail detail, float &value) const
void sort(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of sort for DataVector/List.
PFO_v1 PFO
Definition of the current "pfo version".
TauTrack_v1 TauTrack
Definition of the current version.
TauJet_v3 TauJet
Definition of the current "tau version".