// Fragment of the tool constructor (presumably the standard
// AthAlgTool(type, name, parent) constructor — signature is only
// partially visible here; confirm against the full file).
8 const std::string&
name,
// Advertise the AthInfer::IAthInferenceTool interface so the framework's
// tool service can retrieve this tool through that abstract interface.
12 declareInterface<AthInfer::IAthInferenceTool>(
this);
// Fragment of initialize(): build the Triton inference options once from
// the tool's configured properties (model name / version / client timeout)
// and keep them for every subsequent inference call.
17 m_options = std::make_unique<tc::InferOptions>(m_modelName.value());
18 m_options->model_version_ = m_modelVersion;
19 m_options->client_timeout_ = m_clientTimeout;
// Eagerly create the gRPC client so a misconfigured server URL/model fails
// at initialization instead of on the first inference request.
// getClient() returns a raw pointer; nullptr maps to FAILURE.
21 return getClient()? StatusCode::SUCCESS : StatusCode::FAILURE;
// Fragment of getClient(): lazily create and cache one Triton gRPC client
// per thread. NOTE(review): the thread_local cache suggests the gRPC client
// is intentionally not shared across threads — confirm thread-safety
// assumptions against the Triton client documentation.
25 thread_local std::unique_ptr<tc::InferenceServerGrpcClient> threadClient;
// (Creation call is on lines not visible in this chunk; on failure we fall
// into the error branch below — presumably returning nullptr. TODO confirm.)
33 ATH_MSG_ERROR(
"Failed to create Triton gRPC client for model: " + m_modelName.value() +
" at url: " +
url);
// Success path: report which model/endpoint this thread's client talks to.
38 ATH_MSG_INFO(
"Triton client created for model: "+ m_modelName.value() +
" at url: "+
url);
// Non-owning pointer; ownership stays with the thread_local unique_ptr,
// so the client lives for the lifetime of the calling thread.
41 return threadClient.get();
// Fragment of the inference entry point: convert the framework-level
// input map into Triton InferInput objects, run the request, and unpack
// the outputs back into the caller-provided map.
//
// Inputs are owned via shared_ptr while the raw-pointer view required by
// the Triton C++ API is built separately below.
48 std::vector<std::shared_ptr<tc::InferInput> > inputs_;
// Reserve up front: one InferInput per entry in inputData.
49 inputs_.reserve(inputData.size());
// Each input is (name -> (shape, variant-of-data-vectors)); dispatch on the
// variant's active alternative so prepareInput<T> is instantiated with the
// element type actually stored.
51 for (
auto& [inputName, inputInfo]: inputData) {
52 const std::vector<int64_t>& inputShape = inputInfo.first;
53 const auto& variant = inputInfo.second;
55 const auto status = std::visit([&](
const auto& dataVec) {
// decltype(dataVec[0]) is a reference; decay_t strips ref/const to get T.
56 using T = std::decay_t<decltype(dataVec[0])>;
57 return prepareInput<T>(inputName, inputShape, dataVec, inputs_);
// Non-owning raw-pointer view of the inputs, as required by the Triton
// Infer() signature; lifetime is backed by inputs_ above.
64 std::vector<tc::InferInput*> rawInputs;
65 for (
auto&
input: inputs_) {
66 rawInputs.push_back(
input.get());
// Triton returns ownership of the result through a raw out-pointer.
70 tc::InferResult* rawResultPtr =
nullptr;
// No extra HTTP headers and no gRPC compression for this request.
71 tc::Headers http_headers;
72 grpc_compression_algorithm compression_algorithm =
73 grpc_compression_algorithm::GRPC_COMPRESS_NONE;
// (The actual Infer() call and its status check are on lines not visible
// in this chunk; these are its arguments and error message. TODO confirm.)
77 &rawResultPtr, *m_options, rawInputs, {}, http_headers, compression_algorithm),
78 "unable to run model "+ m_modelName.value() +
" error: " +
err.Message()
// Take ownership immediately so the result is released on every exit path.
81 std::shared_ptr<tc::InferResult>
results(rawResultPtr);
// Unpack each requested output: same variant-dispatch pattern as the
// inputs, writing into the caller's (mutable) data vectors.
84 for (
auto& [
outputName, outputInfo]: outputData) {
85 auto& variant = outputInfo.second;
87 const auto status = std::visit([&](
auto& dataVec) {
88 using T = std::decay_t<decltype(dataVec[0])>;
94 return StatusCode::SUCCESS;