ATLAS Offline Software
Loading...
Searching...
No Matches
LundJetOnnxAlg Class Reference

#include <LundJetOnnxAlg.h>

Inheritance diagram for LundJetOnnxAlg:
Collaboration diagram for LundJetOnnxAlg:

Public Member Functions

virtual StatusCode initialize () override
virtual StatusCode execute (const EventContext &ctx) const override
virtual StatusCode sysInitialize () override
 Override sysInitialize.
virtual bool isClonable () const override
 Specify if the algorithm is clonable.
virtual unsigned int cardinality () const override
 Cardinality (Maximum number of clones that can exist) special value 0 means that algorithm is reentrant.
virtual StatusCode sysExecute (const EventContext &ctx) override
 Execute an algorithm.
virtual const DataObjIDColl & extraOutputDeps () const override
 Return the list of extra output dependencies.
virtual bool filterPassed (const EventContext &ctx) const
virtual void setFilterPassed (bool state, const EventContext &ctx) const
ServiceHandle< StoreGateSvc > & evtStore ()
 The standard StoreGateSvc (event store) Returns (kind of) a pointer to the StoreGateSvc.
const ServiceHandle< StoreGateSvc > & detStore () const
 The standard StoreGateSvc/DetectorStore Returns (kind of) a pointer to the StoreGateSvc.
virtual StatusCode sysStart () override
 Handle START transition.
virtual std::vector< Gaudi::DataHandle * > inputHandles () const override
 Return this algorithm's input handles.
virtual std::vector< Gaudi::DataHandle * > outputHandles () const override
 Return this algorithm's output handles.
Gaudi::Details::PropertyBase & declareProperty (Gaudi::Property< T, V, H > &t)
void updateVHKA (Gaudi::Details::PropertyBase &)
MsgStream & msg () const
bool msgLvl (const MSG::Level lvl) const

Protected Member Functions

void renounceArray (SG::VarHandleKeyArray &handlesArray)
 remove all handles from I/O resolution
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce (T &h)
void extraDeps_update_handler (Gaudi::Details::PropertyBase &ExtraDeps)
 Add StoreName to extra input/output deps as needed.

Private Types

typedef ServiceHandle< StoreGateSvc > StoreGateSvc_t

Private Member Functions

bool buildOnnxInputs (const xAOD::Jet &jet, std::vector< float > &out_x_float, std::vector< int64_t > &out_edge_index_int64, std::vector< int64_t > &out_batch_int64, std::vector< int64_t > &out_counts_int64, std::vector< float > &out_Ntrk_float) const
 Helper to build ONNX inputs (batch_size=1)
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T, V, H > &hndl, const SG::VarHandleKeyType &)
 specialization for handling Gaudi::Property<SG::VarHandleKey>

Private Attributes

Gaudi::Property< std::string > m_inputJetContainer
Gaudi::Property< std::string > m_prefix
Gaudi::Property< float > m_kTSelection
Gaudi::Property< std::string > m_scoreName
Gaudi::Property< int > m_expectedBatchSize
SG::WriteDecorHandleKey< xAOD::JetContainer > m_validDecorKey
SG::WriteDecorHandleKey< xAOD::JetContainer > m_scoreDecorKey
Gaudi::Property< float > m_mean_z { this, "MeanZ", 2.0568479032747313f, "mean z (for normalization)" }
Gaudi::Property< float > m_std_z { this, "StdZ", 1.4450598054504056f, "std z (for normalization)" }
Gaudi::Property< float > m_mean_dr { this, "MeanDR", 3.8597358364389427f, "mean dr (for normalization)" }
Gaudi::Property< float > m_std_dr { this, "StdDR", 2.2748462855901073f, "std dr (for normalization)" }
Gaudi::Property< float > m_mean_kt { this, "MeanKT", -2.379904791478249f, "mean kt (for normalization)" }
Gaudi::Property< float > m_std_kt { this, "StdKT", 2.940813577366582f, "std kt (for normalization)" }
Gaudi::Property< float > m_mean_ntrk { this, "MeanNtrk", 57.588158609500134f, "mean Ntrk" }
Gaudi::Property< float > m_std_ntrk { this, "StdNtrk", 23.900100132781983f, "std Ntrk" }
std::string m_resolvedModelPath
Gaudi::Property< std::string > m_modelPath
std::unique_ptr< Ort::Env > m_env
std::unique_ptr< Ort::Session > m_session
DataObjIDColl m_extendedExtraObjects
 Extra output dependency collection, extended by AthAlgorithmDHUpdate to add symlinks.
StoreGateSvc_t m_evtStore
 Pointer to StoreGate (event store by default)
StoreGateSvc_t m_detStore
 Pointer to StoreGate (detector store by default)
std::vector< SG::VarHandleKeyArray * > m_vhka
bool m_varHandleArraysDeclared

Detailed Description

Definition at line 24 of file LundJetOnnxAlg.h.

Member Typedef Documentation

◆ StoreGateSvc_t

typedef ServiceHandle<StoreGateSvc> AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::StoreGateSvc_t
privateinherited

Definition at line 388 of file AthCommonDataStore.h.

Member Function Documentation

