#include <TritonTool.h>
|
StringProperty | m_modelName {this, "ModelName", "", "Model name"} |
|
IntegerProperty | m_port {this, "Port", 8001, "Port ID for Triton server"} |
|
StringProperty | m_modelVersion {this, "ModelVersion", "", "Model version, empty for latest"} |
|
FloatProperty | m_clientTimeout {this, "ClientTimeout", 0, "Client timeout in milliseconds, 0 for no timeout"} |
|
StringProperty | m_url {this, "URL", "", "Triton URL"} |
|
BooleanProperty | m_useSSL {this, "UseSSL", false, "Use SSL for Triton server connection"} |
|
|
std::unique_ptr< tc::InferOptions > | m_options |
|
Definition at line 30 of file TritonTool.h.
◆ TritonTool() [1/3]
AthInfer::TritonTool::TritonTool |
( |
const std::string & |
type, |
|
|
const std::string & |
name, |
|
|
const IInterface * |
parent |
|
) |
| |
Definition at line 7 of file TritonTool.cxx.
12 declareInterface<AthInfer::IAthInferenceTool>(
this);
◆ TritonTool() [2/3]
AthInfer::TritonTool::TritonTool |
( |
| ) |
|
|
protected delete |
◆ TritonTool() [3/3]
◆ extractOutput()
template<typename T >
StatusCode AthInfer::TritonTool::extractOutput |
( |
const std::string & |
name, |
|
|
const std::shared_ptr< tc::InferResult > & |
result, |
|
|
std::vector< T > & |
outputVec |
|
) |
| const |
|
private |
◆ getClient()
tc::InferenceServerGrpcClient * AthInfer::TritonTool::getClient |
( |
| ) |
const |
|
private |
Definition at line 24 of file TritonTool.cxx.
25 thread_local std::unique_ptr<tc::InferenceServerGrpcClient> threadClient;
41 return threadClient.get();
◆ inference()
Definition at line 44 of file TritonTool.cxx.
48 std::vector<std::shared_ptr<tc::InferInput> > inputs_;
49 inputs_.reserve(inputData.size());
51 for (
auto& [inputName, inputInfo]: inputData) {
52 const std::vector<int64_t>& inputShape = inputInfo.first;
53 const auto& variant = inputInfo.second;
55 const auto status = std::visit([&](
const auto& dataVec) {
56 using T = std::decay_t<decltype(dataVec[0])>;
57 return prepareInput<T>(inputName, inputShape, dataVec, inputs_);
64 std::vector<tc::InferInput*> rawInputs;
65 for (
auto&
input: inputs_) {
66 rawInputs.push_back(
input.get());
70 tc::InferResult* rawResultPtr =
nullptr;
71 tc::Headers http_headers;
72 grpc_compression_algorithm compression_algorithm =
73 grpc_compression_algorithm::GRPC_COMPRESS_NONE;
77 &rawResultPtr, *
m_options, rawInputs, {}, http_headers, compression_algorithm),
78 "unable to run model "+
m_modelName.value() +
" error: " +
err.Message()
81 std::shared_ptr<tc::InferResult>
results(rawResultPtr);
84 for (
auto& [
outputName, outputInfo]: outputData) {
85 auto& variant = outputInfo.second;
87 const auto status = std::visit([&](
auto& dataVec) {
88 using T = std::decay_t<decltype(dataVec[0])>;
94 return StatusCode::SUCCESS;
◆ initialize()
StatusCode AthInfer::TritonTool::initialize |
( |
| ) |
|
|
final override |
◆ operator=()
◆ prepareInput()
template<typename T >
StatusCode AthInfer::TritonTool::prepareInput |
( |
const std::string & |
name, |
|
|
const std::vector< int64_t > & |
shape, |
|
|
const std::vector< T > & |
data, |
|
|
std::vector< std::shared_ptr< tc::InferInput >> & |
inputs |
|
) |
| const |
|
private |
◆ print()
void AthInfer::TritonTool::print |
( |
| ) |
const |
|
inline final override |
◆ m_clientTimeout
FloatProperty AthInfer::TritonTool::m_clientTimeout {this, "ClientTimeout", 0, "Client timeout in milliseconds, 0 for no timeout"} |
|
protected |
◆ m_modelName
StringProperty AthInfer::TritonTool::m_modelName {this, "ModelName", "", "Model name"} |
|
protected |
◆ m_modelVersion
StringProperty AthInfer::TritonTool::m_modelVersion {this, "ModelVersion", "", "Model version, empty for latest"} |
|
protected |
◆ m_options
std::unique_ptr<tc::InferOptions> AthInfer::TritonTool::m_options |
|
private |
◆ m_port
IntegerProperty AthInfer::TritonTool::m_port {this, "Port", 8001, "Port ID for Triton server"} |
|
protected |
◆ m_url
StringProperty AthInfer::TritonTool::m_url {this, "URL", "", "Triton URL"} |
|
protected |
◆ m_useSSL
BooleanProperty AthInfer::TritonTool::m_useSSL {this, "UseSSL", false, "Use SSL for Triton server connection"} |
|
protected |
The documentation for this class was generated from the following files: