13 #include <onnxruntime_cxx_api.h>
15 #include "nlohmann/json.hpp"
16 #include "lwtnn/parse_json.hh"
// One model input: a flat buffer of float values paired with a vector of
// int64_t values (presumably the tensor shape/dimensions -- TODO confirm
// against the code that fills it).
28 using Inputs = std::pair<std::vector<float>, std::vector<int64_t>>;
// Collection of OnnxOutput entries (OnnxOutput is declared elsewhere);
// this is the type returned by getOutputConfig().
46 using OutputConfig = std::vector<OnnxOutput>;
// Constructor taking the path to an ONNX file.
// NOTE(review): the definition is not visible here; presumably it creates the
// ONNX Runtime environment/session and reads the model's nodes -- confirm.
// NOTE(review): single-argument constructor is not marked `explicit`, so a
// std::string converts implicitly to OnnxUtil -- confirm this is intended.
48 OnnxUtil(
const std::string& path_to_onnx);
// Result container returned by runInference(). Values are grouped by their
// C++ type; each map is keyed by a string (presumably the output node
// name -- TODO confirm against the implementation).
52 struct InferenceOutput {
// Outputs that are a single float value.
53 std::map<std::string, float> singleFloat;
// Outputs that are a sequence of char values.
54 std::map<std::string, std::vector<char>> vecChar;
// Outputs that are a sequence of float values.
55 std::map<std::string, std::vector<float>> vecFloat;
// Runs the model on the supplied inputs and returns the results as an
// InferenceOutput.
//   gnn_inputs: map from a string key (presumably the input node name --
//               TODO confirm) to the Inputs payload for that input. Taken by
//               non-const reference; whether it is modified cannot be told
//               from this declaration.
// Const member function: callable on a const OnnxUtil.
58 InferenceOutput runInference(std::map<std::string, Inputs>& gnn_inputs)
const;
// Returns an lwt::GraphConfig by value.
// NOTE(review): where the configuration comes from (presumably metadata
// stored with the ONNX model) is not visible here -- confirm.
// NOTE(review): the return type is a const-qualified value, which prevents
// callers from moving out of the returned object.
60 const lwt::GraphConfig getLwtConfig()
const;
// Returns a const reference to an OutputConfig (presumably the
// m_output_nodes member -- TODO confirm). The reference is only valid while
// the referenced object is alive.
62 const OutputConfig& getOutputConfig()
const;
// Returns a const reference to the model name string (presumably the
// m_model_name member -- TODO confirm). The reference is only valid while
// the referenced object is alive.
64 const std::string& getModelName()
const;
// Computes and returns a model name as a new string.
// NOTE(review): how the name is derived is not visible here; presumably the
// result is stored in m_model_name -- confirm.
// NOTE(review): the return type is a const-qualified value, which prevents
// callers from moving out of the returned string.
68 const std::string determineModelName()
const;
// Path to the ONNX file (presumably a copy of the constructor argument --
// TODO confirm).
71 std::string m_path_to_onnx;
// Owning pointers to the ONNX Runtime session and environment.
// NOTE(review): non-static members are destroyed in reverse declaration
// order, so m_env is destroyed before m_session. Confirm that ONNX Runtime
// permits the environment to be released while a session still exists;
// otherwise declare m_env first.
73 std::unique_ptr< Ort::Session > m_session;
74 std::unique_ptr< Ort::Env > m_env;
// Name of the model; see getModelName() / determineModelName().
78 std::string m_model_name;
// Names of the model's input nodes.
79 std::vector<std::string> m_input_node_names;
// Configuration of the model's output nodes; see getOutputConfig().
80 OutputConfig m_output_nodes;