Loading [MathJax]/extensions/tex2jax.js
 |
ATLAS Offline Software
|
Go to the documentation of this file.
// Scalar inputs: variable name -> single value.
22 using ValueMap = std::map<std::string, double>;
// Sequence inputs: variable name -> one value per object (e.g. per track or PFO),
// filled by push_back in the per-object setters.
23 using VectorMap = std::map<std::string, std::vector<double>>;
// Nested map: outer string key -> ValueMap (presumably one entry per network input node; confirm).
24 using InputMap = std::map<std::string, ValueMap>;
52 if (weightFile.empty())
55 return StatusCode::FAILURE;
57 ATH_MSG_INFO(
"Loaded network configuration from: " << weightFile);
61 lwt::GraphConfig lwtGraphConfig;
66 catch (
const std::logic_error &
e)
69 return StatusCode::FAILURE;
75 m_lwtGraph = std::make_unique<lwt::LightweightGraph>(lwtGraphConfig, lwtGraphConfig.outputs.cbegin()->first);
77 catch (
const lwt::NNConfigurationException &
e)
80 return StatusCode::FAILURE;
83 return StatusCode::SUCCESS;
95 std::set<std::string>
branches = {
"TauTrack",
"NeutralPFO",
"ShotPFO",
"ConvTrack"};
113 return StatusCode::FAILURE;
121 std::array<float, DMVar::nClasses> probs;
123 std::string
prefix =
"c_";
124 for (std::size_t
i = 0;
i < probs.size(); ++
i)
133 std::array<float, DMVar::nClasses>::const_iterator itMax;
137 itMax = std::max_element(probs.cbegin(), probs.cbegin() + 3);
142 itMax = std::max_element(probs.cbegin() + 3, probs.cend());
147 itMax = std::max_element(probs.cbegin(), probs.cend());
155 for (std::size_t
i = 0;
i < probs.size(); ++
i)
159 accProb(xTau) = probs[
i];
163 return StatusCode::SUCCESS;
168 std::vector<TrkPtr> vTauTracks;
169 std::vector<PFOPtr> vNeutralPFOs;
170 std::vector<PFOPtr> vShotPFOs;
171 std::vector<TrkPtr> vConvTracks;
186 vNeutralPFOs.push_back(pfo);
201 ATH_MSG_ERROR(
"Error retrieving tauShots_nPhotons: " <<
e.what());
202 return StatusCode::FAILURE;
206 vShotPFOs.push_back(pfo);
224 std::pair<float, bool> tau_etaTrkECal{0.,
false};
225 std::pair<float, bool> tau_phiTrkECal{0.,
false};
235 tau_phiTrkECal.second =
true;
243 tau_etaTrkECal.second =
true;
248 auto setCommonP4Vars = [&tau_p4, &tau_etaTrkECal, &tau_phiTrkECal](
VectorMap &in_seq_map,
const TLorentzVector &obj_p4) {
259 in_seq_map[
"d0TJVA"].push_back(trk->d0TJVA());
260 in_seq_map[
"d0SigTJVA"].push_back(trk->d0SigTJVA());
261 in_seq_map[
"z0sinthetaTJVA"].push_back(trk->z0sinthetaTJVA());
262 in_seq_map[
"z0sinthetaSigTJVA"].push_back(trk->z0sinthetaSigTJVA());
266 auto setNeutralPFOVars = [](
VectorMap &in_seq_map,
const PFOPtr &pfo) {
268 auto getAttr = std::bind(DMVar::pfoAttr<float>, pfo, std::placeholders::_1);
269 auto getAttrInt = std::bind(DMVar::pfoAttr<int>, pfo, std::placeholders::_1);
290 VectorMap &chrg_map = inputSeqMap.at(
"TauTrack");
293 for (
const auto &trk : vTauTracks)
295 setCommonP4Vars(chrg_map, trk->p4());
296 setTrackIPVars(chrg_map, trk);
300 VectorMap &neut_map = inputSeqMap.at(
"NeutralPFO");
303 for (
const auto &pfo : vNeutralPFOs)
305 setCommonP4Vars(neut_map, pfo->p4());
308 setNeutralPFOVars(neut_map, pfo);
312 ATH_MSG_ERROR(
"Error setting neutral PFO variables: " <<
e.what());
313 return StatusCode::FAILURE;
318 VectorMap &shot_map = inputSeqMap.at(
"ShotPFO");
320 for (
const auto &pfo : vShotPFOs)
322 setCommonP4Vars(shot_map, pfo->p4());
326 VectorMap &conv_map = inputSeqMap.at(
"ConvTrack");
329 for (
const auto &trk : vConvTracks)
331 setCommonP4Vars(conv_map, trk->p4());
332 setTrackIPVars(conv_map, trk);
335 return StatusCode::SUCCESS;
342 "dphiECal",
"detaECal",
"dphi",
"deta",
"pt_log",
"jetpt_log"};
345 "d0TJVA",
"d0SigTJVA",
"z0sinthetaTJVA",
"z0sinthetaSigTJVA"};
348 "FIRST_ETA",
"SECOND_R_log",
"DELTA_THETA",
"CENTER_LAMBDA_log",
"LONGITUDINAL",
"ENG_FRAC_CORE",
349 "SECOND_ENG_DENS_log",
"NPosECells_EM1",
"NPosECells_EM2",
"energy_EM1",
"energy_EM2",
"EM1CoreFrac",
350 "firstEtaWRTClusterPosition_EM1",
"firstEtaWRTClusterPosition_EM2",
351 "secondEtaWRTClusterPosition_EM1_log",
"secondEtaWRTClusterPosition_EM2_log"};
354 "1p0n",
"1p1n",
"1pXn",
"3p0n",
"3pXn"};
358 return p4_tau.DeltaPhi(p4);
363 return p4.Eta() - p4_tau.Eta();
375 return tau_etaTrkECal.second ? p4.Eta() - tau_etaTrkECal.first : 0.0f;
378 template <
typename T>
381 T
val{
static_cast<T
>(0)};
384 throw std::runtime_error(
"Can not retrieve PFO attribute! enum = " +
std::to_string(
static_cast<unsigned>(attr)));
392 return clus0pt > 0.0f ? (clus0pt - pfo->
pt()) / clus0pt : 0.0
f;
397 float clus0e = pfo->
cluster(0)->
e();
398 return clus0e > 0.0f ? energy_em2 / clus0e : 0.0f;
406 template <
typename T>
409 auto cmp_pt = [](
const T lhs,
const T rhs) {
return lhs->pt() > rhs->pt(); };
410 std::sort(
vec.begin(),
vec.end(), cmp_pt);
411 if (
vec.size() > n_obj)
413 vec.erase(
vec.begin() + n_obj,
vec.end());
417 template <
typename T>
419 const std::set<std::string> &
keys)
std::map< std::string, VectorMap > InputSequenceMap
std::size_t m_maxConvTracks
@ cellBased_firstEtaWRTClusterPosition_EM2
virtual StatusCode getInputs(const xAOD::TauJet &xTau, std::map< std::string, std::map< std::string, std::vector< double >>> &inputSeqMap) const
retrieve the input variables from a TauJet
@ cellBased_NPosECells_EM1
virtual double pt() const
The transverse momentum (\(p_T\)) of the particle.
size_t nNeutralPFOs() const
Get the number of neutral PFO particles associated with this tau.
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T > &t)
bool attribute(PFODetails::PFOAttributes AttributeType, T &anAttribute) const
get a PFO Variable via enum
size_t nTracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
virtual ~TauDecayModeNNClassifier()
std::vector< size_t > vec
@ cellBased_secondEtaWRTClusterPosition_EM1
__HOSTDEV__ double Phi_mpi_pi(double)
std::map< std::string, ValueMap > InputMap
TauDecayModeNNClassifier(const std::string &name="TauDecayModeNNClassifier")
bool detail(TauJetParameters::TrackDetail detail, float &value) const
std::size_t m_maxShotPFOs
@ cellBased_firstEtaWRTClusterPosition_EM1
::StatusCode StatusCode
StatusCode definition for legacy code.
@ cellBased_secondEtaWRTClusterPosition_EM2
Class describing a tau jet.
bool m_ensureTrackConsistency
const TauTrack * track(size_t i, TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged, int *container_index=0) const
Get the pointer to a given tauTrack associated with this tau /*container index needed by trackNonCons...
@ cellBased_FIRST_ETA
These variables belong to the cell-based particle flow algorithm.
virtual double pt() const
The transverse momentum (\(p_T\)) of the particle (negative for negative-energy clusters).
Class describing a particle flow object.
std::size_t m_maxTauTracks
std::string to_string(const DetectorType &type)
@ cellBased_ENG_FRAC_CORE
std::map< std::string, std::vector< double > > VectorMap
GraphConfig parse_json_graph(std::istream &json)
@ cellBased_SECOND_ENG_DENS
@ cellBased_NPosECells_EM2
std::unique_ptr< const lwt::LightweightGraph > m_lwtGraph
lwtnn graph
#define ATH_MSG_WARNING(x)
@ cellBased_CENTER_LAMBDA
virtual FourMom_t p4() const
The full 4-momentum of the particle.
std::string m_outputName
properties of the tool
virtual StatusCode execute(xAOD::TauJet &xTau) const override
Execute - called for each tau candidate.
const CaloCluster * cluster(unsigned int index) const
Retrieve a const pointer to a CaloCluster.
const PFO * shotPFO(size_t i) const
Get the pointer to a given shot PFO associated with this tau.
std::size_t m_maxNeutralPFOs
float distance(const Amg::Vector3D &p1, const Amg::Vector3D &p2)
Calculates the distance between two points in 3D space.
std::map< std::string, double > ValueMap
const PFO * neutralPFO(size_t i) const
Get the pointer to a given neutral PFO associated with this tau.
virtual StatusCode initialize() override
Tool initializer.
std::vector< const TauTrack * > tracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
Get the v<const pointer> to a given tauTrack collection associated with this tau.
virtual double e() const
The total energy of the particle.
size_t nShotPFOs() const
Get the number of shot PFO particles associated with this tau.