Loading [MathJax]/extensions/tex2jax.js
 |
ATLAS Offline Software
|
Go to the documentation of this file.
22 using ValueMap = std::map<std::string, double>;
23 using VectorMap = std::map<std::string, std::vector<double>>;
24 using InputMap = std::map<std::string, ValueMap>;
42 if (weightFile.empty())
45 return StatusCode::FAILURE;
47 ATH_MSG_INFO(
"Loaded network configuration from: " << weightFile);
51 lwt::GraphConfig lwtGraphConfig;
56 catch (
const std::logic_error &
e)
59 return StatusCode::FAILURE;
65 m_lwtGraph = std::make_unique<lwt::LightweightGraph>(lwtGraphConfig, lwtGraphConfig.outputs.cbegin()->first);
67 catch (
const lwt::NNConfigurationException &
e)
70 return StatusCode::FAILURE;
73 return StatusCode::SUCCESS;
85 std::set<std::string>
branches = {
"TauTrack",
"NeutralPFO",
"ShotPFO",
"ConvTrack"};
103 return StatusCode::FAILURE;
111 std::array<float, DMVar::nClasses> probs;
113 std::string
prefix =
"c_";
114 for (std::size_t
i = 0;
i < probs.size(); ++
i)
123 std::array<float, DMVar::nClasses>::const_iterator itMax;
127 itMax = std::max_element(probs.cbegin(), probs.cbegin() + 3);
132 itMax = std::max_element(probs.cbegin() + 3, probs.cend());
137 itMax = std::max_element(probs.cbegin(), probs.cend());
145 for (std::size_t
i = 0;
i < probs.size(); ++
i)
149 accProb(xTau) = probs[
i];
153 return StatusCode::SUCCESS;
158 std::vector<TrkPtr> vTauTracks;
159 std::vector<PFOPtr> vNeutralPFOs;
160 std::vector<PFOPtr> vShotPFOs;
161 std::vector<TrkPtr> vConvTracks;
176 vNeutralPFOs.push_back(pfo);
191 ATH_MSG_ERROR(
"Error retrieving tauShots_nPhotons: " <<
e.what());
192 return StatusCode::FAILURE;
196 vShotPFOs.push_back(pfo);
214 std::pair<float, bool> tau_etaTrkECal{0.,
false};
215 std::pair<float, bool> tau_phiTrkECal{0.,
false};
225 tau_phiTrkECal.second =
true;
233 tau_etaTrkECal.second =
true;
238 auto setCommonP4Vars = [&tau_p4, &tau_etaTrkECal, &tau_phiTrkECal](
VectorMap &in_seq_map,
const TLorentzVector &obj_p4) {
249 in_seq_map[
"d0TJVA"].push_back(trk->d0TJVA());
250 in_seq_map[
"d0SigTJVA"].push_back(trk->d0SigTJVA());
251 in_seq_map[
"z0sinthetaTJVA"].push_back(trk->z0sinthetaTJVA());
252 in_seq_map[
"z0sinthetaSigTJVA"].push_back(trk->z0sinthetaSigTJVA());
256 auto setNeutralPFOVars = [](
VectorMap &in_seq_map,
const PFOPtr &pfo) {
258 auto getAttr = std::bind(DMVar::pfoAttr<float>, pfo, std::placeholders::_1);
259 auto getAttrInt = std::bind(DMVar::pfoAttr<int>, pfo, std::placeholders::_1);
280 VectorMap &chrg_map = inputSeqMap.at(
"TauTrack");
283 for (
const auto &trk : vTauTracks)
285 setCommonP4Vars(chrg_map, trk->p4());
286 setTrackIPVars(chrg_map, trk);
290 VectorMap &neut_map = inputSeqMap.at(
"NeutralPFO");
293 for (
const auto &pfo : vNeutralPFOs)
295 setCommonP4Vars(neut_map, pfo->p4());
298 setNeutralPFOVars(neut_map, pfo);
302 ATH_MSG_ERROR(
"Error setting neutral PFO variables: " <<
e.what());
303 return StatusCode::FAILURE;
308 VectorMap &shot_map = inputSeqMap.at(
"ShotPFO");
310 for (
const auto &pfo : vShotPFOs)
312 setCommonP4Vars(shot_map, pfo->p4());
316 VectorMap &conv_map = inputSeqMap.at(
"ConvTrack");
319 for (
const auto &trk : vConvTracks)
321 setCommonP4Vars(conv_map, trk->p4());
322 setTrackIPVars(conv_map, trk);
325 return StatusCode::SUCCESS;
332 "dphiECal",
"detaECal",
"dphi",
"deta",
"pt_log",
"jetpt_log"};
335 "d0TJVA",
"d0SigTJVA",
"z0sinthetaTJVA",
"z0sinthetaSigTJVA"};
338 "FIRST_ETA",
"SECOND_R_log",
"DELTA_THETA",
"CENTER_LAMBDA_log",
"LONGITUDINAL",
"ENG_FRAC_CORE",
339 "SECOND_ENG_DENS_log",
"NPosECells_EM1",
"NPosECells_EM2",
"energy_EM1",
"energy_EM2",
"EM1CoreFrac",
340 "firstEtaWRTClusterPosition_EM1",
"firstEtaWRTClusterPosition_EM2",
341 "secondEtaWRTClusterPosition_EM1_log",
"secondEtaWRTClusterPosition_EM2_log"};
344 "1p0n",
"1p1n",
"1pXn",
"3p0n",
"3pXn"};
348 return p4_tau.DeltaPhi(p4);
353 return p4.Eta() - p4_tau.Eta();
365 return tau_etaTrkECal.second ? p4.Eta() - tau_etaTrkECal.first : 0.0f;
368 template <
typename T>
371 T
val{
static_cast<T
>(0)};
374 throw std::runtime_error(
"Can not retrieve PFO attribute! enum = " +
std::to_string(
static_cast<unsigned>(attr)));
384 template <
typename T>
387 auto cmp_pt = [](
const T lhs,
const T rhs) {
return lhs->pt() > rhs->pt(); };
388 std::sort(
vec.begin(),
vec.end(), cmp_pt);
389 if (
vec.size() > n_obj)
391 vec.erase(
vec.begin() + n_obj,
vec.end());
395 template <
typename T>
397 const std::set<std::string> &
keys)
std::map< std::string, VectorMap > InputSequenceMap
@ cellBased_firstEtaWRTClusterPosition_EM2
virtual StatusCode getInputs(const xAOD::TauJet &xTau, std::map< std::string, std::map< std::string, std::vector< double >>> &inputSeqMap) const
retrieve the input variables from a TauJet
@ cellBased_NPosECells_EM1
size_t nNeutralPFOs() const
Get the number of neutral PFO particles associated with this tau.
Gaudi::Property< bool > m_ensureTrackConsistency
Gaudi::Property< std::string > m_weightFile
bool attribute(PFODetails::PFOAttributes AttributeType, T &anAttribute) const
get a PFO Variable via enum
size_t nTracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
virtual ~TauDecayModeNNClassifier()
std::vector< size_t > vec
@ cellBased_secondEtaWRTClusterPosition_EM1
__HOSTDEV__ double Phi_mpi_pi(double)
Gaudi::Property< std::string > m_probPrefix
std::map< std::string, ValueMap > InputMap
TauDecayModeNNClassifier(const std::string &name="TauDecayModeNNClassifier")
bool detail(TauJetParameters::TrackDetail detail, float &value) const
@ cellBased_firstEtaWRTClusterPosition_EM1
Gaudi::Property< std::size_t > m_maxNeutralPFOs
::StatusCode StatusCode
StatusCode definition for legacy code.
@ cellBased_secondEtaWRTClusterPosition_EM2
Class describing a tau jet.
const TauTrack * track(size_t i, TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged, int *container_index=0) const
Get the pointer to a given tauTrack associated with this tau /*container index needed by trackNonCons...
@ cellBased_FIRST_ETA
These variables belong to the cell-based particle flow algorithm.
Class describing a particle flow object.
std::string to_string(const DetectorType &type)
Gaudi::Property< std::string > m_outputName
properties of the tool
@ cellBased_ENG_FRAC_CORE
std::map< std::string, std::vector< double > > VectorMap
GraphConfig parse_json_graph(std::istream &json)
Gaudi::Property< std::size_t > m_maxTauTracks
@ cellBased_SECOND_ENG_DENS
@ cellBased_NPosECells_EM2
Gaudi::Property< float > m_neutralPFOPtCut
std::unique_ptr< const lwt::LightweightGraph > m_lwtGraph
lwtnn graph
#define ATH_MSG_WARNING(x)
@ cellBased_CENTER_LAMBDA
virtual FourMom_t p4() const
The full 4-momentum of the particle.
virtual StatusCode execute(xAOD::TauJet &xTau) const override
Execute - called for each tau candidate.
const PFO * shotPFO(size_t i) const
Get the pointer to a given shot PFO associated with this tau.
Gaudi::Property< std::size_t > m_maxShotPFOs
Gaudi::Property< bool > m_decorateProb
float distance(const Amg::Vector3D &p1, const Amg::Vector3D &p2)
calculates the distance between two point in 3D space
std::map< std::string, double > ValueMap
const PFO * neutralPFO(size_t i) const
Get the pointer to a given neutral PFO associated with this tau.
virtual StatusCode initialize() override
Tool initializer.
std::vector< const TauTrack * > tracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
Get the v<const pointer> to a given tauTrack collection associated with this tau.
Gaudi::Property< std::size_t > m_maxConvTracks
size_t nShotPFOs() const
Get the number of shot PFO particles associated with this tau.