◆ buildOnnxInputs()

bool LundJetOnnxAlg::buildOnnxInputs ( const xAOD::Jet & jet,
std::vector< float > & out_x_float,
std::vector< int64_t > & out_edge_index_int64,
std::vector< int64_t > & out_batch_int64,
std::vector< int64_t > & out_counts_int64,
std::vector< float > & out_Ntrk_float ) const
private

Helper to build ONNX inputs (batch_size=1)

Definition at line 255 of file LundJetOnnxAlg.cxx.

260 {
261 // Clear outputs
262 out_x_float.clear();
263 out_edge_index_int64.clear();
264 out_batch_int64.clear();
265 out_counts_int64.clear();
266 out_Ntrk_float.clear();
267
268 auto tryGetVectorFloat = [&](const std::string& baseName, std::vector<float>& out)->bool {
269 std::string withPref = m_prefix.value() + baseName;
270 if (jet.getAttribute(withPref, out)) {
271 ATH_MSG_DEBUG("Found attribute: " << withPref << " (used)");
272 return true;
273 }
274 if (jet.getAttribute(baseName, out)) {
275 ATH_MSG_DEBUG("Found attribute: " << baseName << " (used)");
276 return true;
277 }
278 ATH_MSG_DEBUG("Attribute not found: " << withPref << " nor " << baseName);
279 return false;
280 };
281
282 auto tryGetVectorInt = [&](const std::string& baseName, std::vector<int>& out)->bool {
283 std::string withPref = m_prefix.value() + baseName;
284 if (jet.getAttribute(withPref, out)) {
285 ATH_MSG_DEBUG("Found attribute: " << withPref << " (used)");
286 return true;
287 }
288 if (jet.getAttribute(baseName, out)) {
289 ATH_MSG_DEBUG("Found attribute: " << baseName << " (used)");
290 return true;
291 }
292 ATH_MSG_DEBUG("Attribute not found: " << withPref << " nor " << baseName);
293 return false;
294 };
295
296 auto tryGetInt = [&](const std::string& baseName, int& out)->bool {
297 std::string withPref = m_prefix.value() + baseName;
298 if (jet.getAttribute(withPref, out)) {
299 ATH_MSG_DEBUG("Found attribute: " << withPref << " (used)");
300 return true;
301 }
302 if (jet.getAttribute(baseName, out)) {
303 ATH_MSG_DEBUG("Found attribute: " << baseName << " (used)");
304 return true;
305 }
306 ATH_MSG_DEBUG("Attribute not found: " << withPref << " nor " << baseName);
307 return false;
308 };
309
310 // Read Lund decorations (try prefixed and unprefixed names)
311 std::vector<float> lnR, lnkT, z;
312 std::vector<int> idp1, idp2;
313 int nSplits = 0;
314
315 bool ok_lnR = tryGetVectorFloat("LundAllLnR", lnR);
316 bool ok_lnkT = tryGetVectorFloat("LundAllLnKT", lnkT);
317 bool ok_z = tryGetVectorFloat("LundAllZ", z);
318 bool ok_idp1 = tryGetVectorInt("LundAllIDP1", idp1);
319 bool ok_idp2 = tryGetVectorInt("LundAllIDP2", idp2);
320 bool ok_nsp = tryGetInt("nSplits", nSplits);
321
322 if (!(ok_lnR && ok_lnkT && ok_z && ok_idp1 && ok_idp2 && ok_nsp)) {
323 ATH_MSG_DEBUG("Missing one or more Lund decorations (lnR/lnkT/z/idp1/idp2/nSplits). Aborting build.");
324 return false;
325 }
326
327 size_t n_nodes = lnR.size();
328 if (n_nodes == 0) {
329 ATH_MSG_DEBUG("Lund decorations present but zero-length vectors.");
330 return false;
331 }
332
333 // Determine Ntrk: try various attribute names (try prefixed first, then unprefixed)
334 float ntrk_f = 0.f;
335 int tmp_ntrk_i = 0;
336 bool gotNtrk = false;
337 // Candidate names commonly used
338 std::vector<std::string> ntrkCandidates = { "LRJ_Nconst_Charged", "nTrk", "Ntrk", "NTracks" };
339 for (auto &cand : ntrkCandidates) {
340 std::string withPref = m_prefix.value() + cand;
341 if (jet.getAttribute(withPref, tmp_ntrk_i)) { ntrk_f = static_cast<float>(tmp_ntrk_i); gotNtrk = true; ATH_MSG_DEBUG("Using Ntrk attr: " << withPref); break; }
342 if (jet.getAttribute(cand, tmp_ntrk_i)) { ntrk_f = static_cast<float>(tmp_ntrk_i); gotNtrk = true; ATH_MSG_DEBUG("Using Ntrk attr: " << cand); break; }
343 }
344 if (!gotNtrk) {
345 const auto & links = jet.constituentLinks();
346 for (size_t i = 0; i < jet.numConstituents(); ++i) {
347 const xAOD::IParticle* p = *links[i];
348 const xAOD::FlowElement* fe = dynamic_cast<const xAOD::FlowElement*>(p);
349 if (fe && fe->isCharged()) ntrk_f += 1.f;
350 }
351 ATH_MSG_DEBUG("Computed Ntrk from constituents: " << ntrk_f);
352 }
353
354 // Build node mask based on kTSelection.
355 // m_kTSelection is a Gaudi Property<float> — get the value for comparison.
356 float kTsel_val = m_kTSelection;
357 float ln_kTcut = (kTsel_val > 0.f) ? std::log(std::max(1e-12f, kTsel_val)) : -1e9f;
358
359 std::vector<char> mask(n_nodes, 0);
360 size_t n_selected = 0;
361 for (size_t i = 0; i < n_nodes; ++i) {
362 float lnk = lnkT[i];
363 if (lnk > ln_kTcut) { mask[i] = 1; ++n_selected; }
364 else mask[i] = 0;
365 }
366
367 if (n_selected < 1) {
368 ATH_MSG_DEBUG("No nodes passed kT selection (n_selected=" << n_selected << ").");
369 return false;
370 }
371
372 // Build feature matrix x: [ln(1/dR), ln(kt), ln(1/z)] then standardize
373 out_x_float.reserve(n_selected * 3);
374 for (size_t i = 0; i < n_nodes; ++i) {
375 if (!mask[i]) continue;
376 float f_ln1overdR = lnR[i];
377 float f_lnkT = lnkT[i];
378 float zval = std::max(1e-6f, z[i]);
379 float f_ln1overz = -std::log(zval);
380
381 float z_std = (f_ln1overz - m_mean_z) / m_std_z;
382 float kt_std = (f_lnkT - m_mean_kt) / m_std_kt;
383 float dr_std = (f_ln1overdR - m_mean_dr) / m_std_dr;
384
385 out_x_float.push_back(dr_std);
386 out_x_float.push_back(kt_std);
387 out_x_float.push_back(z_std);
388 }
389
390 // Reindex old->new for masked nodes
391 std::vector<int> old2new(n_nodes, -1);
392 int new_idx = 0;
393 for (size_t i = 0; i < n_nodes; ++i) {
394 if (mask[i]) old2new[i] = new_idx++;
395 }
396
397 // Build edges (only include edges where both endpoints survive)
398 for (size_t child = 0; child < n_nodes; ++child) {
399 if (!mask[child]) continue;
400 int p1 = (child < idp1.size()) ? idp1[child] : -1;
401 int p2 = (child < idp2.size()) ? idp2[child] : -1;
402 if (p1 >= 0 && p1 < static_cast<int>(n_nodes) && old2new[p1] >= 0) {
403 out_edge_index_int64.push_back(static_cast<int64_t>(old2new[p1]));
404 out_edge_index_int64.push_back(static_cast<int64_t>(old2new[child]));
405 // optionally also add reverse if model expects undirected edges (you did both)
406 out_edge_index_int64.push_back(static_cast<int64_t>(old2new[child]));
407 out_edge_index_int64.push_back(static_cast<int64_t>(old2new[p1]));
408 }
409 if (p2 >= 0 && p2 < static_cast<int>(n_nodes) && old2new[p2] >= 0) {
410 out_edge_index_int64.push_back(static_cast<int64_t>(old2new[p2]));
411 out_edge_index_int64.push_back(static_cast<int64_t>(old2new[child]));
412 out_edge_index_int64.push_back(static_cast<int64_t>(old2new[child]));
413 out_edge_index_int64.push_back(static_cast<int64_t>(old2new[p2]));
414 }
415 }
416
417 // batch vector: all nodes belong to graph 0
418 for (int64_t i = 0; i < new_idx; ++i) out_batch_int64.push_back(0);
419
420 // counts: [num_nodes_selected]
421 out_counts_int64.push_back(static_cast<int64_t>(new_idx));
422
423 // Ntrk: normalize as in Python
424 float ntrk_std = (ntrk_f - m_mean_ntrk) / m_std_ntrk;
425 //out_Ntrk_float.push_back(ntrk_std);
426 const int64_t batch_size = m_expectedBatchSize; // Gaudi::Property
427 out_Ntrk_float.reserve(batch_size);
428 for (int64_t i = 0; i < batch_size; ++i) {
429 out_Ntrk_float.push_back(ntrk_std);
430 }
431 // Final sanity checks
432 if (out_x_float.empty() || out_counts_int64.empty()) {
433 ATH_MSG_DEBUG("After masking, nothing to evaluate.");
434 return false;
435 }
436
437 ATH_MSG_DEBUG("Built ONNX input: nodes=" << new_idx
438 << " edges=" << (out_edge_index_int64.size()/2)
439 << " ntrk=" << ntrk_f);
440
441 return true;
442}
#define ATH_MSG_DEBUG(x)
#define z
Gaudi::Property< float > m_mean_ntrk
Gaudi::Property< float > m_mean_dr
Gaudi::Property< float > m_std_ntrk
Gaudi::Property< float > m_std_kt
Gaudi::Property< float > m_mean_z
Gaudi::Property< float > m_std_z
Gaudi::Property< float > m_std_dr
Gaudi::Property< float > m_kTSelection
Gaudi::Property< int > m_expectedBatchSize
Gaudi::Property< float > m_mean_kt
Gaudi::Property< std::string > m_prefix
size_t numConstituents() const
Number of constituents in this jets (this is valid even when reading a file where the constituents ha...
Definition Jet_v1.cxx:153
const std::vector< ElementLink< IParticleContainer > > & constituentLinks() const
Direct access to constituents. WARNING expert use only.
Definition Jet_v1.cxx:162
bool getAttribute(AttributeID type, T &value) const
Retrieve attribute moment by enum.
const hsize_t batch_size
Definition defaults.h:9
FlowElement_v1 FlowElement
Definition of the current "pfo version".
Definition FlowElement.h:16

◆ cardinality()

unsigned int AthCommonReentrantAlgorithm< Gaudi::Algorithm >::cardinality ( ) const
overridevirtualinherited

Cardinality (Maximum number of clones that can exist) special value 0 means that algorithm is reentrant.

Override this to return 0 for reentrant algorithms.

Definition at line 75 of file AthCommonReentrantAlgorithm.cxx.

64{
65 return 0;
66}

◆ declareGaudiProperty()

Gaudi::Details::PropertyBase & AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::declareGaudiProperty ( Gaudi::Property< T, V, H > & hndl,
const SG::VarHandleKeyType &  )
inlineprivateinherited

specialization for handling Gaudi::Property<SG::VarHandleKey>

Definition at line 156 of file AthCommonDataStore.h.

158 {
160 hndl.value(),
161 hndl.documentation());
162
163 }
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T, V, H > &t)

