15 m_net_inclusive(nullptr),
16 m_net_0p(nullptr), m_net_1p(nullptr), m_net_2p(nullptr), m_net_3p(nullptr) {
28 "Discriminant used to calculate the output score: 0 -> -log(PJet), 1 -> PTau");
67 ATH_MSG_ERROR(
"Cannot load both prong-inclusive and prong-dependent networks!");
68 return StatusCode::FAILURE;
81 if(!
m_net_0p)
return StatusCode::FAILURE;
86 if(!
m_net_1p)
return StatusCode::FAILURE;
92 if(!
m_net_2p)
return StatusCode::FAILURE;
97 if(!
m_net_3p)
return StatusCode::FAILURE;
100 if(m_output_discriminant < Discriminant::NegLogPJet || m_output_discriminant > Discriminant::PTau) {
104 return StatusCode::SUCCESS;
109 if(network_file.empty())
return nullptr;
111 const std::string pr_network_file =
find_file(network_file);
112 if(pr_network_file.empty()) {
113 ATH_MSG_ERROR(
"Could not find network weights: " << network_file);
117 ATH_MSG_INFO(
"Using network config: " << pr_network_file);
120 std::unique_ptr<TauGNN> net = std::make_unique<TauGNN>(pr_network_file,
config);
134 out_ptau(tau) = -1111.0f;
135 out_pjet(tau) = -1111.0f;
139 return StatusCode::SUCCESS;
144 std::vector<const xAOD::TauTrack *> tracks;
147 std::vector<xAOD::CaloVertexedTopoCluster>
clusters;
153 std::vector<const xAOD::TauTrack *> trackVec(tracks.begin(), tracks.begin()+numTracksMax);
156 std::map<std::string, float> out_f;
157 std::map<std::string, std::vector<char>> out_vc;
158 std::map<std::string, std::vector<float>> out_vf;
175 else if(n_tracks == 1) std::tie(out_f, out_vc, out_vf) =
m_net_1p->compute(tau, trackVec,
clusters);
176 else if(n_tracks == 2) {
178 else std::tie(out_f, out_vc, out_vf) =
m_net_3p->compute(tau, trackVec,
clusters);
179 }
else if(n_tracks == 3) std::tie(out_f, out_vc, out_vf) =
m_net_3p->compute(tau, trackVec,
clusters);
194 for(
size_t i = 0;
i < tracks.size();
i++) {
195 if(
i < out_vc.at(
"track_class").size()) out_trkclass(*tracks.at(
i)) = out_vc.at(
"track_class").at(
i);
196 else out_trkclass(*tracks.at(
i)) =
'9';
201 return StatusCode::SUCCESS;
206 std::vector<const xAOD::TauTrack*> tracks = tau.
allTracks();
214 while(
it != tracks.end()) {
216 it = tracks.erase(
it);
226 return lhs->
pt() > rhs->pt();
228 std::sort(tracks.begin(), tracks.end(), cmp_pt);
229 out = std::move(tracks);
231 return StatusCode::SUCCESS;
239 TLorentzVector clusterP4 = vertexedCluster.p4();
242 clusters.push_back(vertexedCluster);
248 return lhs.
p4().Et() > rhs.p4().Et();
257 return StatusCode::SUCCESS;