ATLAS Offline Software
METNetHandler.cxx
Go to the documentation of this file.
1 /*
3  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
4 */
5 // Author: Bill Balunas <balunas@cern.ch>, based on earlier implementation by M. Leigh
7 
8 // METUtilities includes
11 
12 namespace met {
13 
14  // Dimensions cannot be changed without altering METNet, so it's OK to hard code these for the foreseeable future.
15  METNetHandler::METNetHandler(const std::string& modelName) :
16  m_modelName(modelName),
17  m_numInputs(1),
18  m_numOutputs(1),
19  m_inputDims({1,77}),
20  m_outputDims({1,2}),
21  m_graphInputNames({"inputs"}),
22  m_graphOutputNames({"outputs"}){}
23 
25 
26  // Use the path resolver to find the location of the network .onnx file
28  if (m_modelPath == "") return 1;
29 
30  // Use the default ONNX session settings for 1 CPU thread
31  m_onnxSessionOptions.SetIntraOpNumThreads(1);
32  m_onnxSessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
33 
34  // Initialise the ONNX environment and session using the above options and the model name
35  m_onnxSession = std::make_unique<Ort::Session>(m_onnxEnv, m_modelPath.c_str(), m_onnxSessionOptions);
36 
37  return 0;
38  }
39 
40  unsigned int METNetHandler::getReqSize() const{
41  // Returns the required size of the inputs for the network
42  return static_cast<unsigned int>(m_inputDims[1]);
43  }
44 
45  void METNetHandler::predict(std::vector<float>& outputs, std::vector<float>& inputs) const {
46  // This method passes a input vector through the neural network and returns its estimate
47  // It requires conversions to onnx type tensors and back
48 
49  // Create a CPU tensor to be used as input
50  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
51  Ort::Value input_tensor = Ort::Value::CreateTensor<float>( memory_info,
52  inputs.data(),
53  inputs.size(),
54  m_inputDims.data(),
55  m_inputDims.size() );
56 
57  outputs = {0, 0};
58  Ort::Value output_tensor = Ort::Value::CreateTensor<float>( memory_info,
59  outputs.data(),
60  outputs.size(),
61  m_outputDims.data(),
62  m_outputDims.size() );
63 
64  // Pass the input through the network, getting a vector of outputs
65  std::lock_guard<std::mutex> lock(m_onnxMutex);
66  m_onnxSession->Run( Ort::RunOptions{nullptr},
67  m_graphInputNames.data(), &input_tensor, m_numInputs,
68  m_graphOutputNames.data(), &output_tensor, m_numOutputs );
69  }
70 
71 }
met::METNetHandler::m_modelName
std::string m_modelName
Definition: METNetHandler.h:43
met::METNetHandler::m_numInputs
size_t m_numInputs
Definition: METNetHandler.h:47
met::METNetHandler::m_graphOutputNames
std::vector< const char * > m_graphOutputNames
Definition: METNetHandler.h:52
met::METNetHandler::m_modelPath
std::string m_modelPath
Definition: METNetHandler.h:44
python.RatesEmulationExample.lock
lock
Definition: RatesEmulationExample.py:148
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
met::METNetHandler::m_numOutputs
size_t m_numOutputs
Definition: METNetHandler.h:48
met::METNetHandler::m_onnxEnv
Ort::Env m_onnxEnv
Definition: METNetHandler.h:55
met
Definition: IMETSignificance.h:24
met::METNetHandler::initialize
int initialize()
Definition: METNetHandler.cxx:24
met::METNetHandler::getReqSize
unsigned int getReqSize() const
Definition: METNetHandler.cxx:40
METNetHandler.h
python.CreateTierZeroArgdict.outputs
outputs
Definition: CreateTierZeroArgdict.py:189
met::METNetHandler::m_outputDims
std::vector< int64_t > m_outputDims
Definition: METNetHandler.h:50
met::METNetHandler::m_inputDims
std::vector< int64_t > m_inputDims
Definition: METNetHandler.h:49
PathResolver.h
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:321
met::METNetHandler::m_graphInputNames
std::vector< const char * > m_graphInputNames
Definition: METNetHandler.h:51
met::METNetHandler::predict
void predict(std::vector< float > &outputs, std::vector< float > &inputs) const
Definition: METNetHandler.cxx:45
met::METNetHandler::m_onnxSessionOptions
Ort::SessionOptions m_onnxSessionOptions
Definition: METNetHandler.h:56
met::METNetHandler::METNetHandler
METNetHandler()=delete