◆ declareProperty()

Gaudi::Details::PropertyBase & AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::declareProperty ( Gaudi::Property< T, V, H > & t)
inlineinherited

Definition at line 145 of file AthCommonDataStore.h.

145 {
146 typedef typename SG::HandleClassifier<T>::type htype;
148 }
Gaudi::Details::PropertyBase & declareGaudiProperty(Gaudi::Property< T, V, H > &hndl, const SG::VarHandleKeyType &)
specialization for handling Gaudi::Property<SG::VarHandleKey>

◆ detStore()

const ServiceHandle< StoreGateSvc > & AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::detStore ( ) const
inlineinherited

The standard StoreGateSvc/DetectorStore Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 95 of file AthCommonDataStore.h.

◆ evtStore()

ServiceHandle< StoreGateSvc > & AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::evtStore ( )
inlineinherited

The standard StoreGateSvc (event store) Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 85 of file AthCommonDataStore.h.

◆ execute()

StatusCode LundJetOnnxAlg::execute ( const EventContext & ctx) const
overridevirtual

Definition at line 89 of file LundJetOnnxAlg.cxx.

89 {
90
91 SG::ReadHandle<xAOD::JetContainer> jets(m_inputJetContainer, ctx);
92 if (!jets.isValid()) {
93 ATH_MSG_ERROR("Failed to retrieve JetContainer: " << m_inputJetContainer);
94 return StatusCode::FAILURE;
95 }
96
97 SG::WriteDecorHandle<xAOD::JetContainer, float> scoreDecor(m_scoreDecorKey, ctx);
98 SG::WriteDecorHandle<xAOD::JetContainer, char> validDecor(m_validDecorKey, ctx);
99 // ONNX helpers
100 Ort::AllocatorWithDefaultOptions allocator;
101 auto memory_info =
102 Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
103
104 for (const xAOD::Jet* jptr : *jets) {
105
106 const xAOD::Jet& jet = *jptr;
107
108 validDecor(*jptr) = 0;
109 // ---------------------------
110 // Build inputs
111 // ---------------------------
112 std::vector<float> x_flat; // [num_nodes * 3]
113 std::vector<int64_t> edge_index_flat; // [2 * num_edges]
114 std::vector<int64_t> batch_vec; // [num_nodes]
115 std::vector<int64_t> counts_vec; // unused but kept
116 std::vector<float> Ntrk_vec; // [1]
117
118 bool ok = buildOnnxInputs(
119 jet, x_flat, edge_index_flat, batch_vec, counts_vec, Ntrk_vec
120 );
121
122 if (!ok) {
123 ATH_MSG_DEBUG("Skipping jet: couldn't build inputs");
124 continue;
125 }
126
127 const int64_t num_nodes = batch_vec.size();
128
129 if (num_nodes == 0 || x_flat.size() != static_cast<size_t>(num_nodes * 3)) {
130 ATH_MSG_WARNING("Inconsistent node inputs");
131 continue;
132 }
133
134 // ---------------------------
135 // Create ONNX tensors
136 // ---------------------------
137 const int64_t min_nodes = 2;
138
139 if (num_nodes < min_nodes) {
140 ATH_MSG_DEBUG("Skipping jet: num_nodes < 2");
141 continue;
142 }
143
144 const int64_t num_edges = edge_index_flat.size() / 2;
145 if (num_edges < 1) {
146 ATH_MSG_DEBUG("Skipping jet: no edges in graph");
147 continue;
148 }
149
150 // x : [num_nodes, 3]
151 std::vector<int64_t> x_shape = { num_nodes, 3 };
152 Ort::Value x_tensor =
153 Ort::Value::CreateTensor<float>(
154 memory_info,
155 x_flat.data(),
156 x_flat.size(),
157 x_shape.data(),
158 x_shape.size()
159 );
160
161 // edge_index : [2, num_edges]
162 std::vector<int64_t> ei_shape = {
163 2,
164 static_cast<int64_t>(edge_index_flat.size() / 2)
165 };
166 Ort::Value edge_tensor =
167 Ort::Value::CreateTensor<int64_t>(
168 memory_info,
169 edge_index_flat.data(),
170 edge_index_flat.size(),
171 ei_shape.data(),
172 ei_shape.size()
173 );
174
175 // batch : [num_nodes]
176 std::vector<int64_t> batch_shape = { num_nodes };
177 Ort::Value batch_tensor =
178 Ort::Value::CreateTensor<int64_t>(
179 memory_info,
180 batch_vec.data(),
181 batch_vec.size(),
182 batch_shape.data(),
183 batch_shape.size()
184 );
185
186 // Ntrk : [batch_size]
187 std::vector<int64_t> ntrk_shape = {static_cast<int64_t>(Ntrk_vec.size())};
188 Ort::Value ntrk_tensor =
189 Ort::Value::CreateTensor<float>(
190 memory_info,
191 Ntrk_vec.data(),
192 Ntrk_vec.size(),
193 ntrk_shape.data(),
194 ntrk_shape.size()
195 );
196 std::array<Ort::Value, 4> input_tensors = {
197 std::move(x_tensor),
198 std::move(edge_tensor),
199 std::move(batch_tensor),
200 std::move(ntrk_tensor)
201 };
202
203 std::array<const char*, 4> input_names = {
204 "x",
205 "edge_index",
206 "batch",
207 "Ntrk"
208 };
209
210 std::array<const char*, 1> output_names = {
211 "output"
212 };
213
214 // ---------------------------
215 // Run inference
216 // ---------------------------
217 std::vector<Ort::Value> output_tensors;
218 try {
219 output_tensors = m_session->Run(
220 Ort::RunOptions{nullptr},
221 input_names.data(),
222 input_tensors.data(),
223 input_tensors.size(),
224 output_names.data(),
225 output_names.size()
226 );
227 }
228 catch (const Ort::Exception& e) {
229 ATH_MSG_ERROR("ONNX Runtime exception: " << e.what());
230 continue;
231 }
232
233 // ---------------------------
234 // Read output
235 // ---------------------------
236 if (output_tensors.empty() || !output_tensors.front().IsTensor()) {
237 ATH_MSG_WARNING("Invalid output tensor");
238 //scoreDecor(*jptr) = -999.f;
239 continue;
240 }
241
242 float* out_data =
243 output_tensors.front().GetTensorMutableData<float>();
244
245 float score = out_data[0]; // shape [-1,1], batch=1
246
247 scoreDecor(*jptr) = score;
248 validDecor(*jptr) = 1;
249
250 ATH_MSG_DEBUG("Jet decorated with LundNet score = " << score);
251 }
252
253 return StatusCode::SUCCESS;
254}
#define ATH_MSG_ERROR(x)
#define ATH_MSG_WARNING(x)
Gaudi::Property< std::string > m_inputJetContainer
SG::WriteDecorHandleKey< xAOD::JetContainer > m_validDecorKey
bool buildOnnxInputs(const xAOD::Jet &jet, std::vector< float > &out_x_float, std::vector< int64_t > &out_edge_index_int64, std::vector< int64_t > &out_batch_int64, std::vector< int64_t > &out_counts_int64, std::vector< float > &out_Ntrk_float) const
Helper to build ONNX inputs (batch_size=1)
std::unique_ptr< Ort::Session > m_session
SG::WriteDecorHandleKey< xAOD::JetContainer > m_scoreDecorKey
Jet_v1 Jet
Definition of the current "jet version".

