92 if (!jets.isValid()) {
94 return StatusCode::FAILURE;
100 Ort::AllocatorWithDefaultOptions allocator;
102 Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
108 validDecor(*jptr) = 0;
112 std::vector<float> x_flat;
113 std::vector<int64_t> edge_index_flat;
114 std::vector<int64_t> batch_vec;
115 std::vector<int64_t> counts_vec;
116 std::vector<float> Ntrk_vec;
119 jet, x_flat, edge_index_flat, batch_vec, counts_vec, Ntrk_vec
127 const int64_t num_nodes = batch_vec.size();
129 if (num_nodes == 0 || x_flat.size() !=
static_cast<size_t>(num_nodes * 3)) {
137 const int64_t min_nodes = 2;
139 if (num_nodes < min_nodes) {
144 const int64_t num_edges = edge_index_flat.size() / 2;
151 std::vector<int64_t> x_shape = { num_nodes, 3 };
152 Ort::Value x_tensor =
153 Ort::Value::CreateTensor<float>(
162 std::vector<int64_t> ei_shape = {
164 static_cast<int64_t
>(edge_index_flat.size() / 2)
166 Ort::Value edge_tensor =
167 Ort::Value::CreateTensor<int64_t>(
169 edge_index_flat.data(),
170 edge_index_flat.size(),
176 std::vector<int64_t> batch_shape = { num_nodes };
177 Ort::Value batch_tensor =
178 Ort::Value::CreateTensor<int64_t>(
187 std::vector<int64_t> ntrk_shape = {
static_cast<int64_t
>(Ntrk_vec.size())};
188 Ort::Value ntrk_tensor =
189 Ort::Value::CreateTensor<float>(
196 std::array<Ort::Value, 4> input_tensors = {
198 std::move(edge_tensor),
199 std::move(batch_tensor),
200 std::move(ntrk_tensor)
203 std::array<const char*, 4> input_names = {
210 std::array<const char*, 1> output_names = {
217 std::vector<Ort::Value> output_tensors;
220 Ort::RunOptions{
nullptr},
222 input_tensors.data(),
223 input_tensors.size(),
228 catch (
const Ort::Exception& e) {
236 if (output_tensors.empty() || !output_tensors.front().IsTensor()) {
243 output_tensors.front().GetTensorMutableData<
float>();
245 float score = out_data[0];
247 scoreDecor(*jptr) = score;
248 validDecor(*jptr) = 1;
250 ATH_MSG_DEBUG(
"Jet decorated with LundNet score = " << score);
253 return StatusCode::SUCCESS;
256 std::vector<float>& out_x_float,
257 std::vector<int64_t>& out_edge_index_int64,
258 std::vector<int64_t>& out_batch_int64,
259 std::vector<int64_t>& out_counts_int64,
260 std::vector<float>& out_Ntrk_float)
const {
263 out_edge_index_int64.clear();
264 out_batch_int64.clear();
265 out_counts_int64.clear();
266 out_Ntrk_float.clear();
268 auto tryGetVectorFloat = [&](
const std::string& baseName, std::vector<float>& out)->
bool {
269 std::string withPref =
m_prefix.value() + baseName;
270 if (
jet.getAttribute(withPref, out)) {
274 if (
jet.getAttribute(baseName, out)) {
278 ATH_MSG_DEBUG(
"Attribute not found: " << withPref <<
" nor " << baseName);
282 auto tryGetVectorInt = [&](
const std::string& baseName, std::vector<int>& out)->
bool {
283 std::string withPref =
m_prefix.value() + baseName;
284 if (
jet.getAttribute(withPref, out)) {
288 if (
jet.getAttribute(baseName, out)) {
292 ATH_MSG_DEBUG(
"Attribute not found: " << withPref <<
" nor " << baseName);
296 auto tryGetInt = [&](
const std::string& baseName,
int& out)->
bool {
297 std::string withPref =
m_prefix.value() + baseName;
298 if (
jet.getAttribute(withPref, out)) {
302 if (
jet.getAttribute(baseName, out)) {
306 ATH_MSG_DEBUG(
"Attribute not found: " << withPref <<
" nor " << baseName);
311 std::vector<float> lnR, lnkT,
z;
312 std::vector<int> idp1, idp2;
315 bool ok_lnR = tryGetVectorFloat(
"LundAllLnR", lnR);
316 bool ok_lnkT = tryGetVectorFloat(
"LundAllLnKT", lnkT);
317 bool ok_z = tryGetVectorFloat(
"LundAllZ",
z);
318 bool ok_idp1 = tryGetVectorInt(
"LundAllIDP1", idp1);
319 bool ok_idp2 = tryGetVectorInt(
"LundAllIDP2", idp2);
320 bool ok_nsp = tryGetInt(
"nSplits", nSplits);
322 if (!(ok_lnR && ok_lnkT && ok_z && ok_idp1 && ok_idp2 && ok_nsp)) {
323 ATH_MSG_DEBUG(
"Missing one or more Lund decorations (lnR/lnkT/z/idp1/idp2/nSplits). Aborting build.");
327 size_t n_nodes = lnR.size();
329 ATH_MSG_DEBUG(
"Lund decorations present but zero-length vectors.");
336 bool gotNtrk =
false;
338 std::vector<std::string> ntrkCandidates = {
"LRJ_Nconst_Charged",
"nTrk",
"Ntrk",
"NTracks" };
339 for (
auto &cand : ntrkCandidates) {
340 std::string withPref =
m_prefix.value() + cand;
341 if (
jet.getAttribute(withPref, tmp_ntrk_i)) { ntrk_f =
static_cast<float>(tmp_ntrk_i); gotNtrk =
true;
ATH_MSG_DEBUG(
"Using Ntrk attr: " << withPref);
break; }
342 if (
jet.getAttribute(cand, tmp_ntrk_i)) { ntrk_f =
static_cast<float>(tmp_ntrk_i); gotNtrk =
true;
ATH_MSG_DEBUG(
"Using Ntrk attr: " << cand);
break; }
345 const auto & links =
jet.constituentLinks();
346 for (
size_t i = 0; i <
jet.numConstituents(); ++i) {
349 if (fe && fe->
isCharged()) ntrk_f += 1.f;
351 ATH_MSG_DEBUG(
"Computed Ntrk from constituents: " << ntrk_f);
357 float ln_kTcut = (kTsel_val > 0.f) ? std::log(std::max(1e-12f, kTsel_val)) : -1e9f;
359 std::vector<char> mask(n_nodes, 0);
360 size_t n_selected = 0;
361 for (
size_t i = 0; i < n_nodes; ++i) {
363 if (lnk > ln_kTcut) { mask[i] = 1; ++n_selected; }
367 if (n_selected < 1) {
368 ATH_MSG_DEBUG(
"No nodes passed kT selection (n_selected=" << n_selected <<
").");
373 out_x_float.reserve(n_selected * 3);
374 for (
size_t i = 0; i < n_nodes; ++i) {
375 if (!mask[i])
continue;
376 float f_ln1overdR = lnR[i];
377 float f_lnkT = lnkT[i];
378 float zval = std::max(1e-6f,
z[i]);
379 float f_ln1overz = -std::log(zval);
385 out_x_float.push_back(dr_std);
386 out_x_float.push_back(kt_std);
387 out_x_float.push_back(z_std);
391 std::vector<int> old2new(n_nodes, -1);
393 for (
size_t i = 0; i < n_nodes; ++i) {
394 if (mask[i]) old2new[i] = new_idx++;
398 for (
size_t child = 0; child < n_nodes; ++child) {
399 if (!mask[child])
continue;
400 int p1 = (child < idp1.size()) ? idp1[child] : -1;
401 int p2 = (child < idp2.size()) ? idp2[child] : -1;
402 if (p1 >= 0 && p1 <
static_cast<int>(n_nodes) && old2new[p1] >= 0) {
403 out_edge_index_int64.push_back(
static_cast<int64_t
>(old2new[p1]));
404 out_edge_index_int64.push_back(
static_cast<int64_t
>(old2new[child]));
406 out_edge_index_int64.push_back(
static_cast<int64_t
>(old2new[child]));
407 out_edge_index_int64.push_back(
static_cast<int64_t
>(old2new[p1]));
409 if (p2 >= 0 && p2 <
static_cast<int>(n_nodes) && old2new[p2] >= 0) {
410 out_edge_index_int64.push_back(
static_cast<int64_t
>(old2new[p2]));
411 out_edge_index_int64.push_back(
static_cast<int64_t
>(old2new[child]));
412 out_edge_index_int64.push_back(
static_cast<int64_t
>(old2new[child]));
413 out_edge_index_int64.push_back(
static_cast<int64_t
>(old2new[p2]));
418 for (int64_t i = 0; i < new_idx; ++i) out_batch_int64.push_back(0);
421 out_counts_int64.push_back(
static_cast<int64_t
>(new_idx));
427 out_Ntrk_float.reserve(batch_size);
428 for (int64_t i = 0; i < batch_size; ++i) {
429 out_Ntrk_float.push_back(ntrk_std);
432 if (out_x_float.empty() || out_counts_int64.empty()) {
438 <<
" edges=" << (out_edge_index_int64.size()/2)
439 <<
" ntrk=" << ntrk_f);
Class providing the definition of the 4-vector interface.