ATLAS Offline Software
Classes | Namespaces | Functions
SaltModelTriton.cxx File Reference
#include "FlavorTagInference/SaltModelTriton.h"
#include "FlavorTagInference/SaltModelGraphConfig.h"
#include "FlavorTagInference/SaltModelOutput.h"
#include "CxxUtils/checker_macros.h"
#include <onnxruntime_cxx_api.h>
#include <stdexcept>
#include <tuple>
#include <set>
#include <memory>
Include dependency graph for SaltModelTriton.cxx:

Go to the source code of this file.

Classes

struct  TritonDType< T >
 
struct  TritonDType< float >
 
struct  TritonDType< int64_t >
 

Namespaces

 FlavorTagInference
 This file contains "getter" functions used for accessing tagger inputs from the EDM.
 

Functions

template<typename T >
bool prepareInput (const std::string &name, const std::vector< int64_t > &shape, const std::vector< T > &data, std::vector< std::shared_ptr< tc::InferInput >> &inputs)
 
template<typename T >
bool extractOutput (const std::string &name, const std::shared_ptr< tc::InferResult > &result, std::vector< T > &outputVec)
 

Function Documentation

◆ extractOutput()

template<typename T >
bool extractOutput ( const std::string &  name,
const std::shared_ptr< tc::InferResult > &  result,
std::vector< T > &  outputVec 
)

Definition at line 55 of file SaltModelTriton.cxx.

58 {
59  const uint8_t* rawData = nullptr;
60  size_t size = 0;
61 
62  // Get access to the buffer holding raw results of specified output returned by the server.
63  // Note: the buffer is owned by InferResult instance.
64  // Users can copy out the data if required to extend the lifetime.
65  tc::Error err = result->RawData(name, &rawData, &size);
66  if(!err.IsOk()) {
67  std::cerr << "Unable to get raw output for: " + name << std::endl;
68  return false;
69  }
70 
71  outputVec.resize(size / sizeof(T));
72  std::memcpy(outputVec.data(), rawData, size);
73  return true;
74 }

◆ prepareInput()

template<typename T >
bool prepareInput ( const std::string &  name,
const std::vector< int64_t > &  shape,
const std::vector< T > &  data,
std::vector< std::shared_ptr< tc::InferInput >> &  inputs 
)

Definition at line 22 of file SaltModelTriton.cxx.

26 {
27  const char* dtype = TritonDType<T>::value;
28  tc::InferInput* rawInputPtr = nullptr;
29 
30  // create the InferInput object with the predefined name, shape, and data type.
31  tc::Error err = tc::InferInput::Create(&rawInputPtr, name, shape, dtype);
32  if(!err.IsOk()) {
33  std::cerr << "Unable to create input: " + name << std::endl;
34  return false;
35  }
36 
37  // Append tensor values for this input from a byte array.
38  // Note: The vector is not copied and so it must not be modified or destroyed
39  // until this input is no longer needed (that is until the Infer() call(s) that use the input have completed).
40  // Multiple calls can be made to this API to keep adding tensor data for this input.
41  // The data will be delivered in the order it was added.
42  std::shared_ptr<tc::InferInput> input(rawInputPtr);
43  err = input->AppendRaw(reinterpret_cast<const uint8_t*>(data.data()),
44  data.size() * sizeof(T));
45  if(!err.IsOk()) {
46  std::cerr << "Unable to set input data for: " + name << std::endl;
47  return false;
48  }
49 
50  inputs.push_back(std::move(input));
51  return true;
52 }
TritonDType
Definition: SaltModelTriton.cxx:17
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
get_generator_info.result
result
Definition: get_generator_info.py:21
xAOD::uint8_t
uint8_t
Definition: Muon_v1.cxx:553
PlotCalibFromCool.dtype
dtype
Definition: PlotCalibFromCool.py:495
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
python.setupRTTAlg.size
int size
Definition: setupRTTAlg.py:39
dqt_zlumi_pandas.err
err
Definition: dqt_zlumi_pandas.py:183
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
trigbs_mixBSevents.input
input
Definition: trigbs_mixBSevents.py:56
L1Topo::Error
Error
The different types of error that can be flagged in the L1TopoRDO.
Definition: Error.h:16
TSU::T
unsigned long long T
Definition: L1TopoDataTypes.h:35