13 : m_path_to_onnx (
name)
22 m_env = std::make_unique< Ort::Env >(ORT_LOGGING_LEVEL_WARNING,
"");
25 Ort::SessionOptions session_options;
26 session_options.SetIntraOpNumThreads(1);
27 session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
30 m_session = std::make_unique< Ort::Session >(*m_env, fullPathToFile.c_str(), session_options);
31 Ort::AllocatorWithDefaultOptions allocator;
34 size_t num_input_nodes = m_session->GetInputCount();
37 for (std::size_t
i = 0;
i < num_input_nodes;
i++) {
38 auto input_name = m_session->GetInputNameAllocated(
i, allocator);
39 m_input_node_names.emplace_back(input_name.get());
43 size_t num_output_nodes = m_session->GetOutputCount();
44 std::vector<int64_t> output_node_dims;
47 for(std::size_t
i = 0;
i < num_output_nodes;
i++ ) {
48 auto output_name = m_session->GetOutputNameAllocated(
i, allocator);
49 m_output_node_names.emplace_back(output_name.get());
52 Ort::TypeInfo type_info = m_session->GetOutputTypeInfo(
i);
53 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
55 output_node_dims = tensor_info.GetShape();
58 m_num_wp = output_node_dims.at(2);
65 const std::vector<std::vector<float>> & node_feat,
66 std::vector<float>& effAllJet)
const {
72 std::vector<float> input_tensor_values;
73 std::vector<int64_t> input_node_dims = {1,
static_cast<int>(node_feat.size()),
static_cast<int>(node_feat.at(0).size())};
75 for (
const auto&
it : node_feat){
76 input_tensor_values.insert(input_tensor_values.end(),
it.begin(),
it.end());
80 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
81 Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_values.size(), input_node_dims.data(), input_node_dims.size());
84 std::vector<const char*> input_node_names(m_input_node_names.size(),
nullptr);
85 for (
unsigned int i=0;
i<m_input_node_names.size();
i++) {
86 input_node_names[
i]= m_input_node_names.at(
i).c_str();
88 std::vector<const char*> output_node_names(m_output_node_names.size(),
nullptr);
89 for (
int i=0; i<static_cast<int>(m_output_node_names.size());
i++) {
90 output_node_names[
i]= m_output_node_names.at(
i).c_str();
98 auto output_tensors = session.Run(Ort::RunOptions{
nullptr}, input_node_names.data(), &input_tensor, input_node_names.size(), output_node_names.data(), output_node_names.size());
101 float* float_ptr = output_tensors.front().GetTensorMutableData<
float>();
102 int num_jets = node_feat.size();
103 effAllJet = {float_ptr, float_ptr + num_jets};
109 const std::vector<std::vector<float>> & node_feat,
110 std::vector<std::vector<float>> & effAllJetAllWp)
const{
118 std::vector<float> input_tensor_values;
119 std::vector<int64_t> input_node_dims = {1,
static_cast<int>(node_feat.size()),
static_cast<int>(node_feat.at(0).size())};
121 for (
auto&
it : node_feat){
122 input_tensor_values.insert(input_tensor_values.end(),
it.begin(),
it.end());
126 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
127 Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_values.size(), input_node_dims.data(), input_node_dims.size());
130 std::vector<const char*> input_node_names(m_input_node_names.size(),
nullptr);
131 for (
int i=0; i<static_cast<int>(m_input_node_names.size());
i++) {
132 input_node_names[
i]= m_input_node_names.at(
i).c_str();
134 std::vector<const char*> output_node_names(m_output_node_names.size(),
nullptr);
135 for (
int i=0; i<static_cast<int>(m_output_node_names.size());
i++) {
136 output_node_names[
i]= m_output_node_names.at(
i).c_str();
144 auto output_tensors = session.Run(Ort::RunOptions{
nullptr}, input_node_names.data(), &input_tensor, input_node_names.size(), output_node_names.data(), output_node_names.size());
147 float* float_ptr = output_tensors.front().GetTensorMutableData<
float>();
149 int num_jets = node_feat.size();
151 for (
int i=0;
i<num_jets;
i++){
152 std::vector<float> eff_one_jet_tmp;
153 for (
int j=0; j<m_num_wp; j++){
154 eff_one_jet_tmp.push_back(float_ptr[
i*m_num_wp+j]);
156 effAllJetAllWp.push_back(eff_one_jet_tmp);