ATLAS Offline Software
METNetHandler.cxx — Doxygen-generated source listing.
Go to the documentation of this file.
1
2/*
3 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
4*/
5// Author: Bill Balunas <balunas@cern.ch>, based on earlier implementation by M. Leigh
7
 8// METUtilities includes
 9#include "METUtilities/METNetHandler.h"
10#include "PathResolver/PathResolver.h"
11
12namespace met {
13
14 // Dimensions cannot be changed without altering METNet, so it's OK to hard code these for the foreseeable future.
15 METNetHandler::METNetHandler(const std::string& modelName) :
16 m_modelName(modelName),
17 m_numInputs(1),
18 m_numOutputs(1),
19 m_inputDims({1,77}),
20 m_outputDims({1,2}),
21 m_graphInputNames({"inputs"}),
22 m_graphOutputNames({"outputs"}){}
23
25
26 // Use the path resolver to find the location of the network .onnx file
28 if (m_modelPath == "") return 1;
29
30 // Use the default ONNX session settings for 1 CPU thread
31 m_onnxSessionOptions.SetIntraOpNumThreads(1);
32 m_onnxSessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
33
34 // Initialise the ONNX environment and session using the above options and the model name
35 m_onnxSession = std::make_unique<Ort::Session>(m_onnxEnv, m_modelPath.c_str(), m_onnxSessionOptions);
36
37 return 0;
38 }
39
40 unsigned int METNetHandler::getReqSize() const{
41 // Returns the required size of the inputs for the network
42 return static_cast<unsigned int>(m_inputDims[1]);
43 }
44
45 void METNetHandler::predict(std::vector<float>& outputs, std::vector<float>& inputs) const {
46 // This method passes a input vector through the neural network and returns its estimate
47 // It requires conversions to onnx type tensors and back
48
49 // Create a CPU tensor to be used as input
50 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
51 Ort::Value input_tensor = Ort::Value::CreateTensor<float>( memory_info,
52 inputs.data(),
53 inputs.size(),
54 m_inputDims.data(),
55 m_inputDims.size() );
56
57 outputs = {0, 0};
58 Ort::Value output_tensor = Ort::Value::CreateTensor<float>( memory_info,
59 outputs.data(),
60 outputs.size(),
61 m_outputDims.data(),
62 m_outputDims.size() );
63
64 // Pass the input through the network, getting a vector of outputs
65 std::lock_guard<std::mutex> lock(m_onnxMutex);
66 m_onnxSession->Run( Ort::RunOptions{nullptr},
67 m_graphInputNames.data(), &input_tensor, m_numInputs,
68 m_graphOutputNames.data(), &output_tensor, m_numOutputs );
69 }
70
71}
Referenced declarations (Doxygen tooltips, from METNetHandler.h and PathResolver.h):
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Ort::SessionOptions m_onnxSessionOptions
std::string m_modelPath
std::vector< const char * > m_graphInputNames
METNetHandler()=delete
unsigned int getReqSize() const
std::vector< const char * > m_graphOutputNames
std::vector< int64_t > m_outputDims
std::vector< int64_t > m_inputDims
std::string m_modelName
void predict(std::vector< float > &outputs, std::vector< float > &inputs) const