23 return StatusCode::SUCCESS;
36 return StatusCode::SUCCESS;
64 return inputDataSize / abs(tensorSize);
81 return StatusCode::SUCCESS;
101 std::string shapeStr =
"\t";
102 for (
const auto& dim : shape) {
103 shapeStr += std::to_string(dim) +
" ";
110 std::string shapeStr =
"\t";
111 for (
const auto& dim : shape) {
112 shapeStr += std::to_string(dim) +
" ";
121 std::vector<Ort::Value> inputTensors;
122 for (
auto& [inputName, inputInfo] : inputData) {
123 const std::vector<int64_t>& shape = inputInfo.first;
124 if (std::holds_alternative<std::vector<float>>(inputInfo.second)) {
125 auto&
data = std::get<std::vector<float>>(inputInfo.second);
127 }
else if (std::holds_alternative<std::vector<int64_t>>(inputInfo.second)) {
128 auto&
data = std::get<std::vector<int64_t>>(inputInfo.second);
132 return StatusCode::FAILURE;
137 std::vector<Ort::Value> outputTensors;
138 outputTensors.reserve(inputData.size());
140 if (outputData.find(outName) == outputData.end()) {
141 ATH_MSG_ERROR(
"Output name " << outName <<
" not found in output data map");
142 return StatusCode::FAILURE;
144 auto& outputInfo = outputData.at(outName);
145 auto& shape = outputInfo.first;
146 auto tensorSize = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
148 if (std::holds_alternative<std::vector<float>>(outputInfo.second)) {
149 auto&
data = std::get<std::vector<float>>(outputInfo.second);
150 data.resize(tensorSize);
152 }
else if (std::holds_alternative<std::vector<int64_t>>(outputInfo.second)) {
153 auto&
data = std::get<std::vector<int64_t>>(outputInfo.second);
154 data.resize(tensorSize);
158 return StatusCode::FAILURE;
164 return StatusCode::SUCCESS;
#define ATH_CHECK
Evaluate an expression and check for errors.
char data[hepevt_bytes_allocation_ATLAS]
std::map< std::string, InferenceData > OutputDataMap
std::map< std::string, InferenceData > InputDataMap
void inferenceWithIOBinding(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, const std::vector< Ort::Value > &outputData)
int64_t getTensorSize(const std::vector< int64_t > &dataShape)
void getOutputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
void getInputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Ort::Value createTensor(std::vector< T > &data, const std::vector< int64_t > &dataShape)