◆ extraDeps_update_handler()

void AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::extraDeps_update_handler ( Gaudi::Details::PropertyBase & ExtraDeps)
protectedinherited

Add StoreName to extra input/output deps as needed.

Uses the logic of the VarHandleKey to parse the DataObjID keys supplied via the ExtraInputs and ExtraOutputs properties, adding the StoreName when it is not given explicitly.

◆ extraOutputDeps()

const DataObjIDColl & AthCommonReentrantAlgorithm< Gaudi::Algorithm >::extraOutputDeps ( ) const
overridevirtualinherited

Return the list of extra output dependencies.

This list is extended to include symlinks implied by inheritance relations.

Definition at line 94 of file AthCommonReentrantAlgorithm.cxx.

90{
91 // If we didn't find any symlinks to add, just return the collection
92 // from the base class. Otherwise, return the extended collection.
93 if (!m_extendedExtraObjects.empty()) {
95 }
97}
An algorithm that can be simultaneously executed in multiple threads.

◆ filterPassed()

virtual bool AthCommonReentrantAlgorithm< Gaudi::Algorithm >::filterPassed ( const EventContext & ctx) const
inlinevirtualinherited

Definition at line 96 of file AthCommonReentrantAlgorithm.h.

96 {
97 return execState( ctx ).filterPassed();
98 }
virtual bool filterPassed(const EventContext &ctx) const

