ATLAS Offline Software
Public Member Functions | Private Attributes | List of all members
Ringer::onnx::Model Class Reference

#include <Model.h>

Collaboration diagram for Ringer::onnx::Model:

Public Member Functions

 Model (const std::string &modelPath, AthOnnx::IOnnxRuntimeSvc *svc, float etmin, float etmax, float etamin, float etamax, unsigned barcode)
 Constructor. More...
 
 ~Model ()=default
 Destructor. More...
 
void compile ()
 
float etMin () const
 Get the Et lower edge. More...
 
float etMax () const
 Get the Et upper edge. More...
 
float etaMin () const
 Get the Eta lower edge. More...
 
float etaMax () const
 Get the Eta upper edge. More...
 
float predict (std::vector< std::vector< float > > &) const
 Calculate the discriminant. More...
 
unsigned barcode () const
 

Private Attributes

std::shared_ptr< Ort::Session > m_session
 
std::vector< std::vector< int64_t > > m_input_node_dims
 
std::vector< int64_t > m_output_node_dims
 
std::vector< const char * > m_input_node_names
 
std::vector< const char * > m_output_node_names
 
float m_etmin
 
float m_etmax
 
float m_etamin
 
float m_etamax
 
unsigned m_barcode
 

Detailed Description

Definition at line 21 of file Model.h.

Constructor & Destructor Documentation

◆ Model()

Ringer::onnx::Model::Model ( const std::string &  modelPath,
AthOnnx::IOnnxRuntimeSvc *  svc,
float  etmin,
float  etmax,
float  etamin,
float  etamax,
unsigned  barcode 
)

Constructor.

Definition at line 14 of file Model.cxx.

16  :
17  m_etmin(etmin),
18  m_etmax(etmax),
19  m_etamin(etamin),
20  m_etamax(etamax),
21  m_barcode(barcode)
22 
23  {
24  // Some ORT related initialization
25  Ort::SessionOptions sessionOptions;
26  sessionOptions.SetIntraOpNumThreads( 1 );
27  sessionOptions.SetGraphOptimizationLevel( ORT_ENABLE_BASIC );
28  m_session = std::make_shared< Ort::Session >( svc->env(), modelPath.c_str(), sessionOptions );
29  }

◆ ~Model()

Ringer::onnx::Model::~Model ( )
default

Destructor.

Member Function Documentation

◆ barcode()

unsigned Ringer::onnx::Model::barcode ( ) const
inline

Definition at line 50 of file Model.h.

50 { return m_barcode; };

◆ compile()

void Ringer::onnx::Model::compile ( )

Definition at line 32 of file Model.cxx.

33  {
34  Ort::AllocatorWithDefaultOptions allocator;
35  size_t num_input_nodes = m_session->GetInputCount();
36 
37  for( std::size_t i = 0; i < num_input_nodes; i++ ) {
38  char* input_name = m_session->GetInputNameAllocated(i, allocator).release();
39  m_input_node_names.push_back(input_name);
40  Ort::TypeInfo type_info = m_session->GetInputTypeInfo(i);
41  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
42  auto input_node_dims = tensor_info.GetShape();
43  for (std::size_t j = 0; j < input_node_dims.size(); j++){
44  if(input_node_dims[j]<0)
45  input_node_dims[j] =1;
46  }
47  m_input_node_dims.push_back(input_node_dims);
48  }
49 
50 
51  // Always have only one output
52  size_t num_output_nodes = m_session->GetOutputCount();
53  for( std::size_t i = 0; i < num_output_nodes; i++ ) {
54  char* output_name = m_session->GetOutputNameAllocated(i, allocator).release();
55  m_output_node_names.push_back(output_name);
56  }
57 
58 
59  }

◆ etaMax()

float Ringer::onnx::Model::etaMax ( ) const
inline

Get the Eta upper edge.

Definition at line 45 of file Model.h.

45 { return m_etamax; };

◆ etaMin()

float Ringer::onnx::Model::etaMin ( ) const
inline

Get the Eta lower edge.

Definition at line 42 of file Model.h.

42 { return m_etamin; };

◆ etMax()

float Ringer::onnx::Model::etMax ( ) const
inline

Get the Et upper edge.

