ATLAS Offline Software
Loading...
Searching...
No Matches
NNSharingSvc.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
6#ifndef XAOD_ANALYSIS
7#include "SaltModelTriton.h"
8#endif
10#include "src/hash.h"
11
12namespace FlavorTagInference {
13
14 namespace detail {
15 std::size_t NNKey::hash() const {
16 return combine(getHash(path), getHash(opts));
17 }
18
19 // allow NNKey to be used as unordered map key
20 bool NNKey::operator==(const NNKey& key) const {
21 return path == key.path && opts == key.opts;
22 }
23 }
24
// Return a shared, const GNN for the model at `nn_name` configured with `opts`.
//
// Lookup order:
//  1. m_gnns, keyed on (nn_name, opts): an exact match is returned as-is.
//  2. m_base_gnns, keyed on nn_name only: a GNN for this model already exists
//     (built with other options), so a new GNN is constructed from it plus
//     `opts` and cached in m_gnns under the full key. The base entry is not
//     replaced.
//  3. Otherwise a new GNN is built:
//     - outside XAOD_ANALYSIS builds, if m_useTriton is true and nn_name is a
//       key of m_tritonPathToName, the model is wrapped in a SaltModelTriton
//       built from the PathResolver-resolved file path and the mapped name;
//     - in every other case it is constructed directly from nn_name ("from
//       onnx file" per the log message).
//     The new GNN becomes the base for nn_name and is cached under the full key.
//
// NOTE(review): m_gnns / m_base_gnns are modified here without any visible
// locking; confirm get() is never called concurrently (e.g. only from
// initialize() of client tools).
25 std::shared_ptr<const GNN> NNSharingSvc::get(
26 const std::string& nn_name,
27 const GNNOptions& opts) {
28 detail::NNKey key{nn_name, opts};
// Stage 1: exact (path, options) match.
29 if (m_gnns.count(key)) {
30 ATH_MSG_INFO("getting " << nn_name << " from cached NNs");
31 return m_gnns.at(key);
// Stage 2: same model already loaded with different options; derive a new
// GNN from the cached base instead of building the model again.
32 } else if (m_base_gnns.count(nn_name) ) {
33 ATH_MSG_INFO("adapting " << nn_name << " from cached NNs, new opts");
34 auto nn = std::make_shared<const GNN>(*m_base_gnns.at(nn_name), opts);
35 m_gnns[key] = nn;
36 return nn;
37 }
// Stage 3: first request for this model path; build it from scratch.
38 std::shared_ptr<const GNN> nn;
39#ifndef XAOD_ANALYSIS
// Only paths listed in m_tritonPathToName can be served via Triton; all
// others (or all paths when m_useTriton is false) use the direct ONNX path.
40 auto it = m_tritonPathToName.find(nn_name);
41 if(m_useTriton && it!=m_tritonPathToName.end()) {
42 ATH_MSG_INFO("building " << nn_name << " from onnx file to run with Triton");
43 //using namespace FlavorTagInference;
44 std::string fullPathToOnnxFile = PathResolverFindCalibFile(nn_name);
45 auto saltSharedTriton = std::make_shared<const SaltModelTriton>(fullPathToOnnxFile
46 , it->second
51 ISaltModelPtr saltShared = saltSharedTriton;
52 nn = std::make_shared<const GNN>(saltShared, opts);
53 }
54 else {
55 ATH_MSG_INFO("building " << nn_name << " from onnx file");
56 nn = std::make_shared<const GNN>(nn_name, opts);
57 }
58#else
// Analysis-release builds have no Triton support: always load directly.
59 ATH_MSG_INFO("building " << nn_name << " from onnx file");
60 nn = std::make_shared<const GNN>(nn_name, opts);
61#endif
// Register the freshly built GNN both as the base for later option
// variants (stage 2) and under its exact key (stage 1).
62 m_base_gnns[nn_name] = nn;
63 m_gnns[key] = nn;
64 return nn;
65 }
66
67#ifndef XAOD_ANALYSIS
// NOTE(review): presumably the tail of initialize(); its opening lines are
// not visible in this view.
70 return StatusCode::SUCCESS;
71 }
72
// NOTE(review): the table below maps model paths (the nn_name strings passed
// to get()) to the model names served by Triton. It presumably fills
// m_tritonPathToName, which get() consults only when m_useTriton is true; the
// assignment target is not visible in this view. Model paths that are not
// listed here are built by get() directly from the ONNX file.
73 // This is a quick solution for the initial stage of testing.
74 // In the long run we need to find some other mechanism for mapping
75 // model paths to model names
78 {"BTagging/20250527/GN3V01/antikt4empflow/network.onnx"
79 , "BTagging_network_93a858f5c730"},
80 {"BTagging/20231205/GN2v01/antikt4empflow/network_fold0.onnx"
81 , "BTagging_network_fold0_4812578c733e"},
82 {"BTagging/20231205/GN2v01/antikt4empflow/network_fold1.onnx"
83 , "BTagging_network_fold1_9280d77c131c"},
84 {"BTagging/20231205/GN2v01/antikt4empflow/network_fold2.onnx"
85 , "BTagging_network_fold2_25c6ad03db10"},
86 {"BTagging/20231205/GN2v01/antikt4empflow/network_fold3.onnx"
87 , "BTagging_network_fold3_0558b4924c49"},
88 {"BTagging/20250213/GN3V00/antikt4empflow/network.onnx"
89 , "BTagging_network_cce6be90efd1"},
90 {"BTagging/20250213/GN3PflowMuonsV00/antikt4empflow/network.onnx"
91 , "BTagging_network_d2138c4252e6"},
92// {"BTagging/20230705/gn2xv01/antikt10ufo/network.onnx" << This model is commented out because, at the time of submission,
93// , "BTagging_network_9f8aadb82b76"}, << it did not work on Triton. get() therefore falls back to direct ONNX reading for it.
94 {"BTagging/20240925/GN2Xv02/antikt10ufo/network.onnx"
95 , "BTagging_network_09c2dddf15bf"},
96 {"BTagging/20250310/GN2XTauV00/antikt10ufo/network.onnx"
97 , "BTagging_network_e8d5e9a3059b"},
98 {"BTagging/20250912/GN3XPV01/antikt10ufo/network.onnx"
99 , "BTagging_network_08105bb8c1d6"},
100 {"BTagging/20250912/GN3EPCLV01/antikt4empflow/network.onnx"
101 , "BTagging_network_8085e6c5717c"},
102 {"JetCalibTools/CalibArea-00-04-83/CalibrationFactors/bbJESJMS_calibFactors_R22_MC20_CSSKUFO_bJR10v00Ext_20250212.onnx"
103 , "JetCalibTools_bbJESJMS_calibFactor_80138d800ac5"},
104 {"JetCalibTools/CalibArea-00-04-83/CalibrationFactors/bbJESJMS_calibFactors_R22_MC20MC23_CSSKUFO_bJR10v01_20250212.onnx"
105 , "JetCalibTools_bbJESJMS_calibFactor_fefb85f452f9"}
106 };
107 }
108#endif
109}
#define ATH_MSG_INFO(x)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
virtual StatusCode initialize() override
Gaudi::Property< bool > m_tritonUseSsl
std::unordered_map< std::string, val_t > m_base_gnns
std::map< std::string, std::string > m_tritonPathToName
virtual std::shared_ptr< const GNN > get(const std::string &nn_name, const GNNOptions &opts) override
std::unordered_map< detail::NNKey, val_t, detail::NNHasher > m_gnns
Gaudi::Property< float > m_tritonTimeout
Gaudi::Property< int > m_tritonPort
Gaudi::Property< bool > m_useTriton
Gaudi::Property< std::string > m_tritonUrl
This file contains "getter" functions used for accessing tagger inputs from the EDM.
size_t combine(size_t lhs, size_t rhs)
Definition hash.h:21
std::size_t getHash(const T &obj)
Definition hash.h:13
std::shared_ptr< const ISaltModel > ISaltModelPtr
Definition ISaltModel.h:54
bool operator==(const NNKey &) const