◆ initialize()

StatusCode LundJetOnnxAlg::initialize ( )
overridevirtual

Definition at line 25 of file LundJetOnnxAlg.cxx.

// Initialize: create the ONNX Runtime environment and session for the model,
// then build and initialize the two output decoration keys
// "<InputJetContainer>.<Prefix><ScoreDecoration>" and
// "<InputJetContainer>.<Prefix>LundNetValid".
// Returns FAILURE if the resolved model path is empty or the session cannot be
// created.
25 {
26
27 ATH_MSG_INFO("Initializing LundJetOnnxAlg");
28
29 ATH_MSG_INFO("Loading ONNX model from: " << m_modelPath);
30 // -------------------------
31 // Create ONNX Runtime session
32 // -------------------------
33 Ort::SessionOptions sessionOptions;
34 sessionOptions.SetIntraOpNumThreads(1);
35 sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
36
37 // Recommended in Athena to reduce memory usage
38 sessionOptions.DisableCpuMemArena();
39
40 m_env = std::make_unique<Ort::Env>(
41 ORT_LOGGING_LEVEL_WARNING,
42 "LundNetGNN"
43 );
// NOTE(review): source lines 44-45 are missing from this listing. They
// presumably fill m_resolvedModelPath from m_modelPath (find_file is referenced
// by this function) — confirm against LundJetOnnxAlg.cxx.
46
47 if (m_resolvedModelPath.empty()) {
48 ATH_MSG_ERROR("Could not resolve ONNX model path: " << m_modelPath);
49 return StatusCode::FAILURE;
50 }
51
52 try {
53 m_session = std::make_unique<Ort::Session>(
54 *m_env,
55 m_resolvedModelPath.c_str(), // ← now stable (points into the m_resolvedModelPath member)
56 sessionOptions
57 );
58 } catch (const Ort::Exception& e) {
59 ATH_MSG_ERROR("Failed to create ONNX Runtime session: " << e.what());
60 return StatusCode::FAILURE;
61 }
62
63 ATH_MSG_INFO("ONNX Runtime session successfully created");
64
65 // -------------------------
66 // Decorations
67 // -------------------------
68 ATH_MSG_INFO("InputJetContainer: " << m_inputJetContainer);
69 ATH_MSG_INFO("Prefix: '" << m_prefix << "'");
70 ATH_MSG_INFO("kT selection: " << m_kTSelection);
71 std::string decorFull =
72 m_inputJetContainer.value() + "." +
73 m_prefix.value() +
74 m_scoreName.value();
75 std::string validDecorFull =
76 m_inputJetContainer.value() + "." +
77 m_prefix.value() +
78 "LundNetValid";
79
// NOTE(review): these assignments overwrite whatever was configured through the
// ScoreDecorKey / ValidDecor properties, so those properties are effectively ignored.
80 m_scoreDecorKey = decorFull;
81 m_validDecorKey = validDecorFull;
82 ATH_CHECK(m_scoreDecorKey.initialize());
83 ATH_CHECK(m_validDecorKey.initialize());
84 ATH_MSG_INFO("Will write decoration: " << decorFull);
85
86 return StatusCode::SUCCESS;
87}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_INFO(x)
Gaudi::Property< std::string > m_modelPath
Gaudi::Property< std::string > m_scoreName
std::unique_ptr< Ort::Env > m_env
std::string m_resolvedModelPath
static std::string find_file(const std::string &logical_file_name, const std::string &search_path)

