10 #include "Acts/Utilities/MathHelpers.hpp"
15 std::ostream&
operator<<(std::ostream& ostr,
const std::vector<T>&
vec) {
23 template <
typename T1 ,
typename T2>
24 std::ostream&
operator<<(std::ostream& ostr,
const std::pair<T1, T2>& pair) {
25 ostr <<
"(" << pair.first <<
", " << pair.second <<
")";
29 std::vector<std::pair<int64_t, int64_t>> makeIndexPairs(
const std::vector<int64_t>& edges) {
30 std::vector<std::pair<int64_t, int64_t>> indexPairs;
31 const size_t nEdges = edges.size() / 2;
32 indexPairs.reserve(nEdges);
33 for (
size_t i = 0;
i < nEdges; ++
i) {
34 indexPairs.emplace_back(std::make_pair(edges[
i], edges[
i+nEdges]));
39 std::vector<int64_t> makeSortedEdges(
const std::vector<int64_t>& edges) {
40 const size_t nEdges = edges.size() / 2;
41 std::vector<std::pair<int64_t, int64_t>> indexPairs;
42 indexPairs.reserve(nEdges);
45 for (
size_t i = 0;
i < nEdges; ++
i) {
46 indexPairs.emplace_back(edges[
i], edges[
i + nEdges]);
50 std::sort(indexPairs.begin(), indexPairs.end(), [](
const auto&
a,
const auto&
b) {
51 return (a.first < b.first) || (a.first == b.first && a.second < b.second);
55 std::vector<int64_t> sortedEdges;
56 sortedEdges.reserve(2 * nEdges);
57 for (
const auto& pair : indexPairs) {
58 sortedEdges.push_back(pair.first);
60 for (
const auto& pair : indexPairs) {
61 sortedEdges.push_back(pair.second);
67 std::string formatNodeFeatures(
const std::vector<float>& featureLeaves,
size_t numFeaturesPerNode) {
68 std::ostringstream oss;
70 const size_t numNodes = featureLeaves.size() / numFeaturesPerNode;
71 oss <<
"Number of nodes: " << numNodes <<
"\n";
72 oss <<
"Features per node: " << numFeaturesPerNode <<
"\n";
74 for (
size_t nodeIdx = 0; nodeIdx < numNodes; ++nodeIdx) {
75 oss <<
"Node[" << nodeIdx <<
"]: [";
76 for (
size_t f = 0;
f < numFeaturesPerNode; ++
f) {
77 if (
f > 0) oss <<
", ";
78 oss << featureLeaves[nodeIdx * numFeaturesPerNode +
f];
95 Ort::AllocatorWithDefaultOptions allocator;
96 Ort::AllocatedStringPtr feature_json_ptr =
metadata.LookupCustomMetadataMapAllocated(
"feature_names", allocator);
98 if (feature_json_ptr) {
99 std::string feature_json = feature_json_ptr.get();
101 for (
const auto& feature : json_obj) {
109 return StatusCode::FAILURE;
111 return StatusCode::SUCCESS;
119 graphData.
graph.reset();
122 if (graphData.
graph) {
123 return StatusCode::SUCCESS;
126 ATH_MSG_ERROR(
"The feature list is in complete. Either it has no features or no node connector set");
127 return StatusCode::FAILURE;
130 graphData.
graph = std::make_unique<InferenceGraph>();
135 int64_t nNodes{0}, possConn{0};
148 graphData.
srcEdges.reserve(possConn);
149 graphData.
desEdges.reserve(possConn);
158 std::make_move_iterator(graphData.
desEdges.end()));
165 std::vector<int64_t> edgeShape{2,
static_cast<int64_t
>(graphData.
srcEdges.size() / 2)};
168 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
169 graphData.
graph->dataTensor.emplace_back(Ort::Value::CreateTensor<float>(memInfo,
171 featShape.data(), featShape.size()));
174 Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>(memInfo, graphData.
srcEdges.data(),
175 graphData.
srcEdges.size(), edgeShape.data(), edgeShape.size());
177 graphData.
graph->dataTensor.emplace_back(std::move(edge_tensor));
185 return StatusCode::SUCCESS;
190 ATH_MSG_ERROR(
"ONNX model is not loaded. Please call setupModel()");
191 return StatusCode::FAILURE;
193 if (!graphData.
graph) {
195 return StatusCode::FAILURE;
198 if (graphData.
graph->dataTensor.size() < 2) {
199 ATH_MSG_ERROR(
"Data tensor does not contain both feature and edge tensors.");
200 return StatusCode::FAILURE;
203 std::vector<const char*> inputNames = {
"features",
"edge_index"};
206 Ort::RunOptions run_options;
207 run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_WARNING);
209 std::vector<Ort::Value> outputTensors =
model().Run(run_options,
211 graphData.
graph->dataTensor.data(),
212 graphData.
graph->dataTensor.size(),
216 if (outputTensors.empty()) {
218 return StatusCode::FAILURE;
221 float* output_data = outputTensors[0].GetTensorMutableData<
float>();
222 size_t output_size = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
224 std::span<float> predictions(output_data, output_data + output_size);
226 for (
size_t i = 0;
i < output_size;
i++) {
227 if (!std::isfinite(predictions[
i])) {
229 predictions[
i] = -100.0f;
233 graphData.
graph->dataTensor.emplace_back(std::move(outputTensors[0]));
235 return StatusCode::SUCCESS;