ATLAS Offline Software
Loading...
Searching...
No Matches
TritonTool.cxx
Go to the documentation of this file.
1// Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
4
5namespace tc = triton::client;
6
// Presumably the AthInfer::TritonTool constructor. The opening line of its
// signature (which declares the `type` parameter used below) is missing from
// this extracted listing.
// Forwards type/name/parent to base_class and registers this tool as an
// implementation of the AthInfer::IAthInferenceTool interface.
8 const std::string& name,
9 const IInterface* parent)
10 : base_class(type, name, parent)
11{
12 declareInterface<AthInfer::IAthInferenceTool>(this);
13}
14
16
// Body of AthInfer::TritonTool::initialize(); its opening signature line is
// missing from this extracted listing.
// Builds the inference options shared by every request this tool sends
// (model name, model version, client timeout), then eagerly creates a client.
17 m_options = std::make_unique<tc::InferOptions>(m_modelName.value());
18 m_options->model_version_ = m_modelVersion;
// NOTE(review): m_clientTimeout is a FloatProperty, while Triton's
// InferOptions::client_timeout_ is (to my knowledge) an integer number of
// microseconds, so any fractional part is dropped — confirm the intended units.
19 m_options->client_timeout_ = m_clientTimeout;
20
// Creating the client here reports client-creation failures at initialize
// time. getClient() caches clients per thread, so other threads still create
// their own client lazily on first use.
21 return getClient()? StatusCode::SUCCESS : StatusCode::FAILURE;
22}
23
24tc::InferenceServerGrpcClient* AthInfer::TritonTool::getClient() const {
25 thread_local std::unique_ptr<tc::InferenceServerGrpcClient> threadClient;
26 if (!threadClient) {
27 std::string url = m_url.value() + ":" + std::to_string(m_port); // always use the gRPC port
28
29 bool verbose = false;
30
31 tc::Error err = tc::InferenceServerGrpcClient::Create(&threadClient, url, verbose, m_useSSL);
32 if (!err.IsOk()) {
33 ATH_MSG_ERROR("Failed to create Triton gRPC client for model: " + m_modelName.value() + " at url: " + url);
34 ATH_MSG_ERROR("useSSL is set to: " + std::to_string(m_useSSL));
35 ATH_MSG_ERROR("Error message: " + err.Message());
36 return nullptr;
37 }
38
39 ATH_MSG_INFO("Triton client created for model: "+ m_modelName.value() + " at url: "+ url);
40
41 }
42 return threadClient.get();
43}
44
45StatusCode AthInfer::TritonTool::inference(InputDataMap& inputData, OutputDataMap& outputData) const {
46
47 // Create the tensor for the input data.
48 // Use shared_ptr to manage the memory of the InferInput objects.
49 std::vector<std::shared_ptr<tc::InferInput> > inputs_;
50 inputs_.reserve(inputData.size());
51
52 for (auto& [inputName, inputInfo]: inputData) {
53 const std::vector<int64_t>& inputShape = inputInfo.first;
54 const auto& variant = inputInfo.second;
55
56 const auto status = std::visit([&](const auto& dataVec) {
57 using T = std::decay_t<decltype(dataVec[0])>;
58 return prepareInput<T>(inputName, inputShape, dataVec, inputs_);
59 }, variant);
60
61 if (status != StatusCode::SUCCESS) return status;
62 }
63
64 // construct raw points for inference
65 std::vector<tc::InferInput*> rawInputs;
66 for (auto& input: inputs_) {
67 rawInputs.push_back(input.get());
68 }
69
70 // perform the inference.
71 tc::InferResult* rawResultPtr = nullptr;
72 tc::Headers http_headers;
73 grpc_compression_algorithm compression_algorithm =
74 grpc_compression_algorithm::GRPC_COMPRESS_NONE;
75
77 getClient()->Infer(
78 &rawResultPtr, *m_options, rawInputs, {}, http_headers, compression_algorithm),
79 "unable to run model "+ m_modelName.value() + " error: " + err.Message()
80 );
81
82 std::shared_ptr<tc::InferResult> results(rawResultPtr);
83
84 // Get the result of the inference.
85 for (auto& [outputName, outputInfo]: outputData) {
86 auto& variant = outputInfo.second;
87
88 const auto status = std::visit([&](auto& dataVec) {
89 using T = std::decay_t<decltype(dataVec[0])>;
90 return extractOutput<T>(outputName, results, dataVec);
91 }, variant);
92
93 if (status != StatusCode::SUCCESS) return status;
94 }
95 return StatusCode::SUCCESS;
96}
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
static Double_t tc
bool extractOutput(const std::string &name, const std::shared_ptr< tc::InferResult > &result, std::vector< T > &outputVec)
bool prepareInput(const std::string &name, const std::vector< int64_t > &shape, const std::vector< T > &data, std::vector< std::shared_ptr< tc::InferInput > > &inputs)
#define FAIL_IF_ERR(X, MSG)
Definition TritonTool.h:21
StatusCode initialize() override final
virtual StatusCode inference(InputDataMap &inputData, OutputDataMap &outputData) const override final
StringProperty m_modelVersion
Definition TritonTool.h:49
FloatProperty m_clientTimeout
Definition TritonTool.h:50
tc::InferenceServerGrpcClient * getClient() const
StringProperty m_modelName
Definition TritonTool.h:47
StringProperty m_url
Definition TritonTool.h:51
IntegerProperty m_port
Definition TritonTool.h:48
BooleanProperty m_useSSL
Definition TritonTool.h:52
std::unique_ptr< tc::InferOptions > m_options
Definition TritonTool.h:56
bool verbose
Definition hcg.cxx:73
std::map< std::string, InferenceData > OutputDataMap
std::map< std::string, InferenceData > InputDataMap