125 ONNXTensorElementDataType first_input_type =
126 ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
130 const int num_input_nodes =
m_session->GetInputCount();
131 Ort::AllocatorWithDefaultOptions allocator;
132 for (
int i = 0; i < num_input_nodes; i++) {
134#if ORT_API_VERSION > 11
135 Ort::AllocatedStringPtr node_names =
m_session->GetInputNameAllocated(i, allocator);
136 m_storeInputNodeNames.push_back(std::move(node_names));
137 const char *input_name = m_storeInputNodeNames.back().get();
139 const char *input_name =
m_session->GetInputName(i, allocator);
144 Ort::TypeInfo type_info =
m_session->GetInputTypeInfo(i);
148 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
150 first_input_type = tensor_info.GetElementType();
152 if (tensor_info.GetElementType() != first_input_type) {
153 ATH_MSG_ERROR(
"First type was " << first_input_type <<
". In node " << i
155 << tensor_info.GetElementType());
156 throw std::runtime_error(
"Networks with varying input types not "
157 "yet impelmented in TFCSONNXHandler.");
160 std::vector<int64_t> recieved_dimension = tensor_info.GetShape();
169 std::vector<int64_t> dimension_of_node;
170 for (int64_t node_dim : recieved_dimension) {
173 << node_dim <<
" in node named " << input_name
174 <<
". Will treat this as dimension 1.");
175 dimension_of_node.push_back(1);
177 dimension_of_node.push_back(node_dim);
187 ONNXTensorElementDataType first_output_type =
188 ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
191 int num_output_nodes =
m_session->GetOutputCount();
192 ATH_MSG_DEBUG(
"Getting " << num_output_nodes <<
" output nodes.");
193 for (
int i = 0; i < num_output_nodes; i++) {
194#if ORT_API_VERSION > 11
195 Ort::AllocatedStringPtr node_names =
m_session->GetOutputNameAllocated(i, allocator);
196 m_storeOutputNodeNames.push_back(std::move(node_names));
197 const char *output_name = m_storeOutputNodeNames.back().get();
199 const char *output_name =
m_session->GetOutputName(i, allocator);
204 const Ort::TypeInfo type_info =
m_session->GetOutputTypeInfo(i);
205 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
207 first_output_type = tensor_info.GetElementType();
210 if (tensor_info.GetElementType() != first_output_type) {
211 ATH_MSG_ERROR(
"First type was " << first_output_type <<
". In node " << i
213 << tensor_info.GetElementType());
214 throw std::runtime_error(
"Networks with varying output types not "
215 "yet impelmented in TFCSONNXHandler.");
218 const std::vector<int64_t> recieved_dimension = tensor_info.GetShape();
222 std::vector<int64_t> dimension_of_node;
224 for (int64_t node_dim : recieved_dimension) {
227 << node_dim <<
" in node named " << output_name
228 <<
". Will treat this as dimension 1.");
229 dimension_of_node.push_back(1);
231 dimension_of_node.push_back(node_dim);
232 node_size *= node_dim;
239 for (
int part_n = 0; part_n < node_size; part_n++) {
241 std::string layer_name =
242 std::string(output_name) +
"_" + std::to_string(part_n);
252 << first_input_type <<
" and output type "
253 << first_output_type);
254 if (first_input_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT &&
255 first_output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
260 }
else if (first_input_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE &&
261 first_output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
266 throw std::runtime_error(
"Haven't yet implemented that combination of "
267 "input and output types as a subclass of VState.");
343 ATH_MSG_DEBUG(
"Setting up inputs for computation on ONNX network.");
344 ATH_MSG_DEBUG(
"Input type " <<
typeid(Tin).name() <<
" output type "
345 <<
typeid(Tout).name());
351 std::vector<std::vector<Tin>> input_values(num_input_nodes);
352 std::vector<Ort::Value> node_values;
354 std::string node_name;
355 int n_dimensions, elements_in_node, key_number;
361 <<
" input nodes of ONNX network.");
369 elements_in_node = 1;
371 elements_in_node *= dimension_len;
374 for (
const auto & inp : inputs) {
378 const std::map<std::string, double> node_inputs = inputs.at(node_name);
379 std::vector<Tin> node_elements(elements_in_node);
382 << elements_in_node <<
" elements.");
384 for (
const auto & element : node_inputs){
385 first_digit = element.first.find_first_of(
"0123456789");
387 if (first_digit < element.first.length()){
388 key_number = std::stoi(element.first.substr(first_digit));
389 node_elements[key_number] = element.second;
392 input_values[node_n] = std::move(node_elements);
396 <<
", elements_in_node = " << elements_in_node);
398 Ort::Value
node = Ort::Value::CreateTensor<Tin>(
405 node_values.push_back(std::move(
node));
414 ATH_MSG_DEBUG(
"Sorting outputs from computation on ONNX network.");
419 std::string output_name;
420 const Tout *output_node;
423 output_node = output_tensors[node_n].GetTensorMutableData<Tout>();
431 for (
int part_n = 0; part_n < elements_in_node; part_n++) {
433 << output_node[part_n]);
435 output_name = node_name +
"_" + std::to_string(part_n);
436 outputs[output_name] =
static_cast<double>(output_node[part_n]);
440 ATH_MSG_DEBUG(
"Returning outputs from computation on ONNX network.");