ATLAS Offline Software
Loading...
Searching...
No Matches
Ringer::onnx::Model Class Reference

#include <Model.h>

Collaboration diagram for Ringer::onnx::Model:

Public Member Functions

 Model (const std::string &modelPath, AthOnnx::IOnnxRuntimeSvc *svc, float etmin, float etmax, float etamin, float etamax, unsigned barcode)
 Constructor.
 ~Model ()=default
 Destructor.
void compile ()
float etMin () const
 Get the Et lower edge.
float etMax () const
 Get the Et upper edge.
float etaMin () const
 Get the Eta lower edge.
float etaMax () const
 Get the Eta upper edge.
float predict (std::vector< std::vector< float > > &) const
 Calculate the discriminant.
unsigned barcode () const
 Get the model barcode.

Private Attributes

std::shared_ptr< Ort::Session > m_session
std::vector< std::vector< int64_t > > m_input_node_dims
std::vector< int64_t > m_output_node_dims
std::vector< const char * > m_input_node_names
std::vector< const char * > m_output_node_names
float m_etmin
float m_etmax
float m_etamin
float m_etamax
unsigned m_barcode

Detailed Description

Definition at line 21 of file Model.h.

Constructor & Destructor Documentation

◆ Model()

Ringer::onnx::Model::Model ( const std::string & modelPath,
AthOnnx::IOnnxRuntimeSvc * svc,
float etmin,
float etmax,
float etamin,
float etamax,
unsigned barcode )

Constructor.

Definition at line 14 of file Model.cxx.

16 :
17 m_etmin(etmin),
18 m_etmax(etmax),
19 m_etamin(etamin),
20 m_etamax(etamax),
22
23 {
24 // Some ORT related initialization
25 Ort::SessionOptions sessionOptions;
26 sessionOptions.SetIntraOpNumThreads( 1 );
27 sessionOptions.SetGraphOptimizationLevel( ORT_ENABLE_BASIC );
28 m_session = std::make_shared< Ort::Session >( svc->env(), modelPath.c_str(), sessionOptions );
29 }
unsigned barcode() const
Definition Model.h:50
unsigned m_barcode
Definition Model.h:63
std::shared_ptr< Ort::Session > m_session
Definition Model.h:55
AthOnnx::IOnnxRuntimeSvc * svc

◆ ~Model()

Ringer::onnx::Model::~Model ( )
default

Destructor.

Member Function Documentation

◆ barcode()

unsigned Ringer::onnx::Model::barcode ( ) const
inline

Definition at line 50 of file Model.h.

50{ return m_barcode; };

◆ compile()

void Ringer::onnx::Model::compile ( )

Definition at line 32 of file Model.cxx.

33 {
34 Ort::AllocatorWithDefaultOptions allocator;
35 size_t num_input_nodes = m_session->GetInputCount();
36
37 for( std::size_t i = 0; i < num_input_nodes; i++ ) {
38 char* input_name = m_session->GetInputNameAllocated(i, allocator).release();
39 m_input_node_names.push_back(input_name);
40 Ort::TypeInfo type_info = m_session->GetInputTypeInfo(i);
41 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
42 auto input_node_dims = tensor_info.GetShape();
43 for (std::size_t j = 0; j < input_node_dims.size(); j++){
44 if(input_node_dims[j]<0)
45 input_node_dims[j] =1;
46 }
47 m_input_node_dims.push_back(std::move(input_node_dims));
48 }
49
50
51 // Always have only one output
52 size_t num_output_nodes = m_session->GetOutputCount();
53 for( std::size_t i = 0; i < num_output_nodes; i++ ) {
54 char* output_name = m_session->GetOutputNameAllocated(i, allocator).release();
55 m_output_node_names.push_back(output_name);
56 }
57
58
59 }
std::vector< std::vector< int64_t > > m_input_node_dims
Definition Model.h:56
std::vector< const char * > m_input_node_names
Definition Model.h:58
std::vector< const char * > m_output_node_names
Definition Model.h:59

◆ etaMax()

float Ringer::onnx::Model::etaMax ( ) const
inline

Get the Eta high edge.

Definition at line 45 of file Model.h.

45{ return m_etamax; };

◆ etaMin()

float Ringer::onnx::Model::etaMin ( ) const
inline

Get the Eta lower edge.

Definition at line 42 of file Model.h.

42{ return m_etamin; };

◆ etMax()

float Ringer::onnx::Model::etMax ( ) const
inline

Get the Et high edge.

Definition at line 39 of file Model.h.

39{ return m_etmax; };

◆ etMin()

float Ringer::onnx::Model::etMin ( ) const
inline

Get the Et lower edge.

Definition at line 36 of file Model.h.

36{ return m_etmin; };

◆ predict()

float Ringer::onnx::Model::predict ( std::vector< std::vector< float > > & input_vecs) const

Calculate the discriminant.

Definition at line 65 of file Model.cxx.

66 {
67 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
68 std::vector<Ort::Value> ort_inputs;
69
70 size_t num_inputs = input_vecs.size();
71 for( size_t i=0; i < num_inputs; ++i ){
72 ort_inputs.emplace_back( Ort::Value::CreateTensor<float>(memory_info, input_vecs[i].data(),
73 input_vecs[i].size(), m_input_node_dims[i].data(),
74 m_input_node_dims[i].size()) );
75 }
76
77 // score model & input tensor, get back output tensor
78 auto output_tensors = m_session->Run(Ort::RunOptions{nullptr}, m_input_node_names.data(), ort_inputs.data(),
79 ort_inputs.size(), m_output_node_names.data(), 1);
80
81 assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
82
83 // Get pointer to output tensor float values
84 float* output_arr = output_tensors.front().GetTensorMutableData<float>();
85
86 return output_arr[0];
87 }
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11

Member Data Documentation

◆ m_barcode

unsigned Ringer::onnx::Model::m_barcode
private

Definition at line 63 of file Model.h.

◆ m_etamax

float Ringer::onnx::Model::m_etamax
private

Definition at line 62 of file Model.h.

◆ m_etamin

float Ringer::onnx::Model::m_etamin
private

Definition at line 62 of file Model.h.

◆ m_etmax

float Ringer::onnx::Model::m_etmax
private

Definition at line 61 of file Model.h.

◆ m_etmin

float Ringer::onnx::Model::m_etmin
private

Definition at line 61 of file Model.h.

◆ m_input_node_dims

std::vector<std::vector<int64_t> > Ringer::onnx::Model::m_input_node_dims
private

Definition at line 56 of file Model.h.

◆ m_input_node_names

std::vector<const char*> Ringer::onnx::Model::m_input_node_names
private

Definition at line 58 of file Model.h.

◆ m_output_node_dims

std::vector<int64_t> Ringer::onnx::Model::m_output_node_dims
private

Definition at line 57 of file Model.h.

◆ m_output_node_names

std::vector<const char*> Ringer::onnx::Model::m_output_node_names
private

Definition at line 59 of file Model.h.

◆ m_session

std::shared_ptr<Ort::Session> Ringer::onnx::Model::m_session
private

Definition at line 55 of file Model.h.


The documentation for this class was generated from the following files: