ATLAS Offline Software
Model.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 
7 #include <iostream>
8 using namespace std;
9 
10 namespace Ringer{
11 
12  namespace onnx{
13 
14  Model::Model( const std::string& modelPath, AthOnnx::IOnnxRuntimeSvc *svc,
15  float etmin, float etmax, float etamin, float etamax,
16  unsigned barcode):
17  m_etmin(etmin),
18  m_etmax(etmax),
19  m_etamin(etamin),
20  m_etamax(etamax),
21  m_barcode(barcode)
22 
23  {
24  // Some ORT related initialization
25  Ort::SessionOptions sessionOptions;
26  sessionOptions.SetIntraOpNumThreads( 1 );
27  sessionOptions.SetGraphOptimizationLevel( ORT_ENABLE_BASIC );
28  m_session = std::make_shared< Ort::Session >( svc->env(), modelPath.c_str(), sessionOptions );
29  }
30 
31 
33  {
34  Ort::AllocatorWithDefaultOptions allocator;
35  size_t num_input_nodes = m_session->GetInputCount();
36 
37  for( std::size_t i = 0; i < num_input_nodes; i++ ) {
38  char* input_name = m_session->GetInputNameAllocated(i, allocator).release();
39  m_input_node_names.push_back(input_name);
40  Ort::TypeInfo type_info = m_session->GetInputTypeInfo(i);
41  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
42  auto input_node_dims = tensor_info.GetShape();
43  for (std::size_t j = 0; j < input_node_dims.size(); j++){
44  if(input_node_dims[j]<0)
45  input_node_dims[j] =1;
46  }
47  m_input_node_dims.push_back(input_node_dims);
48  }
49 
50 
51  // Always have only one output
52  size_t num_output_nodes = m_session->GetOutputCount();
53  for( std::size_t i = 0; i < num_output_nodes; i++ ) {
54  char* output_name = m_session->GetOutputNameAllocated(i, allocator).release();
55  m_output_node_names.push_back(output_name);
56  }
57 
58 
59  }
60 
61 
62 
63 
64 
65  float Model::predict( std::vector< std::vector<float> > &input_vecs ) const
66  {
67  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
68  std::vector<Ort::Value> ort_inputs;
69 
70  size_t num_inputs = input_vecs.size();
71  for( size_t i=0; i < num_inputs; ++i ){
72  ort_inputs.emplace_back( Ort::Value::CreateTensor<float>(memory_info, input_vecs[i].data(),
73  input_vecs[i].size(), m_input_node_dims[i].data(),
74  m_input_node_dims[i].size()) );
75  }
76 
77  // score model & input tensor, get back output tensor
78  auto output_tensors = m_session->Run(Ort::RunOptions{nullptr}, m_input_node_names.data(), ort_inputs.data(),
79  ort_inputs.size(), m_output_node_names.data(), 1);
80 
81  assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
82 
83  // Get pointer to output tensor float values
84  float* output_arr = output_tensors.front().GetTensorMutableData<float>();
85 
86  return output_arr[0];
87  }
88 
89  }// onnx
90 }// ringer
AthOnnx::IOnnxRuntimeSvc
Service used for managing global objects used by Onnx Runtime.
Definition: IOnnxRuntimeSvc.h:25
Ringer::onnx::Model::compile
void compile()
Definition: Model.cxx:32
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
Model.h
Ringer::onnx::Model::m_input_node_dims
std::vector< std::vector< int64_t > > m_input_node_dims
Definition: Model.h:56
Ringer::onnx::Model::predict
float predict(std::vector< std::vector< float > > &) const
Calculate the discriminant.
Definition: Model.cxx:65
Ringer::onnx::Model::m_output_node_names
std::vector< const char * > m_output_node_names
Definition: Model.h:59
python.setupRTTAlg.size
int size
Definition: setupRTTAlg.py:39
lumiFormat.i
int i
Definition: lumiFormat.py:92
HepMC::barcode
int barcode(const T *p)
Definition: Barcode.h:16
Epos_Base_Fragment.Model
Model
Definition: Epos_Base_Fragment.py:10
Ringer::onnx::Model::m_input_node_names
std::vector< const char * > m_input_node_names
Definition: Model.h:58
Handler::svc
AthROOTErrorHandlerSvc * svc
Definition: AthROOTErrorHandlerSvc.cxx:10
Ringer::onnx::Model::m_session
std::shared_ptr< Ort::Session > m_session
Definition: Model.h:50
LArCellBinning.etamin
etamin
Definition: LArCellBinning.py:137
Ringer
Namespace dedicated for Ringer utilities.
Definition: CaloRingsDefs.h:9