ATLAS Offline Software
Functions
AthOnnxUtils Namespace Reference

Functions

template<typename T >
std::vector< T > flattenNestedVectors (const std::vector< std::vector< T >> &features)
 
void getInputNodeInfo (const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
 
void getOutputNodeInfo (const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
 
void getNodeInfo (const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames, bool isInput)
 
int64_t getTensorSize (const std::vector< int64_t > &dataShape)
 
void inferenceWithIOBinding (Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, const std::vector< Ort::Value > &outputData)
 
template<typename T >
Ort::Value createTensor (std::vector< T > &data, const std::vector< int64_t > &dataShape)
 

Function Documentation

◆ createTensor()

namespace AthOnnxUtils {

/// @brief Build an ONNX Runtime tensor over the contents of @p data.
///
/// A CPU memory descriptor (arena allocator, default memory type) is created and
/// passed together with the raw buffer, its element count and the shape to
/// `Ort::Value::CreateTensor<T>`.
///
/// NOTE(review): the user-buffer overload of ORT's CreateTensor wraps the buffer
/// rather than copying it, so @p data must outlive the returned value and must not
/// be reallocated while the tensor is in use — confirm for the ORT version in use.
///
/// @param data       Tensor contents; non-const because ORT takes a mutable pointer.
/// @param dataShape  Dimensions of the tensor.
/// @return           The tensor value produced by ORT.
///
/// Defined in OnnxUtils.h.
template <typename T>
Ort::Value createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape)
{
  Ort::MemoryInfo cpuMemory = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

  T* const buffer = data.data();
  const std::size_t elementCount = data.size();
  const int64_t* const shapeBegin = dataShape.data();
  const std::size_t rank = dataShape.size();

  return Ort::Value::CreateTensor<T>(cpuMemory, buffer, elementCount, shapeBegin, rank);
}

}  // namespace AthOnnxUtils

◆ flattenNestedVectors()

namespace AthOnnxUtils {

/// @brief Concatenate a vector of vectors into a single contiguous vector.
///
/// Elements appear in the same order as in @p features: all of features[0],
/// then all of features[1], and so on. Empty inner vectors contribute nothing.
///
/// @param features  Nested input (e.g. one feature vector per object).
/// @return          Flat vector holding every element of every inner vector.
///
/// Defined in OnnxUtils.h.
template <typename T>
inline std::vector<T> flattenNestedVectors(const std::vector<std::vector<T>>& features)
{
  // Count the total size first so the output is allocated exactly once.
  // std::size_t (not int) avoids narrowing each inner size() and overflowing
  // for very large inputs, which would otherwise feed a bogus value to reserve().
  std::size_t totalSize = 0;
  for (const auto& feature : features) {
    totalSize += feature.size();
  }

  std::vector<T> flatten1D;
  flatten1D.reserve(totalSize);
  for (const auto& feature : features) {
    flatten1D.insert(flatten1D.end(), feature.begin(), feature.end());
  }
  return flatten1D;
}

}  // namespace AthOnnxUtils

◆ getInputNodeInfo()

void AthOnnxUtils::getInputNodeInfo ( const Ort::Session &  session,
std::vector< std::vector< int64_t > > &  dataShape,
std::vector< std::string > &  nodeNames 
)

Definition at line 33 of file OnnxUtils.cxx.

37  {
38  getNodeInfo(session, dataShape, nodeNames, true);
39 }

◆ getNodeInfo()

void AthOnnxUtils::getNodeInfo ( const Ort::Session &  session,
std::vector< std::vector< int64_t > > &  dataShape,
std::vector< std::string > &  nodeNames,
bool  isInput 
)

Definition at line 9 of file OnnxUtils.cxx.

14  {
15  dataShape.clear();
16  nodeNames.clear();
17 
18  size_t numNodes = isInput? session.GetInputCount(): session.GetOutputCount();
19  dataShape.reserve(numNodes);
20  nodeNames.reserve(numNodes);
21 
22  Ort::AllocatorWithDefaultOptions allocator;
23  for( std::size_t i = 0; i < numNodes; i++ ) {
24  Ort::TypeInfo typeInfo = isInput? session.GetInputTypeInfo(i): session.GetOutputTypeInfo(i);
25  auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo();
26  dataShape.emplace_back(tensorInfo.GetShape());
27 
28  auto nodeName = isInput? session.GetInputNameAllocated(i, allocator) : session.GetOutputNameAllocated(i, allocator);
29  nodeNames.emplace_back(nodeName.get());
30  }
31 }

◆ getOutputNodeInfo()

void AthOnnxUtils::getOutputNodeInfo ( const Ort::Session &  session,
std::vector< std::vector< int64_t > > &  dataShape,
std::vector< std::string > &  nodeNames 
)

Definition at line 41 of file OnnxUtils.cxx.

45  {
46  getNodeInfo(session, dataShape, nodeNames, false);
47 }

◆ getTensorSize()

namespace AthOnnxUtils {

/// @brief Number of elements in a tensor of the given shape.
///
/// Returns the product of all dimensions; an empty shape (a scalar) yields 1,
/// and any zero dimension yields 0.
///
/// NOTE(review): no validation is done — a negative entry (e.g. -1, which ONNX
/// Runtime uses for dynamic dimensions) makes the result negative, and very large
/// shapes can overflow int64_t. Resolve dynamic dimensions before calling.
///
/// @param dataShape  Tensor dimensions.
/// @return           Product of the dimensions.
///
/// Defined in OnnxUtils.cxx.
int64_t getTensorSize(const std::vector<int64_t>& dataShape)
{
  int64_t elementCount = 1;
  for (std::size_t axis = 0; axis < dataShape.size(); ++axis) {
    elementCount = elementCount * dataShape[axis];
  }
  return elementCount;
}

}  // namespace AthOnnxUtils

◆ inferenceWithIOBinding()

void AthOnnxUtils::inferenceWithIOBinding ( Ort::Session &  session,
const std::vector< std::string > &  inputNames,
const std::vector< Ort::Value > &  inputData,
const std::vector< std::string > &  outputNames,
const std::vector< Ort::Value > &  outputData 
)

Definition at line 49 of file OnnxUtils.cxx.

53  {
54 
55  if (inputNames.empty()) {
56  throw std::runtime_error("Onnxruntime input data maping cannot be empty");
57  }
58  assert(inputNames.size() == inputData.size());
59 
60  Ort::IoBinding iobinding(session);
61  for(size_t idx = 0; idx < inputNames.size(); ++idx){
62  iobinding.BindInput(inputNames[idx].data(), inputData[idx]);
63  }
64 
65 
66  for(size_t idx = 0; idx < outputNames.size(); ++idx){
67  iobinding.BindOutput(outputNames[idx].data(), outputData[idx]);
68  }
69 
70  session.Run(Ort::RunOptions{nullptr}, iobinding);
71 }
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
yodamerge_tmp.dim
dim
Definition: yodamerge_tmp.py:239
python.setupRTTAlg.size
int size
Definition: setupRTTAlg.py:39
lumiFormat.i
int i
Definition: lumiFormat.py:85
AthOnnxUtils::getNodeInfo
void getNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames, bool isInput)
Definition: OnnxUtils.cxx:9
XMLtoHeader.outputNames
outputNames
Definition: XMLtoHeader.py:18
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69