8 #include "AthLinks/ElementLink.h"
// Fragment of the tool's initialisation sequence; several intermediate
// lines are not part of this listing.
// Announce that initialisation has started.
29 ATH_MSG_INFO(
"Initializing DiTauOnnxDiscriminantTool" );
// An empty model_path ends initialisation with FAILURE.
32 if (model_path.empty()) {
34 return StatusCode::FAILURE;
// Create the ONNX Runtime environment with warning-level logging and the
// log id "OnnxUtil".
36 m_ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
"OnnxUtil");
// Session options: one intra-op thread and the extended graph-optimisation
// level.
37 Ort::SessionOptions session_options;
38 session_options.SetIntraOpNumThreads(1);
39 session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
// Report success to the caller.
61 return StatusCode::SUCCESS;
// Write the score onto xDiTau through omni_scoreDec, then report success.
// NOTE(review): how `score` is computed is not visible in this listing.
70 omni_scoreDec(xDiTau) =
score;
71 return StatusCode::SUCCESS;
// Special-value substitution: NaN, +infinity and -infinity are each mapped
// to their own replacement value.
// NOTE(review): presumably the body of nan_to_num(value, nan_replacement,
// posinf_replacement, neginf_replacement), which is called with four
// arguments further down; the signature and the fall-through return for
// ordinary values are not visible here - confirm.
75 if (std::isnan(
value))
76 return nan_replacement;
// Exact comparison is intentional: only the infinities themselves match.
77 if (
value == std::numeric_limits<float>::infinity())
78 return posinf_replacement;
79 if (
value == -std::numeric_limits<float>::infinity())
80 return neginf_replacement;
// Reserve capacity for rows * columns, taking the column count from the
// first row (0 when vec_2d is empty, which also avoids indexing vec_2d[0]).
// NOTE(review): the estimate assumes every inner vector is as long as
// vec_2d[0]; reserve() is only a capacity hint, so ragged rows would not
// break it - confirm against the loop body, which is not visible here.
86 flattened.reserve(vec_2d.size() * (vec_2d.empty() ? 0 : vec_2d[0].size()));
// Visit the inner vectors in order, by const reference (no copies).
87 for (
const auto &inner : vec_2d) {
// Reserve two slots per entry of track_features before iterating.
// NOTE(review): which two components are stored per track is not visible
// in this listing.
95 points.reserve(track_features.size() * 2);
96 for (
const auto &
track : track_features) {
// Mask buffer with capacity for one entry per element of track_features.
104 std::vector<float>
mask;
105 mask.reserve(track_features.size());
// Yields 1.0f when the magnitude of the track's component at index 2
// exceeds 1e-6, otherwise 0.0f.
// NOTE(review): which feature lives at index 2 is not visible here; this
// presumably separates filled rows from zero-padded ones - TODO confirm
// against the feature layout.
107 return std::abs(track[2]) > 1e-6 ? 1.0f : 0.0f;
// Describe CPU memory (arena allocator, default memory type) for the tensor.
113 Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
// Build a float tensor over data.data() with data.size() elements and the
// given shape.
// NOTE(review): this CreateTensor overload takes a user-supplied buffer, so
// `data` presumably has to outlive the returned Ort::Value - confirm
// against the ONNX Runtime documentation and the callers.
114 return Ort::Value::CreateTensor<float>(memory_info,
data.data(),
data.size(),shape.data(), shape.size());
// Containers for the input tensors and for the input/output node names
// handed to the session. How they are filled is not visible in this listing.
118 std::vector<Ort::Value> input_tensors;
126 std::vector<const char *> input_node_names;
130 std::vector<const char *> output_node_names;
// Run the session: input names with their tensors (count taken from
// input_node_names) and the names of the requested outputs.
134 auto output_tensors =
m_ort_session->Run(Ort::RunOptions{
nullptr}, input_node_names.data(), input_tensors.data(), input_node_names.size(), output_node_names.data(), output_node_names.size());
// Copy each output tensor's float data into the result structure.
137 for (
size_t i = 0;
i < output_tensors.size(); ++
i) {
138 const auto &tensor = output_tensors[
i];
// Total number of elements in this tensor.
139 const size_t length = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
140 const float *
data = tensor.GetTensorData<
float>();
// Index 0 is stored in output.output_1; every other index is stored in
// output.output_2.
// NOTE(review): with more than two outputs, later tensors would overwrite
// output_2 - confirm the model only has two outputs.
141 (
i == 0 ? output.output_1 : output.output_2) = std::vector<float>(
data,
data +
length);
// Jet-level input features, read from `ditau` through the *Dec accessors in
// the order listed.
// NOTE(review): the element order presumably has to match the input layout
// the model was trained with - confirm before reordering or inserting
// entries. Some entries of the list are not part of this listing.
173 std::vector<float> jet_vars = {
174 R_max_leadDec (ditau),
175 R_max_subleadDec (ditau),
176 R_tracks_sublDec (ditau),
177 R_isotracDec (ditau),
178 d0_leadtrack_leadDec (ditau),
179 d0_leadtrack_sublDec (ditau),
180 f_core_leadDec (ditau),
181 f_core_sublDec (ditau),
182 f_subjet_sublDec (ditau),
184 f_isotracks_Dec (ditau),
185 M_core_leadDec (ditau),
186 M_core_sublDec (ditau),
187 M_tracks_leadDec (ditau),
// The n_trackDec value is converted to float explicitly.
188 static_cast<float>( n_trackDec (ditau)),
// Tensor shape {1, number of jet features}; the second dimension follows
// the size of jet_vars.
190 std::vector<int64_t> jet_shape = {1,
static_cast<int64_t
>(jet_vars.size())};
// Track feature matrix: m_maxTracks rows of 11 floats, all initialised to
// 0.0f, so rows that are never filled stay zero.
193 std::vector<std::vector<float>> track_features(
m_maxTracks, std::vector<float>(11, 0.0
f));
// Eta of the ditau object, used below as the reference for delta_eta.
195 float jet_eta = ditau.
eta();
// Loop over track indices below num_tracks.
// NOTE(review): the definition of num_tracks is not visible here; it must
// not exceed m_maxTracks, since row i of track_features is written - confirm.
199 for (
size_t i = 0;
i < num_tracks; ++
i) {
// Tracks with an invalid link are skipped; their row keeps its zeros.
201 if (!trackLink.
isValid())
continue;
// Track direction and its eta offset from the ditau axis.
203 float track_eta = xTrack->
eta();
204 float track_phi = xTrack->
phi();
205 float delta_eta = track_eta - jet_eta;
// Track pT, narrowed to float.
208 float track_pt =
static_cast<float>(xTrack->
pt());
// Ratio of the track pT to the value read through ditau_ptDec.
210 float jet_pt = ditau_ptDec(ditau);
211 float pt_ratio = track_pt / jet_pt;
// log(1 - pt_ratio + 1e-8): the small offset keeps the argument positive
// when pt_ratio is exactly 1. A negative argument yields NaN.
212 float pt_ratio_log =
std::log(1.0
f - pt_ratio + 1
e-8
f);
213 float track_charge = xTrack->
charge();
// All three replacement values passed to nan_to_num are 0.0f.
// NOTE(review): presumably NaN and both infinities therefore become 0 -
// confirm against the nan_to_num definition.
214 float pt_ratio_log_nan_less =
nan_to_num(pt_ratio_log, 0.0
f, 0.0
f, 0.0
f);
// Fill row i with this track's features. Only some of the entries are
// part of this listing; the hit counts are read through accessors and
// converted to float.
216 track_features[
i] = {
221 pt_ratio_log_nan_less,
224 static_cast<float>(numberOfInrmstPxlLyrHitsAcc(*xTrack)),
225 static_cast<float>(numberOfPixelHitsAcc(*xTrack)),
226 static_cast<float>(numberOfSCTHitsAcc(*xTrack)),
// Tensor shape {1, m_maxTracks, 11}, matching the dimensions track_features
// was constructed with.
230 std::vector<int64_t> track_shape = {1,
static_cast<int64_t
>(
m_maxTracks), 11};
// Shape {1, track count, 2}: reuses the track dimension of track_shape, with
// a last dimension of 2.
237 {1, track_shape[1], 2},
// jet_shape is moved from here and should not be read afterwards.
241 std::move(jet_shape),
// Return the element at index 1 of the first model output.
// NOTE(review): no size check is visible before this index; output_1 must
// hold at least two values - confirm against the model's output shape.
246 return output.output_1[1];