◆ inputHandles()

virtual std::vector< Gaudi::DataHandle * > AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::inputHandles ( ) const
overridevirtualinherited

Return this algorithm's input handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ isClonable()

◆ msg()

MsgStream & AthCommonMsg< Gaudi::Algorithm >::msg ( ) const
inlineinherited

Definition at line 24 of file AthCommonMsg.h.

24 {
25 return this->msgStream();
26 }

◆ msgLvl()

bool AthCommonMsg< Gaudi::Algorithm >::msgLvl ( const MSG::Level lvl) const
inlineinherited

Definition at line 30 of file AthCommonMsg.h.

30 {
31 return this->msgLevel(lvl);
32 }

◆ outputHandles()

virtual std::vector< Gaudi::DataHandle * > AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::outputHandles ( ) const
overridevirtualinherited

Return this algorithm's output handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ renounce()

std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::renounce ( T & h)
inlineprotectedinherited

Definition at line 380 of file AthCommonDataStore.h.

381 {
382 h.renounce();
384 }
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce(T &h)

◆ renounceArray()

void AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::renounceArray ( SG::VarHandleKeyArray & handlesArray)
inlineprotectedinherited

remove all handles from I/O resolution

Definition at line 364 of file AthCommonDataStore.h.

364 {
366 }

◆ setFilterPassed()

virtual void AthCommonReentrantAlgorithm< Gaudi::Algorithm >::setFilterPassed ( bool state,
const EventContext & ctx ) const
inlinevirtualinherited

Definition at line 100 of file AthCommonReentrantAlgorithm.h.

100 {
102 }
virtual void setFilterPassed(bool state, const EventContext &ctx) const

◆ sysExecute()

