ATLAS Offline Software
Loading...
Searching...
No Matches
AthInfer::TritonTool Class Reference

#include <TritonTool.h>

Inheritance diagram for AthInfer::TritonTool:
Collaboration diagram for AthInfer::TritonTool:

Public Member Functions

 TritonTool (const std::string &type, const std::string &name, const IInterface *parent)
StatusCode initialize () override final
virtual StatusCode inference (InputDataMap &inputData, OutputDataMap &outputData) const override final
void print () const override final

Protected Member Functions

 TritonTool ()=delete
 TritonTool (const TritonTool &)=delete
TritonTool & operator= (const TritonTool &)=delete

Protected Attributes

StringProperty m_modelName {this, "ModelName", "", "Model name"}
IntegerProperty m_port {this, "Port", 8001, "Port ID for Triton server"}
StringProperty m_modelVersion {this, "ModelVersion", "", "Model version, empty for latest"}
FloatProperty m_clientTimeout {this, "ClientTimeout", 0, "Client timeout in milliseconds, 0 for no timeout"}
StringProperty m_url {this, "URL", "", "Triton URL"}
BooleanProperty m_useSSL {this, "UseSSL", false, "Use SSL for Triton server connection"}

Private Member Functions

tc::InferenceServerGrpcClient * getClient () const
template<typename T>
StatusCode prepareInput (const std::string &name, const std::vector< int64_t > &shape, const std::vector< T > &data, std::vector< std::shared_ptr< tc::InferInput > > &inputs) const
template<typename T>
StatusCode extractOutput (const std::string &name, const std::shared_ptr< tc::InferResult > &result, std::vector< T > &outputVec) const

Private Attributes

std::unique_ptr< tc::InferOptions > m_options

Detailed Description

Definition at line 30 of file TritonTool.h.

Constructor & Destructor Documentation

◆ TritonTool() [1/3]

AthInfer::TritonTool::TritonTool ( const std::string & type,
const std::string & name,
const IInterface * parent )

Definition at line 7 of file TritonTool.cxx.

10 : base_class(type, name, parent)
11{
12 declareInterface<AthInfer::IAthInferenceTool>(this);
13}

◆ TritonTool() [2/3]

AthInfer::TritonTool::TritonTool ( )
protected delete

◆ TritonTool() [3/3]

AthInfer::TritonTool::TritonTool ( const TritonTool & )
protected delete

Member Function Documentation

◆ extractOutput()

template<typename T>
StatusCode AthInfer::TritonTool::extractOutput ( const std::string & name,
const std::shared_ptr< tc::InferResult > & result,
std::vector< T > & outputVec ) const
private

Definition at line 39 of file TritonTool.h.

40 {} // nothing to print, but required by the interface.
41
42 protected:
43 TritonTool() = delete;
44 TritonTool(const TritonTool&) =delete;
45 TritonTool &operator=(const TritonTool&) = delete;
46
47 StringProperty m_modelName{this, "ModelName", "", "Model name"};
48 IntegerProperty m_port{this, "Port", 8001, "Port ID for Triton server"};
49 StringProperty m_modelVersion{this, "ModelVersion", "", "Model version, empty for latest"};
50 FloatProperty m_clientTimeout{this, "ClientTimeout", 0, "Client timeout in milliseconds, 0 for no timeout"};
51 StringProperty m_url{this, "URL", "", "Triton URL"};
52 BooleanProperty m_useSSL{this, "UseSSL", false, "Use SSL for Triton server connection"};
53
54 private:
55 tc::InferenceServerGrpcClient* getClient() const;
StringProperty m_modelVersion
Definition TritonTool.h:49
FloatProperty m_clientTimeout
Definition TritonTool.h:50
tc::InferenceServerGrpcClient * getClient() const
StringProperty m_modelName
Definition TritonTool.h:47
StringProperty m_url
Definition TritonTool.h:51
TritonTool(const std::string &type, const std::string &name, const IInterface *parent)
Definition TritonTool.cxx:7
IntegerProperty m_port
Definition TritonTool.h:48
TritonTool & operator=(const TritonTool &)=delete
BooleanProperty m_useSSL
Definition TritonTool.h:52

◆ getClient()

tc::InferenceServerGrpcClient * AthInfer::TritonTool::getClient ( ) const
private

Definition at line 24 of file TritonTool.cxx.

24 {
25 thread_local std::unique_ptr<tc::InferenceServerGrpcClient> threadClient;
26 if (!threadClient) {
27 std::string url = m_url.value() + ":" + std::to_string(m_port); // always use the gRPC port
28
29 bool verbose = false;
30
31 tc::Error err = tc::InferenceServerGrpcClient::Create(&threadClient, url, verbose, m_useSSL);
32 if (!err.IsOk()) {
33 ATH_MSG_ERROR("Failed to create Triton gRPC client for model: " + m_modelName.value() + " at url: " + url);
34 ATH_MSG_ERROR("Error message: " + err.Message());
35 return nullptr;
36 }
37
38 ATH_MSG_INFO("Triton client created for model: "+ m_modelName.value() + " at url: "+ url);
39
40 }
41 return threadClient.get();
42}
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
bool verbose
Definition hcg.cxx:73

◆ inference()

StatusCode AthInfer::TritonTool::inference ( InputDataMap & inputData,
OutputDataMap & outputData ) const
final override virtual

Definition at line 44 of file TritonTool.cxx.

