ATLAS Offline Software
Loading...
Searching...
No Matches
SaltModelTriton.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef FLAVORTAGINFERENCE_SALTMODELTRITON_H
6#define FLAVORTAGINFERENCE_SALTMODELTRITON_H
7
14
16
17#include "grpc_client.h"
18#include "grpc_service.pb.h"
19
// Forward declaration of Ort::Session. No header providing the full type is
// included in this listing; the class below only refers to it through
// `const Ort::Session*` parameters, for which an incomplete type suffices.
20namespace Ort {
21 class Session;
22}
23
// Short alias for the triton::client namespace; used below as
// tc::InferenceServerGrpcClient and tc::InferOptions.
// NOTE(review): this alias lives at global scope in a header, so every
// translation unit that includes the header also sees `tc`.
24namespace tc = triton::client;
25
26namespace FlavorTagInference {
27
// SaltModelTriton: a final, concrete implementation of the ISaltModel
// interface. Its declaration refers to the Triton client types
// tc::InferenceServerGrpcClient and tc::InferOptions.
// NOTE(review): presumably inference is delegated to a remote Triton server
// over gRPC -- confirm against the implementation file.
// NOTE(review): this listing skips original source lines 52, 55 and 57, so
// any declarations on those lines are not visible here.
28 class SaltModelTriton final : public ISaltModel
29 {
30 public:
// Constructor.
//   path_to_onnx   - presumably the path of the ONNX model file (TODO confirm
//                    how it is used; no member with that name is visible).
//   model_name     - presumably stored in m_model_name.
//   client_timeout - presumably stored in m_clientTimeout; the unit is not
//                    stated in this header (TODO confirm).
//   port, url, useSSL - presumably stored in m_port, m_url and m_useSSL.
31 SaltModelTriton(const std::string& path_to_onnx
32 , const std::string& model_name
33 , float client_timeout
34 , int port
35 , const std::string& url
36 , bool useSSL);
37
// ISaltModel override. Takes the inputs as a map from string keys to Inputs
// by non-const reference and returns an InferenceOutput. The method itself is
// const-qualified.
38 virtual InferenceOutput runInference(std::map<std::string, Inputs>& gnn_inputs) const override;
39
// ISaltModel accessor overrides. getGraphConfig() and getSaltModelVersion()
// return by value; getOutputConfig() and getModelName() return const
// references.
40 virtual const SaltModelGraphConfig::GraphConfig getGraphConfig() const override;
41 virtual const OutputConfig& getOutputConfig() const override;
42 virtual SaltModelVersion getSaltModelVersion() const override;
43 virtual const std::string& getModelName() const override;
44
45 private:
// Private helpers that take a (possibly null -- not specified here) pointer to
// an Ort::Session. Both return by const value.
46 const nlohmann::json loadMetadata(const std::string& key, const Ort::Session* session) const;
47 const std::string determineModelType(const Ort::Session* session) const;
// Returns a raw pointer to a Triton gRPC client.
// NOTE(review): ownership and lifetime of the returned client are not
// visible in this header -- confirm in the implementation.
48 tc::InferenceServerGrpcClient* getClient() const;
49
// JSON metadata object (presumably filled via loadMetadata -- TODO confirm).
50 nlohmann::json m_metadata;
51
53 std::string m_model_name;
54 std::string m_model_type;
56
58
// Owned Triton inference options.
59 std::unique_ptr<tc::InferOptions> m_options;
// Connection settings with in-class defaults: timeout 0, port 8001, empty
// URL, SSL disabled.
60 float m_clientTimeout{0.f};
61 int m_port{8001};
62 std::string m_url{};
63 bool m_useSSL{false};
64 }; // Class SaltModelTriton
65} // end of FlavorTagInference namespace
66
67#endif
68
static Double_t tc
tc::InferenceServerGrpcClient * getClient() const
SaltModelTriton(const std::string &path_to_onnx, const std::string &model_name, float client_timeout, int port, const std::string &url, bool useSSL)
virtual InferenceOutput runInference(std::map< std::string, Inputs > &gnn_inputs) const override
const nlohmann::json loadMetadata(const std::string &key, const Ort::Session *session) const
virtual const OutputConfig & getOutputConfig() const override
virtual const SaltModelGraphConfig::GraphConfig getGraphConfig() const override
std::unique_ptr< tc::InferOptions > m_options
const std::string determineModelType(const Ort::Session *session) const
virtual SaltModelVersion getSaltModelVersion() const override
virtual const std::string & getModelName() const override
This file contains "getter" functions used for accessing tagger inputs from the EDM.
std::vector< SaltModelOutput > OutputConfig
Definition ISaltModel.h:36