25 return StatusCode::SUCCESS;
30 auto& session = m_onnxSessionTool->session();
32 m_numInputs = session.GetInputCount();
33 m_numOutputs = session.GetOutputCount();
38 return StatusCode::SUCCESS;
49 for (
auto& shape : m_inputShapes) {
55 for (
auto& shape : m_outputShapes) {
66 return inputDataSize / abs(tensorSize);
74 assert (inputTensors.size() == m_numInputs);
75 assert (outputTensors.size() == m_numOutputs);
79 m_onnxSessionTool->session(),
80 m_inputNodeNames, inputTensors,
81 m_outputNodeNames, outputTensors);
83 return StatusCode::SUCCESS;
92 for (
const auto&
name : m_inputNodeNames) {
97 for (
const auto&
name : m_outputNodeNames) {
102 for (
const auto& shape : m_inputShapes) {
103 std::string shapeStr =
"\t";
104 for (
const auto&
dim : shape) {
111 for (
const auto& shape : m_outputShapes) {
112 std::string shapeStr =
"\t";
113 for (
const auto&
dim : shape) {
123 std::vector<Ort::Value> inputTensors;
124 for (
auto& [inputName, inputInfo] : inputData) {
125 const std::vector<int64_t>& shape = inputInfo.first;
126 if (std::holds_alternative<std::vector<float>>(inputInfo.second)) {
127 auto&
data = std::get<std::vector<float>>(inputInfo.second);
129 }
else if (std::holds_alternative<std::vector<int64_t>>(inputInfo.second)) {
130 auto&
data = std::get<std::vector<int64_t>>(inputInfo.second);
134 return StatusCode::FAILURE;
139 std::vector<Ort::Value> outputTensors;
140 outputTensors.reserve(inputData.size());
141 for (
auto& [
outputName, outputInfo] : outputData) {
142 auto& shape = outputInfo.first;
143 auto tensorSize =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
145 if (std::holds_alternative<std::vector<float>>(outputInfo.second)) {
146 auto&
data = std::get<std::vector<float>>(outputInfo.second);
147 data.resize(tensorSize);
149 }
else if (std::holds_alternative<std::vector<int64_t>>(outputInfo.second)) {
150 auto&
data = std::get<std::vector<int64_t>>(outputInfo.second);
151 data.resize(tensorSize);
155 return StatusCode::FAILURE;
159 ATH_CHECK(inference(inputTensors, outputTensors));
161 return StatusCode::SUCCESS;