// ATLAS Offline Software: TritonTool.h
// Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
3#pragma once
4
6#include "grpc_client.h"
7#include "grpc_service.pb.h"
8
9#include <string>
10#include <vector>
11#include <memory>
12
13
15
16namespace tc = triton::client;
17
18
19namespace AthInfer {
20
21#define FAIL_IF_ERR(X, MSG) \
22{ \
23 tc::Error err = (X); \
24 if (!err.IsOk()) { \
25 ATH_MSG_ERROR(MSG); \
26 return StatusCode::FAILURE; \
27 } \
28}
29
30class TritonTool: public extends<AthAlgTool, IAthInferenceTool>
31{
32
33 public:
34 TritonTool(const std::string& type, const std::string& name, const IInterface* parent);
35
36 StatusCode initialize() override final;
37
38 virtual StatusCode inference(InputDataMap& inputData, OutputDataMap& outputData) const override final;
40 void print() const override final {} // nothing to print, but required by the interface.
41
42 protected:
43 TritonTool() = delete;
44 TritonTool(const TritonTool&) =delete;
45 TritonTool &operator=(const TritonTool&) = delete;
46
47 StringProperty m_modelName{this, "ModelName", "", "Model name"};
48 IntegerProperty m_port{this, "Port", 8001, "Port ID for Triton server"};
49 StringProperty m_modelVersion{this, "ModelVersion", "", "Model version, empty for latest"};
50 FloatProperty m_clientTimeout{this, "ClientTimeout", 0, "Client timeout in milliseconds, 0 for no timeout"};
51 StringProperty m_url{this, "URL", "", "Triton URL"};
52 BooleanProperty m_useSSL{this, "UseSSL", false, "Use SSL for Triton server connection"};
53
54 private:
55 tc::InferenceServerGrpcClient* getClient() const;
56 std::unique_ptr<tc::InferOptions> m_options;
57
58 template <typename T>
59 StatusCode prepareInput(const std::string& name,
60 const std::vector<int64_t>& shape,
61 const std::vector<T>& data,
62 std::vector<std::shared_ptr<tc::InferInput>>& inputs) const;
63
64 template <typename T>
65 StatusCode extractOutput(const std::string& name,
66 const std::shared_ptr<tc::InferResult>& result,
67 std::vector<T>& outputVec) const;
68
69};
70
71 #include "TritonTool.icc"
72
73}
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
static Double_t tc
StatusCode initialize() override final
virtual StatusCode inference(InputDataMap &inputData, OutputDataMap &outputData) const override final
StringProperty m_modelVersion
Definition TritonTool.h:49
StatusCode prepareInput(const std::string &name, const std::vector< int64_t > &shape, const std::vector< T > &data, std::vector< std::shared_ptr< tc::InferInput > > &inputs) const
Definition TritonTool.h:11
FloatProperty m_clientTimeout
Definition TritonTool.h:50
void print() const override final
Definition TritonTool.h:40
tc::InferenceServerGrpcClient * getClient() const
StringProperty m_modelName
Definition TritonTool.h:47
StatusCode extractOutput(const std::string &name, const std::shared_ptr< tc::InferResult > &result, std::vector< T > &outputVec) const
Definition TritonTool.h:39
StringProperty m_url
Definition TritonTool.h:51
TritonTool(const std::string &type, const std::string &name, const IInterface *parent)
Definition TritonTool.cxx:7
IntegerProperty m_port
Definition TritonTool.h:48
TritonTool & operator=(const TritonTool &)=delete
BooleanProperty m_useSSL
Definition TritonTool.h:52
std::unique_ptr< tc::InferOptions > m_options
Definition TritonTool.h:56
TritonTool(const TritonTool &)=delete
std::map< std::string, InferenceData > OutputDataMap
std::map< std::string, InferenceData > InputDataMap