15 float etmin,
float etmax,
float etamin,
float etamax,
25 Ort::SessionOptions sessionOptions;
26 sessionOptions.SetIntraOpNumThreads( 1 );
27 sessionOptions.SetGraphOptimizationLevel( ORT_ENABLE_BASIC );
28 m_session = std::make_shared< Ort::Session >( svc->env(), modelPath.c_str(), sessionOptions );
34 Ort::AllocatorWithDefaultOptions allocator;
35 size_t num_input_nodes =
m_session->GetInputCount();
37 for( std::size_t i = 0; i < num_input_nodes; i++ ) {
38 char* input_name =
m_session->GetInputNameAllocated(i, allocator).release();
40 Ort::TypeInfo type_info =
m_session->GetInputTypeInfo(i);
41 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
42 auto input_node_dims = tensor_info.GetShape();
43 for (std::size_t j = 0; j < input_node_dims.size(); j++){
44 if(input_node_dims[j]<0)
45 input_node_dims[j] =1;
52 size_t num_output_nodes =
m_session->GetOutputCount();
53 for( std::size_t i = 0; i < num_output_nodes; i++ ) {
54 char* output_name =
m_session->GetOutputNameAllocated(i, allocator).release();
67 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
68 std::vector<Ort::Value> ort_inputs;
70 size_t num_inputs = input_vecs.size();
71 for(
size_t i=0; i < num_inputs; ++i ){
72 ort_inputs.emplace_back( Ort::Value::CreateTensor<float>(memory_info, input_vecs[i].
data(),
81 assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
84 float* output_arr = output_tensors.front().GetTensorMutableData<
float>();
char data[hepevt_bytes_allocation_ATLAS]
size_t size() const
Number of registered mappings.
Service used for managing global objects used by Onnx Runtime.
std::vector< std::vector< int64_t > > m_input_node_dims
std::vector< const char * > m_input_node_names
std::vector< const char * > m_output_node_names
Model(const std::string &modelPath, AthOnnx::IOnnxRuntimeSvc *svc, float etmin, float etmax, float etamin, float etamax, unsigned barcode)
Constructor.
std::shared_ptr< Ort::Session > m_session
float predict(std::vector< std::vector< float > > &) const
Calculate the discriminant.
Namespace dedicated for Ringer utilities.