23 return StatusCode::SUCCESS;
28 auto& session = m_onnxSessionTool->session();
30 m_numInputs = session.GetInputCount();
31 m_numOutputs = session.GetOutputCount();
36 return StatusCode::SUCCESS;
47 for (
auto& shape : m_inputShapes) {
53 for (
auto& shape : m_outputShapes) {
64 return inputDataSize / abs(tensorSize);
72 assert (inputTensors.size() == m_numInputs);
73 assert (outputTensors.size() == m_numOutputs);
77 m_onnxSessionTool->session(),
78 m_inputNodeNames, inputTensors,
79 m_outputNodeNames, outputTensors);
81 return StatusCode::SUCCESS;
90 for (
const auto&
name : m_inputNodeNames) {
95 for (
const auto&
name : m_outputNodeNames) {
100 for (
const auto& shape : m_inputShapes) {
101 std::string shapeStr =
"\t";
102 for (
const auto&
dim : shape) {
109 for (
const auto& shape : m_outputShapes) {
110 std::string shapeStr =
"\t";
111 for (
const auto&
dim : shape) {
121 std::vector<Ort::Value> inputTensors;
122 for (
auto& [inputName, inputInfo] : inputData) {
123 const std::vector<int64_t>& shape = inputInfo.first;
124 if (std::holds_alternative<std::vector<float>>(inputInfo.second)) {
125 auto&
data = std::get<std::vector<float>>(inputInfo.second);
127 }
else if (std::holds_alternative<std::vector<int64_t>>(inputInfo.second)) {
128 auto&
data = std::get<std::vector<int64_t>>(inputInfo.second);
132 return StatusCode::FAILURE;
137 std::vector<Ort::Value> outputTensors;
138 outputTensors.reserve(inputData.size());
139 for (
auto& [
outputName, outputInfo] : outputData) {
140 auto& shape = outputInfo.first;
141 auto tensorSize =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
143 if (std::holds_alternative<std::vector<float>>(outputInfo.second)) {
144 auto&
data = std::get<std::vector<float>>(outputInfo.second);
145 data.resize(tensorSize);
147 }
else if (std::holds_alternative<std::vector<int64_t>>(outputInfo.second)) {
148 auto&
data = std::get<std::vector<int64_t>>(outputInfo.second);
149 data.resize(tensorSize);
153 return StatusCode::FAILURE;
157 ATH_CHECK(inference(inputTensors, outputTensors));
159 return StatusCode::SUCCESS;