ATLAS Offline Software
Loading...
Searching...
No Matches
METNetHandler.h
Go to the documentation of this file.
1
2/*
3 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
4*/
5// Author: Bill Balunas <balunas@cern.ch>, based on earlier implementation by M. Leigh
7
8#ifndef METUTILITIES_MET_METNETHANDLER_H
9#define METUTILITIES_MET_METNETHANDLER_H
10
11// STL includes
12#include <string>
13#include <mutex>
14
15// ONNX Library
16#include <onnxruntime_cxx_api.h>
17
18// For ATLAS_THREAD_SAFE
19#include "CxxUtils/checker_macros.h"
20
21namespace met {
22
23 class METNetHandler {
24
25 public:
26
27 // Constructor with parameters
28 METNetHandler(const std::string& modelName);
29
30 // Destructor
31 virtual ~METNetHandler() = default;
32
33 // Public methods
34 int initialize();
35 unsigned int getReqSize() const;
36 void predict(std::vector<float>& outputs, std::vector<float>& inputs) const;
37
38 private:
39
40 // Default constructor
41 METNetHandler() = delete;
42
43 // Class properties
44 std::string m_modelName; // Path to the onnx file
45 std::string m_modelPath; // Output of the path resolver
46
47 // Features of the network structure
50 std::vector<int64_t> m_inputDims;
51 std::vector<int64_t> m_outputDims;
52 std::vector<const char *> m_graphInputNames;
53 std::vector<const char *> m_graphOutputNames;
54
55 // ONNX session objects
56 Ort::Env m_onnxEnv;
57 Ort::SessionOptions m_onnxSessionOptions;
58 Ort::AllocatorWithDefaultOptions m_onnxAllocator;
59 mutable std::unique_ptr<Ort::Session> m_onnxSession ATLAS_THREAD_SAFE {nullptr};
60 mutable std::mutex m_onnxMutex ATLAS_THREAD_SAFE;
61 };
62
63}
64
65#endif
Define macros for attributes used to control the static checker.
Ort::SessionOptions m_onnxSessionOptions
Ort::AllocatorWithDefaultOptions m_onnxAllocator
std::string m_modelPath
std::vector< const char * > m_graphInputNames
METNetHandler()=delete
unsigned int getReqSize() const
std::vector< const char * > m_graphOutputNames
std::vector< int64_t > m_outputDims
std::vector< int64_t > m_inputDims
METNetHandler(const std::string &modelName)
std::string m_modelName
void predict(std::vector< float > &outputs, std::vector< float > &inputs) const
std::unique_ptr< Ort::Session > m_onnxSession ATLAS_THREAD_SAFE
virtual ~METNetHandler()=default