16 std::ostream&
operator<<(std::ostream& ostr,
const std::vector<T>&
vec) {
/// Stream a std::pair as "(first, second)".
///
/// @tparam T1   Type of the first member; must itself be streamable.
/// @tparam T2   Type of the second member; must itself be streamable.
/// @param ostr  Destination stream.
/// @param pair  Pair to print.
/// @return      @p ostr, so that insertions can be chained (`os << p << x`).
template <typename T1, typename T2>
std::ostream& operator<<(std::ostream& ostr, const std::pair<T1, T2>& pair) {
  ostr << "(" << pair.first << ", " << pair.second << ")";
  // Returning the stream is mandatory: flowing off the end of a non-void
  // function is undefined behaviour, and chained insertion relies on it.
  return ostr;
}
/// Split a flat edge list into one pair per edge.
///
/// @param edges  Flat list of length 2n laid out like a row-major [2, n] array:
///               the first n entries form row 0 and the last n entries form row 1.
/// @return       n pairs where pair i is {edges[i], edges[i + n]},
///               with n = edges.size() / 2.
///
/// If @p edges has an odd length, the trailing element is ignored.
std::vector<std::pair<int64_t, int64_t>> makeIndexPairs(
    const std::vector<int64_t>& edges) {
  const size_t nEdges = edges.size() / 2;

  std::vector<std::pair<int64_t, int64_t>> indexPairs;
  indexPairs.reserve(nEdges);
  for (size_t i = 0; i < nEdges; ++i) {
    // Construct the pair in place; no temporary std::make_pair is needed.
    indexPairs.emplace_back(edges[i], edges[i + nEdges]);
  }
  return indexPairs;
}
/// Sort the columns of a flat [2, n] edge list lexicographically.
///
/// @param edges  Flat list of length 2n laid out like a row-major [2, n] array:
///               edge i is (edges[i], edges[i + n]) with n = edges.size() / 2.
/// @return       A new flat list in the same [2, n] layout whose edges are
///               ordered by their row-0 value, ties broken by their row-1 value.
///
/// If @p edges has an odd length, the trailing element is ignored, so the
/// result has length 2 * (edges.size() / 2).
std::vector<int64_t> makeSortedEdges(const std::vector<int64_t>& edges) {
  const size_t nEdges = edges.size() / 2;

  std::vector<std::pair<int64_t, int64_t>> indexPairs;
  indexPairs.reserve(nEdges);
  for (size_t i = 0; i < nEdges; ++i) {
    indexPairs.emplace_back(edges[i], edges[i + nEdges]);
  }

  // std::pair's operator< is already lexicographic (first, then second),
  // which is exactly the ordering a hand-written comparator would spell out.
  std::sort(indexPairs.begin(), indexPairs.end());

  // Re-flatten into the same layout: all row-0 values, then all row-1 values.
  std::vector<int64_t> sortedEdges(2 * nEdges);
  for (size_t i = 0; i < nEdges; ++i) {
    sortedEdges[i] = indexPairs[i].first;
    sortedEdges[i + nEdges] = indexPairs[i].second;
  }
  return sortedEdges;
}
// Build a human-readable text dump of a flat per-node feature buffer.
//
// Parameters:
//   featureLeaves      - flat buffer read node-major: feature f of node n is
//                        featureLeaves[n * numFeaturesPerNode + f].
//   numFeaturesPerNode - number of features per node. Must be non-zero
//                        (see the review note on the division below).
//
// Output written to the stream:
//   "Number of nodes: <N>"
//   "Features per node: <F>"
//   one "Node[<i>]: [v0, v1, ...": line per node.
//
// NOTE(review): the remainder of this function (closing each node's line and
// the return) is not visible in this view. It presumably returns oss.str() --
// confirm against the full source.
68 std::string formatNodeFeatures(
const std::vector<float>& featureLeaves,
size_t numFeaturesPerNode) {
69 std::ostringstream oss;
// Floor division: trailing elements that do not fill a whole node are not printed.
// NOTE(review): numFeaturesPerNode == 0 makes this an integer division by zero
// (undefined behaviour). Callers must guarantee a non-zero value, or a guard
// should be added here.
71 const size_t numNodes = featureLeaves.size() / numFeaturesPerNode;
72 oss <<
"Number of nodes: " << numNodes <<
"\n";
73 oss <<
"Features per node: " << numFeaturesPerNode <<
"\n";
// One line per node, listing that node's features comma-separated.
75 for (
size_t nodeIdx = 0; nodeIdx < numNodes; ++nodeIdx) {
76 oss <<
"Node[" << nodeIdx <<
"]: [";
77 for (
size_t f = 0;
f < numFeaturesPerNode; ++
f) {
// Separator goes before every feature except the first.
78 if (
f > 0) oss <<
", ";
79 oss << featureLeaves[nodeIdx * numFeaturesPerNode +
f];
96 Ort::AllocatorWithDefaultOptions allocator;
97 Ort::AllocatedStringPtr feature_json_ptr =
metadata.LookupCustomMetadataMapAllocated(
"feature_names", allocator);
99 if (feature_json_ptr) {
100 std::string feature_json = feature_json_ptr.get();
102 for (
const auto& feature : json_obj) {
110 return StatusCode::FAILURE;
112 return StatusCode::SUCCESS;
120 graphData.
graph.reset();
123 if (graphData.
graph) {
124 return StatusCode::SUCCESS;
127 ATH_MSG_ERROR(
"The feature list is in complete. Either it has no features or no node connector set");
128 return StatusCode::FAILURE;
131 graphData.
graph = std::make_unique<InferenceGraph>();
136 int64_t nNodes{0}, possConn{0};
149 graphData.
srcEdges.reserve(possConn);
150 graphData.
desEdges.reserve(possConn);
159 std::make_move_iterator(graphData.
desEdges.end()));
166 std::vector<int64_t> edgeShape{2,
static_cast<int64_t
>(graphData.
srcEdges.size() / 2)};
169 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
170 graphData.
graph->dataTensor.emplace_back(Ort::Value::CreateTensor<float>(memInfo,
172 featShape.data(), featShape.size()));
175 Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>(memInfo, graphData.
srcEdges.data(),
176 graphData.
srcEdges.size(), edgeShape.data(), edgeShape.size());
178 graphData.
graph->dataTensor.emplace_back(std::move(edge_tensor));
186 return StatusCode::SUCCESS;
191 ATH_MSG_ERROR(
"ONNX model is not loaded. Please call setupModel()");
192 return StatusCode::FAILURE;
194 if (!graphData.
graph) {
196 return StatusCode::FAILURE;
199 if (graphData.
graph->dataTensor.size() < 2) {
200 ATH_MSG_ERROR(
"Data tensor does not contain both feature and edge tensors.");
201 return StatusCode::FAILURE;
204 std::vector<const char*> inputNames = {
"features",
"edge_index"};
207 Ort::RunOptions run_options;
208 run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_WARNING);
210 std::vector<Ort::Value> outputTensors =
model().Run(run_options,
212 graphData.
graph->dataTensor.data(),
213 graphData.
graph->dataTensor.size(),
217 if (outputTensors.empty()) {
219 return StatusCode::FAILURE;
222 float* output_data = outputTensors[0].GetTensorMutableData<
float>();
223 size_t output_size = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
225 std::span<float> predictions(output_data, output_data + output_size);
227 for (
size_t i = 0;
i < output_size;
i++) {
228 if (!std::isfinite(predictions[
i])) {
230 predictions[
i] = -100.0f;
234 graphData.
graph->dataTensor.emplace_back(std::move(outputTensors[0]));
236 return StatusCode::SUCCESS;