ATLAS Offline Software
TritonTool.h
Go to the documentation of this file.
1 // Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2 
3 #pragma once
4 
6 #include "grpc_client.h"
7 #include "grpc_service.pb.h"
8 
9 #include <string>
10 #include <vector>
11 #include <memory>
12 
13 
15 
16 namespace tc = triton::client;
17 
18 
19 namespace AthInfer {
20 
/// @brief Evaluate a Triton client call and return StatusCode::FAILURE from
///        the enclosing function if it reports an error.
///
/// @param X   Expression yielding a tc::Error (e.g. a Triton gRPC client call).
/// @param MSG Streamable message for ATH_MSG_ERROR; the Triton error text
///            (tc::Error::Message()) is appended after ": ".
///
/// Usable only where ATH_MSG_ERROR is available and the enclosing function
/// returns StatusCode. The body is wrapped in do { ... } while (0) so the
/// macro expands to exactly one statement: it requires a trailing semicolon
/// and is safe inside an unbraced if/else.
#define FAIL_IF_ERR(X, MSG)                                        \
  do {                                                             \
    const tc::Error fail_if_err_status_ = (X);                     \
    if (!fail_if_err_status_.IsOk()) {                             \
      ATH_MSG_ERROR(MSG << ": " << fail_if_err_status_.Message()); \
      return StatusCode::FAILURE;                                  \
    }                                                              \
  } while (0)
29 
30 class TritonTool: public extends<AthAlgTool, IAthInferenceTool>
31 {
32 
33  public:
34  TritonTool(const std::string& type, const std::string& name, const IInterface* parent);
35 
36  StatusCode initialize() override final;
37 
38  virtual StatusCode inference(InputDataMap& inputData, OutputDataMap& outputData) const override final;
39 
40  void print() const override final {} // nothing to print, but required by the interface.
41 
42  protected:
43  TritonTool() = delete;
44  TritonTool(const TritonTool&) =delete;
45  TritonTool &operator=(const TritonTool&) = delete;
46 
47  StringProperty m_modelName{this, "ModelName", "", "Model name"};
48  IntegerProperty m_port{this, "Port", 8001, "Port ID for Triton server"};
49  StringProperty m_modelVersion{this, "ModelVersion", "", "Model version, empty for latest"};
50  FloatProperty m_clientTimeout{this, "ClientTimeout", 0, "Client timeout in milliseconds, 0 for no timeout"};
51  StringProperty m_url{this, "URL", "", "Triton URL"};
52  BooleanProperty m_useSSL{this, "UseSSL", false, "Use SSL for Triton server connection"};
53 
54  private:
55  tc::InferenceServerGrpcClient* getClient() const;
56  std::unique_ptr<tc::InferOptions> m_options;
57 
58  template <typename T>
59  StatusCode prepareInput(const std::string& name,
60  const std::vector<int64_t>& shape,
61  const std::vector<T>& data,
62  std::vector<std::shared_ptr<tc::InferInput>>& inputs) const;
63 
64  template <typename T>
65  StatusCode extractOutput(const std::string& name,
66  const std::shared_ptr<tc::InferResult>& result,
67  std::vector<T>& outputVec) const;
68 
69 };
70 
71  #include "TritonTool.icc"
72 
73 }
AthInfer::TritonTool::m_modelName
StringProperty m_modelName
Definition: TritonTool.h:47
AthInfer
Definition: ExampleMLInferenceWithTriton.cxx:12
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
AthInfer::TritonTool::extractOutput
StatusCode extractOutput(const std::string &name, const std::shared_ptr< tc::InferResult > &result, std::vector< T > &outputVec) const
get_generator_info.result
result
Definition: get_generator_info.py:21
AthInfer::TritonTool::m_options
std::unique_ptr< tc::InferOptions > m_options
Definition: TritonTool.h:56
AthInfer::TritonTool::m_useSSL
BooleanProperty m_useSSL
Definition: TritonTool.h:52
AthInfer::TritonTool
Definition: TritonTool.h:31
TrigInDetValidation_menu_test.tc
tc
Definition: TrigInDetValidation_menu_test.py:8
AthInfer::TritonTool::m_clientTimeout
FloatProperty m_clientTimeout
Definition: TritonTool.h:50
AthInfer::TritonTool::TritonTool
TritonTool()=delete
AthInfer::TritonTool::inference
virtual StatusCode inference(InputDataMap &inputData, OutputDataMap &outputData) const override final
Definition: TritonTool.cxx:44
const
bool const RAWDATA *ch2 const
Definition: LArRodBlockPhysicsV0.cxx:560
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
python.CaloAddPedShiftConfig.type
type
Definition: CaloAddPedShiftConfig.py:42
AthInfer::TritonTool::prepareInput
StatusCode prepareInput(const std::string &name, const std::vector< int64_t > &shape, const std::vector< T > &data, std::vector< std::shared_ptr< tc::InferInput >> &inputs) const
rerun_display.client
client
Definition: rerun_display.py:31
AthInfer::TritonTool::TritonTool
TritonTool(const TritonTool &)=delete
AthInfer::TritonTool::m_modelVersion
StringProperty m_modelVersion
Definition: TritonTool.h:49
AthInfer::TritonTool::initialize
StatusCode initialize() override final
Definition: TritonTool.cxx:15
AthInfer::TritonTool::operator=
TritonTool & operator=(const TritonTool &)=delete
TritonTool.icc
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
AthAlgTool.h
AthInfer::OutputDataMap
std::map< std::string, InferenceData > OutputDataMap
Definition: IAthInferenceTool.h:17
test_pyathena.parent
parent
Definition: test_pyathena.py:15
IAthInferenceTool.h
AthInfer::TritonTool::m_url
StringProperty m_url
Definition: TritonTool.h:51
AthInfer::TritonTool::getClient
tc::InferenceServerGrpcClient * getClient() const
Definition: TritonTool.cxx:24
AthInfer::TritonTool::print
void print() const override final
Definition: TritonTool.h:40
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
AthInfer::InputDataMap
std::map< std::string, InferenceData > InputDataMap
Definition: IAthInferenceTool.h:16
AthInfer::TritonTool::m_port
IntegerProperty m_port
Definition: TritonTool.h:48