ATLAS Offline Software
Loading...
Searching...
No Matches
SaltModelTriton.cxx File Reference
#include "SaltModelTriton.h"
#include "FlavorTagInference/SaltModelGraphConfig.h"
#include "FlavorTagInference/SaltModelOutput.h"
#include "CxxUtils/checker_macros.h"
#include <onnxruntime_cxx_api.h>
#include <stdexcept>
#include <tuple>
#include <set>
#include <memory>
Include dependency graph for SaltModelTriton.cxx:

Go to the source code of this file.

Classes

struct  TritonDType< float >
struct  TritonDType< int64_t >

Namespaces

namespace  FlavorTagInference
 This file contains "getter" functions used for accessing tagger inputs from the EDM.

Functions

template<typename T>
bool prepareInput (const std::string &name, const std::vector< int64_t > &shape, const std::vector< T > &data, std::vector< std::shared_ptr< tc::InferInput > > &inputs)
template<typename T>
bool extractOutput (const std::string &name, const std::shared_ptr< tc::InferResult > &result, std::vector< T > &outputVec)

Function Documentation

◆ extractOutput()

/// @brief Copy the raw bytes of a named Triton output tensor into a typed vector.
///
/// The buffer returned by InferResult::RawData() is owned by the InferResult
/// instance, so the data is copied out to give it an independent lifetime.
///
/// @tparam T         Element type the output is reinterpreted as.
/// @param name       Name of the model output to read.
/// @param result     Inference result returned by the server.
/// @param outputVec  Receives the decoded elements; resized to the element count.
/// @return true on success; false if the output could not be read or its byte
///         size is not a whole number of T elements (e.g. a dtype mismatch).
template<typename T>
bool extractOutput(const std::string& name,
                   const std::shared_ptr<tc::InferResult>& result,
                   std::vector<T>& outputVec)
{
  const uint8_t* rawData = nullptr;
  size_t size = 0;

  // Get access to the buffer holding raw results of specified output returned by the server.
  // Note: the buffer is owned by InferResult instance.
  tc::Error err = result->RawData(name, &rawData, &size);
  if (!err.IsOk()) {
    std::cerr << "Unable to get raw output for: " + name << ": " << err.Message() << std::endl;
    return false;
  }

  // Copying `size` bytes into a buffer of size / sizeof(T) elements would write
  // past the end of the vector whenever the byte count is not an exact multiple
  // of sizeof(T), so reject such payloads instead of overflowing.
  if (size % sizeof(T) != 0) {
    std::cerr << "Raw output size for " + name << " (" << size
              << " bytes) is not a multiple of the element size ("
              << sizeof(T) << " bytes)" << std::endl;
    return false;
  }

  outputVec.resize(size / sizeof(T));
  // memcpy requires valid pointers even for a zero-length copy; an empty vector
  // and an empty server buffer may both be null.
  if (size != 0) {
    std::memcpy(outputVec.data(), rawData, size);
  }
  return true;
}

◆ prepareInput()

template<typename T>
bool prepareInput ( const std::string & name,
const std::vector< int64_t > & shape,
const std::vector< T > & data,
std::vector< std::shared_ptr< tc::InferInput > > & inputs )

Definition at line 22 of file SaltModelTriton.cxx.

26{
27 const char* dtype = TritonDType<T>::value;
28 tc::InferInput* rawInputPtr = nullptr;
29
30 // create the InferInput object with the predefined name, shape, and data type.
31 tc::Error err = tc::InferInput::Create(&rawInputPtr, name, shape, dtype);
32 if(!err.IsOk()) {
33 std::cerr << "Unable to create input: " + name << std::endl;
34 return false;
35 }
36
37 // Append tensor values for this input from a byte array.
38 // Note: The vector is not copied and so it must not be modified or destroyed
39 // until this input is no longer needed (that is until the Infer() call(s) that use the input have completed).
40 // Multiple calls can be made to this API to keep adding tensor data for this input.
41 // The data will be delivered in the order it was added.
42 std::shared_ptr<tc::InferInput> input(rawInputPtr);
43 err = input->AppendRaw(reinterpret_cast<const uint8_t*>(data.data()),
44 data.size() * sizeof(T));
45 if(!err.IsOk()) {
46 std::cerr << "Unable to set input data for: " + name << std::endl;
47 return false;
48 }
49
50 inputs.push_back(std::move(input));
51 return true;
52}
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11