8 #include "AthLinks/ElementLink.h"
23 declareInterface<DiTauToolBase > (
this);
32 ATH_MSG_INFO(
"Initializing DiTauOnnxDiscriminantTool" );
36 if (model_path.empty()) {
38 return StatusCode::FAILURE;
40 m_ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
"OnnxUtil");
41 Ort::SessionOptions session_options;
42 session_options.SetIntraOpNumThreads(1);
43 session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
44 session_options.DisableCpuMemArena();
46 return StatusCode::SUCCESS;
54 return StatusCode::SUCCESS;
64 omni_scoreDec(*xDitau) =
score;
65 return StatusCode::SUCCESS;
74 omni_scoreDec(xDiTau) =
score;
75 return StatusCode::SUCCESS;
80 flattened.reserve(vec_2d.size() * (vec_2d.empty() ? 0 : vec_2d[0].size()));
81 for (
const auto &inner : vec_2d) {
89 points.reserve(track_features.size() * 2);
90 for (
const auto &
track : track_features) {
98 std::vector<float>
mask;
99 mask.reserve(track_features.size());
101 return std::abs(track[2]) > 1e-6 ? 1.0f : 0.0f;
107 Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
108 return Ort::Value::CreateTensor<float>(memory_info,
data.data(),
data.size(),shape.data(), shape.size());
112 std::vector<Ort::Value> input_tensors;
120 std::vector<const char *> input_node_names;
124 std::vector<const char *> output_node_names;
128 auto output_tensors =
m_ort_session->Run(Ort::RunOptions{
nullptr}, input_node_names.data(), input_tensors.data(), input_node_names.size(), output_node_names.data(), output_node_names.size());
131 for (
size_t i = 0;
i < output_tensors.size(); ++
i) {
132 const auto &tensor = output_tensors[
i];
133 const size_t length = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
134 const float *
data = tensor.GetTensorData<
float>();
159 std::vector<float> jet_vars = {
160 R_max(ditau, ditauTrackingInfo, 0),
161 R_max(ditau, ditauTrackingInfo, 1),
162 R_tracks(ditau, ditauTrackingInfo, 1),
174 static_cast<float>(
n_track(ditau)),
176 std::vector<int64_t> jet_shape = {1,
static_cast<int64_t
>(jet_vars.size())};
179 std::vector<std::vector<float>> track_features(
m_maxTracks, std::vector<float>(11, 0.0
f));
181 float jet_eta = ditau.
eta();
185 for (
size_t i = 0;
i < num_tracks; ++
i) {
187 if (!trackLink.
isValid())
continue;
189 float track_eta = xTrack->
eta();
190 float track_phi = xTrack->
phi();
191 float delta_eta = track_eta - jet_eta;
194 float track_pt =
static_cast<float>(xTrack->
pt());
197 float pt_ratio = track_pt / jet_pt;
198 float pt_ratio_log = (pt_ratio <= 1.0f) ?
std::log(1.0
f - pt_ratio + 1
e-8
f) : 0.0f;
199 float track_charge = xTrack->
charge();
201 track_features[
i] = {
209 static_cast<float>(numberOfInrmstPxlLyrHitsAcc(*xTrack)),
210 static_cast<float>(numberOfPixelHitsAcc(*xTrack)),
211 static_cast<float>(numberOfSCTHitsAcc(*xTrack)),
215 std::vector<int64_t> track_shape = {1,
static_cast<int64_t
>(
m_maxTracks), 11};
222 {1, track_shape[1], 2},
226 std::move(jet_shape),
231 return output.output_1[1];
237 while (xDiTau.
subjetPt(nSubjet) > 0. ){
250 return xDiTau.
fCore(iSubjet);
254 return xDiTau.
subjetPt(iSubjet) / xDiTau.
pt();
267 if (subjetInfo.
subjet_p4.DeltaR(xTrack->
p4()) > Rmax) {
282 for (
int i = 0;
i < 2;
i++) {
285 R_sum += subjetInfo.
subjet_p4.DeltaR(xTrack->
p4()) * xTrack->
pt();
301 R_sum += subjetInfo.
subjet_p4.DeltaR(xTrack->
p4()) * xTrack->
pt();
311 TLorentzVector allCoreTracks_p4;
314 allCoreTracks_p4 += xTrack->
p4();
316 float mass = allCoreTracks_p4.M();
324 TLorentzVector allTracks_p4;
327 allTracks_p4 += xTrack->
p4();
329 float mass = allTracks_p4.M();
347 iso_pt += xTrack->
pt();
349 if( xDiTau.
pt() == 0.){
352 return iso_pt / xDiTau.
pt();
364 ATH_MSG_WARNING(
"Track " << (!trackLinksAcc.
isAvailable(xDiTau) ?
"DiTauJet.trackLinks" :
"DiTauJet.isoTrackLinks") <<
" links not available.");
365 return StatusCode::FAILURE;
369 float Rsubjet = R_subjetAcc(xDiTau);
370 float RCore = R_coreAcc(xDiTau);
378 std::vector<ElementLink<xAOD::TrackParticleContainer>> isoTrackLinks = xDiTau.
isoTrackLinks();
379 for (
const auto &trackLink: isoTrackLinks) {
380 if (!trackLink.isValid()) {
387 std::vector<ElementLink<xAOD::TrackParticleContainer>> trackLinks = xDiTau.
trackLinks();
388 for (
const auto &trackLink : trackLinks) {
389 if (!trackLink.isValid()) {
394 trackingInfo.
vTracks.push_back(xTrack);
397 for (
int i=0;
i<nSubjets; ++
i){
399 TLorentzVector subjet_p4 = TLorentzVector();
401 subjetTrackingInfo.
subjet_p4 = subjet_p4;
402 trackingInfo.
vSubjetInfo.push_back(subjetTrackingInfo);
407 for (
int i=0;
i<nSubjets; ++
i){
409 if (dRTrackSubjet < Rsubjet && dRTrackSubjet < dRMin){
410 dRMin = dRTrackSubjet;
419 for (
int i=0;
i<nSubjets; ++
i){
420 float ptLeadTrack = 0;
422 if (
track->pt() > ptLeadTrack){
423 ptLeadTrack =
track->pt();
429 for (
int i=0;
i<nSubjets; ++
i){
432 if (subjetTrackingInfo.subjet_p4.DeltaR(
track->p4()) < RCore){
442 for (
int i=0;
i<nSubjets; ++
i){
444 if (dRTrackSubjet > Rsubjet && dRTrackSubjet < RIso && dRTrackSubjet < dRMin){
445 dRMin = dRTrackSubjet;
453 return StatusCode::SUCCESS;