 |
ATLAS Offline Software
|
Go to the documentation of this file.
25 ATH_MSG_ERROR(
"Cannot load both prong-inclusive and prong-dependent networks!");
26 return StatusCode::FAILURE;
39 if(!
m_net_0p)
return StatusCode::FAILURE;
44 if(!
m_net_1p)
return StatusCode::FAILURE;
50 if(!
m_net_2p)
return StatusCode::FAILURE;
55 if(!
m_net_3p)
return StatusCode::FAILURE;
58 if(m_output_discriminant < Discriminant::NegLogPJet || m_output_discriminant > Discriminant::PTau) {
67 return StatusCode::SUCCESS;
72 if(network_file.empty())
return nullptr;
74 const std::string pr_network_file =
find_file(network_file);
75 if(pr_network_file.empty()) {
76 ATH_MSG_ERROR(
"Could not find network weights: " << network_file);
80 ATH_MSG_INFO(
"Using network config: " << pr_network_file);
84 config.nnFile = pr_network_file;
97 std::unique_ptr<TauGNN> net = std::make_unique<TauGNN>(
config);
111 out_ptau(tau) = -1111.0f;
112 out_pjet(tau) = -1111.0f;
116 return StatusCode::SUCCESS;
121 if (tau.
nTracks()>5)
return StatusCode::SUCCESS;
126 if (tau.
nTracks()!=1 && tau.
nTracks()!=3)
return StatusCode::SUCCESS;
130 std::map<std::string, float> out_f;
131 std::map<std::string, std::vector<char>> out_vc;
132 std::map<std::string, std::vector<float>> out_vf;
146 n_tracks = std::count_if(trks.begin(), trks.end(),
151 if(n_tracks == 0 &&
m_net_0p) std::tie(out_f, out_vc, out_vf) =
m_net_0p->compute(tau);
152 else if(n_tracks == 1) std::tie(out_f, out_vc, out_vf) =
m_net_1p->compute(tau);
153 else if(n_tracks == 2) {
155 else std::tie(out_f, out_vc, out_vf) =
m_net_3p->compute(tau);
156 }
else if(n_tracks == 3) std::tie(out_f, out_vc, out_vf) =
m_net_3p->compute(tau);
171 return StatusCode::SUCCESS;
Gaudi::Property< bool > m_doTrackClassification
Gaudi::Property< std::string > m_output_pjet
Gaudi::Property< bool > m_applyLooseTrackSel
std::unique_ptr< TauGNN > load_network(const std::string &network_file) const
size_t nTracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
std::unique_ptr< TauGNN > m_net_3p
Gaudi::Property< std::string > m_weightfile_inclusive
Gaudi::Property< int > m_max_clusters
std::unique_ptr< TauGNN > m_net_1p
Gaudi::Property< std::string > m_weightfile_3p
virtual StatusCode initialize() override
Tool initializer.
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
Gaudi::Property< bool > m_applyTightTrackSel
Gaudi::Property< float > m_max_cluster_dr
size_t nTracksCharged() const
std::unique_ptr< TauGNN > m_net_0p
Gaudi::Property< bool > m_useTRT
std::unique_ptr< TauGNN > m_net_2p
::StatusCode StatusCode
StatusCode definition for legacy code.
Gaudi::Property< std::string > m_weightfile_2p
Class describing a tau jet.
Gaudi::Property< unsigned int > m_output_discriminant
Gaudi::Property< std::string > m_tauContainerName
Gaudi::Property< int > m_max_tracks
std::unique_ptr< TauGNN > m_net_inclusive
Gaudi::Property< float > m_minTauPt
Gaudi::Property< std::string > m_outnode_tau
Gaudi::Property< std::string > m_input_layer_tracks
Gaudi::Property< std::string > m_input_layer_scalar
Gaudi::Property< float > m_min_prong_track_pt
Gaudi::Property< bool > m_doVertexCorrection
Gaudi::Property< std::string > m_outnode_jet
Gaudi::Property< std::string > m_weightfile_0p
Gaudi::Property< std::string > m_weightfile_1p
virtual ~TauGNNEvaluator()
Gaudi::Property< std::string > m_input_layer_clusters
TauGNNEvaluator(const std::string &name="TauGNNEvaluator")
SG::WriteDecorHandleKey< xAOD::TauJetContainer > m_scoreHandleKey
std::vector< const TauTrack * > allTracks() const
Get the vector of const pointers to all tracks associated with this tau, regardless of classification.
Gaudi::Property< std::string > m_output_ptau
Gaudi::Property< std::string > m_output_varname
virtual StatusCode execute(xAOD::TauJet &tau) const override
Execute - called for each tau candidate.