ATLAS Offline Software
TFCSONNXHandler.h
Go to the documentation of this file.
1 
17 #ifndef TFCSONNXHANDLER_H
18 #define TFCSONNXHANDLER_H
19 
20 // inherits from
22 
23 #include <iostream>
24 
25 // ONNX Runtime include(s).
26 #include <onnxruntime_cxx_api.h>
27 
28 // For reading and writing to root
29 #include "TFile.h"
30 #include "TTree.h"
31 
32 // For storing the lambda function
33 #include <functional>
34 
35 
43 class TFCSONNXHandler : public VNetworkBase {
44 public:
45  // Don't lose the default constructor
47 
57  explicit TFCSONNXHandler(const std::string &inputFile);
58 
73  explicit TFCSONNXHandler(const std::vector<char> &bytes);
74 
83  TFCSONNXHandler(const TFCSONNXHandler &copy_from);
84 
96  NetworkOutputs compute(NetworkInputs const &inputs) const override;
97 
98  // Output to a ttree file
100 
111  void writeNetToTTree(TTree &tree) override;
112 
121  std::vector<std::string> getOutputLayers() const override;
122 
131  void deleteAllButNet() override;
132 
133 protected:
142  virtual void print(std::ostream &strm) const override;
143 
153  void setupPersistedVariables() override;
154 
164  void setupNet() override;
165 
166 private:
170  std::vector<char> m_bytes;
181  std::vector<char>
182  getSerializedSession(const std::string& tree_name = m_defaultTreeName);
190  std::vector<char> readBytesFromTTree(TTree &tree);
197  void writeBytesToTTree(TTree &tree, const std::vector<char> &bytes);
198 
199  // unique ptr deletes the object when it goes out of scope
210  std::unique_ptr<Ort::Session> m_session;
211 
220  void readSerializedSession();
221 
234  std::vector<const char *> m_inputNodeNames;
235 
236 #if ORT_API_VERSION > 11
237 
243  std::vector<Ort::AllocatedStringPtr> m_storeInputNodeNames;
244 #endif
245 
258  std::vector<const char *> m_outputNodeNames;
259 
260 #if ORT_API_VERSION > 11
261 
267  std::vector<Ort::AllocatedStringPtr> m_storeOutputNodeNames;
268 #endif
269 
276  std::vector<std::vector<int64_t>> m_inputNodeDims;
277 
284  std::vector<std::vector<int64_t>> m_outputNodeDims;
285 
292  std::vector<int64_t> m_outputNodeSize;
293 
300  template <typename Tin, typename Tout>
302 
306  std::function<NetworkOutputs(NetworkInputs)>
308 
312  Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(
313  OrtArenaAllocator, OrtMemTypeDefault);
314 
318  std::vector<std::string> m_outputLayers;
319 
320  // For the streamer
322 };
323 
324 #endif // TFCSONNXHANDLER_H
VNetworkBase::NetworkOutputs
std::map< std::string, double > NetworkOutputs
Format for network outputs.
Definition: VNetworkBase.h:100
VNetworkBase.h
VNetworkBase::VNetworkBase
VNetworkBase()
VNetworkBase default constructor.
Definition: VNetworkBase.cxx:16
TFCSONNXHandler::writeBytesToTTree
void writeBytesToTTree(TTree &tree, const std::vector< char > &bytes)
Write the content of the proto file to a TTree as a branch.
Definition: TFCSONNXHandler.cxx:310
TFCSONNXHandler::m_inputNodeNames
std::vector< const char * > m_inputNodeNames
names that index the input nodes
Definition: TFCSONNXHandler.h:234
TFCSONNXHandler::computeTemplate
NetworkOutputs computeTemplate(NetworkInputs const &input)
Do not persistify.
TFCSONNXHandler::getOutputLayers
std::vector< std::string > getOutputLayers() const override
List the names of the outputs.
Definition: TFCSONNXHandler.cxx:67
tree
TChain * tree
Definition: tile_monitor.h:30
TFCSONNXHandler::m_outputNodeNames
std::vector< const char * > m_outputNodeNames
Do not persistify.
Definition: TFCSONNXHandler.h:258
TFCSONNXHandler::m_computeLambda
std::function< NetworkOutputs(NetworkInputs)> m_computeLambda
computeTemplate with appropriate types selected.
Definition: TFCSONNXHandler.h:307
TFCSONNXHandler::print
virtual void print(std::ostream &strm) const override
Write a short description of this net to the string stream.
Definition: TFCSONNXHandler.cxx:78
VNetworkBase::NetworkInputs
std::map< std::string, std::map< std::string, double > > NetworkInputs
Format for network inputs.
Definition: VNetworkBase.h:90
VNetworkBase::m_defaultTreeName
static const std::string m_defaultTreeName
Default name for the TTree to save in.
Definition: VNetworkBase.h:173
TFCSONNXHandler::compute
NetworkOutputs compute(NetworkInputs const &inputs) const override
Function to pass values to the network.
Definition: TFCSONNXHandler.cxx:57
TFCSONNXHandler::writeNetToTTree
void writeNetToTTree(TTree &tree) override
Save the network to a TTree.
Definition: TFCSONNXHandler.cxx:62
TFCSONNXHandler::m_outputLayers
std::vector< std::string > m_outputLayers
Do not persistify.
Definition: TFCSONNXHandler.h:318
TFCSONNXHandler::setupPersistedVariables
void setupPersistedVariables() override
Perform actions that prep data to create the net.
Definition: TFCSONNXHandler.cxx:102
TFCSONNXHandler::deleteAllButNet
void deleteAllButNet() override
Get rid of any memory objects that aren't needed to run the net.
Definition: TFCSONNXHandler.cxx:72
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
TFCSONNXHandler::getSerializedSession
std::vector< char > getSerializedSession(const std::string &tree_name=m_defaultTreeName)
Return content of the proto (.onnx) file in memory.
Definition: TFCSONNXHandler.cxx:273
CaloCondBlobAlgs_fillNoiseFromASCII.inputFile
string inputFile
Definition: CaloCondBlobAlgs_fillNoiseFromASCII.py:17
TFCSONNXHandler::TFCSONNXHandler
TFCSONNXHandler(const std::string &inputFile)
TFCSONNXHandler constructor.
Definition: TFCSONNXHandler.cxx:24
PlotPulseshapeFromCool.input
input
Definition: PlotPulseshapeFromCool.py:106
TFCSONNXHandler::m_inputNodeDims
std::vector< std::vector< int64_t > > m_inputNodeDims
Do not persistify.
Definition: TFCSONNXHandler.h:276
TFCSONNXHandler::setupNet
void setupNet() override
Perform actions that prepare network for use.
Definition: TFCSONNXHandler.cxx:112
TFCSONNXHandler::readSerializedSession
void readSerializedSession()
Do not persistify.
Definition: TFCSONNXHandler.cxx:324
TFCSONNXHandler::m_outputNodeSize
std::vector< int64_t > m_outputNodeSize
Do not persistify.
Definition: TFCSONNXHandler.h:292
TFCSONNXHandler::m_outputNodeDims
std::vector< std::vector< int64_t > > m_outputNodeDims
Do not persistify.
Definition: TFCSONNXHandler.h:284
VNetworkBase
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration.
Definition: VNetworkBase.h:38
TFCSONNXHandler::ClassDefOverride
ClassDefOverride(TFCSONNXHandler, 1)
Do not persistify.
TFCSONNXHandler::m_memoryInfo
Ort::MemoryInfo m_memoryInfo
Do not persistify.
Definition: TFCSONNXHandler.h:312
TFCSONNXHandler::m_bytes
std::vector< char > m_bytes
Content of the proto file.
Definition: TFCSONNXHandler.h:170
VNetworkBase::writeNetToTTree
virtual void writeNetToTTree(TTree &tree)=0
Save the network to a TTree.
TFCSONNXHandler::readBytesFromTTree
std::vector< char > readBytesFromTTree(TTree &tree)
Retrieve the content of the proto file from a TTree.
Definition: TFCSONNXHandler.cxx:297
TFCSONNXHandler::m_session
std::unique_ptr< Ort::Session > m_session
The network session itself.
Definition: TFCSONNXHandler.h:210
TFCSONNXHandler
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration.
Definition: TFCSONNXHandler.h:43