ATLAS Offline Software
Public Member Functions | Protected Member Functions | Protected Attributes | Private Member Functions | Private Attributes | List of all members
AthInfer::TritonTool Class Reference

#include <TritonTool.h>

Inheritance diagram for AthInfer::TritonTool:
Collaboration diagram for AthInfer::TritonTool:

Public Member Functions

 TritonTool (const std::string &type, const std::string &name, const IInterface *parent)
 
StatusCode initialize () override final
 
virtual StatusCode inference (InputDataMap &inputData, OutputDataMap &outputData) const override final
 
void print () const override final
 

Protected Member Functions

 TritonTool ()=delete
 
 TritonTool (const TritonTool &)=delete
 
TritonTool & operator= (const TritonTool &)=delete
 

Protected Attributes

StringProperty m_modelName {this, "ModelName", "", "Model name"}
 
IntegerProperty m_port {this, "Port", 8001, "Port ID for Triton server"}
 
StringProperty m_modelVersion {this, "ModelVersion", "", "Model version, empty for latest"}
 
FloatProperty m_clientTimeout {this, "ClientTimeout", 0, "Client timeout in milliseconds, 0 for no timeout"}
 
StringProperty m_url {this, "URL", "", "Triton URL"}
 
BooleanProperty m_useSSL {this, "UseSSL", false, "Use SSL for Triton server connection"}
 

Private Member Functions

tc::InferenceServerGrpcClient * getClient () const
 
template<typename T >
StatusCode prepareInput (const std::string &name, const std::vector< int64_t > &shape, const std::vector< T > &data, std::vector< std::shared_ptr< tc::InferInput >> &inputs) const
 
template<typename T >
StatusCode extractOutput (const std::string &name, const std::shared_ptr< tc::InferResult > &result, std::vector< T > &outputVec) const
 

Private Attributes

std::unique_ptr< tc::InferOptions > m_options
 

Detailed Description

Definition at line 30 of file TritonTool.h.

Constructor & Destructor Documentation

◆ TritonTool() [1/3]

AthInfer::TritonTool::TritonTool ( const std::string &  type,
const std::string &  name,
const IInterface *  parent 
)

Definition at line 7 of file TritonTool.cxx.

10  : base_class(type, name, parent)
11 {
12  declareInterface<AthInfer::IAthInferenceTool>(this);
13 }

◆ TritonTool() [2/3]

AthInfer::TritonTool::TritonTool ( )
protected, delete

◆ TritonTool() [3/3]

AthInfer::TritonTool::TritonTool ( const TritonTool & )
protected, delete

Member Function Documentation

◆ extractOutput()

template<typename T >
StatusCode AthInfer::TritonTool::extractOutput ( const std::string &  name,
const std::shared_ptr< tc::InferResult > &  result,
std::vector< T > &  outputVec 
) const
private

◆ getClient()

tc::InferenceServerGrpcClient * AthInfer::TritonTool::getClient ( ) const
private

Definition at line 24 of file TritonTool.cxx.

24  {
25  thread_local std::unique_ptr<tc::InferenceServerGrpcClient> threadClient;
26  if (!threadClient) {
27  std::string url = m_url.value() + ":" + std::to_string(m_port); // always use the gRPC port
28 
29  bool verbose = false;
30 
31  tc::Error err = tc::InferenceServerGrpcClient::Create(&threadClient, url, verbose, m_useSSL);
32  if (!err.IsOk()) {
33  ATH_MSG_ERROR("Failed to create Triton gRPC client for model: " + m_modelName.value() + " at url: " + url);
34  ATH_MSG_ERROR("Error message: " + err.Message());
35  return nullptr;
36  }
37 
38  ATH_MSG_INFO("Triton client created for model: "+ m_modelName.value() + " at url: "+ url);
39 
40  }
41  return threadClient.get();
42 }

◆ inference()

StatusCode AthInfer::TritonTool::inference ( InputDataMap &  inputData,
OutputDataMap &  outputData 
) const
final, override, virtual

Definition at line 44 of file TritonTool.cxx.

44  {
45 
46  // Create the tensor for the input data.
47  // Use shared_ptr to manage the memory of the InferInput objects.
48  std::vector<std::shared_ptr<tc::InferInput> > inputs_;
49  inputs_.reserve(inputData.size());
50 
51  for (auto& [inputName, inputInfo]: inputData) {
52  const std::vector<int64_t>& inputShape = inputInfo.first;
53  const auto& variant = inputInfo.second;
54 
55  const auto status = std::visit([&](const auto& dataVec) {
56  using T = std::decay_t<decltype(dataVec[0])>;
57  return prepareInput<T>(inputName, inputShape, dataVec, inputs_);
58  }, variant);
59 
60  if (status != StatusCode::SUCCESS) return status;
61  }
62 
63  // construct raw points for inference
64  std::vector<tc::InferInput*> rawInputs;
65  for (auto& input: inputs_) {
66  rawInputs.push_back(input.get());
67  }
68 
69  // perform the inference.
70  tc::InferResult* rawResultPtr = nullptr;
71  tc::Headers http_headers;
72  grpc_compression_algorithm compression_algorithm =
73  grpc_compression_algorithm::GRPC_COMPRESS_NONE;
74 
75  FAIL_IF_ERR(
76  getClient()->Infer(
77  &rawResultPtr, *m_options, rawInputs, {}, http_headers, compression_algorithm),
78  "unable to run model "+ m_modelName.value() + " error: " + err.Message()
79  );
80 
81  std::shared_ptr<tc::InferResult> results(rawResultPtr);
82 
83  // Get the result of the inference.
84  for (auto& [outputName, outputInfo]: outputData) {
85  auto& variant = outputInfo.second;
86 
87  const auto status = std::visit([&](auto& dataVec) {
88  using T = std::decay_t<decltype(dataVec[0])>;
89  return extractOutput<T>(outputName, results, dataVec);
90  }, variant);
91 
92  if (status != StatusCode::SUCCESS) return status;
93  }
94  return StatusCode::SUCCESS;
95 }

