28 ATH_MSG_INFO(
"Initializing DiTauOnnxDiscriminantTool" );
31 if (model_path.empty()) {
33 return StatusCode::FAILURE;
35 m_ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
"OnnxUtil");
36 Ort::SessionOptions session_options;
37 session_options.SetIntraOpNumThreads(1);
38 session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
60 return StatusCode::SUCCESS;
69 omni_scoreDec(xDiTau) =
score;
70 return StatusCode::SUCCESS;
74 if (std::isnan(
value))
75 return nan_replacement;
76 if (
value == std::numeric_limits<float>::infinity())
77 return posinf_replacement;
78 if (
value == -std::numeric_limits<float>::infinity())
79 return neginf_replacement;
85 flattened.reserve(vec_2d.size() * (vec_2d.empty() ? 0 : vec_2d[0].size()));
86 for (
const auto &inner : vec_2d) {
94 points.reserve(track_features.size() * 2);
95 for (
const auto &
track : track_features) {
103 std::vector<float>
mask;
104 mask.reserve(track_features.size());
106 return std::abs(track[2]) > 1e-6 ? 1.0f : 0.0f;
112 Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
113 return Ort::Value::CreateTensor<float>(memory_info,
data.data(),
data.size(),shape.data(), shape.size());
117 std::vector<Ort::Value> input_tensors;
125 std::vector<const char *> input_node_names;
129 std::vector<const char *> output_node_names;
133 auto output_tensors =
m_ort_session->Run(Ort::RunOptions{
nullptr}, input_node_names.data(), input_tensors.data(), input_node_names.size(), output_node_names.data(), output_node_names.size());
136 for (
size_t i = 0;
i < output_tensors.size(); ++
i) {
137 const auto &tensor = output_tensors[
i];
138 const size_t length = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
139 const float *
data = tensor.GetTensorData<
float>();
140 (
i == 0 ? output.output_1 : output.output_2) = std::vector<float>(
data,
data +
length);
172 std::vector<float> jet_vars = {
173 R_max_leadDec (ditau),
174 R_max_subleadDec (ditau),
175 R_tracks_sublDec (ditau),
176 R_isotracDec (ditau),
177 d0_leadtrack_leadDec (ditau),
178 d0_leadtrack_sublDec (ditau),
179 f_core_leadDec (ditau),
180 f_core_sublDec (ditau),
181 f_subjet_sublDec (ditau),
183 f_isotracks_Dec (ditau),
184 M_core_leadDec (ditau),
185 M_core_sublDec (ditau),
186 M_tracks_leadDec (ditau),
187 static_cast<float>( n_trackDec (ditau)),
189 std::vector<int64_t> jet_shape = {1,
static_cast<int64_t
>(jet_vars.size())};
192 std::vector<std::vector<float>> track_features(
m_maxTracks, std::vector<float>(11, 0.0
f));
194 float jet_eta = ditau.
eta();
198 for (
size_t i = 0;
i < num_tracks; ++
i) {
200 if (!trackLink.
isValid())
continue;
202 float track_eta = xTrack->
eta();
203 float track_phi = xTrack->
phi();
204 float delta_eta = track_eta - jet_eta;
207 float track_pt =
static_cast<float>(xTrack->
pt());
209 float jet_pt = ditau_ptDec(ditau);
210 float pt_ratio = track_pt / jet_pt;
211 float pt_ratio_log =
std::log(1.0
f - pt_ratio + 1
e-8
f);
212 float track_charge = xTrack->
charge();
213 float pt_ratio_log_nan_less =
nan_to_num(pt_ratio_log, 0.0
f, 0.0
f, 0.0
f);
215 track_features[
i] = {
220 pt_ratio_log_nan_less,
223 static_cast<float>(numberOfInrmstPxlLyrHitsAcc(*xTrack)),
224 static_cast<float>(numberOfPixelHitsAcc(*xTrack)),
225 static_cast<float>(numberOfSCTHitsAcc(*xTrack)),
229 std::vector<int64_t> track_shape = {1,
static_cast<int64_t
>(
m_maxTracks), 11};
236 {1, track_shape[1], 2},
240 std::move(jet_shape),
245 return output.output_1[1];