10#include "Acts/Utilities/MathHelpers.hpp"
// Output-stream operator for std::vector<T>: iterates over every element of
// `vec` with a range-for.
// NOTE(review): this listing omits the `template <typename T>` header, the
// per-element formatting inside the loop and the `return ostr;` of the
// original definition — only the signature and loop header are visible here.
16 std::ostream&
operator<<(std::ostream& ostr,
const std::vector<T>&
vec) {
18 for (
const T& val :
vec) {
/// @brief Stream a std::pair as "(first, second)".
/// @param ostr Output stream to write to.
/// @param pair Pair whose members are written with their own operator<<.
/// @return The same stream, so that insertions can be chained.
template <typename T1, typename T2>
std::ostream& operator<<(std::ostream& ostr, const std::pair<T1, T2>& pair) {
    ostr << "(" << pair.first << ", " << pair.second << ")";
    // A stream inserter must hand the stream back; falling off the end of a
    // non-void function is undefined behaviour and breaks `os << p << x`.
    return ostr;
}
/// @brief Pair up the source and destination node indices of a flat edge list.
///
/// `edges` is read in the layout [src_0 ... src_{N-1}, dst_0 ... dst_{N-1}]
/// with N = edges.size() / 2, i.e. edge i is (edges[i], edges[i + N]).
/// @param edges Flat edge-index buffer; its size is expected to be even.
/// @return One (source, destination) pair per edge, in input order.
/// @note An odd-sized input is not rejected: the halves are split at
///       size() / 2 and the trailing element is ignored.
std::vector<std::pair<int64_t, int64_t>> makeIndexPairs(const std::vector<int64_t>& edges) {
    const std::size_t nEdges = edges.size() / 2;
    std::vector<std::pair<int64_t, int64_t>> indexPairs;
    indexPairs.reserve(nEdges);
    for (std::size_t i = 0; i < nEdges; ++i) {
        // Construct the pair in place; no std::make_pair temporary needed.
        indexPairs.emplace_back(edges[i], edges[i + nEdges]);
    }
    return indexPairs;
}
/// @brief Return the edge list sorted by (source, destination).
///
/// Input and output share the flat layout
/// [src_0 ... src_{N-1}, dst_0 ... dst_{N-1}] with N = edges.size() / 2.
/// Edges are ordered lexicographically: by source index first, with ties
/// broken by destination index.
/// @param edges Flat edge-index buffer; its size is expected to be even.
/// @return A new buffer of size 2 * N in the same layout, sorted.
/// @note For an odd-sized input the halves are split at size() / 2 and the
///       trailing element is dropped.
std::vector<int64_t> makeSortedEdges(const std::vector<int64_t>& edges) {
    const std::size_t nEdges = edges.size() / 2;

    // Gather (source, destination) pairs so both indices move together.
    std::vector<std::pair<int64_t, int64_t>> indexPairs;
    indexPairs.reserve(nEdges);
    for (std::size_t i = 0; i < nEdges; ++i) {
        indexPairs.emplace_back(edges[i], edges[i + nEdges]);
    }

    // std::pair's operator< compares `first`, then `second` — exactly the
    // (source, destination) lexicographic order we want.
    std::sort(indexPairs.begin(), indexPairs.end());

    // Scatter back into the flat [sources..., destinations...] layout.
    std::vector<int64_t> sortedEdges(2 * nEdges);
    for (std::size_t i = 0; i < nEdges; ++i) {
        sortedEdges[i] = indexPairs[i].first;
        sortedEdges[i + nEdges] = indexPairs[i].second;
    }
    return sortedEdges;
}
// Build a human-readable text dump of a flat node-feature buffer.
// `featureLeaves` is read as consecutive rows of `numFeaturesPerNode` values,
// one row per node (element index = nodeIdx * numFeaturesPerNode + f).
// NOTE(review): numFeaturesPerNode == 0 makes the division below divide by
// zero; values that do not fill a complete row are silently skipped.
// Consider guarding both cases.
// NOTE(review): the end of this function (closing of each row and the
// returned string, presumably oss.str()) is not visible in this listing.
68 std::string formatNodeFeatures(
const std::vector<float>& featureLeaves,
size_t numFeaturesPerNode) {
69 std::ostringstream oss;
// Integer division: number of complete rows in the buffer.
71 const size_t numNodes = featureLeaves.size() / numFeaturesPerNode;
72 oss <<
"Number of nodes: " << numNodes <<
"\n";
73 oss <<
"Features per node: " << numFeaturesPerNode <<
"\n";
// One "Node[i]: [v0, v1, ...": line per node.
75 for (
size_t nodeIdx = 0; nodeIdx < numNodes; ++nodeIdx) {
76 oss <<
"Node[" << nodeIdx <<
"]: [";
77 for (
size_t f = 0;
f < numFeaturesPerNode; ++
f) {
// Comma-separate the values belonging to one node.
78 if (f > 0) oss <<
", ";
79 oss << featureLeaves[nodeIdx * numFeaturesPerNode +
f];
// (Tail of a log statement whose start is not visible in this listing.)
101 <<
". I/O binding will be used.");
// Look up the model's custom metadata entry "feature_names" and, if present,
// parse it as JSON and iterate over its entries.
// NOTE(review): the enclosing function header, the loop body and the
// condition guarding `return StatusCode::FAILURE` are not visible here.
// NOTE(review): nlohmann::json::parse is expected to throw on malformed
// input — confirm that an exception is acceptable here or catch it.
107 Ort::ModelMetadata metadata =
model().GetModelMetadata();
108 Ort::AllocatorWithDefaultOptions allocator;
109 Ort::AllocatedStringPtr feature_json_ptr = metadata.LookupCustomMetadataMapAllocated(
"feature_names", allocator);
// Presumably null when the model carries no "feature_names" entry — verify.
111 if (feature_json_ptr) {
112 std::string feature_json = feature_json_ptr.get();
113 nlohmann::json json_obj = nlohmann::json::parse(feature_json);
114 for (
const auto& feature : json_obj) {
122 return StatusCode::FAILURE;
124 return StatusCode::SUCCESS;
// NOTE(review): fragmentary listing — the enclosing function header(s) and
// many statements between the numbered lines are not visible here.
// Release any graph previously held by graphData.
132 graphData.
graph.reset();
135 if (graphData.
graph) {
136 return StatusCode::SUCCESS;
// NOTE(review): the message reads "in complete"; presumably "incomplete" was meant.
139 ATH_MSG_ERROR(
"The feature list is in complete. Either it has no features or no node connector set");
140 return StatusCode::FAILURE;
// Allocate a fresh InferenceGraph; its dataTensor vector receives the
// feature tensor and the edge tensor further below.
143 graphData.
graph = std::make_unique<InferenceGraph>();
// Node count and number of possible connections; filled by code not shown here.
148 int64_t nNodes{0}, possConn{0};
// Pre-size the source and destination edge buffers for possConn entries each.
161 graphData.
srcEdges.reserve(possConn);
162 graphData.
desEdges.reserve(possConn);
// Continuation of a statement whose start is not visible; it appears to move
// desEdges onto the end of srcEdges, giving srcEdges the layout
// [sources..., destinations...] — TODO confirm against the full source.
171 std::make_move_iterator(graphData.
desEdges.end()));
// Feature tensor shape: nNodes rows x numFeatures() columns.
177 std::vector<int64_t> featShape{nNodes,
static_cast<int64_t
>(
m_graphFeatures.numFeatures())};
// Edge tensor shape: 2 x (srcEdges.size() / 2), i.e. srcEdges is read as a
// two-row buffer (presumably sources, then destinations).
178 std::vector<int64_t> edgeShape{2,
static_cast<int64_t
>(graphData.
srcEdges.size() / 2)};
// CPU memory descriptor used to wrap existing host buffers as tensors.
// NOTE(review): CreateTensor over a caller-provided buffer wraps rather than
// copies it, so the feature buffer and graphData.srcEdges must outlive the
// tensors — verify against the ONNX Runtime docs and InferenceGraph lifetime.
181 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
// First dataTensor entry: node features (its data-pointer arguments sit on
// a source line that is not shown).
182 graphData.
graph->dataTensor.emplace_back(Ort::Value::CreateTensor<float>(memInfo,
184 featShape.data(), featShape.size()));
// Second dataTensor entry: edge indices, built directly on srcEdges' storage.
187 Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>(memInfo, graphData.
srcEdges.data(),
188 graphData.
srcEdges.size(), edgeShape.data(), edgeShape.size());
190 graphData.
graph->dataTensor.emplace_back(std::move(edge_tensor));
198 return StatusCode::SUCCESS;
// NOTE(review): fragmentary listing — the function header and parts of the
// guard conditions / branch structure around the numbered lines are missing.
// Error path for an unloaded ONNX model (the guarding condition is not shown).
203 ATH_MSG_ERROR(
"ONNX model is not loaded. Please call setupModel()");
204 return StatusCode::FAILURE;
// No graph has been built yet.
206 if (!graphData.
graph) {
208 return StatusCode::FAILURE;
// Both input tensors (features and edge index) must be present.
211 if (graphData.
graph->dataTensor.size() < 2) {
212 ATH_MSG_ERROR(
"Data tensor does not contain both feature and edge tensors.");
213 return StatusCode::FAILURE;
// Model input/output names; inputNames[i] is paired with dataTensor[i].
216 std::vector<const char*> inputNames = {
"features",
"edge_index"};
217 std::vector<const char*> outputNames = {
"output"};
// Set this run's log severity level to ERROR.
219 Ort::RunOptions run_options;
220 run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR);
// I/O-binding path: bind the named inputs to the first dataTensor entries and
// request every output in CPU memory.
224 Ort::IoBinding binding(
model());
225 for (std::size_t i = 0; i < inputNames.size(); ++i) {
226 binding.BindInput(inputNames[i], graphData.
graph->dataTensor[i]);
229 Ort::MemoryInfo cpuOut = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
230 for (
const char* outName : outputNames) {
231 binding.BindOutput(outName, cpuOut);
234 model().Run(run_options, binding);
235 binding.SynchronizeOutputs();
237 std::vector<Ort::Value> outputTensors = binding.GetOutputValues();
238 if (outputTensors.empty()) {
240 return StatusCode::FAILURE;
// Post-process the first output in place through a span over its data:
// non-finite scores (NaN/inf) are replaced by -100.0f.
// NOTE(review): this post-processing is duplicated verbatim in the
// non-binding path below; a shared helper would keep them in sync.
243 float* output_data = outputTensors[0].GetTensorMutableData<
float>();
244 size_t output_size = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
246 std::span<float> predictions(output_data, output_data + output_size);
247 for (
size_t i = 0; i < output_size; i++) {
248 if (!std::isfinite(predictions[i])) {
250 predictions[i] = -100.0f;
// Append the output tensor to the graph's dataTensor, after the inputs.
253 graphData.
graph->dataTensor.emplace_back(std::move(outputTensors[0]));
254 return StatusCode::SUCCESS;
// Non-binding path: plain Run() over the graph's dataTensor array.
// NOTE(review): the input count passed here is dataTensor.size(). Because the
// output tensor is appended to dataTensor on success, a second call on the
// same graph would pass more values than inputNames has entries — consider
// using inputNames.size() instead.
258 std::vector<Ort::Value> outputTensors =
model().Run(run_options,
260 graphData.
graph->dataTensor.data(),
261 graphData.
graph->dataTensor.size(),
265 if (outputTensors.empty()) {
267 return StatusCode::FAILURE;
// Same in-place replacement of non-finite scores by -100.0f as above.
270 float* output_data = outputTensors[0].GetTensorMutableData<
float>();
271 size_t output_size = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
273 std::span<float> predictions(output_data, output_data + output_size);
275 for (
size_t i = 0; i < output_size; i++) {
276 if (!std::isfinite(predictions[i])) {
278 predictions[i] = -100.0f;
282 graphData.
graph->dataTensor.emplace_back(std::move(outputTensors[0]));
284 return StatusCode::SUCCESS;
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_WARNING(x)
std::vector< size_t > vec
std::ostream & operator<<(std::ostream &lhs, const TestGaudiProperty &rhs)
The LayerSpBucket is a space-point bucket where the points are internally sorted by their layer number...
: The muon space-point bucket represents a collection of points that will be processed together in t...
void sort(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of sort for DataVector/List.
Helper struct to ship the Graph from the space point buckets to ONNX.
FeatureVec_t featureLeaves
Vector containing all features.
std::unique_ptr< InferenceGraph > graph
Pointer to the graph to be passed to ONNX.
unsigned int nodeIndex
Number of nodes already filled.
const NodeFeatureList * previousList
Pointer to the latest parsed NodeFeatureList.
std::vector< float >::iterator currLeave
The following variables are needed to consistently fill the raw data for the graph building.
EdgeCounterVec_t srcEdges
Vector encoding the source index of the.
EdgeCounterVec_t desEdges
Vect.
NodeConnectVec_t spacePointsInBucket
Vector keeping track of how many space points are in each parsed bucket.