ATLAS Offline Software
Loading...
Searching...
No Matches
SaltModelTriton.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef FLAVORTAGINFERENCE_SALTMODELTRITON_H
6#define FLAVORTAGINFERENCE_SALTMODELTRITON_H
7
14
16
17#include "grpc_client.h"
18#include "grpc_service.pb.h"
19
// Forward declaration of Ort::Session. This header only refers to the type
// through `const Ort::Session*` parameters, so the full class definition is
// not needed here.
20namespace Ort {
21 class Session;
22}
23
// Short alias for the Triton client namespace, used by the declarations below.
// NOTE(review): this alias is declared at global scope in a header, so it is
// visible in every translation unit that includes this file.
24namespace tc = triton::client;
25
26namespace FlavorTagInference {
27
// ISaltModel implementation, declared `final`. Its private interface works
// with a Triton gRPC client (tc::InferenceServerGrpcClient) and
// tc::InferOptions, so inference is presumably delegated to a Triton
// inference server rather than run in-process -- confirm against the
// implementation file.
28 class SaltModelTriton final : public ISaltModel
29 {
30 public:
// Construct the model.
//   path_to_onnx   - path to an ONNX file; presumably the source of the
//                    metadata read through loadMetadata() -- TODO confirm.
//   model_name     - name of the model; see getModelName().
//   client_timeout - client timeout; the unit is not visible in this
//                    header -- TODO confirm.
//   port, url,
//   useSSL         - connection settings; presumably describe the server
//                    endpoint -- confirm against the implementation.
//   bearer         - optional, defaults to an empty string; presumably a
//                    bearer token used for authentication -- TODO confirm.
31 SaltModelTriton(const std::string& path_to_onnx
32 , const std::string& model_name
33 , float client_timeout
34 , int port
35 , const std::string& url
36 , bool useSSL
37 , const std::string& bearer = "");
38
// Run inference on the named inputs and return the result.
// The input map is taken by non-const reference, so the implementation is
// permitted to modify it even though the method itself is const.
39 virtual InferenceOutput runInference(std::map<std::string, Inputs>& gnn_inputs) const override;
40
// Accessors overriding the ISaltModel interface. getGraphConfig() and
// getSaltModelVersion() return by value; getOutputConfig() and
// getModelName() return const references (presumably to state owned by
// this object -- confirm against the implementation).
41 virtual const SaltModelGraphConfig::GraphConfig getGraphConfig() const override;
42 virtual const OutputConfig& getOutputConfig() const override;
43 virtual SaltModelVersion getSaltModelVersion() const override;
44 virtual const std::string& getModelName() const override;
45
46 private:
// Helpers that take an ONNX Runtime session. Judging by their names they
// look up the metadata entry `key` (returned as JSON) and work out the
// model type string -- verify against the implementation.
47 const nlohmann::json loadMetadata(const std::string& key, const Ort::Session* session) const;
48 const std::string determineModelType(const Ort::Session* session) const;
// Returns a raw pointer to a Triton gRPC client.
// NOTE(review): ownership and lifetime of the returned client are not
// visible from this header -- confirm that callers must not delete it.
49 tc::InferenceServerGrpcClient* getClient() const;
50
// Metadata held as JSON (presumably filled via loadMetadata() -- confirm).
51 nlohmann::json m_metadata;
52
// Model name and model type strings.
54 std::string m_model_name;
55 std::string m_model_type;
// NOTE(review): the listing's own line numbering skips around here, so
// some declarations may be missing from this view.
57
59
// Inference options object, owned by this instance.
60 std::unique_ptr<tc::InferOptions> m_options;
// Settings with in-class defaults (timeout 0, port 8001, empty URL, SSL
// off, empty bearer string); presumably overwritten from the constructor
// arguments of the same names -- confirm against the implementation.
61 float m_clientTimeout{0.f};
62 int m_port{8001};
63 std::string m_url{};
64 bool m_useSSL{false};
65 std::string m_bearer{};
66 }; // Class SaltModelTriton
67} // end of FlavorTagInference namespace
68
69#endif
70
static Double_t tc
tc::InferenceServerGrpcClient * getClient() const
virtual InferenceOutput runInference(std::map< std::string, Inputs > &gnn_inputs) const override
const nlohmann::json loadMetadata(const std::string &key, const Ort::Session *session) const
virtual const OutputConfig & getOutputConfig() const override
virtual const SaltModelGraphConfig::GraphConfig getGraphConfig() const override
std::unique_ptr< tc::InferOptions > m_options
const std::string determineModelType(const Ort::Session *session) const
virtual SaltModelVersion getSaltModelVersion() const override
SaltModelTriton(const std::string &path_to_onnx, const std::string &model_name, float client_timeout, int port, const std::string &url, bool useSSL, const std::string &bearer="")
virtual const std::string & getModelName() const override
This file contains "getter" functions used for accessing tagger inputs from the EDM.
std::vector< SaltModelOutput > OutputConfig
Definition ISaltModel.h:36