ATLAS Offline Software
Loading...
Searching...
No Matches
METNetHandler.h
Go to the documentation of this file.
1
2/*
3 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
4*/
5// Author: Bill Balunas <balunas@cern.ch>, based on earlier implementation by M. Leigh
7
8#ifndef METUTILITIES_MET_METNETHANDLER_H
9#define METUTILITIES_MET_METNETHANDLER_H
10
11// STL includes
12#include <string>
13
14// ONNX Library
15#include <onnxruntime_cxx_api.h>
16
17// For ATLAS_THREAD_SAFE
19
20namespace met {
21
23
24 public:
25
26 // Constructor taking the name of the ONNX model this handler will use.
// modelName: model identifier; presumably kept in m_modelName and resolved
// to a file location (m_modelPath) later -- TODO confirm in the implementation.
// NOTE(review): this single-argument constructor is not marked explicit, so
// a std::string converts implicitly to a METNetHandler.
27 METNetHandler(const std::string& modelName);
28
29 // Destructor: virtual and defaulted, so no hand-written cleanup is declared
// here; each data member is torn down by its own destructor.
30 virtual ~METNetHandler() = default;
31
32 // Public methods
// initialize: non-const set-up call. Presumably prepares the handler for use
// (e.g. locating and loading the model) -- TODO confirm in the implementation.
// Returns an int; the success/failure convention is not visible in this
// header, so check the implementation before testing the result.
33 int initialize();
// getReqSize: const accessor returning an unsigned int. Presumably the number
// of input values the network requires (cf. m_inputDims) -- TODO confirm.
34 unsigned int getReqSize() const;
// predict: takes a vector of float inputs and a vector of float outputs, both
// by reference; presumably evaluates the network on `inputs` and writes the
// result to `outputs` -- TODO confirm in the implementation.
// The method is const; the session pointer and mutex members of this class
// are declared mutable, which is what permits a const method to use them.
// NOTE(review): `inputs` is a non-const reference; whether the call modifies
// it cannot be determined from this header.
35 void predict(std::vector<float>& outputs, std::vector<float>& inputs) const;
36
37 private:
38
39 // Default constructor: explicitly deleted, so default construction is a
// compile error. The constructor visible in this class takes a model name.
40 METNetHandler() = delete;
41
42 // Class properties: identification of the model file.
43 std::string m_modelName; // Path to the ONNX file (presumably as given to the constructor -- TODO confirm)
44 std::string m_modelPath; // Output of the path resolver (presumably the resolved location of m_modelName)
45
46 // Features of the network structure.
// NOTE(review): std::vector and int64_t are used here, but <string> is the
// only standard header included directly in the visible listing; they are
// presumably provided transitively -- consider including <vector> and
// <cstdint> explicitly.
49 std::vector<int64_t> m_inputDims; // presumably the shape of the network input -- TODO confirm
50 std::vector<int64_t> m_outputDims; // presumably the shape of the network output -- TODO confirm
// The name lists hold raw C-string pointers; who owns the character data and
// how long it stays valid is not visible in this header -- TODO confirm.
51 std::vector<const char *> m_graphInputNames;
52 std::vector<const char *> m_graphOutputNames;
53
54 // ONNX session objects.
// NOTE(review): std::unique_ptr and std::mutex are used here, but <string> is
// the only standard header included directly in the visible listing; they are
// presumably provided transitively -- consider including <memory> and <mutex>
// explicitly.
55 Ort::Env m_onnxEnv;
56 Ort::SessionOptions m_onnxSessionOptions;
57 Ort::AllocatorWithDefaultOptions m_onnxAllocator;
// Session pointer starts out null; presumably created during initialize()
// -- TODO confirm. Declared mutable so const member functions may use it.
// ATLAS_THREAD_SAFE is an annotation for the static thread-safety checker.
58 mutable std::unique_ptr<Ort::Session> m_onnxSession ATLAS_THREAD_SAFE {nullptr};
// Mutable mutex; presumably serialises access to the session from const
// methods such as predict() -- TODO confirm in the implementation.
59 mutable std::mutex m_onnxMutex ATLAS_THREAD_SAFE;
60 };
61
62}
63
64#endif
Define macros for attributes used to control the static checker.
Ort::SessionOptions m_onnxSessionOptions
Ort::AllocatorWithDefaultOptions m_onnxAllocator
std::string m_modelPath
std::vector< const char * > m_graphInputNames
METNetHandler()=delete
unsigned int getReqSize() const
std::vector< const char * > m_graphOutputNames
std::vector< int64_t > m_outputDims
std::vector< int64_t > m_inputDims
METNetHandler(const std::string &modelName)
std::string m_modelName
void predict(std::vector< float > &outputs, std::vector< float > &inputs) const
std::unique_ptr< Ort::Session > m_onnxSession ATLAS_THREAD_SAFE
virtual ~METNetHandler()=default