ATLAS Offline Software
Loading...
Searching...
No Matches
METNetSigHandler.cxx
Go to the documentation of this file.
1
2
3/*
4 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
5*/
6
7// METNetSigHandler.cxx
8// Implementation file for class METNetSigHandler
9// Author: Alberto Plebani <alberto.plebani@cern.ch>, based on earlier implementation by M. Leigh
11
12// METUtilities includes
15#include <iostream>
16
17namespace met {
18
19 // Dimensions cannot be changed without altering METNet, so it's OK to hard code these for the foreseeable future.
20 METNetSigHandler::METNetSigHandler(const std::string& modelName) :
21 m_modelName(modelName),
22 m_numInputs(1),
23 m_numOutputs(4),
24 m_inputDims({1,77}),
25 m_outputDims({1,1}), // m_outputDims({1,2}),
26 m_graphInputNames({"input"}),
27 m_graphOutputNames({"MET_x", "MET_y", "Sigma_x", "Sigma_y"}){}
28
30
32
33 // Use the path resolver to find the location of the network .onnx file
34 // m_modelPath = PathResolverFindCalibFile(m_modelName);
35 // m_modelPath = "ggHyydAnalysis/model_mc23.onnx";
37 // m_modelPath = "../source/ggHyydAnalysis/share/model_mc23.onnx";
38 if (m_modelPath == "") return 1;
39
40 // Use the default ONNX session settings for 1 CPU thread
41 m_onnxSessionOptions.SetIntraOpNumThreads(1);
42 m_onnxSessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
43
44 // Initialise the ONNX environment and session using the above options and the model name
45 m_onnxSession = std::make_unique<Ort::Session>(m_onnxEnv, m_modelPath.c_str(), m_onnxSessionOptions);
46
47 return 0;
48 }
49
51 // Returns the required size of the inputs for the network
52 return static_cast<int>(m_inputDims[1]);
53 }
54
55 std::vector<float> METNetSigHandler::predict(std::vector<float> inputs) const {
56 // This method passes a input vector through the neural network and returns its estimate
57 // It requires conversions to onnx type tensors and back
58
59 // Create a CPU tensor to be used as input
60 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
61 Ort::Value input_tensor = Ort::Value::CreateTensor<float>( memory_info,
62 inputs.data(),
63 inputs.size(),
64 m_inputDims.data(),
65 m_inputDims.size() );
66
67
68 std::vector<Ort::Value> ort_outputs = m_onnxSession->Run(
69 Ort::RunOptions{nullptr},
70 m_graphInputNames.data(), &input_tensor, m_numInputs,
72 );
73
74 // Extract output values and convert from GeV to MeV
75 std::vector<float> outputs;
76 for (size_t i = 0; i < ort_outputs.size(); ++i) {
77 const float* data = ort_outputs[i].GetTensorData<float>();
78 outputs.push_back(data[0] * 1000.0f);
79 }
80 return outputs;
81
82 }
83
84}
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
std::vector< const char * > m_graphInputNames
Ort::SessionOptions m_onnxSessionOptions
std::vector< float > predict(std::vector< float > inputs) const
std::unique_ptr< Ort::Session > m_onnxSession
std::vector< int64_t > m_inputDims
std::vector< const char * > m_graphOutputNames