StatusCode AthCommonReentrantAlgorithm< Gaudi::Algorithm >::sysExecute ( const EventContext & ctx)
overridevirtualinherited

Execute an algorithm.

We override this in order to work around an issue with the Algorithm base class storing the event context in a member variable that can cause crashes in MT jobs.

Definition at line 85 of file AthCommonReentrantAlgorithm.cxx.

77{
78 return BaseAlg::sysExecute (ctx);
79}

◆ sysInitialize()

StatusCode AthCommonReentrantAlgorithm< Gaudi::Algorithm >::sysInitialize ( )
overridevirtualinherited

Override sysInitialize.

Override sysInitialize from the base class.

Loop through all output handles, and if they're WriteCondHandles, automatically register them and this Algorithm with the CondSvc

Scan through all outputHandles, and if they're WriteCondHandles, register them with the CondSvc

Reimplemented from AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >.

Reimplemented in HypoBase, and InputMakerBase.

Definition at line 61 of file AthCommonReentrantAlgorithm.cxx.

107 {
109
110 if (sc.isFailure()) {
111 return sc;
112 }
113
114 ServiceHandle<ICondSvc> cs("CondSvc",name());
115 for (auto h : outputHandles()) {
116 if (h->isCondition() && h->mode() == Gaudi::DataHandle::Writer) {
117 // do this inside the loop so we don't create the CondSvc until needed
118 if ( cs.retrieve().isFailure() ) {
119 ATH_MSG_WARNING("no CondSvc found: won't autoreg WriteCondHandles");
120 return StatusCode::SUCCESS;
121 }
122 if (cs->regHandle(this,*h).isFailure()) {
124 ATH_MSG_ERROR("unable to register WriteCondHandle " << h->fullKey()
125 << " with CondSvc");
126 }
127 }
128 }
129 return sc;
130}
virtual std::vector< Gaudi::DataHandle * > outputHandles() const override

◆ sysStart()

virtual StatusCode AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::sysStart ( )
overridevirtualinherited

Handle START transition.

We override this in order to make sure that conditions handle keys can cache a pointer to the conditions container.

◆ updateVHKA()

void AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::updateVHKA ( Gaudi::Details::PropertyBase & )
inlineinherited

Definition at line 308 of file AthCommonDataStore.h.

308 {
309 // debug() << "updateVHKA for property " << p.name() << " " << p.toString()
310 // << " size: " << m_vhka.size() << endmsg;
311 for (auto &a : m_vhka) {
313 for (auto k : keys) {
314 k->setOwner(this);
315 }
316 }
317 }

Member Data Documentation

◆ m_detStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::m_detStore
privateinherited

Pointer to StoreGate (detector store by default)

Definition at line 393 of file AthCommonDataStore.h.

◆ m_env

std::unique_ptr<Ort::Env> LundJetOnnxAlg::m_env
private

Definition at line 79 of file LundJetOnnxAlg.h.

◆ m_evtStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::m_evtStore
privateinherited

Pointer to StoreGate (event store by default)

Definition at line 390 of file AthCommonDataStore.h.

◆ m_expectedBatchSize

Gaudi::Property<int> LundJetOnnxAlg::m_expectedBatchSize
private
Initial value:
{
this,
"ExpectedBatchSize",
1,
"Batch size expected by the ONNX model (from training/export)"
}

Definition at line 50 of file LundJetOnnxAlg.h.

50 {
51 this,
52 "ExpectedBatchSize",
53 1,
54 "Batch size expected by the ONNX model (from training/export)"
55 };

◆ m_extendedExtraObjects

DataObjIDColl AthCommonReentrantAlgorithm< Gaudi::Algorithm >::m_extendedExtraObjects
privateinherited

Extra output dependency collection, extended by AthAlgorithmDHUpdate to add symlinks.

Empty if no symlinks were found.

Definition at line 114 of file AthCommonReentrantAlgorithm.h.

◆ m_inputJetContainer

Gaudi::Property<std::string> LundJetOnnxAlg::m_inputJetContainer
private
Initial value:
{
this, "InputJetContainer", "AntiKt10UFO",
"Name of input jet container"
}

Definition at line 32 of file LundJetOnnxAlg.h.

32 {
33 this, "InputJetContainer", "AntiKt10UFO",
34 "Name of input jet container"
35 };

◆ m_kTSelection

Gaudi::Property<float> LundJetOnnxAlg::m_kTSelection
private
Initial value:
{
this, "kTSelection", -1000.f,
"kT cut to apply (same meaning as kT_Cut in python loader)"
}

Definition at line 40 of file LundJetOnnxAlg.h.

40 {
41 this, "kTSelection", -1000.f,
42 "kT cut to apply (same meaning as kT_Cut in python loader)"
43 };

◆ m_mean_dr

Gaudi::Property<float> LundJetOnnxAlg::m_mean_dr { this, "MeanDR", 3.8597358364389427f, "mean dr (for normalization)" }
private

Definition at line 66 of file LundJetOnnxAlg.h.

66{ this, "MeanDR", 3.8597358364389427f, "mean dr (for normalization)" };

◆ m_mean_kt