Definition at line 39 of file Model.h.

39 { return m_etmax; };

◆ etMin()

float Ringer::onnx::Model::etMin ( ) const
inline

Get the Et lower edge.

Definition at line 36 of file Model.h.

36 { return m_etmin; };

◆ predict()

float Ringer::onnx::Model::predict ( std::vector< std::vector< float > > &  input_vecs) const

Calculate the discriminant.

Definition at line 65 of file Model.cxx.

66  {
67  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
68  std::vector<Ort::Value> ort_inputs;
69 
70  size_t num_inputs = input_vecs.size();
71  for( size_t i=0; i < num_inputs; ++i ){
72  ort_inputs.emplace_back( Ort::Value::CreateTensor<float>(memory_info, input_vecs[i].data(),
73  input_vecs[i].size(), m_input_node_dims[i].data(),
74  m_input_node_dims[i].size()) );
75  }
76 
77  // score model & input tensor, get back output tensor
78  auto output_tensors = m_session->Run(Ort::RunOptions{nullptr}, m_input_node_names.data(), ort_inputs.data(),
79  ort_inputs.size(), m_output_node_names.data(), 1);
80 
81  assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
82 
83  // Get pointer to output tensor float values
84  float* output_arr = output_tensors.front().GetTensorMutableData<float>();
85 
86  return output_arr[0];
87  }

Member Data Documentation

◆ m_barcode

unsigned Ringer::onnx::Model::m_barcode
private

Definition at line 63 of file Model.h.

◆ m_etamax

float Ringer::onnx::Model::m_etamax
private

Definition at line 62 of file Model.h.

◆ m_etamin

float Ringer::onnx::Model::m_etamin
private

Definition at line 62 of file Model.h.

◆ m_etmax

float Ringer::onnx::Model::m_etmax
private

Definition at line 61 of file Model.h.

◆ m_etmin

float Ringer::onnx::Model::m_etmin
private

Definition at line 61 of file Model.h.

◆ m_input_node_dims

std::vector<std::vector<int64_t> > Ringer::onnx::Model::m_input_node_dims
private

Definition at line 56 of file Model.h.

◆ m_input_node_names

std::vector<const char*> Ringer::onnx::Model::m_input_node_names
private

Definition at line 58 of file Model.h.

◆ m_output_node_dims

std::vector<int64_t> Ringer::onnx::Model::m_output_node_dims
private

Definition at line 57 of file Model.h.

◆ m_output_node_names

std::vector<const char*> Ringer::onnx::Model::m_output_node_names
private

Definition at line 59 of file Model.h.

◆ m_session

std::shared_ptr<Ort::Session> Ringer::onnx::Model::m_session
private

Definition at line 55 of file Model.h.


The documentation for this class was generated from the following files:
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
Ringer::onnx::Model::m_input_node_dims
std::vector< std::vector< int64_t > > m_input_node_dims
Definition: Model.h:56
Ringer::onnx::Model::m_output_node_names
std::vector< const char * > m_output_node_names
Definition: Model.h:59
Ringer::onnx::Model::m_barcode
unsigned m_barcode
Definition: Model.h:63
python.setupRTTAlg.size
int size
Definition: setupRTTAlg.py:39
Ringer::onnx::Model::m_etamax
float m_etamax
Definition: Model.h:62
lumiFormat.i
int i
Definition: lumiFormat.py:92
Ringer::onnx::Model::m_etmax
float m_etmax
Definition: Model.h:61
Ringer::onnx::Model::m_input_node_names
std::vector< const char * > m_input_node_names
Definition: Model.h:58
Handler::svc
AthROOTErrorHandlerSvc * svc
Definition: AthROOTErrorHandlerSvc.cxx:10
Ringer::onnx::Model::m_etamin
float m_etamin
Definition: Model.h:62
Ringer::onnx::Model::m_session
std::shared_ptr< Ort::Session > m_session
Definition: Model.h:55
Ringer::onnx::Model::m_etmin
float m_etmin
Definition: Model.h:61
Ringer::onnx::Model::barcode
unsigned barcode() const
Definition: Model.h:50
LArCellBinning.etamin
etamin
Definition: LArCellBinning.py:137