ATLAS Offline Software
Public Member Functions | Private Member Functions | Private Attributes | List of all members
met::METNetHandler Class Reference

#include <METNetHandler.h>

Collaboration diagram for met::METNetHandler:

Public Member Functions

 METNetHandler (const std::string &modelName)
 
virtual ~METNetHandler ()=default
 
int initialize ()
 
unsigned int getReqSize () const
 
void predict (std::vector< float > &outputs, std::vector< float > &inputs) const
 

Private Member Functions

 METNetHandler ()=delete
 

Private Attributes

std::string m_modelName
 
std::string m_modelPath
 
size_t m_numInputs
 
size_t m_numOutputs
 
std::vector< int64_t > m_inputDims
 
std::vector< int64_t > m_outputDims
 
std::vector< const char * > m_graphInputNames
 
std::vector< const char * > m_graphOutputNames
 
Ort::Env m_onnxEnv
 
Ort::SessionOptions m_onnxSessionOptions
 
Ort::AllocatorWithDefaultOptions m_onnxAllocator
 
std::unique_ptr< Ort::Session > m_onnxSession ATLAS_THREAD_SAFE {nullptr}
 
std::mutex m_onnxMutex ATLAS_THREAD_SAFE
 

Detailed Description

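Wraps an ONNX Runtime session that runs the neural network stored in the .onnx file named by modelName.

Typical use: construct with the model name, call initialize(), then call predict().
- The constructor only records the model name and fixes the tensor layout: one input named "inputs" with shape {1, 77} and one output named "outputs" with shape {1, 2}. No model is loaded yet.
- initialize() resolves the model file with the path resolver and returns 1 if no file is found. Otherwise it:
  - configures the session options (1 intra-op thread, ORT_ENABLE_BASIC graph optimisation);
  - creates the Ort::Session;
  - returns 0.
- getReqSize() returns the number of input values the network expects, m_inputDims[1] (77 with the dimensions set in the constructor).
- predict(outputs, inputs):
  - wraps the storage of inputs as a {1, 77} float tensor without copying;
  - resets outputs to two zeros;
  - runs the session so that the network writes its two outputs directly into outputs.
  The Run() call is serialised by m_onnxMutex.

Preconditions for predict():
- initialize() must have returned 0. m_onnxSession is nullptr until then, and predict() dereferences it without a check.
- inputs.size() should equal getReqSize(). The input tensor is created with inputs.size() elements but the fixed shape {1, 77}. With ONNX Runtime's CreateTensor, a buffer that is too small is rejected (Ort::Exception), while surplus values are silently ignored. Confirm this for the ONNX Runtime version in use.

Definition at line 22 of file METNetHandler.h.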

Constructor & Destructor Documentation

◆ METNetHandler() [1/2]

met::METNetHandler::METNetHandler ( const std::string &  modelName)

Definition at line 15 of file METNetHandler.cxx.

15  :
16  m_modelName(modelName),
17  m_numInputs(1),
18  m_numOutputs(1),
19  m_inputDims({1,77}),
20  m_outputDims({1,2}),
21  m_graphInputNames({"inputs"}),
22  m_graphOutputNames({"outputs"}){}

◆ ~METNetHandler()

virtual met::METNetHandler::~METNetHandler ( )
virtual default

◆ METNetHandler() [2/2]

met::METNetHandler::METNetHandler ( )
private delete

Member Function Documentation

◆ getReqSize()

unsigned int met::METNetHandler::getReqSize ( ) const

Definition at line 40 of file METNetHandler.cxx.

40  {
41  // Returns the required size of the inputs for the network
42  return static_cast<unsigned int>(m_inputDims[1]);
43  }

◆ initialize()

int met::METNetHandler::initialize ( )

Definition at line 24 of file METNetHandler.cxx.

24  {
25 
26  // Use the path resolver to find the location of the network .onnx file
27  m_modelPath = PathResolverFindCalibFile(m_modelName);
28  if (m_modelPath == "") return 1;
29 
30  // Use the default ONNX session settings for 1 CPU thread
31  m_onnxSessionOptions.SetIntraOpNumThreads(1);
32  m_onnxSessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
33 
34  // Initialise the ONNX environment and session using the above options and the model name
35  m_onnxSession = std::make_unique<Ort::Session>(m_onnxEnv, m_modelPath.c_str(), m_onnxSessionOptions);
36 
37  return 0;
38  }

◆ predict()

void met::METNetHandler::predict ( std::vector< float > &  outputs,
std::vector< float > &  inputs 
) const

Definition at line 45 of file METNetHandler.cxx.

45  {
46  // This method passes a input vector through the neural network and returns its estimate
47  // It requires conversions to onnx type tensors and back
48 
49  // Create a CPU tensor to be used as input
50  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
51  Ort::Value input_tensor = Ort::Value::CreateTensor<float>( memory_info,
52  inputs.data(),
53  inputs.size(),
54  m_inputDims.data(),
55  m_inputDims.size() );
56 
57  outputs = {0, 0};
58  Ort::Value output_tensor = Ort::Value::CreateTensor<float>( memory_info,
59  outputs.data(),
60  outputs.size(),
61  m_outputDims.data(),
62  m_outputDims.size() );
63 
64  // Pass the input through the network, getting a vector of outputs
65  std::lock_guard<std::mutex> lock(m_onnxMutex);
66  m_onnxSession->Run( Ort::RunOptions{nullptr},
67  m_graphInputNames.data(), &input_tensor, m_numInputs,
68  m_graphOutputNames.data(), &output_tensor, m_numOutputs );
69  }

Member Data Documentation

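◆ m_onnxSession (Doxygen entry ATLAS_THREAD_SAFE [1/2]; Doxygen mistook the ATLAS_THREAD_SAFE annotation macro for the member name)

std::unique_ptr<Ort::Session> met::METNetHandler::m_onnxSession ATLAS_THREAD_SAFE {nullptr}
mutable private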

Definition at line 58 of file METNetHandler.h.

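◆ m_onnxMutex (Doxygen entry ATLAS_THREAD_SAFE [2/2]; Doxygen mistook the ATLAS_THREAD_SAFE annotation macro for the member name)

std::mutex met::METNetHandler::m_onnxMutex ATLAS_THREAD_SAFE
mutable private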

Definition at line 59 of file METNetHandler.h.

◆ m_graphInputNames

std::vector<const char *> met::METNetHandler::m_graphInputNames
private

Definition at line 51 of file METNetHandler.h.

◆ m_graphOutputNames

std::vector<const char *> met::METNetHandler::m_graphOutputNames
private

Definition at line 52 of file METNetHandler.h.

◆ m_inputDims

std::vector<int64_t> met::METNetHandler::m_inputDims
private

Definition at line 49 of file METNetHandler.h.

◆ m_modelName

std::string met::METNetHandler::m_modelName
private

Definition at line 43 of file METNetHandler.h.

◆ m_modelPath

std::string met::METNetHandler::m_modelPath
private

Definition at line 44 of file METNetHandler.h.

◆ m_numInputs

size_t met::METNetHandler::m_numInputs
private

Definition at line 47 of file METNetHandler.h.

◆ m_numOutputs

size_t met::METNetHandler::m_numOutputs
private

Definition at line 48 of file METNetHandler.h.

◆ m_onnxAllocator

Ort::AllocatorWithDefaultOptions met::METNetHandler::m_onnxAllocator
private

Definition at line 57 of file METNetHandler.h.

◆ m_onnxEnv

Ort::Env met::METNetHandler::m_onnxEnv
private

Definition at line 55 of file METNetHandler.h.

◆ m_onnxSessionOptions

Ort::SessionOptions met::METNetHandler::m_onnxSessionOptions
private

Definition at line 56 of file METNetHandler.h.

◆ m_outputDims

std::vector<int64_t> met::METNetHandler::m_outputDims
private

Definition at line 50 of file METNetHandler.h.


The documentation for this class was generated from the following files:
METNetHandler.h
METNetHandler.cxx
met::METNetHandler::m_modelName
std::string m_modelName
Definition: METNetHandler.h:43
met::METNetHandler::m_numInputs
size_t m_numInputs
Definition: METNetHandler.h:47
met::METNetHandler::m_graphOutputNames
std::vector< const char * > m_graphOutputNames
Definition: METNetHandler.h:52
met::METNetHandler::m_modelPath
std::string m_modelPath
Definition: METNetHandler.h:44
python.RatesEmulationExample.lock
lock
Definition: RatesEmulationExample.py:148
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
met::METNetHandler::m_numOutputs
size_t m_numOutputs
Definition: METNetHandler.h:48
met::METNetHandler::m_onnxEnv
Ort::Env m_onnxEnv
Definition: METNetHandler.h:55
python.CreateTierZeroArgdict.outputs
outputs
Definition: CreateTierZeroArgdict.py:189
met::METNetHandler::m_outputDims
std::vector< int64_t > m_outputDims
Definition: METNetHandler.h:50
met::METNetHandler::m_inputDims
std::vector< int64_t > m_inputDims
Definition: METNetHandler.h:49
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:321
met::METNetHandler::m_graphInputNames
std::vector< const char * > m_graphInputNames
Definition: METNetHandler.h:51
met::METNetHandler::m_onnxSessionOptions
Ort::SessionOptions m_onnxSessionOptions
Definition: METNetHandler.h:56