ATLAS Offline Software
Loading...
Searching...
No Matches
Model.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5
7#include <iostream>
8using namespace std;
9
10namespace Ringer{
11
12 namespace onnx{
13
14 Model::Model( const std::string& modelPath, AthOnnx::IOnnxRuntimeSvc *svc,
15 float etmin, float etmax, float etamin, float etamax,
16 unsigned barcode):
17 m_etmin(etmin),
18 m_etmax(etmax),
19 m_etamin(etamin),
20 m_etamax(etamax),
22
23 {
24 // Some ORT related initialization
25 Ort::SessionOptions sessionOptions;
26 sessionOptions.SetIntraOpNumThreads( 1 );
27 sessionOptions.SetGraphOptimizationLevel( ORT_ENABLE_BASIC );
28 m_session = std::make_shared< Ort::Session >( svc->env(), modelPath.c_str(), sessionOptions );
29 }
30
31
33 {
  // Cache the model's input names, input shapes and output names from the ORT
  // session into m_input_node_names, m_input_node_dims and m_output_node_names;
  // predict() uses these tables to build input tensors and to call Run().
34 Ort::AllocatorWithDefaultOptions allocator;
35 size_t num_input_nodes = m_session->GetInputCount();
36
37 for( std::size_t i = 0; i < num_input_nodes; i++ ) {
  // NOTE(review): release() detaches the name buffer from its owning
  // AllocatedStringPtr, and nothing in this block frees it. Confirm the
  // destructor returns these buffers to the allocator; otherwise every call
  // leaks the input (and, below, output) name strings.
38 char* input_name = m_session->GetInputNameAllocated(i, allocator).release();
39 m_input_node_names.push_back(input_name);
40 Ort::TypeInfo type_info = m_session->GetInputTypeInfo(i);
41 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
42 auto input_node_dims = tensor_info.GetShape();
  // Negative entries (ORT's marker for dynamic dimensions) are pinned to 1;
  // predict() then creates each input tensor with exactly this shape.
43 for (std::size_t j = 0; j < input_node_dims.size(); j++){
44 if(input_node_dims[j]<0)
45 input_node_dims[j] =1;
46 }
47 m_input_node_dims.push_back(std::move(input_node_dims));
48 }
49
50
51 // Always have only one output
  // All output names are cached, but predict() requests only the first one.
52 size_t num_output_nodes = m_session->GetOutputCount();
53 for( std::size_t i = 0; i < num_output_nodes; i++ ) {
54 char* output_name = m_session->GetOutputNameAllocated(i, allocator).release();
55 m_output_node_names.push_back(output_name);
56 }
57
58
59 }
60
61
62
63
64
65 float Model::predict( std::vector< std::vector<float> > &input_vecs ) const
66 {
67 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
68 std::vector<Ort::Value> ort_inputs;
69
70 size_t num_inputs = input_vecs.size();
71 for( size_t i=0; i < num_inputs; ++i ){
72 ort_inputs.emplace_back( Ort::Value::CreateTensor<float>(memory_info, input_vecs[i].data(),
73 input_vecs[i].size(), m_input_node_dims[i].data(),
74 m_input_node_dims[i].size()) );
75 }
76
77 // score model & input tensor, get back output tensor
78 auto output_tensors = m_session->Run(Ort::RunOptions{nullptr}, m_input_node_names.data(), ort_inputs.data(),
79 ort_inputs.size(), m_output_node_names.data(), 1);
80
81 assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());
82
83 // Get pointer to output tensor float values
84 float* output_arr = output_tensors.front().GetTensorMutableData<float>();
85
86 return output_arr[0];
87 }
88
89 }// onnx
90}// ringer
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
Service used for managing global objects used by Onnx Runtime.
std::vector< std::vector< int64_t > > m_input_node_dims
Definition Model.h:56
std::vector< const char * > m_input_node_names
Definition Model.h:58
std::vector< const char * > m_output_node_names
Definition Model.h:59
unsigned barcode() const
Definition Model.h:50
unsigned m_barcode
Definition Model.h:63
Model(const std::string &modelPath, AthOnnx::IOnnxRuntimeSvc *svc, float etmin, float etmax, float etamin, float etamax, unsigned barcode)
Constructor.
Definition Model.cxx:14
std::shared_ptr< Ort::Session > m_session
Definition Model.h:55
float predict(std::vector< std::vector< float > > &) const
Calculate the discriminant.
Definition Model.cxx:65
Namespace dedicated for Ringer utilities.
STL namespace.