ATLAS Offline Software
Loading...
Searching...
No Matches
NNSharingSvc.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef NN_SHARING_SVC_H
6#define NN_SHARING_SVC_H
7
10
13
14namespace FlavorTagInference
15{
16
17 namespace detail {
// Lookup key for the shared-network cache: it is the key type of
// NNSharingSvc::m_gnns, hashed through detail::NNHasher.
// NOTE(review): listing line 20 was dropped from this page, so the key has at
// least one more member than the `path` shown below (presumably something
// derived from the GNNOptions passed to get()) — confirm in the real header.
18 struct NNKey {
19 std::string path; // presumably the network file path/name — TODO confirm
21 bool operator==(const NNKey&) const; // defined out of line; keys that compare equal must also produce equal hash() values for std::unordered_map
22 std::size_t hash() const; // hash value that NNHasher forwards to the container
23 };
24 struct NNHasher {
25 std::size_t operator()(const NNKey& o) const {
26 return o.hash();
27 }
28 };
29 }
30
// Service implementing INNSharingSvc: clients ask it for a GNN by name plus
// options through get() and receive a std::shared_ptr<const GNN>.
// Presumably the two maps below cache the loaded networks, so that clients
// requesting the same network share one immutable instance; the lookup logic
// lives in the implementation file — confirm there.
// NOTE(review): get() is non-const and mutates plain unordered_maps, and no
// mutex is visible in this listing (listing line 61 is missing). If get() can
// be reached concurrently, confirm that the implementation serializes access.
31 class NNSharingSvc: public extends<asg::AsgService, INNSharingSvc>
32 {
33 public:
34 using extends::extends; // base class constructor
35#ifndef XAOD_ANALYSIS
// Only declared when XAOD_ANALYSIS is not defined, which is the same condition
// that guards the Triton properties below; presumably it sets up the Triton
// path — confirm in the implementation.
36 virtual StatusCode initialize() override;
37#endif
// Return the network identified by `nn_name`, configured with `opts`.
// @param nn_name  network identifier (presumably a model path, cf. NNKey::path)
// @param opts     options that the returned network is built/configured with
// @return         shared, read-only handle to the network
38 virtual std::shared_ptr<const GNN> get(
39 const std::string& nn_name,
40 const GNNOptions& opts) override;
41 private:
42 using val_t = std::shared_ptr<const GNN>;
// Networks keyed by the full NNKey (hashed with detail::NNHasher).
43 std::unordered_map<detail::NNKey, val_t, detail::NNHasher> m_gnns;
// Networks keyed by a plain string; presumably base instances keyed by name
// alone, from which option-specific entries are derived — TODO confirm.
44 std::unordered_map<std::string, val_t> m_base_gnns;
45#ifndef XAOD_ANALYSIS
// Triton inference-server client configuration (only compiled when
// XAOD_ANALYSIS is not defined). Defaults: disabled, timeout 0, port 443,
// empty URL, SSL on. Port 443 is the conventional HTTPS port, which matches
// the UseSSL=true default.
46 Gaudi::Property<bool> m_useTriton {this, "UseTriton", false
47 , "Toggle running the inference through Triton"};
// NOTE(review): the units of this timeout, and whether 0 means "no timeout",
// are not visible here — confirm against the Triton client call site.
48 Gaudi::Property<float> m_tritonTimeout {this, "TritonTimeout", 0.f
49 , "Timeout value for Triton client"};
50 Gaudi::Property<int> m_tritonPort {this, "TritonPort", 443
51 , "Triton server port"};
// Empty by default; presumably it must be set whenever UseTriton is true.
52 Gaudi::Property<std::string> m_tritonUrl {this, "TritonUrl", ""
53 , "Triton server URL"};
54 Gaudi::Property<bool> m_tritonUseSsl {this, "TritonUseSSL", true
55 , "Connect to the Triton server over SSL"};
56
57 // !!! ------ For testing purpose only! -------- !!!
58 // In the long run we need to find some other mechanism
59 // for mapping physical paths to model names
60 std::map<std::string, std::string> m_tritonPathToName;
// NOTE(review): listing line 61 was dropped from this page; its content is unknown.
62 // !!! ----------------------------------------- !!!
63#endif
64 };
65
66}
67
68#endif
virtual StatusCode initialize() override
Gaudi::Property< bool > m_tritonUseSsl
std::unordered_map< std::string, val_t > m_base_gnns
std::map< std::string, std::string > m_tritonPathToName
std::shared_ptr< const GNN > val_t
virtual std::shared_ptr< const GNN > get(const std::string &nn_name, const GNNOptions &opts) override
std::unordered_map< detail::NNKey, val_t, detail::NNHasher > m_gnns
Gaudi::Property< float > m_tritonTimeout
Gaudi::Property< int > m_tritonPort
Gaudi::Property< bool > m_useTriton
Gaudi::Property< std::string > m_tritonUrl
This file contains "getter" functions used for accessing tagger inputs from the EDM.
std::size_t operator()(const NNKey &o) const
bool operator==(const NNKey &) const