Generic named inference, for tools with different I/O conventions.
363{
364 if (!graphData.
graph) {
366 return StatusCode::FAILURE;
367 }
368 if (graphData.
graph->dataTensor.empty()) {
370 return StatusCode::FAILURE;
371 }
372
373
374
375 graphData.
graph->dataTensor.reserve(inputNames.size() +
outputNames.size());
376 if (graphData.
graph->dataTensor.size() < inputNames.size()) {
378 << " tensors but inference expects " << inputNames.size() << " inputs.");
379 return StatusCode::FAILURE;
380 }
381
382 if (msgLvl(MSG::DEBUG)) {
383
384
386 if (!graphData.
graph->dataTensor.empty()) {
387 const auto& featureTensor = graphData.
graph->dataTensor[0];
388 auto featShape = featureTensor.GetTensorTypeAndShapeInfo().GetShape();
390 << (featShape.size()>1 ? ("," + std::to_string(featShape[1])) : "")
391 << (featShape.size()>2 ? ("," + std::to_string(featShape[2])) : "") << "]");
392
393 float* featData = const_cast<Ort::Value&>(featureTensor).GetTensorMutableData<float>();
394 const size_t totalElements = featureTensor.GetTensorTypeAndShapeInfo().GetElementCount();
395 ATH_MSG_DEBUG(
"Features tensor total elements: " << totalElements);
396
397
398 const size_t nFeat = (featShape.size() > 1 && featShape[1] > 0) ? static_cast<size_t>(featShape[1]) : 1;
399 const size_t nNodes = totalElements / nFeat;
400 const size_t debugNodes = std::min(nNodes, static_cast<size_t>(10));
401
402
403
404 std::vector<std::string> featNames;
405 {
406 Ort::AllocatorWithDefaultOptions allocator;
407 Ort::ModelMetadata meta =
model().GetModelMetadata();
408 auto keys = meta.GetCustomMetadataMapKeysAllocated(allocator);
409 std::vector<std::string> keyNames;
410 keyNames.reserve(
keys.size());
411 for (
const auto& k : keys) keyNames.emplace_back(
k.get());
413 "x_feature_names", "node_feature_names", "feature_names", "input_feature_names"};
414 for (const std::string& key : candidates) {
415 if (std::find(keyNames.begin(), keyNames.end(), key) != keyNames.end()) {
416 std::string
val = meta.LookupCustomMetadataMapAllocated(
key.c_str(), allocator).get();
418 break;
419 }
420 }
421 if (featNames.empty()) {
422 ATH_MSG_DEBUG(
"No usable feature-name metadata key found in model; using generic fN labels.");
423 }
424 }
425 auto featLabel = [&](
size_t f) -> std::string {
426 if (f < featNames.size())
return featNames[
f];
427 return "f" + std::to_string(f);
428 };
429
430
431 {
432 std::ostringstream
legend;
433 legend <<
"Node feature legend (" << nFeat <<
" features):";
434 for (
size_t f = 0;
f < nFeat; ++
f) {
435 legend <<
" f" <<
f <<
"=" << featLabel(f);
436 if (f + 1 < nFeat)
legend <<
",";
437 }
439 }
440
441 for (
size_t n = 0;
n < debugNodes; ++
n) {
442 std::ostringstream
row;
443 row <<
"ONNXNode[" <<
n <<
"]:";
444 for (
size_t f = 0;
f < nFeat; ++
f) {
445 row <<
" f" <<
f <<
"=" << featData[
n * nFeat +
f];
446 if (f + 1 < nFeat)
row <<
",";
447 }
449 }
450 }
452 }
453
454 Ort::RunOptions run_options;
455 run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR);
456
458
459 Ort::IoBinding binding(
model());
460 for (std::size_t i = 0;
i < inputNames.size(); ++
i) {
461 binding.BindInput(inputNames[i], graphData.
graph->dataTensor[i]);
462 }
463
464 Ort::MemoryInfo cpuOut = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
465 for (const char* outName : outputNames) {
466 binding.BindOutput(outName, cpuOut);
467 }
468
469 model().Run(run_options, binding);
470 binding.SynchronizeOutputs();
471
472 std::vector<Ort::Value>
outputs = binding.GetOutputValues();
475 return StatusCode::FAILURE;
476 }
477
478 float* outData =
outputs[0].GetTensorMutableData<
float>();
479 const size_t outSize =
outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
480 ATH_MSG_DEBUG(
"ONNX (IoBinding) raw output elementCount = " << outSize);
481
483 std::span<float> preds(outData, outData + outSize);
484 for (
size_t i = 0;
i < outSize; ++
i) {
485 if (!std::isfinite(preds[i])) {
486 ATH_MSG_WARNING(
"Non-finite prediction detected at " << i <<
" -> set to -100.");
488 }
489 }
490 }
491
492 for (auto& v : outputs) {
493 graphData.
graph->dataTensor.emplace_back(std::move(v));
494 }
495 return StatusCode::SUCCESS;
496 }
497
498
499 std::vector<Ort::Value>
outputs =
500 model().Run(run_options,
501 inputNames.data(),
502 graphData.
graph->dataTensor.data(),
503 inputNames.size(),
506
509 return StatusCode::FAILURE;
510 }
511
512 float* outData =
outputs[0].GetTensorMutableData<
float>();
513 const size_t outSize =
outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
515
517 std::span<float> preds(outData, outData + outSize);
518 for (
size_t i = 0;
i < outSize; ++
i) {
519 if (!std::isfinite(preds[i])) {
520 ATH_MSG_WARNING(
"Non-finite prediction detected at " << i <<
" -> set to -100.");
522 }
523 }
524 }
525
526 for (auto& v : outputs) {
527 graphData.
graph->dataTensor.emplace_back(std::move(v));
528 }
529 return StatusCode::SUCCESS;
530}
row
Appending html table to final .html summary file.