◆ initialize()

StatusCode AthInfer::TritonTool::initialize ( )
final, override

Definition at line 15 of file TritonTool.cxx.

15  {
16 
17  m_options = std::make_unique<tc::InferOptions>(m_modelName.value());
18  m_options->model_version_ = m_modelVersion;
19  m_options->client_timeout_ = m_clientTimeout;
20 
21  return getClient()? StatusCode::SUCCESS : StatusCode::FAILURE;
22 }

◆ operator=()

TritonTool& AthInfer::TritonTool::operator= ( const TritonTool & )
protected, delete

◆ prepareInput()

template<typename T >
StatusCode AthInfer::TritonTool::prepareInput ( const std::string &  name,
const std::vector< int64_t > &  shape,
const std::vector< T > &  data,
std::vector< std::shared_ptr< tc::InferInput >> &  inputs 
) const
private

◆ print()

void AthInfer::TritonTool::print ( ) const
inline, final, override

Definition at line 40 of file TritonTool.h.

40 {} // nothing to print, but required by the interface.

Member Data Documentation

◆ m_clientTimeout

FloatProperty AthInfer::TritonTool::m_clientTimeout {this, "ClientTimeout", 0, "Client timeout in milliseconds, 0 for no timeout"}
protected

Definition at line 50 of file TritonTool.h.

◆ m_modelName

StringProperty AthInfer::TritonTool::m_modelName {this, "ModelName", "", "Model name"}
protected

Definition at line 47 of file TritonTool.h.

◆ m_modelVersion

StringProperty AthInfer::TritonTool::m_modelVersion {this, "ModelVersion", "", "Model version, empty for latest"}
protected

Definition at line 49 of file TritonTool.h.

◆ m_options

std::unique_ptr<tc::InferOptions> AthInfer::TritonTool::m_options
private

Definition at line 56 of file TritonTool.h.

◆ m_port

IntegerProperty AthInfer::TritonTool::m_port {this, "Port", 8001, "Port ID for Triton server"}
protected

Definition at line 48 of file TritonTool.h.

◆ m_url

StringProperty AthInfer::TritonTool::m_url {this, "URL", "", "Triton URL"}
protected

Definition at line 51 of file TritonTool.h.

◆ m_useSSL

BooleanProperty AthInfer::TritonTool::m_useSSL {this, "UseSSL", false, "Use SSL for Triton server connection"}
protected

Definition at line 52 of file TritonTool.h.


The documentation for this class was generated from the following files:
TritonTool.h
TritonTool.cxx
AthInfer::TritonTool::m_modelName
StringProperty m_modelName
Definition: TritonTool.h:47
AthInfer::TritonTool::m_options
std::unique_ptr< tc::InferOptions > m_options
Definition: TritonTool.h:56
AthInfer::TritonTool::m_useSSL
BooleanProperty m_useSSL
Definition: TritonTool.h:52
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
AthInfer::TritonTool::m_clientTimeout
FloatProperty m_clientTimeout
Definition: TritonTool.h:50
physics_parameters.url
string url
Definition: physics_parameters.py:27
python.CaloAddPedShiftConfig.type
type
Definition: CaloAddPedShiftConfig.py:42
AthInfer::TritonTool::m_modelVersion
StringProperty m_modelVersion
Definition: TritonTool.h:49
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
dqt_zlumi_pandas.err
err
Definition: dqt_zlumi_pandas.py:183
PlotPulseshapeFromCool.input
input
Definition: PlotPulseshapeFromCool.py:106
test_pyathena.parent
parent
Definition: test_pyathena.py:15
add-xsec-uncert-quadrature-N.results
dictionary results
Definition: add-xsec-uncert-quadrature-N.py:39
AthInfer::TritonTool::m_url
StringProperty m_url
Definition: TritonTool.h:51
AthInfer::TritonTool::getClient
tc::InferenceServerGrpcClient * getClient() const
Definition: TritonTool.cxx:24
FAIL_IF_ERR
#define FAIL_IF_ERR(X, MSG)
Definition: TritonTool.h:21
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
lumiFormat.outputName
string outputName
Definition: lumiFormat.py:65
python.TriggerHandler.verbose
verbose
Definition: TriggerHandler.py:296
L1Topo::Error
Error
The different types of error that can be flagged in the L1TopoRDO.
Definition: Error.h:16
merge.status
status
Definition: merge.py:16
TSU::T
unsigned long long T
Definition: L1TopoDataTypes.h:35
AthInfer::TritonTool::m_port
IntegerProperty m_port
Definition: TritonTool.h:48