25 ATH_MSG_ERROR(
"Cannot load both prong-inclusive and prong-dependent networks!");
26 return StatusCode::FAILURE;
39 if(!
m_net_0p)
return StatusCode::FAILURE;
44 if(!
m_net_1p)
return StatusCode::FAILURE;
50 if(!
m_net_2p)
return StatusCode::FAILURE;
55 if(!
m_net_3p)
return StatusCode::FAILURE;
83 ATH_MSG_ERROR(
"TauContainerName and HitsHandleKey must be provided to read hits for the GNN evaluation");
84 return StatusCode::FAILURE;
87 return StatusCode::SUCCESS;
92 if(network_file.empty())
return nullptr;
94 const std::string pr_network_file =
find_file(network_file);
95 if(pr_network_file.empty()) {
96 ATH_MSG_ERROR(
"Could not find network weights: " << network_file);
100 ATH_MSG_INFO(
"Using network config: " << pr_network_file);
104 config.nnFile = pr_network_file;
120 std::unique_ptr<TauGNN> net = std::make_unique<TauGNN>(
config);
134 out_ptau(tau) = -1111.0f;
135 out_pjet(tau) = -1111.0f;
139 return StatusCode::SUCCESS;
144 if (tau.
nTracks()>5)
return StatusCode::SUCCESS;
149 if (tau.
nTracks()!=1 && tau.
nTracks()!=3)
return StatusCode::SUCCESS;
153 std::map<std::string, float> out_f;
154 std::map<std::string, std::vector<char>> out_vc;
155 std::map<std::string, std::vector<float>> out_vf;
169 n_tracks = std::count_if(trks.begin(), trks.end(),
174 if(n_tracks == 0 &&
m_net_0p) std::tie(out_f, out_vc, out_vf) =
m_net_0p->compute(tau);
175 else if(n_tracks == 1) std::tie(out_f, out_vc, out_vf) =
m_net_1p->compute(tau);
176 else if(n_tracks == 2) {
178 else std::tie(out_f, out_vc, out_vf) =
m_net_3p->compute(tau);
179 }
else if(n_tracks == 3) std::tie(out_f, out_vc, out_vf) =
m_net_3p->compute(tau);
194 return StatusCode::SUCCESS;
#define ATH_CHECK
Evaluate an expression and check for errors.
Helper class to provide type-safe access to aux data.
Helper class to provide type-safe access to aux data.
Gaudi::Property< std::string > m_weightfile_inclusive
TauGNNEvaluator(const std::string &name="TauGNNEvaluator")
Gaudi::Property< int > m_max_tracks
Gaudi::Property< std::string > m_input_layer_hits
std::unique_ptr< TauGNN > m_net_1p
Gaudi::Property< std::string > m_input_layer_scalar
std::string m_hits_decor_name
SG::WriteDecorHandleKey< xAOD::TauJetContainer > m_scoreHandleKey
virtual ~TauGNNEvaluator()
Gaudi::Property< float > m_max_cluster_dr
std::unique_ptr< TauGNN > m_net_3p
Gaudi::Property< float > m_minTauPt
std::unique_ptr< TauGNN > load_network(const std::string &network_file) const
Gaudi::Property< int > m_output_discriminant
Gaudi::Property< std::string > m_input_layer_clusters
Gaudi::Property< int > m_max_clusters
Gaudi::Property< float > m_min_prong_track_pt
Gaudi::Property< std::string > m_outnode_tau
std::unique_ptr< TauGNN > m_net_0p
Gaudi::Property< std::string > m_tauContainerName
Gaudi::Property< bool > m_doVertexCorrection
Gaudi::Property< int > m_max_hits
Gaudi::Property< bool > m_applyTightTrackSel
Gaudi::Property< std::string > m_output_varname
Gaudi::Property< std::string > m_output_pjet
Gaudi::Property< bool > m_useTRT
std::unique_ptr< TauGNN > m_net_inclusive
std::unique_ptr< TauGNN > m_net_2p
SG::WriteDecorHandleKey< xAOD::TauJetContainer > m_pJetHandleKey
Gaudi::Property< bool > m_doTrackClassification
SG::WriteDecorHandleKey< xAOD::TauJetContainer > m_pTauHandleKey
Gaudi::Property< std::string > m_outnode_jet
Gaudi::Property< std::string > m_output_ptau
Gaudi::Property< std::string > m_weightfile_1p
Gaudi::Property< std::string > m_weightfile_3p
SG::ReadDecorHandleKey< xAOD::TauJetContainer > m_hitsHandleKey
virtual StatusCode execute(xAOD::TauJet &tau) const override
Execute - called for each tau candidate.
Gaudi::Property< bool > m_applyLooseTrackSel
virtual StatusCode initialize() override
Tool initializer.
Gaudi::Property< std::string > m_weightfile_2p
Gaudi::Property< std::string > m_input_layer_tracks
Gaudi::Property< std::string > m_weightfile_0p
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
size_t nTracksCharged() const
std::vector< const TauTrack * > allTracks() const
Get the v<const pointer> to all tracks associated with this tau, regardless of classification.
size_t nTracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
TauTrack_v1 TauTrack
Definition of the current version.
TauJet_v3 TauJet
Definition of the current "tau version".