ATLAS Offline Software
Loading...
Searching...
No Matches
met::METNetHandler Class Reference

#include <METNetHandler.h>

Collaboration diagram for met::METNetHandler:

Public Member Functions

 METNetHandler (const std::string &modelName)
virtual ~METNetHandler ()=default
int initialize ()
unsigned int getReqSize () const
void predict (std::vector< float > &outputs, std::vector< float > &inputs) const

Private Member Functions

 METNetHandler ()=delete

Private Attributes

std::string m_modelName
std::string m_modelPath
size_t m_numInputs
size_t m_numOutputs
std::vector< int64_t > m_inputDims
std::vector< int64_t > m_outputDims
std::vector< const char * > m_graphInputNames
std::vector< const char * > m_graphOutputNames
Ort::Env m_onnxEnv
Ort::SessionOptions m_onnxSessionOptions
Ort::AllocatorWithDefaultOptions m_onnxAllocator
std::unique_ptr< Ort::Session > m_onnxSession ATLAS_THREAD_SAFE {nullptr}
std::mutex m_onnxMutex ATLAS_THREAD_SAFE

Detailed Description

Definition at line 22 of file METNetHandler.h.

Constructor & Destructor Documentation

◆ METNetHandler() [1/2]

met::METNetHandler::METNetHandler ( const std::string & modelName)

Definition at line 15 of file METNetHandler.cxx.

15 :
16 m_modelName(modelName),
17 m_numInputs(1),
18 m_numOutputs(1),
19 m_inputDims({1,77}),
20 m_outputDims({1,2}),
21 m_graphInputNames({"inputs"}),
22 m_graphOutputNames({"outputs"}){}
std::vector< const char * > m_graphInputNames
std::vector< const char * > m_graphOutputNames
std::vector< int64_t > m_outputDims
std::vector< int64_t > m_inputDims
std::string m_modelName

◆ ~METNetHandler()

virtual met::METNetHandler::~METNetHandler ( )
virtual default

◆ METNetHandler() [2/2]

met::METNetHandler::METNetHandler ( )
private delete

Member Function Documentation

◆ getReqSize()

unsigned int met::METNetHandler::getReqSize ( ) const

Definition at line 40 of file METNetHandler.cxx.

40 {
41 // Returns the required size of the inputs for the network
42 return static_cast<unsigned int>(m_inputDims[1]);
43 }

◆ initialize()

int met::METNetHandler::initialize ( )

Definition at line 24 of file METNetHandler.cxx.

24 {
25
26 // Use the path resolver to find the location of the network .onnx file
27 m_modelPath = PathResolverFindCalibFile(m_modelName);
28 if (m_modelPath == "") return 1;
29
30 // Use the default ONNX session settings for 1 CPU thread
31 m_onnxSessionOptions.SetIntraOpNumThreads(1);
32 m_onnxSessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
33
34 // Initialise the ONNX environment and session using the above options and the resolved model path
35 m_onnxSession = std::make_unique<Ort::Session>(m_onnxEnv, m_modelPath.c_str(), m_onnxSessionOptions);
36
37 return 0;
38 }
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Ort::SessionOptions m_onnxSessionOptions
std::string m_modelPath

◆ predict()

void met::METNetHandler::predict ( std::vector< float > & outputs,
std::vector< float > & inputs ) const

Definition at line 45 of file METNetHandler.cxx.

45 {
46 // This method passes an input vector through the neural network and returns its estimate
47 // It requires conversions to onnx type tensors and back
48
49 // Create a CPU tensor to be used as input
50 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
51 Ort::Value input_tensor = Ort::Value::CreateTensor<float>( memory_info,
52 inputs.data(),
53 inputs.size(),
54 m_inputDims.data(),
55 m_inputDims.size() );
56
57 outputs = {0, 0};
58 Ort::Value output_tensor = Ort::Value::CreateTensor<float>( memory_info,
59 outputs.data(),
60 outputs.size(),
61 m_outputDims.data(),
62 m_outputDims.size() );
63
64 // Pass the input through the network, getting a vector of outputs
65 std::lock_guard<std::mutex> lock(m_onnxMutex);
66 m_onnxSession->Run( Ort::RunOptions{nullptr},
67 m_graphInputNames.data(), &input_tensor, m_numInputs,
68 m_graphOutputNames.data(), &output_tensor, m_numOutputs );
69 }

Member Data Documentation

◆ ATLAS_THREAD_SAFE [1/2]

std::unique_ptr<Ort::Session> m_onnxSession met::METNetHandler::ATLAS_THREAD_SAFE {nullptr}
mutable private

Definition at line 58 of file METNetHandler.h.

58{nullptr};

◆ ATLAS_THREAD_SAFE [2/2]

std::mutex m_onnxMutex met::METNetHandler::ATLAS_THREAD_SAFE
mutable private

Definition at line 59 of file METNetHandler.h.

◆ m_graphInputNames

std::vector<const char *> met::METNetHandler::m_graphInputNames
private

Definition at line 51 of file METNetHandler.h.

◆ m_graphOutputNames

std::vector<const char *> met::METNetHandler::m_graphOutputNames
private

Definition at line 52 of file METNetHandler.h.

◆ m_inputDims

std::vector<int64_t> met::METNetHandler::m_inputDims
private

Definition at line 49 of file METNetHandler.h.

◆ m_modelName

std::string met::METNetHandler::m_modelName
private

Definition at line 43 of file METNetHandler.h.

◆ m_modelPath

std::string met::METNetHandler::m_modelPath
private

Definition at line 44 of file METNetHandler.h.

◆ m_numInputs

size_t met::METNetHandler::m_numInputs
private

Definition at line 47 of file METNetHandler.h.

◆ m_numOutputs

size_t met::METNetHandler::m_numOutputs
private

Definition at line 48 of file METNetHandler.h.

◆ m_onnxAllocator

Ort::AllocatorWithDefaultOptions met::METNetHandler::m_onnxAllocator
private

Definition at line 57 of file METNetHandler.h.

◆ m_onnxEnv

Ort::Env met::METNetHandler::m_onnxEnv
private

Definition at line 55 of file METNetHandler.h.

◆ m_onnxSessionOptions

Ort::SessionOptions met::METNetHandler::m_onnxSessionOptions
private

Definition at line 56 of file METNetHandler.h.

◆ m_outputDims

std::vector<int64_t> met::METNetHandler::m_outputDims
private

Definition at line 50 of file METNetHandler.h.


The documentation for this class was generated from the following files: