1 // Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
/// Compile-time mapping from a C++ element type to the datatype string that
/// is handed to Triton when an input tensor is created.
///
/// Only the specializations are defined; the primary template is declared
/// without a definition, so requesting `TritonDType<T>::value` for an element
/// type that has no specialization is a compile-time error.
template <typename T>
struct TritonDType;

/// `float` elements are labelled "FP32".
template <>
struct TritonDType<float> {
  static constexpr const char* value = "FP32";
};

/// `int64_t` elements are labelled "INT64".
template <>
struct TritonDType<int64_t> {
  static constexpr const char* value = "INT64";
};
10 StatusCode AthInfer::TritonTool::prepareInput(const std::string& name,
11 const std::vector<int64_t>& shape,
12 const std::vector<T>& data,
13 std::vector<std::shared_ptr<tc::InferInput>>& inputs) const
15 const char* dtype = TritonDType<T>::value;
16 tc::InferInput* rawInputPtr = nullptr;
18 // create the InferInput object with the predefined name, shape, and data type.
19 FAIL_IF_ERR(tc::InferInput::Create(&rawInputPtr, name, shape, dtype),
20 "unable to create input: " + name);
22 // Append tensor values for this input from a byte array.
23 // Note: The vector is not copied and so it must not be modified or destroyed
24 // until this input is no longer needed (that is until the Infer() call(s) that use the input have completed).
25 // Multiple calls can be made to this API to keep adding tensor data for this input.
26 // The data will be delivered in the order it was added.
27 std::shared_ptr<tc::InferInput> input(rawInputPtr);
28 FAIL_IF_ERR(input->AppendRaw(reinterpret_cast<const uint8_t*>(data.data()),
29 data.size() * sizeof(T)),
30 "unable to set input data for: " + name);
32 inputs.push_back(std::move(input));
33 return StatusCode::SUCCESS;
38 StatusCode AthInfer::TritonTool::extractOutput(const std::string& name,
39 const std::shared_ptr<tc::InferResult>& result,
40 std::vector<T>& outputVec) const
42 const uint8_t* rawData = nullptr;
45 // Get access to the buffer holding raw results of specified output returned by the server.
46 // Note: the buffer is owned by InferResult instance.
47 // Users can copy out the data if required to extend the lifetime.
48 FAIL_IF_ERR(result->RawData(name, &rawData, &size),
49 "unable to get raw output for: " + name);
51 outputVec.resize(size / sizeof(T));
52 std::memcpy(outputVec.data(), rawData, size);
53 return StatusCode::SUCCESS;