18 const Ort::Session& session,
19 std::vector<std::vector<int64_t> >& dataShape,
20 std::vector<std::string>& nodeNames,
26 size_t numNodes = isInput? session.GetInputCount(): session.GetOutputCount();
27 dataShape.reserve(numNodes);
28 nodeNames.reserve(numNodes);
30 Ort::AllocatorWithDefaultOptions allocator;
31 for( std::size_t i = 0; i < numNodes; i++ ) {
32 Ort::TypeInfo typeInfo = isInput? session.GetInputTypeInfo(i): session.GetOutputTypeInfo(i);
33 auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo();
34 dataShape.emplace_back(tensorInfo.GetShape());
36 auto nodeName = isInput? session.GetInputNameAllocated(i, allocator) : session.GetOutputNameAllocated(i, allocator);
37 nodeNames.emplace_back(nodeName.get());
58 const std::vector<std::string>& inputNames,
59 const std::vector<Ort::Value>& inputData,
60 const std::vector<std::string>& outputNames,
61 const std::vector<Ort::Value>& outputData){
63 if (inputNames.empty()) {
64 throw std::runtime_error(
"Onnxruntime input data maping cannot be empty");
66 assert(inputNames.size() == inputData.size());
68 Ort::IoBinding iobinding(session);
69 for(
size_t idx = 0; idx < inputNames.size(); ++idx){
70 iobinding.BindInput(inputNames[idx].
data(), inputData[idx]);
74 for(
size_t idx = 0; idx < outputNames.size(); ++idx){
75 iobinding.BindOutput(outputNames[idx].
data(), outputData[idx]);
78 session.Run(Ort::RunOptions{
nullptr}, iobinding);
83 const std::vector<std::string>& inputNames,
84 const std::vector<Ort::Value>& inputData,
85 const std::vector<std::string>& outputNames,
86 std::vector<Ort::Value>& outputData,
88 if (inputNames.empty()) {
89 throw std::runtime_error(
"Onnxruntime input data mapping cannot be empty");
91 assert(inputNames.size() == inputData.size());
92 assert(outputNames.size() == outputData.size());
94 Ort::RunOptions runOptions{};
97 std::vector<const char*> inputNamesArray{};
98 std::vector<const char*> outputNamesArray{};
99 inputNamesArray.reserve(inputNames.size());
100 outputNamesArray.reserve(outputNames.size());
101 for (
const auto& name : inputNames) {
102 inputNamesArray.push_back(name.c_str());
104 for (
const auto& name : outputNames) {
105 outputNamesArray.push_back(name.c_str());
109 using Promise_t = boost::fibers::promise<std::string>;
111 boost::fibers::future<std::string> future{promise.get_future()};
114 const auto callback = [](
void* promise, OrtValue**, std::size_t,
115 OrtStatusPtr statusPtr)
mutable {
116 std::string errorMsg{};
117 if (statusPtr !=
nullptr) {
118 Ort::Status status{statusPtr};
119 if (!status.IsOK()) {
120 errorMsg = status.GetErrorMessage();
123 static_cast<Promise_t*
>(promise)->set_value(errorMsg);
127 session.RunAsync(runOptions, inputNamesArray.data(), inputData.data(),
128 inputData.size(), outputNamesArray.data(), outputData.data(),
129 outputData.size(), callback,
static_cast<void*
>(&promise));
131 std::string errorMsg = future.get();
132 parentAlg->
restoreAfterSuspend().orThrow(
"Failed to restore after suspension",
"AsyncAlg");
void inferenceWithIOBinding(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, const std::vector< Ort::Value > &outputData)
std::string asyncInference(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, std::vector< Ort::Value > &outputData, const AthAsynchronousAlgorithm *parentAlg)