ATLAS Offline Software
TritonTool.icc
Go to the documentation of this file.
1 // Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2 
// DType traits for Triton
//
// Maps a C++ element type T to the datatype string that is handed to
// tc::InferInput::Create. The primary template is intentionally left
// undefined, so requesting an unmapped type is a compile-time error.
//
// Only types whose in-memory layout matches the corresponding Triton wire
// format one-to-one are listed here (bool, FP16/BF16 and BYTES are left out).
template <typename T> struct TritonDType;

// Floating point
template <> struct TritonDType<float>    { static constexpr const char* value = "FP32"; };
template <> struct TritonDType<double>   { static constexpr const char* value = "FP64"; };

// Signed integers
template <> struct TritonDType<int8_t>   { static constexpr const char* value = "INT8"; };
template <> struct TritonDType<int16_t>  { static constexpr const char* value = "INT16"; };
template <> struct TritonDType<int32_t>  { static constexpr const char* value = "INT32"; };
template <> struct TritonDType<int64_t>  { static constexpr const char* value = "INT64"; };

// Unsigned integers
template <> struct TritonDType<uint8_t>  { static constexpr const char* value = "UINT8"; };
template <> struct TritonDType<uint16_t> { static constexpr const char* value = "UINT16"; };
template <> struct TritonDType<uint32_t> { static constexpr const char* value = "UINT32"; };
template <> struct TritonDType<uint64_t> { static constexpr const char* value = "UINT64"; };
7 
8 // Input handling
9 template <typename T>
10 StatusCode AthInfer::TritonTool::prepareInput(const std::string& name,
11  const std::vector<int64_t>& shape,
12  const std::vector<T>& data,
13  std::vector<std::shared_ptr<tc::InferInput>>& inputs) const
14 {
15  const char* dtype = TritonDType<T>::value;
16  tc::InferInput* rawInputPtr = nullptr;
17 
18  // create the InferInput object with the predefined name, shape, and data type.
19  FAIL_IF_ERR(tc::InferInput::Create(&rawInputPtr, name, shape, dtype),
20  "unable to create input: " + name);
21 
22  // Append tensor values for this input from a byte array.
23  // Note: The vector is not copied and so it must not be modified or destroyed
24  // until this input is no longer needed (that is until the Infer() call(s) that use the input have completed).
25  // Multiple calls can be made to this API to keep adding tensor data for this input.
26  // The data will be delivered in the order it was added.
27  std::shared_ptr<tc::InferInput> input(rawInputPtr);
28  FAIL_IF_ERR(input->AppendRaw(reinterpret_cast<const uint8_t*>(data.data()),
29  data.size() * sizeof(T)),
30  "unable to set input data for: " + name);
31 
32  inputs.push_back(std::move(input));
33  return StatusCode::SUCCESS;
34 }
35 
36 // Output handling
37 template <typename T>
38 StatusCode AthInfer::TritonTool::extractOutput(const std::string& name,
39  const std::shared_ptr<tc::InferResult>& result,
40  std::vector<T>& outputVec) const
41 {
42  const uint8_t* rawData = nullptr;
43  size_t size = 0;
44 
45  // Get access to the buffer holding raw results of specified output returned by the server.
46  // Note: the buffer is owned by InferResult instance.
47  // Users can copy out the data if required to extend the lifetime.
48  FAIL_IF_ERR(result->RawData(name, &rawData, &size),
49  "unable to get raw output for: " + name);
50 
51  outputVec.resize(size / sizeof(T));
52  std::memcpy(outputVec.data(), rawData, size);
53  return StatusCode::SUCCESS;
54 }