44 {
45
46 // Create the tensor for the input data.
47 // Use shared_ptr to manage the memory of the InferInput objects.
48 std::vector<std::shared_ptr<tc::InferInput> > inputs_;
49 inputs_.reserve(inputData.size());
50
51 for (auto& [inputName, inputInfo]: inputData) {
52 const std::vector<int64_t>& inputShape = inputInfo.first;
53 const auto& variant = inputInfo.second;
54
55 const auto status = std::visit([&](const auto& dataVec) {
56 using T = std::decay_t<decltype(dataVec[0])>;
57 return prepareInput<T>(inputName, inputShape, dataVec, inputs_);
58 }, variant);
59
60 if (status != StatusCode::SUCCESS) return status;
61 }
62
63 // construct raw pointers for inference
64 std::vector<tc::InferInput*> rawInputs;
65 for (auto& input: inputs_) {
66 rawInputs.push_back(input.get());
67 }
68
69 // perform the inference.
70 tc::InferResult* rawResultPtr = nullptr;
71 tc::Headers http_headers;
72 grpc_compression_algorithm compression_algorithm =
73 grpc_compression_algorithm::GRPC_COMPRESS_NONE;
74
75 FAIL_IF_ERR(
76 getClient()->Infer(
77 &rawResultPtr, *m_options, rawInputs, {}, http_headers, compression_algorithm),
78 "unable to run model "+ m_modelName.value() + " error: " + err.Message()
79 );
80
81 std::shared_ptr<tc::InferResult> results(rawResultPtr);
82
83 // Get the result of the inference.
84 for (auto& [outputName, outputInfo]: outputData) {
85 auto& variant = outputInfo.second;
86
87 const auto status = std::visit([&](auto& dataVec) {
88 using T = std::decay_t<decltype(dataVec[0])>;
89 return extractOutput<T>(outputName, results, dataVec);
90 }, variant);
91
92 if (status != StatusCode::SUCCESS) return status;
93 }
94 return StatusCode::SUCCESS;
95}
#define FAIL_IF_ERR(X, MSG)
Definition TritonTool.h:21
StatusCode prepareInput(const std::string &name, const std::vector< int64_t > &shape, const std::vector< T > &data, std::vector< std::shared_ptr< tc::InferInput > > &inputs) const
Definition TritonTool.h:11
StatusCode extractOutput(const std::string &name, const std::shared_ptr< tc::InferResult > &result, std::vector< T > &outputVec) const
Definition TritonTool.h:39
std::unique_ptr< tc::InferOptions > m_options
Definition TritonTool.h:56
unsigned long long T
status
Definition merge.py:16

◆ initialize()

StatusCode AthInfer::TritonTool::initialize ( )
final override

Definition at line 15 of file TritonTool.cxx.

15 {
16
17 m_options = std::make_unique<tc::InferOptions>(m_modelName.value());
18 m_options->model_version_ = m_modelVersion;
19 m_options->client_timeout_ = m_clientTimeout;
20
21 return getClient()? StatusCode::SUCCESS : StatusCode::FAILURE;
22}

◆ operator=()

TritonTool & AthInfer::TritonTool::operator= ( const TritonTool & )
protected delete

◆ prepareInput()

template<typename T>
StatusCode AthInfer::TritonTool::prepareInput ( const std::string & name,
const std::vector< int64_t > & shape,
const std::vector< T > & data,
std::vector< std::shared_ptr< tc::InferInput > > & inputs ) const
private

Definition at line 11 of file TritonTool.h.

19 {
20
21#define FAIL_IF_ERR(X, MSG) \
22{ \
23 tc::Error err = (X); \
24 if (!err.IsOk()) { \
25 ATH_MSG_ERROR(MSG); \
26 return StatusCode::FAILURE; \
27 } \
28}
29
30class TritonTool: public extends<AthAlgTool, IAthInferenceTool>
31{
32
33 public:
34 TritonTool(const std::string& type, const std::string& name, const IInterface* parent);
35

◆ print()

void AthInfer::TritonTool::print ( ) const
inline final override

Definition at line 40 of file TritonTool.h.

40{} // nothing to print, but required by the interface.

Member Data Documentation

◆ m_clientTimeout

FloatProperty AthInfer::TritonTool::m_clientTimeout {this, "ClientTimeout", 0, "Client timeout in milliseconds, 0 for no timeout"}
protected

Definition at line 50 of file TritonTool.h.

50{this, "ClientTimeout", 0, "Client timeout in milliseconds, 0 for no timeout"};

◆ m_modelName

StringProperty AthInfer::TritonTool::m_modelName {this, "ModelName", "", "Model name"}
protected

Definition at line 47 of file TritonTool.h.

47{this, "ModelName", "", "Model name"};

◆ m_modelVersion

StringProperty AthInfer::TritonTool::m_modelVersion {this, "ModelVersion", "", "Model version, empty for latest"}
protected

Definition at line 49 of file TritonTool.h.

49{this, "ModelVersion", "", "Model version, empty for latest"};

◆ m_options

std::unique_ptr<tc::InferOptions> AthInfer::TritonTool::m_options
private

Definition at line 56 of file TritonTool.h.

◆ m_port

IntegerProperty AthInfer::TritonTool::m_port {this, "Port", 8001, "Port ID for Triton server"}
protected

Definition at line 48 of file TritonTool.h.

48{this, "Port", 8001, "Port ID for Triton server"};

◆ m_url

StringProperty AthInfer::TritonTool::m_url {this, "URL", "", "Triton URL"}
protected

Definition at line 51 of file TritonTool.h.

51{this, "URL", "", "Triton URL"};

◆ m_useSSL

BooleanProperty AthInfer::TritonTool::m_useSSL {this, "UseSSL", false, "Use SSL for Triton server connection"}
protected

Definition at line 52 of file TritonTool.h.

52{this, "UseSSL", false, "Use SSL for Triton server connection"};

The documentation for this class was generated from the following files: