14 #include <onnxruntime_cxx_api.h>
80 strm <<
"Unknown network";
84 strm <<
"\nHas input nodes (name:dimensions);\n";
88 strm <<
" " <<
dim <<
",";
92 strm <<
"\nHas output nodes (name:dimensions);\n";
96 strm <<
" " <<
dim <<
",";
103 ATH_MSG_DEBUG(
"Setting up persisted variables for ONNX network.");
109 ATH_MSG_DEBUG(
"Setup persisted variables for ONNX network.");
125 ONNXTensorElementDataType first_input_type =
126 ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
130 const int num_input_nodes =
m_session->GetInputCount();
131 Ort::AllocatorWithDefaultOptions allocator;
132 for (
int i = 0;
i < num_input_nodes;
i++) {
134 #if ORT_API_VERSION > 11
135 Ort::AllocatedStringPtr node_names =
m_session->GetInputNameAllocated(
i, allocator);
136 m_storeInputNodeNames.push_back(std::move(node_names));
137 const char *input_name = m_storeInputNodeNames.back().get();
139 const char *input_name =
m_session->GetInputName(
i, allocator);
144 Ort::TypeInfo type_info =
m_session->GetInputTypeInfo(
i);
148 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
150 first_input_type = tensor_info.GetElementType();
152 if (tensor_info.GetElementType() != first_input_type) {
153 ATH_MSG_ERROR(
"First type was " << first_input_type <<
". In node " <<
i
155 << tensor_info.GetElementType());
156 throw std::runtime_error(
"Networks with varying input types not "
157 "yet impelmented in TFCSONNXHandler.");
160 std::vector<int64_t> recieved_dimension = tensor_info.GetShape();
169 std::vector<int64_t> dimension_of_node;
170 for (int64_t node_dim : recieved_dimension) {
173 << node_dim <<
" in node named " << input_name
174 <<
". Will treat this as dimension 1.");
175 dimension_of_node.push_back(1);
177 dimension_of_node.push_back(node_dim);
187 ONNXTensorElementDataType first_output_type =
188 ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
191 int num_output_nodes =
m_session->GetOutputCount();
192 ATH_MSG_DEBUG(
"Getting " << num_output_nodes <<
" output nodes.");
193 for (
int i = 0;
i < num_output_nodes;
i++) {
194 #if ORT_API_VERSION > 11
195 Ort::AllocatedStringPtr node_names =
m_session->GetOutputNameAllocated(
i, allocator);
196 m_storeOutputNodeNames.push_back(std::move(node_names));
197 const char *output_name = m_storeOutputNodeNames.back().get();
199 const char *output_name =
m_session->GetOutputName(
i, allocator);
204 const Ort::TypeInfo type_info =
m_session->GetOutputTypeInfo(
i);
205 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
207 first_output_type = tensor_info.GetElementType();
210 if (tensor_info.GetElementType() != first_output_type) {
211 ATH_MSG_ERROR(
"First type was " << first_output_type <<
". In node " <<
i
213 << tensor_info.GetElementType());
214 throw std::runtime_error(
"Networks with varying output types not "
215 "yet impelmented in TFCSONNXHandler.");
218 const std::vector<int64_t> recieved_dimension = tensor_info.GetShape();
222 std::vector<int64_t> dimension_of_node;
224 for (int64_t node_dim : recieved_dimension) {
227 << node_dim <<
" in node named " << output_name
228 <<
". Will treat this as dimension 1.");
229 dimension_of_node.push_back(1);
231 dimension_of_node.push_back(node_dim);
232 node_size *= node_dim;
239 for (
int part_n = 0; part_n < node_size; part_n++) {
241 std::string layer_name =
252 << first_input_type <<
" and output type "
253 << first_output_type);
254 if (first_input_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT &&
255 first_output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
258 return computeTemplate<float, float>(
inputs);
260 }
else if (first_input_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE &&
261 first_output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
263 return computeTemplate<double, double>(
inputs);
266 throw std::runtime_error(
"Haven't yet implemented that combination of "
267 "input and output types as a subclass of VState.");
274 ATH_MSG_DEBUG(
"Getting serialized session for ONNX network.");
279 TTree *
tree = (TTree *)
tfile.Get(tree_name.c_str());
288 std::vector<char> bytes((std::istreambuf_iterator<char>(
input)),
289 (std::istreambuf_iterator<char>()));
299 std::vector<char> bytes;
301 tree.SetBranchAddress(
"serialized_m_session", &
data);
302 for (
int i = 0;
tree.LoadTree(
i) >= 0;
i++) {
304 bytes.push_back(
data);
311 const std::vector<char> &bytes) {
314 tree.Branch(
"serialized_m_session", &m_session_data,
315 "serialized_m_session/B");
316 for (Char_t here : bytes) {
317 m_session_data = here;
326 Ort::Env
env(ORT_LOGGING_LEVEL_WARNING,
"test");
327 Ort::SessionOptions
opts;
328 opts.SetInterOpNumThreads(1);
329 opts.SetIntraOpNumThreads(1);
336 template <
typename Tin,
typename Tout>
343 ATH_MSG_DEBUG(
"Setting up inputs for computation on ONNX network.");
345 <<
typeid(Tout).
name());
351 std::vector<std::vector<Tin>> input_values(num_input_nodes);
352 std::vector<Ort::Value> node_values;
354 std::string node_name;
355 int n_dimensions, elements_in_node, key_number;
361 <<
" input nodes of ONNX network.");
369 elements_in_node = 1;
371 elements_in_node *= dimension_len;
378 const std::map<std::string, double> node_inputs =
inputs.at(node_name);
379 std::vector<Tin> node_elements(elements_in_node);
382 << elements_in_node <<
" elements.");
384 for (
auto element : node_inputs){
385 first_digit = element.first.find_first_of(
"0123456789");
387 if (first_digit < element.first.length()){
388 key_number = std::stoi(element.first.substr(first_digit));
389 node_elements[key_number] = element.second;
392 input_values[node_n] = node_elements;
396 <<
", elements_in_node = " << elements_in_node);
398 Ort::Value
node = Ort::Value::CreateTensor<Tin>(
405 node_values.push_back(std::move(
node));
414 ATH_MSG_DEBUG(
"Sorting outputs from computation on ONNX network.");
419 std::string output_name;
420 const Tout *output_node;
423 output_node = output_tensors[node_n].GetTensorMutableData<Tout>();
431 for (
int part_n = 0; part_n < elements_in_node; part_n++) {
433 << output_node[part_n]);
436 outputs[output_name] =
static_cast<double>(output_node[part_n]);
440 ATH_MSG_DEBUG(
"Returning outputs from computation on ONNX network.");
450 void TFCSONNXHandler::Streamer(TBuffer &buf) {
452 if (buf.IsReading()) {
455 TFCSONNXHandler::Class()->ReadBuffer(buf,
this);
459 #ifndef __FastCaloSimStandAlone__
467 TFCSONNXHandler::Class()->WriteBuffer(buf,
this);