ATLAS Offline Software
Loading...
Searching...
No Matches
TritonTool.icc
Go to the documentation of this file.
1// Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
3// DType traits for Triton
/// Compile-time mapping from a C++ element type to the datatype string
/// handed to tc::InferInput::Create.
///
/// The primary template is declared but deliberately left undefined, so using
/// an element type without a specialization below is a compile-time error
/// instead of a wrong datatype string at run time.
template <typename T>
struct TritonDType;

/// float elements are described with the datatype string "FP32".
template <>
struct TritonDType<float> {
  static constexpr const char* value{"FP32"};
};

/// int64_t elements are described with the datatype string "INT64".
template <>
struct TritonDType<int64_t> {
  static constexpr const char* value{"INT64"};
};
7
8// Input handling
9template <typename T>
10StatusCode AthInfer::TritonTool::prepareInput(const std::string& name,
11 const std::vector<int64_t>& shape,
12 const std::vector<T>& data,
13 std::vector<std::shared_ptr<tc::InferInput>>& inputs) const
14{
15 const char* dtype = TritonDType<T>::value;
16 tc::InferInput* rawInputPtr = nullptr;
17
18 // create the InferInput object with the predefined name, shape, and data type.
19 FAIL_IF_ERR(tc::InferInput::Create(&rawInputPtr, name, shape, dtype),
20 "unable to create input: " + name);
21
22 // Append tensor values for this input from a byte array.
23 // Note: The vector is not copied and so it must not be modified or destroyed
24 // until this input is no longer needed (that is until the Infer() call(s) that use the input have completed).
25 // Multiple calls can be made to this API to keep adding tensor data for this input.
26 // The data will be delivered in the order it was added.
27 std::shared_ptr<tc::InferInput> input(rawInputPtr);
28 FAIL_IF_ERR(input->AppendRaw(reinterpret_cast<const uint8_t*>(data.data()),
29 data.size() * sizeof(T)),
30 "unable to set input data for: " + name);
31
32 inputs.push_back(std::move(input));
33 return StatusCode::SUCCESS;
34}
35
36// Output handling
37template <typename T>
38StatusCode AthInfer::TritonTool::extractOutput(const std::string& name,
39 const std::shared_ptr<tc::InferResult>& result,
40 std::vector<T>& outputVec) const
41{
42 const uint8_t* rawData = nullptr;
43 size_t size = 0;
44
45 // Get access to the buffer holding raw results of specified output returned by the server.
46 // Note: the buffer is owned by InferResult instance.
47 // Users can copy out the data if required to extend the lifetime.
48 FAIL_IF_ERR(result->RawData(name, &rawData, &size),
49 "unable to get raw output for: " + name);
50
51 outputVec.resize(size / sizeof(T));
52 std::memcpy(outputVec.data(), rawData, size);
53 return StatusCode::SUCCESS;
54}