ATLAS Offline Software
Loading...
Searching...
No Matches
TFCSONNXHandler.h
Go to the documentation of this file.
1
16
17#ifndef TFCSONNXHANDLER_H
18#define TFCSONNXHANDLER_H
19
20// inherits from
22
23#include <iostream>
24
25// ONNX Runtime include(s).
26#include <onnxruntime_cxx_api.h>
27
28// For reading and writing to root
29#include "TFile.h"
30#include "TTree.h"
31
32// For storing the lambda function
33#include <functional>
34
35
44public:
45 // Don't lose the default constructor
47
57 explicit TFCSONNXHandler(const std::string &inputFile);
58
73 explicit TFCSONNXHandler(const std::vector<char> &bytes);
74
83 TFCSONNXHandler(const TFCSONNXHandler &copy_from);
84
96 NetworkOutputs compute(NetworkInputs const &inputs) const override;
97
98 // Output to a ttree file
100
111 void writeNetToTTree(TTree &tree) override;
112
121 std::vector<std::string> getOutputLayers() const override;
122
131 void deleteAllButNet() override;
132
133protected:
142 virtual void print(std::ostream &strm) const override;
143
153 void setupPersistedVariables() override;
154
164 void setupNet() override;
165
166private:
170 std::vector<char> m_bytes;
181 std::vector<char>
182 getSerializedSession(const std::string& tree_name = m_defaultTreeName);
190 std::vector<char> readBytesFromTTree(TTree &tree);
197 void writeBytesToTTree(TTree &tree, const std::vector<char> &bytes);
198
199 // unique ptr deletes the object when it goes out of scope
210 std::unique_ptr<Ort::Session> m_session;
221
234 std::vector<const char *> m_inputNodeNames;
235
236#if ORT_API_VERSION > 11
243 std::vector<Ort::AllocatedStringPtr> m_storeInputNodeNames;
244#endif
245
258 std::vector<const char *> m_outputNodeNames;
259
260#if ORT_API_VERSION > 11
267 std::vector<Ort::AllocatedStringPtr> m_storeOutputNodeNames;
268#endif
269
276 std::vector<std::vector<int64_t>> m_inputNodeDims;
284 std::vector<std::vector<int64_t>> m_outputNodeDims;
292 std::vector<int64_t> m_outputNodeSize;
293
300 template <typename Tin, typename Tout>
302
306 std::function<NetworkOutputs(NetworkInputs)>
308
312 Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(
313 OrtArenaAllocator, OrtMemTypeDefault);
314
318 std::vector<std::string> m_outputLayers;
319
320 // For the streamer
322};
323
324#endif // TFCSONNXHANDLER_H
ClassDefOverride(TFCSONNXHandler, 1)
Do not persistify.
void readSerializedSession()
Do not persistify.
virtual void print(std::ostream &strm) const override
Write a short description of this net to the string stream.
VNetworkBase()
VNetworkBase default constructor.
std::vector< const char * > m_inputNodeNames
names that index the input nodes
void setupPersistedVariables() override
Perform actions that prep data to create the net.
NetworkOutputs computeTemplate(NetworkInputs const &input)
Do not persistify.
std::vector< char > getSerializedSession(const std::string &tree_name=m_defaultTreeName)
Return content of the proto (.onnx) file in memory.
std::vector< std::string > getOutputLayers() const override
List the names of the outputs.
std::vector< const char * > m_outputNodeNames
Do not persistify.
NetworkOutputs compute(NetworkInputs const &inputs) const override
Function to pass values to the network.
std::vector< std::vector< int64_t > > m_outputNodeDims
Do not persistify.
std::unique_ptr< Ort::Session > m_session
The network session itself.
Ort::MemoryInfo m_memoryInfo
Do not persistify.
std::vector< std::string > m_outputLayers
Do not persistify.
TFCSONNXHandler(const std::string &inputFile)
TFCSONNXHandler constructor.
void setupNet() override
Perform actions that prepare network for use.
std::vector< char > readBytesFromTTree(TTree &tree)
Retrieve the content of the proto file from a TTree.
std::vector< int64_t > m_outputNodeSize
Do not persistify.
std::function< NetworkOutputs(NetworkInputs)> m_computeLambda
computeTemplate with appropriate types selected.
void writeNetToTTree(TTree &tree) override
Save the network to a TTree.
void deleteAllButNet() override
Get rid of any memory objects that aren't needed to run the net.
void writeBytesToTTree(TTree &tree, const std::vector< char > &bytes)
Write the content of the proto file to a TTree as a branch.
std::vector< char > m_bytes
Content of the proto file.
std::vector< std::vector< int64_t > > m_inputNodeDims
Do not persistify.
static const std::string m_defaultTreeName
Default name for the TTree to save in.
VNetworkBase()
VNetworkBase default constructor.
std::map< std::string, std::map< std::string, double > > NetworkInputs
Format for network inputs.
virtual void writeNetToTTree(TTree &tree)=0
Save the network to a TTree.
std::map< std::string, double > NetworkOutputs
Format for network outputs.
TChain * tree