ATLAS Offline Software
Loading...
Searching...
No Matches
LundJetOnnxAlg.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4#ifndef XAOD_ANALYSIS
6
7// xAOD and StoreGate
12// AthOnnx
14
15// ONNX Runtime
16#include <onnxruntime_cxx_api.h>
19
20#include <sstream>
21#include <cmath>
22
23//static const char* k_input_names[] = {"x","edge_index","batch","Ntrk","counts"};
24//static const char* k_output_names[] = {"output"};
26
// NOTE(review): this is the body of LundJetOnnxAlg::initialize(); the
// signature line (original ~25) was dropped by the documentation
// extraction, as were original lines 44-45 (presumably the PathResolver
// call that fills m_resolvedModelPath — TODO confirm against the repo).
27 ATH_MSG_INFO("Initializing LundJetOnnxAlg");
28
29 ATH_MSG_INFO("Loading ONNX model from: " << m_modelPath);
30 // -------------------------
31 // Create ONNX Runtime session
32 // -------------------------
33 Ort::SessionOptions sessionOptions;
// Single intra-op thread and basic graph optimization only; Athena
// manages its own event-level concurrency.
34 sessionOptions.SetIntraOpNumThreads(1);
35 sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
36
37 // Recommended in Athena to reduce memory usage
38 sessionOptions.DisableCpuMemArena();
39
// Keep the Ort::Env alive for the job lifetime (member unique_ptr);
// "LundNetGNN" is the ORT log identifier for this algorithm.
40 m_env = std::make_unique<Ort::Env>(
41 ORT_LOGGING_LEVEL_WARNING,
42 "LundNetGNN"
43 );
46
// Fail configuration early if the model file could not be resolved.
47 if (m_resolvedModelPath.empty()) {
48 ATH_MSG_ERROR("Could not resolve ONNX model path: " << m_modelPath);
49 return StatusCode::FAILURE;
50 }
51
52 try {
53 m_session = std::make_unique<Ort::Session>(
54 *m_env,
55 m_resolvedModelPath.c_str(), // <- member string, so the pointer stays stable
56 sessionOptions
57 );
58 } catch (const Ort::Exception& e) {
59 ATH_MSG_ERROR("Failed to create ONNX Runtime session: " << e.what());
60 return StatusCode::FAILURE;
61 }
62
63 ATH_MSG_INFO("ONNX Runtime session successfully created");
64
65 // -------------------------
66 // Decorations
67 // -------------------------
68 ATH_MSG_INFO("InputJetContainer: " << m_inputJetContainer);
69 ATH_MSG_INFO("Prefix: '" << m_prefix << "'");
70 ATH_MSG_INFO("kT selection: " << m_kTSelection);
// Decoration keys have the form "<container>.<prefix><name>":
// one for the classifier score, one for the per-jet validity flag.
71 std::string decorFull =
72 m_inputJetContainer.value() + "." +
73 m_prefix.value() +
74 m_scoreName.value();
75 std::string validDecorFull =
76 m_inputJetContainer.value() + "." +
77 m_prefix.value() +
78 "LundNetValid";
79
80 m_scoreDecorKey = decorFull;
81 m_validDecorKey = validDecorFull;
82 ATH_CHECK(m_scoreDecorKey.initialize());
83 ATH_CHECK(m_validDecorKey.initialize());
84 ATH_MSG_INFO("Will write decoration: " << decorFull);
85
86 return StatusCode::SUCCESS;
87}
88
// Per-event execution: for each jet in the input container, build the
// Lund-graph ONNX inputs, run the GNN session, and decorate the jet
// with the score and a validity flag.
89StatusCode LundJetOnnxAlg::execute(const EventContext& ctx) const {
90
// NOTE(review): the declaration of `jets` (original line 91, presumably
// an SG::ReadHandle<xAOD::JetContainer> built from ctx) was dropped by
// the documentation extraction — confirm against the repository source.
92 if (!jets.isValid()) {
93 ATH_MSG_ERROR("Failed to retrieve JetContainer: " << m_inputJetContainer);
94 return StatusCode::FAILURE;
95 }
96
// NOTE(review): the WriteDecorHandle declarations for `scoreDecor` and
// `validDecor` (original lines 97-98) are likewise missing from this dump.
99 // ONNX helpers
100 Ort::AllocatorWithDefaultOptions allocator;
101 auto memory_info =
102 Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
103
104 for (const xAOD::Jet* jptr : *jets) {
105
106 const xAOD::Jet& jet = *jptr;
107
// Default every jet to "invalid"; flipped to 1 only after a
// successful inference below.
108 validDecor(*jptr) = 0;
109 // ---------------------------
110 // Build inputs
111 // ---------------------------
112 std::vector<float> x_flat; // [num_nodes * 3]
113 std::vector<int64_t> edge_index_flat; // [2 * num_edges]
114 std::vector<int64_t> batch_vec; // [num_nodes]
115 std::vector<int64_t> counts_vec; // unused but kept
116 std::vector<float> Ntrk_vec; // [1]
117
118 bool ok = buildOnnxInputs(
119 jet, x_flat, edge_index_flat, batch_vec, counts_vec, Ntrk_vec
120 );
121
// Jets with missing decorations or an empty graph are skipped,
// leaving the validity decoration at 0.
122 if (!ok) {
123 ATH_MSG_DEBUG("Skipping jet: couldn't build inputs");
124 continue;
125 }
126
127 const int64_t num_nodes = batch_vec.size();
128
// Sanity: x must be exactly 3 features per node.
129 if (num_nodes == 0 || x_flat.size() != static_cast<size_t>(num_nodes * 3)) {
130 ATH_MSG_WARNING("Inconsistent node inputs");
131 continue;
132 }
133
134 // ---------------------------
135 // Create ONNX tensors
136 // ---------------------------
137 const int64_t min_nodes = 2;
138
// A graph with fewer than two nodes cannot have an edge.
139 if (num_nodes < min_nodes) {
140 ATH_MSG_DEBUG("Skipping jet: num_nodes < 2");
141 continue;
142 }
143
144 const int64_t num_edges = edge_index_flat.size() / 2;
145 if (num_edges < 1) {
146 ATH_MSG_DEBUG("Skipping jet: no edges in graph");
147 continue;
148 }
149
// The tensors below are zero-copy views over the vectors above, so
// those vectors must stay alive until Run() returns.
150 // x : [num_nodes, 3]
151 std::vector<int64_t> x_shape = { num_nodes, 3 };
152 Ort::Value x_tensor =
153 Ort::Value::CreateTensor<float>(
154 memory_info,
155 x_flat.data(),
156 x_flat.size(),
157 x_shape.data(),
158 x_shape.size()
159 );
160
161 // edge_index : [2, num_edges]
162 std::vector<int64_t> ei_shape = {
163 2,
164 static_cast<int64_t>(edge_index_flat.size() / 2)
165 };
166 Ort::Value edge_tensor =
167 Ort::Value::CreateTensor<int64_t>(
168 memory_info,
169 edge_index_flat.data(),
170 edge_index_flat.size(),
171 ei_shape.data(),
172 ei_shape.size()
173 );
174
175 // batch : [num_nodes]
176 std::vector<int64_t> batch_shape = { num_nodes };
177 Ort::Value batch_tensor =
178 Ort::Value::CreateTensor<int64_t>(
179 memory_info,
180 batch_vec.data(),
181 batch_vec.size(),
182 batch_shape.data(),
183 batch_shape.size()
184 );
185
186 // Ntrk : [batch_size]
187 std::vector<int64_t> ntrk_shape = {static_cast<int64_t>(Ntrk_vec.size())};
188 Ort::Value ntrk_tensor =
189 Ort::Value::CreateTensor<float>(
190 memory_info,
191 Ntrk_vec.data(),
192 Ntrk_vec.size(),
193 ntrk_shape.data(),
194 ntrk_shape.size()
195 );
// Order must match input_names below (positional binding in Run()).
196 std::array<Ort::Value, 4> input_tensors = {
197 std::move(x_tensor),
198 std::move(edge_tensor),
199 std::move(batch_tensor),
200 std::move(ntrk_tensor)
201 };
202
203 std::array<const char*, 4> input_names = {
204 "x",
205 "edge_index",
206 "batch",
207 "Ntrk"
208 };
209
210 std::array<const char*, 1> output_names = {
211 "output"
212 };
213
214 // ---------------------------
215 // Run inference
216 // ---------------------------
217 std::vector<Ort::Value> output_tensors;
218 try {
219 output_tensors = m_session->Run(
220 Ort::RunOptions{nullptr},
221 input_names.data(),
222 input_tensors.data(),
223 input_tensors.size(),
224 output_names.data(),
225 output_names.size()
226 );
227 }
// An ORT failure skips this jet (validDecor stays 0) rather than
// aborting the whole event.
228 catch (const Ort::Exception& e) {
229 ATH_MSG_ERROR("ONNX Runtime exception: " << e.what());
230 continue;
231 }
232
233 // ---------------------------
234 // Read output
235 // ---------------------------
236 if (output_tensors.empty() || !output_tensors.front().IsTensor()) {
237 ATH_MSG_WARNING("Invalid output tensor");
238 //scoreDecor(*jptr) = -999.f;
239 continue;
240 }
241
242 float* out_data =
243 output_tensors.front().GetTensorMutableData<float>();
244
// Only element [0] is read: one graph (batch index 0) per Run() call.
245 float score = out_data[0]; // shape [-1,1], batch=1
246
247 scoreDecor(*jptr) = score;
248 validDecor(*jptr) = 1;
249
250 ATH_MSG_DEBUG("Jet decorated with LundNet score = " << score);
251 }
252
253 return StatusCode::SUCCESS;
254}
// NOTE(review): the first line of the buildOnnxInputs signature
// (original 255, "bool LundJetOnnxAlg::buildOnnxInputs(const xAOD::Jet& jet,")
// was dropped by the documentation extraction; the remaining out-parameters
// and the full body follow. Outputs: node features (out_x_float, 3 floats
// per node), flattened edge indices, per-node batch ids, selected-node
// count, and the normalized Ntrk replicated to the expected batch size.
// Returns false when required decorations are missing or no node survives
// the kT selection.
256 std::vector<float>& out_x_float,
257 std::vector<int64_t>& out_edge_index_int64,
258 std::vector<int64_t>& out_batch_int64,
259 std::vector<int64_t>& out_counts_int64,
260 std::vector<float>& out_Ntrk_float) const {
261 // Clear outputs so a failed build leaves them empty, not stale.
262 out_x_float.clear();
263 out_edge_index_int64.clear();
264 out_batch_int64.clear();
265 out_counts_int64.clear();
266 out_Ntrk_float.clear();
267
// Attribute lookup helpers: try "<prefix><name>" first, then the bare
// name, returning whether either was found on the jet.
268 auto tryGetVectorFloat = [&](const std::string& baseName, std::vector<float>& out)->bool {
269 std::string withPref = m_prefix.value() + baseName;
270 if (jet.getAttribute(withPref, out)) {
271 ATH_MSG_DEBUG("Found attribute: " << withPref << " (used)");
272 return true;
273 }
274 if (jet.getAttribute(baseName, out)) {
275 ATH_MSG_DEBUG("Found attribute: " << baseName << " (used)");
276 return true;
277 }
278 ATH_MSG_DEBUG("Attribute not found: " << withPref << " nor " << baseName);
279 return false;
280 };
281
// Same prefixed-then-bare lookup for std::vector<int> attributes.
282 auto tryGetVectorInt = [&](const std::string& baseName, std::vector<int>& out)->bool {
283 std::string withPref = m_prefix.value() + baseName;
284 if (jet.getAttribute(withPref, out)) {
285 ATH_MSG_DEBUG("Found attribute: " << withPref << " (used)");
286 return true;
287 }
288 if (jet.getAttribute(baseName, out)) {
289 ATH_MSG_DEBUG("Found attribute: " << baseName << " (used)");
290 return true;
291 }
292 ATH_MSG_DEBUG("Attribute not found: " << withPref << " nor " << baseName);
293 return false;
294 };
295
// Same prefixed-then-bare lookup for scalar int attributes.
296 auto tryGetInt = [&](const std::string& baseName, int& out)->bool {
297 std::string withPref = m_prefix.value() + baseName;
298 if (jet.getAttribute(withPref, out)) {
299 ATH_MSG_DEBUG("Found attribute: " << withPref << " (used)");
300 return true;
301 }
302 if (jet.getAttribute(baseName, out)) {
303 ATH_MSG_DEBUG("Found attribute: " << baseName << " (used)");
304 return true;
305 }
306 ATH_MSG_DEBUG("Attribute not found: " << withPref << " nor " << baseName);
307 return false;
308 };
309
310 // Read Lund decorations (try prefixed and unprefixed names)
311 std::vector<float> lnR, lnkT, z;
312 std::vector<int> idp1, idp2;
313 int nSplits = 0;
314
315 bool ok_lnR = tryGetVectorFloat("LundAllLnR", lnR);
316 bool ok_lnkT = tryGetVectorFloat("LundAllLnKT", lnkT);
317 bool ok_z = tryGetVectorFloat("LundAllZ", z);
318 bool ok_idp1 = tryGetVectorInt("LundAllIDP1", idp1);
319 bool ok_idp2 = tryGetVectorInt("LundAllIDP2", idp2);
320 bool ok_nsp = tryGetInt("nSplits", nSplits);
321
// All six decorations are required; abort otherwise.
322 if (!(ok_lnR && ok_lnkT && ok_z && ok_idp1 && ok_idp2 && ok_nsp)) {
323 ATH_MSG_DEBUG("Missing one or more Lund decorations (lnR/lnkT/z/idp1/idp2/nSplits). Aborting build.");
324 return false;
325 }
326
// Node count is taken from lnR; the other vectors are assumed to have
// the same length (per-node access below is unchecked) — TODO confirm.
327 size_t n_nodes = lnR.size();
328 if (n_nodes == 0) {
329 ATH_MSG_DEBUG("Lund decorations present but zero-length vectors.");
330 return false;
331 }
332
333 // Determine Ntrk: try various attribute names (try prefixed first, then unprefixed)
334 float ntrk_f = 0.f;
335 int tmp_ntrk_i = 0;
336 bool gotNtrk = false;
337 // Candidate names commonly used
338 std::vector<std::string> ntrkCandidates = { "LRJ_Nconst_Charged", "nTrk", "Ntrk", "NTracks" };
339 for (auto &cand : ntrkCandidates) {
340 std::string withPref = m_prefix.value() + cand;
341 if (jet.getAttribute(withPref, tmp_ntrk_i)) { ntrk_f = static_cast<float>(tmp_ntrk_i); gotNtrk = true; ATH_MSG_DEBUG("Using Ntrk attr: " << withPref); break; }
342 if (jet.getAttribute(cand, tmp_ntrk_i)) { ntrk_f = static_cast<float>(tmp_ntrk_i); gotNtrk = true; ATH_MSG_DEBUG("Using Ntrk attr: " << cand); break; }
343 }
// Fallback: count charged FlowElement constituents directly.
344 if (!gotNtrk) {
345 const auto & links = jet.constituentLinks();
346 for (size_t i = 0; i < jet.numConstituents(); ++i) {
347 const xAOD::IParticle* p = *links[i];
348 const xAOD::FlowElement* fe = dynamic_cast<const xAOD::FlowElement*>(p);
349 if (fe && fe->isCharged()) ntrk_f += 1.f;
350 }
351 ATH_MSG_DEBUG("Computed Ntrk from constituents: " << ntrk_f);
352 }
353
354 // Build node mask based on kTSelection.
355 // m_kTSelection is a Gaudi Property<float>; take its value for comparison.
356 float kTsel_val = m_kTSelection;
// Selection is applied in log space; a non-positive kT cut disables it
// (threshold of -1e9 passes every node).
357 float ln_kTcut = (kTsel_val > 0.f) ? std::log(std::max(1e-12f, kTsel_val)) : -1e9f;
358
359 std::vector<char> mask(n_nodes, 0);
360 size_t n_selected = 0;
361 for (size_t i = 0; i < n_nodes; ++i) {
362 float lnk = lnkT[i];
363 if (lnk > ln_kTcut) { mask[i] = 1; ++n_selected; }
364 else mask[i] = 0;
365 }
366
367 if (n_selected < 1) {
368 ATH_MSG_DEBUG("No nodes passed kT selection (n_selected=" << n_selected << ").");
369 return false;
370 }
371
372 // Build feature matrix x: [ln(1/dR), ln(kt), ln(1/z)] then standardize
373 out_x_float.reserve(n_selected * 3);
374 for (size_t i = 0; i < n_nodes; ++i) {
375 if (!mask[i]) continue;
376 float f_ln1overdR = lnR[i];
377 float f_lnkT = lnkT[i];
// Clamp z away from zero before taking the log.
378 float zval = std::max(1e-6f, z[i]);
379 float f_ln1overz = -std::log(zval);
380
// Standardize each feature with the configured mean/std properties
// (must match the training-time normalization).
381 float z_std = (f_ln1overz - m_mean_z) / m_std_z;
382 float kt_std = (f_lnkT - m_mean_kt) / m_std_kt;
383 float dr_std = (f_ln1overdR - m_mean_dr) / m_std_dr;
384
// Feature order per node: [dr, kt, z].
385 out_x_float.push_back(dr_std);
386 out_x_float.push_back(kt_std);
387 out_x_float.push_back(z_std);
388 }
389
390 // Reindex old->new for masked nodes
391 std::vector<int> old2new(n_nodes, -1);
392 int new_idx = 0;
393 for (size_t i = 0; i < n_nodes; ++i) {
394 if (mask[i]) old2new[i] = new_idx++;
395 }
396
397 // Build edges (only include edges where both endpoints survive)
// Edges are appended flattened as (src, dst) pairs; the caller reshapes
// them to [2, num_edges].
398 for (size_t child = 0; child < n_nodes; ++child) {
399 if (!mask[child]) continue;
400 int p1 = (child < idp1.size()) ? idp1[child] : -1;
401 int p2 = (child < idp2.size()) ? idp2[child] : -1;
402 if (p1 >= 0 && p1 < static_cast<int>(n_nodes) && old2new[p1] >= 0) {
403 out_edge_index_int64.push_back(static_cast<int64_t>(old2new[p1]));
404 out_edge_index_int64.push_back(static_cast<int64_t>(old2new[child]));
405 // also add the reverse direction: both directions are emitted so
// the graph is effectively undirected for the model.
406 out_edge_index_int64.push_back(static_cast<int64_t>(old2new[child]));
407 out_edge_index_int64.push_back(static_cast<int64_t>(old2new[p1]));
408 }
409 if (p2 >= 0 && p2 < static_cast<int>(n_nodes) && old2new[p2] >= 0) {
410 out_edge_index_int64.push_back(static_cast<int64_t>(old2new[p2]));
411 out_edge_index_int64.push_back(static_cast<int64_t>(old2new[child]));
412 out_edge_index_int64.push_back(static_cast<int64_t>(old2new[child]));
413 out_edge_index_int64.push_back(static_cast<int64_t>(old2new[p2]));
414 }
415 }
416
417 // batch vector: all nodes belong to graph 0
418 for (int64_t i = 0; i < new_idx; ++i) out_batch_int64.push_back(0);
419
420 // counts: [num_nodes_selected]
421 out_counts_int64.push_back(static_cast<int64_t>(new_idx));
422
423 // Ntrk: normalize as in Python
424 float ntrk_std = (ntrk_f - m_mean_ntrk) / m_std_ntrk;
425 //out_Ntrk_float.push_back(ntrk_std);
// Replicate the normalized Ntrk to the model's expected batch size.
426 const int64_t batch_size = m_expectedBatchSize; // Gaudi::Property
427 out_Ntrk_float.reserve(batch_size);
428 for (int64_t i = 0; i < batch_size; ++i) {
429 out_Ntrk_float.push_back(ntrk_std);
430 }
431 // Final sanity checks
432 if (out_x_float.empty() || out_counts_int64.empty()) {
433 ATH_MSG_DEBUG("After masking, nothing to evaluate.");
434 return false;
435 }
436
437 ATH_MSG_DEBUG("Built ONNX input: nodes=" << new_idx
438 << " edges=" << (out_edge_index_int64.size()/2)
439 << " ntrk=" << ntrk_f);
440
441 return true;
442}
443#endif //ATHENA-ONLY
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
Handle class for reading from StoreGate.
Handle class for adding a decoration to an object.
#define z
Gaudi::Property< float > m_mean_ntrk
Gaudi::Property< std::string > m_modelPath
Gaudi::Property< float > m_mean_dr
Gaudi::Property< std::string > m_inputJetContainer
Gaudi::Property< float > m_std_ntrk
virtual StatusCode execute(const EventContext &ctx) const override
Gaudi::Property< float > m_std_kt
Gaudi::Property< std::string > m_scoreName
Gaudi::Property< float > m_mean_z
Gaudi::Property< float > m_std_z
Gaudi::Property< float > m_std_dr
std::unique_ptr< Ort::Env > m_env
Gaudi::Property< float > m_kTSelection
SG::WriteDecorHandleKey< xAOD::JetContainer > m_validDecorKey
bool buildOnnxInputs(const xAOD::Jet &jet, std::vector< float > &out_x_float, std::vector< int64_t > &out_edge_index_int64, std::vector< int64_t > &out_batch_int64, std::vector< int64_t > &out_counts_int64, std::vector< float > &out_Ntrk_float) const
Helper to build ONNX inputs (batch_size=1)
std::unique_ptr< Ort::Session > m_session
Gaudi::Property< int > m_expectedBatchSize
Gaudi::Property< float > m_mean_kt
SG::WriteDecorHandleKey< xAOD::JetContainer > m_scoreDecorKey
std::string m_resolvedModelPath
Gaudi::Property< std::string > m_prefix
virtual StatusCode initialize() override
static std::string find_file(const std::string &logical_file_name, const std::string &search_path)
Handle class for adding a decoration to an object.
Class providing the definition of the 4-vector interface.
Jet_v1 Jet
Definition of the current "jet version".
FlowElement_v1 FlowElement
Definition of the current "pfo version".
Definition FlowElement.h:16