Gaudi::Property<float> LundJetOnnxAlg::m_mean_kt { this, "MeanKT", -2.379904791478249f, "mean kt (for normalization)" }
private

Definition at line 68 of file LundJetOnnxAlg.h.

68{ this, "MeanKT", -2.379904791478249f, "mean kt (for normalization)" };

◆ m_mean_ntrk

Gaudi::Property<float> LundJetOnnxAlg::m_mean_ntrk { this, "MeanNtrk", 57.588158609500134f, "mean Ntrk" }
private

Definition at line 70 of file LundJetOnnxAlg.h.

70{ this, "MeanNtrk", 57.588158609500134f, "mean Ntrk" };

◆ m_mean_z

Gaudi::Property<float> LundJetOnnxAlg::m_mean_z { this, "MeanZ", 2.0568479032747313f, "mean z (for normalization)" }
private

Definition at line 64 of file LundJetOnnxAlg.h.

64{ this, "MeanZ", 2.0568479032747313f, "mean z (for normalization)" };

◆ m_modelPath

Gaudi::Property<std::string> LundJetOnnxAlg::m_modelPath
private
Initial value:
{
this,
"ModelPath",
"",
"Path to the ONNX model"
}

Definition at line 73 of file LundJetOnnxAlg.h.

73 {
74 this,
75 "ModelPath",
76 "",
77 "Path to the ONNX model"
78 };

◆ m_prefix

Gaudi::Property<std::string> LundJetOnnxAlg::m_prefix
private
Initial value:
{
this, "Prefix", "",
"Prefix used by LundVariablesTool decorations"
}

Definition at line 36 of file LundJetOnnxAlg.h.

36 {
37 this, "Prefix", "",
38 "Prefix used by LundVariablesTool decorations"
39 };

◆ m_resolvedModelPath

std::string LundJetOnnxAlg::m_resolvedModelPath
private

Definition at line 72 of file LundJetOnnxAlg.h.

◆ m_scoreDecorKey

SG::WriteDecorHandleKey<xAOD::JetContainer> LundJetOnnxAlg::m_scoreDecorKey
private
Initial value:
{
this, "ScoreDecorKey", "", "Score decoration key"
}

Definition at line 60 of file LundJetOnnxAlg.h.

60 {
61 this, "ScoreDecorKey", "", "Score decoration key"
62 };

◆ m_scoreName

Gaudi::Property<std::string> LundJetOnnxAlg::m_scoreName
private
Initial value:
{
this, "ScoreDecoration", "LundNetScore",
"Name of the decoration (without container prefix). Final decoration will be <container>.<prefix><ScoreDecoration>"
}

Definition at line 45 of file LundJetOnnxAlg.h.

45 {
46 this, "ScoreDecoration", "LundNetScore",
47 "Name of the decoration (without container prefix). Final decoration will be <container>.<prefix><ScoreDecoration>"
48 };

◆ m_session

std::unique_ptr<Ort::Session> LundJetOnnxAlg::m_session
private

Definition at line 80 of file LundJetOnnxAlg.h.

◆ m_std_dr

Gaudi::Property<float> LundJetOnnxAlg::m_std_dr { this, "StdDR", 2.2748462855901073f, "std dr (for normalization)" }
private

Definition at line 67 of file LundJetOnnxAlg.h.

67{ this, "StdDR", 2.2748462855901073f, "std dr (for normalization)" };

◆ m_std_kt

Gaudi::Property<float> LundJetOnnxAlg::m_std_kt { this, "StdKT", 2.940813577366582f, "std kt (for normalization)" }
private

Definition at line 69 of file LundJetOnnxAlg.h.

69{ this, "StdKT", 2.940813577366582f, "std kt (for normalization)" };

◆ m_std_ntrk

Gaudi::Property<float> LundJetOnnxAlg::m_std_ntrk { this, "StdNtrk", 23.900100132781983f, "std Ntrk" }
private

Definition at line 71 of file LundJetOnnxAlg.h.

71{ this, "StdNtrk", 23.900100132781983f, "std Ntrk" };

◆ m_std_z

Gaudi::Property<float> LundJetOnnxAlg::m_std_z { this, "StdZ", 1.4450598054504056f, "std z (for normalization)" }
private

Definition at line 65 of file LundJetOnnxAlg.h.

65{ this, "StdZ", 1.4450598054504056f, "std z (for normalization)" };

◆ m_validDecorKey

SG::WriteDecorHandleKey<xAOD::JetContainer> LundJetOnnxAlg::m_validDecorKey
private
Initial value:
{
this, "ValidDecor", "", "Valid LundNet inference"
}

Definition at line 56 of file LundJetOnnxAlg.h.

56 {
57 this, "ValidDecor", "", "Valid LundNet inference"
58 };

◆ m_varHandleArraysDeclared

bool AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::m_varHandleArraysDeclared
privateinherited

Definition at line 399 of file AthCommonDataStore.h.

◆ m_vhka

std::vector<SG::VarHandleKeyArray*> AthCommonDataStore< AthCommonMsg< Gaudi::Algorithm > >::m_vhka
privateinherited

Definition at line 398 of file AthCommonDataStore.h.


The documentation for this class was generated from the following files: