ATLAS Offline Software
Loading...
Searching...
No Matches
AthInfer::TritonTool Class Reference

#include <TritonTool.h>

Inheritance diagram for AthInfer::TritonTool:
Collaboration diagram for AthInfer::TritonTool:

Classes

struct  Impl

Public Member Functions

 TritonTool (const std::string &type, const std::string &name, const IInterface *parent)
 Constructor.
virtual ~TritonTool ()
 Destructor.
Function(s) inherited from AthAlgTool
virtual StatusCode initialize () override
 Initialize the tool.
Function(s) inherited from IAthInferenceTool
virtual StatusCode inference (InputDataMap &inputData, OutputDataMap &outputData) const override final
 Run inference with multiple inputs and multiple outputs.
virtual void print () const override
 Print the tool's properties and configuration.

Private Attributes

std::unique_ptr< Impl > m_impl
 Pointer to the implementation details.
Tool properties
StringProperty m_modelName {this, "ModelName", "", "Model name"}
IntegerProperty m_port {this, "Port", 8001, "Port ID for Triton server"}
StringProperty m_modelVersion
FloatProperty m_clientTimeout
StringProperty m_url {this, "URL", "", "Triton URL"}
BooleanProperty m_useSSL

Detailed Description

Definition at line 14 of file TritonTool.h.

Constructor & Destructor Documentation

◆ TritonTool()

AthInfer::TritonTool::TritonTool ( const std::string & type,
const std::string & name,
const IInterface * parent )

Constructor.

Definition at line 123 of file TritonTool.cxx.

125 : base_class(type, name, parent) {}

◆ ~TritonTool()

AthInfer::TritonTool::~TritonTool ( )
virtual default

Destructor.

Member Function Documentation

◆ inference()

StatusCode AthInfer::TritonTool::inference ( InputDataMap & inputData,
OutputDataMap & outputData ) const
final override virtual

Run inference with multiple inputs and multiple outputs.

Definition at line 168 of file TritonTool.cxx.

169 {
170
171 assert(m_impl);
172
173 // Create the tensor for the input data.
174 // Use unique_ptr to manage the memory of the InferInput objects.
175 std::vector<std::unique_ptr<tc::InferInput>> inputs;
176 inputs.reserve(inputData.size());
177
178 for (auto& [inputName, inputInfo] : inputData) {
179
180 const std::vector<int64_t>& inputShape = inputInfo.first;
181 const DataVariant& variant = inputInfo.second;
182
183 ATH_CHECK(std::visit(
184 [&](const auto& dataVec) {
185 using T = std::decay_t<decltype(dataVec[0])>;
186 return m_impl->prepareInput<T>(inputName, inputShape, dataVec,
187 inputs);
188 },
189 variant));
190 }
191
192 // construct raw pointers for inference
193 std::vector<tc::InferInput*> rawInputs;
194 for (auto& input : inputs) {
195 rawInputs.push_back(input.get());
196 }
197
198 // Get the triton client object.
199 tc::InferenceServerGrpcClient* client = nullptr;
200 ATH_CHECK(m_impl->getClient(client, m_url, m_port, m_useSSL));
201 assert(client != nullptr);
202
203 // perform the inference.
204 std::shared_ptr<tc::InferResult> results;
205 tc::Headers http_headers;
206 grpc_compression_algorithm compression_algorithm =
207 grpc_compression_algorithm::GRPC_COMPRESS_NONE;
208
209 if (m_impl->m_parentAsyncAlg == nullptr) {
210 tc::InferResult* rawResultPtr = nullptr;
211 TRITON_CHECK(client->Infer(&rawResultPtr, *(m_impl->m_options), rawInputs,
212 {}, http_headers, compression_algorithm));
213 assert(rawResultPtr != nullptr);
214 results.reset(rawResultPtr);
215 } else {
216 // If m_impl->m_parentAsyncAlg is set, use asynchronous inference
217 using Promise_t = boost::fibers::promise<tc::InferResult*>;
218 using Future_t = boost::fibers::future<tc::InferResult*>;
219 Promise_t promise{};
220 Future_t future = promise.get_future();
221 auto callback = [&promise](tc::InferResult* resultPtr) {
222 assert(resultPtr != nullptr);
223 promise.set_value(resultPtr);
224 };
225 TRITON_CHECK(client->AsyncInfer(callback, *(m_impl->m_options), rawInputs,
226 {}, http_headers, compression_algorithm));
227 results.reset(future.get());
228 ATH_CHECK(m_impl->m_parentAsyncAlg->restoreAfterSuspend());
229 }
230
231 // Get the result of the inference.
232 for (auto& [outputName, outputInfo] : outputData) {
233
234 DataVariant& variant = outputInfo.second;
235
236 ATH_CHECK(std::visit(
237 [&](auto& dataVec) {
238 using T = std::decay_t<decltype(dataVec[0])>;
239 return m_impl->extractOutput<T>(outputName, *results, dataVec);
240 },
241 variant));
242 }
243
244 // Return gracefully.
245 return StatusCode::SUCCESS;
246}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define TRITON_CHECK(EXP)
Shorthand for the Triton client namespace.
StringProperty m_url
Definition TritonTool.h:54
IntegerProperty m_port
Definition TritonTool.h:48
BooleanProperty m_useSSL
Definition TritonTool.h:55
std::unique_ptr< Impl > m_impl
Pointer to the implementation details.
Definition TritonTool.h:63
std::variant< std::vector< float >, std::vector< int64_t > > DataVariant
unsigned long long T
str outputName
Definition lumiFormat.py:65

◆ initialize()

StatusCode AthInfer::TritonTool::initialize ( )
override virtual

Initialize the tool.

Definition at line 129 of file TritonTool.cxx.

129 {
130
131 // Set up the implementation object.
132 m_impl = std::make_unique<Impl>(name() + "::Impl");
133 m_impl->m_options = std::make_unique<tc::InferOptions>(m_modelName.value());
134 m_impl->m_options->model_version_ = m_modelVersion;
135 m_impl->m_options->client_timeout_ = m_clientTimeout;
136
137 // Figure out if parent is an AthAsynchronousAlgorithm, and set pointer if it
138 // is
139 const IAlgTool* p = dynamic_cast<const IAlgTool*>(this);
140 // Follow chain of parents up until we hit one that can't be converted to an
141 // IAlgTool
142 const IInterface* myParent = nullptr;
143 while (p != nullptr) {
144 myParent = p->parent();
145 p = dynamic_cast<const IAlgTool*>(myParent);
146 }
147 // If this ultimate ancestor can be converted to an AthAsynchronousAlgorithm,
148 // set the member variable
149 m_impl->m_parentAsyncAlg =
150 dynamic_cast<const AthAsynchronousAlgorithm*>(myParent);
151 if (m_impl->m_parentAsyncAlg != nullptr) {
152 ATH_MSG_INFO(
153 "Owned by an AthAsynchronousAlgorithm, using asynchronous inference");
154 } else {
155 ATH_MSG_INFO(
156 "Not owned by an AthAsynchronousAlgorithm, not using asynchronous "
157 "inference");
158 }
159
160 // Make sure already during initialization that a client can be created.
161 tc::InferenceServerGrpcClient* dummyClient = nullptr;
162 ATH_CHECK(m_impl->getClient(dummyClient, m_url, m_port, m_useSSL));
163
164 // Return gracefully.
165 return StatusCode::SUCCESS;
166}
#define ATH_MSG_INFO(x)
StringProperty m_modelVersion
Definition TritonTool.h:49
FloatProperty m_clientTimeout
Definition TritonTool.h:51
StringProperty m_modelName
Definition TritonTool.h:47

◆ print()

void AthInfer::TritonTool::print ( ) const
override virtual

Print the tool's properties and configuration.

Definition at line 248 of file TritonTool.cxx.

248{}

Member Data Documentation

◆ m_clientTimeout

FloatProperty AthInfer::TritonTool::m_clientTimeout
private
Initial value:
{
this, "ClientTimeout", 0,
"Client timeout in milliseconds, 0 for no timeout"}

Definition at line 51 of file TritonTool.h.

51 {
52 this, "ClientTimeout", 0,
53 "Client timeout in milliseconds, 0 for no timeout"};

◆ m_impl

std::unique_ptr<Impl> AthInfer::TritonTool::m_impl
private

Pointer to the implementation details.

Definition at line 63 of file TritonTool.h.

◆ m_modelName

StringProperty AthInfer::TritonTool::m_modelName {this, "ModelName", "", "Model name"}
private

Definition at line 47 of file TritonTool.h.

47{this, "ModelName", "", "Model name"};

◆ m_modelVersion

StringProperty AthInfer::TritonTool::m_modelVersion
private
Initial value:
{this, "ModelVersion", "",
"Model version, empty for latest"}

Definition at line 49 of file TritonTool.h.

49 {this, "ModelVersion", "",
50 "Model version, empty for latest"};

◆ m_port

IntegerProperty AthInfer::TritonTool::m_port {this, "Port", 8001, "Port ID for Triton server"}
private

Definition at line 48 of file TritonTool.h.

48{this, "Port", 8001, "Port ID for Triton server"};

◆ m_url

StringProperty AthInfer::TritonTool::m_url {this, "URL", "", "Triton URL"}
private

Definition at line 54 of file TritonTool.h.

54{this, "URL", "", "Triton URL"};

◆ m_useSSL

BooleanProperty AthInfer::TritonTool::m_useSSL
private
Initial value:
{this, "UseSSL", false,
"Use SSL for Triton server connection"}

Definition at line 55 of file TritonTool.h.

55 {this, "UseSSL", false,
56 "Use SSL for Triton server connection"};

The documentation for this class was generated from the following files: