29 return StatusCode::SUCCESS;
34 graphData.
graph = std::make_unique<InferenceGraph>();
46 std::vector<BucketGraphUtils::NodeAux> nodes;
51 const int64_t numNodes =
static_cast<int64_t
>(nodes.size());
53 <<
" -> nodes (size>0): " << numNodes
57 ATH_MSG_WARNING(
"No valid buckets found (all have size 0.0). Skipping inference.");
58 return StatusCode::SUCCESS;
61 const int64_t nFeatPerNode = 6;
62 if (numNodes * nFeatPerNode !=
static_cast<int64_t
>(graphData.
featureLeaves.size())) {
63 ATH_MSG_ERROR(
"Feature size mismatch: expected " << (numNodes * nFeatPerNode)
65 return StatusCode::FAILURE;
68 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
69 std::vector<int64_t> featShape{numNodes, nFeatPerNode};
70 graphData.
graph->dataTensor.emplace_back(
71 Ort::Value::CreateTensor<float>(memInfo,
76 return StatusCode::SUCCESS;
86 const int64_t
S =
static_cast<int64_t
>(featuresFlat.size() / 6);
89 ATH_MSG_WARNING(
"No valid features for transformer input. Skipping inference.");
90 return StatusCode::SUCCESS;
95 ATH_MSG_DEBUG(
"=== DEBUGGING: Transformer input features for first 10 nodes ===");
96 const int64_t debugNodes =
std::min(
S,
static_cast<int64_t
>(10));
97 for (int64_t nodeIdx = 0; nodeIdx < debugNodes; ++nodeIdx) {
98 const int64_t baseIdx = nodeIdx * 6;
100 <<
"x=" << featuresFlat[baseIdx + 0] <<
", "
101 <<
"y=" << featuresFlat[baseIdx + 1] <<
", "
102 <<
"z=" << featuresFlat[baseIdx + 2] <<
", "
103 <<
"layers=" << featuresFlat[baseIdx + 3] <<
", "
104 <<
"nSp=" << featuresFlat[baseIdx + 4] <<
", "
105 <<
"bucketSize=" << featuresFlat[baseIdx + 5]);
111 graphData.
graph = std::make_unique<InferenceGraph>();
113 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
116 std::vector<int64_t> fShape{1,
S, 6};
118 graphData.
graph->dataTensor.emplace_back(
119 Ort::Value::CreateTensor<float>(memInfo,
126 Ort::AllocatorWithDefaultOptions allocator;
127 std::vector<int64_t> mShape{1,
S};
128 Ort::Value padVal = Ort::Value::CreateTensor(allocator,
131 ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL);
132 bool* maskPtr = padVal.GetTensorMutableData<
bool>();
133 for (int64_t
i = 0;
i <
S; ++
i) maskPtr[
i] =
false;
134 graphData.
graph->dataTensor.emplace_back(std::move(padVal));
136 return StatusCode::SUCCESS;
149 std::vector<BucketGraphUtils::NodeAux> nodes;
150 std::vector<float> throwawayFeatures;
151 std::vector<int64_t> throwawaySp;
154 const int64_t numNodes =
static_cast<int64_t
>(nodes.size());
156 ATH_MSG_WARNING(
"No valid buckets found (all have size 0.0). Skipping graph building.");
157 return StatusCode::SUCCESS;
160 std::vector<int64_t> srcEdges, dstEdges;
170 std::vector<int64_t> newSrc;
171 std::vector<int64_t> newDst;
172 newSrc.reserve(srcEdges.size());
173 newDst.reserve(dstEdges.size());
174 for (
size_t k = 0;
k < srcEdges.size(); ++
k) {
175 const int64_t
u = srcEdges[
k];
176 const int64_t
v = dstEdges[
k];
177 const bool okU = (
u >= 0 &&
u < numNodes);
178 const bool okV = (
v >= 0 &&
v < numNodes);
185 <<
"), valid node range [0," << (numNodes-1) <<
"]");
191 srcEdges.swap(newSrc);
192 dstEdges.swap(newDst);
196 const size_t E = srcEdges.size();
202 for (
unsigned int k = 0;
k < dumpE; ++
k) {
203 ATH_MSG_DEBUG(
"EDGE[" <<
k <<
"]: " << srcEdges[
k] <<
" -> " << dstEdges[
k]);
205 std::vector<int> nodeConnections(numNodes, 0);
206 for (
size_t k = 0;
k < srcEdges.size(); ++
k) {
207 const int64_t
u = srcEdges[
k];
208 const int64_t
v = dstEdges[
k];
209 if (
u >= 0 &&
u < numNodes) nodeConnections[
u]++;
210 if (
v >= 0 &&
v < numNodes) nodeConnections[
v]++;
213 ATH_MSG_INFO(
"=== DEBUGGING: Node Connections (first 10 nodes) ===");
214 const int64_t debugNodeCount =
std::min(numNodes,
static_cast<int64_t
>(10));
215 for (int64_t
i = 0;
i < debugNodeCount; ++
i) {
216 ATH_MSG_DEBUG(
"Node[" <<
i <<
"] connections: " << nodeConnections[
i]);
221 ATH_MSG_DEBUG(
"=== DEBUGGING: Detailed Edge Connections (first 10 nodes) ===");
222 for (int64_t nodeIdx = 0; nodeIdx < debugNodeCount; ++nodeIdx) {
223 std::stringstream connections;
224 connections <<
"Node[" << nodeIdx <<
"] connected to: ";
225 bool foundAny =
false;
227 for (
size_t k = 0;
k < srcEdges.size(); ++
k) {
228 const int64_t
u = srcEdges[
k];
229 const int64_t
v = dstEdges[
k];
232 if (foundAny) connections <<
", ";
235 }
else if (
v == nodeIdx) {
236 if (foundAny) connections <<
", ";
242 if (!foundAny) connections <<
"none";
251 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
252 std::vector<int64_t> edgeShape{2,
static_cast<int64_t
>(Efinal)};
253 graphData.
graph->dataTensor.emplace_back(
254 Ort::Value::CreateTensor<int64_t>(memInfo,
260 ATH_MSG_DEBUG(
"Built sparse bucket graph: N=" << numNodes <<
", E=" << Efinal);
261 return StatusCode::SUCCESS;
266 const std::vector<const char*>& inputNames,
269 if (!graphData.
graph) {
271 return StatusCode::FAILURE;
273 if (graphData.
graph->dataTensor.empty()) {
275 return StatusCode::FAILURE;
282 if (!graphData.
graph->dataTensor.empty()) {
283 const auto& featureTensor = graphData.
graph->dataTensor[0];
284 auto featShape = featureTensor.GetTensorTypeAndShapeInfo().GetShape();
286 << (featShape.size()>1 ? (
"," +
std::to_string(featShape[1])) :
"")
287 << (featShape.size()>2 ? (
"," +
std::to_string(featShape[2])) :
"") <<
"]");
289 float* featData =
const_cast<Ort::Value&
>(featureTensor).GetTensorMutableData<float>();
290 const size_t totalElements = featureTensor.GetTensorTypeAndShapeInfo().GetElementCount();
291 ATH_MSG_DEBUG(
"Features tensor total elements: " << totalElements);
294 const size_t debugElements =
std::min(totalElements,
static_cast<size_t>(60));
295 for (
size_t i = 0;
i < debugElements;
i += 6) {
296 if (
i + 5 < totalElements) {
298 <<
"x=" << featData[
i+0] <<
", "
299 <<
"y=" << featData[
i+1] <<
", "
300 <<
"z=" << featData[
i+2] <<
", "
301 <<
"layers=" << featData[
i+3] <<
", "
302 <<
"nSp=" << featData[
i+4] <<
", "
303 <<
"bucketSize=" << featData[
i+5]);
310 Ort::RunOptions run_options;
311 run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_WARNING);
313 std::vector<Ort::Value>
outputs =
314 model().Run(run_options,
316 graphData.
graph->dataTensor.data(),
317 graphData.
graph->dataTensor.size(),
323 return StatusCode::FAILURE;
326 float* outData =
outputs[0].GetTensorMutableData<
float>();
327 const size_t outSize =
outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
330 std::span<float> preds(outData, outData + outSize);
331 for (
size_t i = 0;
i < outSize; ++
i) {
332 if (!std::isfinite(preds[
i])) {
333 ATH_MSG_WARNING(
"Non-finite prediction detected at " <<
i <<
" -> set to -100.");
339 graphData.
graph->dataTensor.emplace_back(std::move(
v));
341 return StatusCode::SUCCESS;
345 std::vector<const char*> inputNames = {
"features",
"edge_index"};