ATLAS Offline Software
SaltModel Class Reference (final)

#include <SaltModel.h>


Public Member Functions

 SaltModel (const std::string &name)
 ~SaltModel ()=default
void initialize ()
void runInference (const std::vector< std::vector< float > > &node_feat, std::vector< float > &effAllJet) const
void runInference (const std::vector< std::vector< float > > &node_feat, std::vector< std::vector< float > > &effAllJetAllWp) const

Private Attributes

std::vector< std::string > m_input_node_names
std::vector< std::string > m_output_node_names
std::unique_ptr< Ort::Session > m_session
std::unique_ptr< Ort::Env > m_env
std::string m_path_to_onnx
int m_num_wp {}

Detailed Description
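
SaltModel wraps an ONNX Runtime session around a SALT network: initialize() resolves the ONNX file stored in m_path_to_onnx through PathResolverFindCalibFile(), creates the Ort::Env and Ort::Session, and caches the model's input and output node names together with the number of working points; the two runInference() overloads then evaluate the model on per-jet input features. A minimal usage sketch follows; the assumption that the constructor argument is the ONNX file name stored in m_path_to_onnx, as well as the placeholder path and feature values, are illustrative and not taken from the ATLAS code base.

#include <SaltModel.h>

#include <string>
#include <vector>

int main() {
  // The constructor argument is assumed to be the logical name of the ONNX
  // calibration file; "path/to/model.onnx" is a placeholder, not a real path.
  SaltModel model("path/to/model.onnx");
  model.initialize();

  // One inner vector of input features per jet: shape {num_jets, num_features}.
  // The feature values here are purely illustrative.
  std::vector<std::vector<float>> node_feat = {
    {25.0f, 0.1f, -1.2f},
    {60.0f, 1.4f,  0.3f},
  };

  // Overload 1: one efficiency value per jet.
  std::vector<float> effAllJet;
  model.runInference(node_feat, effAllJet);

  // Overload 2: one efficiency value per jet and per working point.
  std::vector<std::vector<float>> effAllJetAllWp;
  model.runInference(node_feat, effAllJetAllWp);

  return 0;
}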

Constructor & Destructor Documentation

◆ SaltModel()

SaltModel::SaltModel ( const std::string & name)

◆ ~SaltModel()

SaltModel::~SaltModel ( )
default

Member Function Documentation

◆ initialize()

void SaltModel::initialize ( )

Definition at line 17 of file OnnxUtil.cxx.

{
  std::string fullPathToFile = PathResolverFindCalibFile(m_path_to_onnx);

  // create the ONNX Runtime environment (the model at m_path_to_onnx is loaded into the session below)
  m_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "");

  // initialize session options if needed
  Ort::SessionOptions session_options;
  session_options.SetIntraOpNumThreads(1);
  session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

  // create session and load model into memory
  m_session = std::make_unique<Ort::Session>(*m_env, fullPathToFile.c_str(), session_options);
  Ort::AllocatorWithDefaultOptions allocator;

  // get the input nodes
  size_t num_input_nodes = m_session->GetInputCount();

  // iterate over all input nodes
  for (std::size_t i = 0; i < num_input_nodes; i++) {
    auto input_name = m_session->GetInputNameAllocated(i, allocator);
    m_input_node_names.emplace_back(input_name.get());
  }

  // get the output nodes
  size_t num_output_nodes = m_session->GetOutputCount();
  std::vector<int64_t> output_node_dims;

  // iterate over all output nodes
  for (std::size_t i = 0; i < num_output_nodes; i++) {
    auto output_name = m_session->GetOutputNameAllocated(i, allocator);
    m_output_node_names.emplace_back(output_name.get());

    // get output node types
    Ort::TypeInfo type_info = m_session->GetOutputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();

    output_node_dims = tensor_info.GetShape();

    // the output has shape {1, num_jets, num_wp}
    m_num_wp = output_node_dims.at(2);
  }
}
std::string PathResolverFindCalibFile(const std::string &logical_file_name)

◆ runInference() [1/2]

void SaltModel::runInference ( const std::vector< std::vector< float > > & node_feat,
std::vector< float > & effAllJet ) const

Definition at line 64 of file OnnxUtil.cxx.

{
  // Inputs:
  //   node_feat : vector<vector<float>>  per-jet input features, shape {num_jets, num_features}
  //   effAllJet : vector<float>&         filled with one efficiency value per jet

  std::vector<float> input_tensor_values;
  std::vector<int64_t> input_node_dims = {1, static_cast<int>(node_feat.size()), static_cast<int>(node_feat.at(0).size())};

  // flatten the per-jet features into one contiguous buffer
  for (const auto& it : node_feat) {
    input_tensor_values.insert(input_tensor_values.end(), it.begin(), it.end());
  }

  // create input tensor object from data values
  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_values.size(), input_node_dims.data(), input_node_dims.size());

  // cast vector<string> to vector<const char*>; this is what ORT expects
  std::vector<const char*> input_node_names(m_input_node_names.size(), nullptr);
  for (unsigned int i = 0; i < m_input_node_names.size(); i++) {
    input_node_names[i] = m_input_node_names.at(i).c_str();
  }
  std::vector<const char*> output_node_names(m_output_node_names.size(), nullptr);
  for (int i = 0; i < static_cast<int>(m_output_node_names.size()); i++) {
    output_node_names[i] = m_output_node_names.at(i).c_str();
  }

  // score model & input tensor, get back output tensor.
  // Although Session::Run is non-const, the onnxruntime authors say
  // it is safe to call from multiple threads:
  // https://github.com/microsoft/onnxruntime/discussions/10107
  Ort::Session& session ATLAS_THREAD_SAFE = *m_session;
  auto output_tensors = session.Run(Ort::RunOptions{nullptr}, input_node_names.data(), &input_tensor, input_node_names.size(), output_node_names.data(), output_node_names.size());

  // set the output vector values to the inference results
  float* float_ptr = output_tensors.front().GetTensorMutableData<float>();
  int num_jets = node_feat.size();
  effAllJet = {float_ptr, float_ptr + num_jets};
}
#define ATLAS_THREAD_SAFE
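
The input tensor is built by flattening node_feat row by row into a single buffer with dimensions {1, num_jets, num_features}. Below is a small standalone sketch of that flattening, independent of ONNX Runtime and using made-up numbers, to illustrate the buffer layout the model receives.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Two jets with three features each: shape {num_jets = 2, num_features = 3}.
  std::vector<std::vector<float>> node_feat = {
    {1.0f, 2.0f, 3.0f},
    {4.0f, 5.0f, 6.0f},
  };

  // Same flattening as in runInference(): row-major, jet by jet.
  std::vector<float> input_tensor_values;
  std::vector<int64_t> input_node_dims = {
    1,
    static_cast<int64_t>(node_feat.size()),
    static_cast<int64_t>(node_feat.at(0).size()),
  };
  for (const auto& jet : node_feat) {
    input_tensor_values.insert(input_tensor_values.end(), jet.begin(), jet.end());
  }

  // The buffer holds the features of jet 0 first, then jet 1, and so on.
  assert(input_tensor_values.size() == 6);
  assert(input_tensor_values[3] == 4.0f);  // first feature of the second jet
  return 0;
}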

◆ runInference() [2/2]

void SaltModel::runInference ( const std::vector< std::vector< float > > & node_feat,
std::vector< std::vector< float > > & effAllJetAllWp ) const

Definition at line 108 of file OnnxUtil.cxx.

{
  // Inputs:
  //   node_feat      : vector<vector<float>>   per-jet input features, shape {num_jets, num_features}
  //   effAllJetAllWp : vector<vector<float>>&  filled with shape {num_jets, num_wp}

  // using float because that is what the model expects;
  // ORT does not execute in-model type casts (x = x.float()) correctly, so the type cannot be changed inside the model
  std::vector<float> input_tensor_values;
  std::vector<int64_t> input_node_dims = {1, static_cast<int>(node_feat.size()), static_cast<int>(node_feat.at(0).size())};

  // flatten the per-jet features into one contiguous buffer
  for (const auto& it : node_feat) {
    input_tensor_values.insert(input_tensor_values.end(), it.begin(), it.end());
  }

  // create input tensor object from data values
  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_values.size(), input_node_dims.data(), input_node_dims.size());

  // cast vector<string> to vector<const char*>; this is what ORT expects
  std::vector<const char*> input_node_names(m_input_node_names.size(), nullptr);
  for (int i = 0; i < static_cast<int>(m_input_node_names.size()); i++) {
    input_node_names[i] = m_input_node_names.at(i).c_str();
  }
  std::vector<const char*> output_node_names(m_output_node_names.size(), nullptr);
  for (int i = 0; i < static_cast<int>(m_output_node_names.size()); i++) {
    output_node_names[i] = m_output_node_names.at(i).c_str();
  }

  // score model & input tensor, get back output tensor.
  // Although Session::Run is non-const, the onnxruntime authors say
  // it is safe to call from multiple threads:
  // https://github.com/microsoft/onnxruntime/discussions/10107
  Ort::Session& session ATLAS_THREAD_SAFE = *m_session;
  auto output_tensors = session.Run(Ort::RunOptions{nullptr}, input_node_names.data(), &input_tensor, input_node_names.size(), output_node_names.data(), output_node_names.size());

  // set the output vector values to the inference results
  float* float_ptr = output_tensors.front().GetTensorMutableData<float>();

  int num_jets = node_feat.size();

  // reshape the flat {1, num_jets, num_wp} output into a nested {num_jets, num_wp} vector
  for (int i = 0; i < num_jets; i++) {
    std::vector<float> eff_one_jet_tmp;
    for (int j = 0; j < m_num_wp; j++) {
      eff_one_jet_tmp.push_back(float_ptr[i*m_num_wp + j]);
    }
    effAllJetAllWp.push_back(std::move(eff_one_jet_tmp));
  }
}
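
This overload reshapes the flat {1, num_jets, num_wp} output buffer into a nested vector indexed as effAllJetAllWp[jet][wp]. The following standalone illustration of that indexing uses a hand-written placeholder buffer instead of an ONNX Runtime output tensor.

#include <cassert>
#include <vector>

int main() {
  // Pretend output buffer for 2 jets and 3 working points, flat layout
  // {1, num_jets, num_wp}: jet 0's values first, then jet 1's.
  const int num_jets = 2;
  const int num_wp = 3;
  std::vector<float> flat = {0.60f, 0.70f, 0.85f,   // jet 0
                             0.55f, 0.65f, 0.80f};  // jet 1

  // Same indexing as in runInference(): element (i, j) lives at i*num_wp + j.
  std::vector<std::vector<float>> effAllJetAllWp;
  for (int i = 0; i < num_jets; i++) {
    std::vector<float> eff_one_jet;
    for (int j = 0; j < num_wp; j++) {
      eff_one_jet.push_back(flat[i * num_wp + j]);
    }
    effAllJetAllWp.push_back(std::move(eff_one_jet));
  }

  assert(effAllJetAllWp[1][2] == 0.80f);  // jet 1, third working point
  return 0;
}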

Member Data Documentation

◆ m_env

std::unique_ptr< Ort::Env > SaltModel::m_env
private

◆ m_input_node_names

std::vector<std::string> SaltModel::m_input_node_names
private

◆ m_num_wp

int SaltModel::m_num_wp {}
private

◆ m_output_node_names

std::vector<std::string> SaltModel::m_output_node_names
private

◆ m_path_to_onnx

std::string SaltModel::m_path_to_onnx
private

◆ m_session

std::unique_ptr< Ort::Session > SaltModel::m_session
private

The documentation for this class was generated from the following files: SaltModel.h and OnnxUtil.cxx.