ATLAS Offline Software
Loading...
Searching...
No Matches
met::METNetSigHandler Class Reference

#include <METNetSigHandler.h>

Collaboration diagram for met::METNetSigHandler:

Public Member Functions

 METNetSigHandler (const std::string &modelName)
virtual ~METNetSigHandler ()
int initialize ()
int getReqSize () const
std::vector< float > predict (std::vector< float > inputs) const

Private Member Functions

 METNetSigHandler ()

Private Attributes

std::string m_modelName
std::string m_modelPath
size_t m_numInputs
size_t m_numOutputs
std::vector< int64_t > m_inputDims
std::vector< int64_t > m_outputDims
std::vector< const char * > m_graphInputNames
std::vector< const char * > m_graphOutputNames
Ort::Env m_onnxEnv
Ort::SessionOptions m_onnxSessionOptions
Ort::AllocatorWithDefaultOptions m_onnxAllocator
std::unique_ptr< Ort::Session > m_onnxSession

Detailed Description

Definition at line 24 of file METNetSigHandler.h.

Constructor & Destructor Documentation

◆ METNetSigHandler() [1/2]

met::METNetSigHandler::METNetSigHandler ( const std::string & modelName)

Definition at line 20 of file METNetSigHandler.cxx.

20 :
21 m_modelName(modelName),
22 m_numInputs(1),
23 m_numOutputs(4),
24 m_inputDims({1,77}),
25 m_outputDims({1,1}), // m_outputDims({1,2}),
26 m_graphInputNames({"input"}),
27 m_graphOutputNames({"MET_x", "MET_y", "Sigma_x", "Sigma_y"}){}
std::vector< const char * > m_graphInputNames
std::vector< int64_t > m_outputDims
std::vector< int64_t > m_inputDims
std::vector< const char * > m_graphOutputNames

◆ ~METNetSigHandler()

/// Destructor (declared `virtual` in the class; defined at line 29 of
/// METNetSigHandler.cxx). It performs no work of its own; cleanup is left to the
/// member destructors. Members are destroyed in reverse declaration order, so
/// m_onnxSession (a std::unique_ptr, declared last) is released before
/// m_onnxSessionOptions and m_onnxEnv.
met::METNetSigHandler::~METNetSigHandler() = default;

◆ METNetSigHandler() [2/2]

met::METNetSigHandler::METNetSigHandler ( )
private

Member Function Documentation

◆ getReqSize()

int met::METNetSigHandler::getReqSize ( ) const

Definition at line 50 of file METNetSigHandler.cxx.

50 {
51 // Returns the required size of the inputs for the network
52 return static_cast<int>(m_inputDims[1]);
53 }

◆ initialize()

// initialize(): configure the ONNX Runtime session options and create the
// inference session (m_onnxSession) from the model file at m_modelPath.
// Returns 0 once the session has been created, or 1 if m_modelPath is empty.
// NOTE(review): Ort::Session's constructor presumably throws Ort::Exception on a
// missing/unreadable model file; such errors would propagate to the caller
// rather than become a return code — confirm.
int met::METNetSigHandler::initialize ( )

Definition at line 31 of file METNetSigHandler.cxx.

31 {
32
33 // Use the path resolver to find the location of the network .onnx file
34 // m_modelPath = PathResolverFindCalibFile(m_modelName);
35 // m_modelPath = "ggHyydAnalysis/model_mc23.onnx";
// NOTE(review): source line 36 is absent from this listing (numbering jumps
// from 35 to 37), and every visible assignment to m_modelPath here is commented
// out. Unless line 36 sets m_modelPath, it stays empty and the check below
// always returns 1, leaving m_onnxSession null — confirm against the .cxx file.
37 // m_modelPath = "../source/ggHyydAnalysis/share/model_mc23.onnx";
38 if (m_modelPath == "") return 1;
39
40 // Use the default ONNX session settings for 1 CPU thread
// Note: beyond the thread count, line 42 also sets the graph optimisation
// level explicitly to ORT_ENABLE_BASIC.
41 m_onnxSessionOptions.SetIntraOpNumThreads(1);
42 m_onnxSessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
43
44 // Initialise the ONNX environment and session using the above options and the model name
// Only the session is created here; m_onnxEnv is an existing member passed in.
// Any session from a previous call is replaced (and released) by this assignment.
45 m_onnxSession = std::make_unique<Ort::Session>(m_onnxEnv, m_modelPath.c_str(), m_onnxSessionOptions);
46
47 return 0;
48 }
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Ort::SessionOptions m_onnxSessionOptions
std::unique_ptr< Ort::Session > m_onnxSession

◆ predict()

std::vector< float > met::METNetSigHandler::predict ( std::vector< float > inputs) const

Definition at line 55 of file METNetSigHandler.cxx.

55 {
56 // This method passes a input vector through the neural network and returns its estimate
57 // It requires conversions to onnx type tensors and back
58
59 // Create a CPU tensor to be used as input
60 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
61 Ort::Value input_tensor = Ort::Value::CreateTensor<float>( memory_info,
62 inputs.data(),
63 inputs.size(),
64 m_inputDims.data(),
65 m_inputDims.size() );
66
67
68 std::vector<Ort::Value> ort_outputs = m_onnxSession->Run(
69 Ort::RunOptions{nullptr},
70 m_graphInputNames.data(), &input_tensor, m_numInputs,
72 );
73
74 // Extract output values and convert from GeV to MeV
75 std::vector<float> outputs;
76 for (size_t i = 0; i < ort_outputs.size(); ++i) {
77 const float* data = ort_outputs[i].GetTensorData<float>();
78 outputs.push_back(data[0] * 1000.0f);
79 }
80 return outputs;
81
82 }
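Note: Doxygen linked the local variable `data` in predict() to the unrelated global `char data[hepevt_bytes_allocation_ATLAS]` (Definition HepEvt.cxx:11). That cross-reference is spurious.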

Member Data Documentation

◆ m_graphInputNames

std::vector<const char *> met::METNetSigHandler::m_graphInputNames
private

Definition at line 53 of file METNetSigHandler.h.

◆ m_graphOutputNames

std::vector<const char *> met::METNetSigHandler::m_graphOutputNames
private

Definition at line 54 of file METNetSigHandler.h.

◆ m_inputDims

std::vector<int64_t> met::METNetSigHandler::m_inputDims
private

Definition at line 51 of file METNetSigHandler.h.

◆ m_modelName

std::string met::METNetSigHandler::m_modelName
private

Definition at line 45 of file METNetSigHandler.h.

◆ m_modelPath

std::string met::METNetSigHandler::m_modelPath
private

Definition at line 46 of file METNetSigHandler.h.

◆ m_numInputs

size_t met::METNetSigHandler::m_numInputs
private

Definition at line 49 of file METNetSigHandler.h.

◆ m_numOutputs

size_t met::METNetSigHandler::m_numOutputs
private

Definition at line 50 of file METNetSigHandler.h.

◆ m_onnxAllocator

Ort::AllocatorWithDefaultOptions met::METNetSigHandler::m_onnxAllocator
private

Definition at line 59 of file METNetSigHandler.h.

◆ m_onnxEnv

Ort::Env met::METNetSigHandler::m_onnxEnv
private

Definition at line 57 of file METNetSigHandler.h.

◆ m_onnxSession

std::unique_ptr<Ort::Session> met::METNetSigHandler::m_onnxSession
private

Definition at line 60 of file METNetSigHandler.h.

◆ m_onnxSessionOptions

Ort::SessionOptions met::METNetSigHandler::m_onnxSessionOptions
private

Definition at line 58 of file METNetSigHandler.h.

◆ m_outputDims

std::vector<int64_t> met::METNetSigHandler::m_outputDims
private

Definition at line 52 of file METNetSigHandler.h.


The documentation for this class was generated from the following files: