ATLAS Offline Software
Loading...
Searching...
No Matches
OnnxNNCondAlg.cxx
Go to the documentation of this file.
/*
  Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
*/

#include "OnnxNNCondAlg.h"

#include "AthenaKernel/IOVInfiniteRange.h"
#include "PathResolver/PathResolver.h"
#include "StoreGate/WriteCondHandle.h"

#include <array>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>

9namespace InDet {
10
11
13 ATH_CHECK(m_onnxSvc.retrieve());
14 ATH_CHECK(m_writeKey.initialize());
15
16 // Validate that at least the number network path is provided
17 if (m_numberNetworkPath.value().empty()) {
18 ATH_MSG_FATAL("NumberNetworkPath must be set");
19 return StatusCode::FAILURE;
20 }
21
22 ATH_MSG_INFO("OnnxNNCondAlg configured with:"
23 << " NumberNetwork=" << m_numberNetworkPath.value()
24 << " PosNetwork1=" << m_posNetwork1Path.value()
25 << " PosNetwork2=" << m_posNetwork2Path.value()
26 << " PosNetwork3=" << m_posNetwork3Path.value());
27
28 return StatusCode::SUCCESS;
29}
30
31std::unique_ptr<Ort::Session> OnnxNNCondAlg::createSession(
32 const std::string& modelPath) const {
33
34 std::string resolvedPath = PathResolver::find_calib_file(modelPath);
35 if (resolvedPath.empty()) {
36 // Try as absolute/relative path directly
37 resolvedPath = modelPath;
38 }
39
40 ATH_MSG_DEBUG("Loading ONNX model from: " << resolvedPath);
41
42 Ort::SessionOptions sessionOptions;
43 sessionOptions.SetIntraOpNumThreads(1);
44 sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
45 sessionOptions.DisablePerSessionThreads();
46
47 return std::make_unique<Ort::Session>(m_onnxSvc->env(), resolvedPath.c_str(),
48 sessionOptions);
49}
50
51StatusCode OnnxNNCondAlg::execute(const EventContext& ctx) const {
52
54 if (writeHandle.isValid()) {
55 ATH_MSG_DEBUG("OnnxNNCollection already valid, skipping");
56 return StatusCode::SUCCESS;
57 }
58
59 auto collection = std::make_unique<OnnxNNCollection>();
60
61 // Load number network
62 try {
63 collection->numberNetwork = createSession(m_numberNetworkPath.value());
64 } catch (const Ort::Exception& e) {
65 ATH_MSG_FATAL("Failed to load number network ONNX model: " << e.what());
66 return StatusCode::FAILURE;
67 }
68
69 // Load position networks (paths are optional)
70 const std::array<std::string, 3> posPaths = {
71 m_posNetwork1Path.value(),
72 m_posNetwork2Path.value(),
73 m_posNetwork3Path.value()
74 };
75 std::unique_ptr<Ort::Session>* posMembers[3] = {
76 &collection->positionNetwork1,
77 &collection->positionNetwork2,
78 &collection->positionNetwork3
79 };
80
81 for (int i = 0; i < 3; ++i) {
82 if (posPaths[i].empty()) {
83 ATH_MSG_DEBUG("Position network " << (i+1) << " path not set, skipping");
84 continue;
85 }
86 try {
87 *posMembers[i] = createSession(posPaths[i]);
88 } catch (const Ort::Exception& e) {
89 ATH_MSG_FATAL("Failed to load position network " << (i+1)
90 << " ONNX model: " << e.what());
91 return StatusCode::FAILURE;
92 }
93 }
94
95 // Set infinite IOV for file-based models
97
98 ATH_CHECK(writeHandle.record(std::move(collection)));
99 ATH_MSG_DEBUG("Recorded OnnxNNCollection successfully");
100
101 return StatusCode::SUCCESS;
102}
103
104} // namespace InDet
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_FATAL(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_DEBUG(x)
static const Attributes_t empty
static EventIDRange infiniteRunLB()
Produces an EventIDRange that is infinite in RunLumi and invalid in Time.
Gaudi::Property< std::string > m_numberNetworkPath
std::unique_ptr< Ort::Session > createSession(const std::string &modelPath) const
Gaudi::Property< std::string > m_posNetwork1Path
virtual StatusCode execute(const EventContext &ctx) const override
SG::WriteCondHandleKey< OnnxNNCollection > m_writeKey
Gaudi::Property< std::string > m_posNetwork3Path
virtual StatusCode initialize() override
Gaudi::Property< std::string > m_posNetwork2Path
ServiceHandle< AthOnnx::IOnnxRuntimeSvc > m_onnxSvc
static std::string find_calib_file(const std::string &logical_file_name)
void addDependency(const EventIDRange &range)
StatusCode record(const EventIDRange &range, T *t)
record handle, with explicit range DEPRECATED
Primary Vertex Finder.