ATLAS Offline Software
Loading...
Searching...
No Matches
GNNDataLoader.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
6
8 SaltModelEDMLoaderBase(saltModel),
9 m_gnn_options(gnn_options)
10 {
11 // Create configuration objects for data preprocessing.
12 auto [inputs_config, constituents_configs, fo] =
16 > (
18 m_gnn_options.flip_config,
19 m_gnn_options.variable_remapping,
20 saltModel->getModelName()
21 );
22 auto salt_model_version = saltModel->getSaltModelVersion();
23
24 for (auto config : constituents_configs){
25 switch (config.type){
26 using enum ConstituentsType;
27 case TRACK:
28 addVectorLoader(getVecInputName(salt_model_version, config),
29 std::make_shared<TracksLoader>(config, fo));
30 break;
31 case FLOW_ELEMENT:
32 addVectorLoader(getVecInputName(salt_model_version, config), std::make_shared<FlowElementsLoader>(config, fo));
33 break;
34 case HIT:
35 addVectorLoader(getVecInputName(salt_model_version, config), std::make_shared<HitsLoader>(config, fo));
36 break;
37 case ELECTRON:
38 addVectorLoader(getVecInputName(salt_model_version, config), std::make_shared<ElectronsLoader>(config, fo));
39 break;
40 case MUON:
41 addVectorLoader(getVecInputName(salt_model_version, config), std::make_shared<MuonsLoader>(config, fo));
42 break;
43 case CALO_CLUSTER:
44 addVectorLoader(getVecInputName(salt_model_version, config), std::make_shared<CaloClusterLoader>(config, fo));
45 break;
46 case TOWER:
47 addVectorLoader(getVecInputName(salt_model_version, config), std::make_shared<TowerLoader>(config, fo));
48 break;
49 default:
50 throw std::runtime_error("Unknown constituent type");
51 }
52 }
53 // Initialize jet and b-tagging input getters.
54 scalarInputName = (salt_model_version == SaltModelVersion::V2 ? "jets" : "jet_features");
55 auto [vars_from_jet, ds] = dataprep::createBvarGetters(inputs_config);
56 data_dependency_names = std::move(ds);
57 ftag_options = std::move(fo);
58 for (const auto& [name, getter]: vars_from_jet) {
60 name,
61 [getter](const xAOD::IParticle* p) {
62 auto jet = dynamic_cast<const xAOD::Jet*>(p);
63 return getter(*jet).second;
64 });
65 }
66 }
67
68std::string FlavorTagInference::GNNDataLoader::getVecInputName(const SaltModelVersion salt_model_version, const ConstituentsInputConfig& constituents_config) const {
69 if (salt_model_version == SaltModelVersion::V2){
70 return constituents_config.output_name;
71 } else {
72 auto out = constituents_config.output_name;
73 out.pop_back();
74 return out + "_features";
75 }
76}
std::string getVecInputName(const SaltModelVersion salt_model_version, const ConstituentsInputConfig &constituents_config) const
GNNDataLoader(ISaltModelPtr salt_model, const GNNOptions &opts)
FTagDataDependencyNames data_dependency_names
void addVectorLoader(const std::string &vecName, std::shared_ptr< IConstituentsLoader > loader)
void addScalarLoader(const std::string &varName, std::function< float(const xAOD::IParticle *)> loader)
SaltModelGraphConfig::GraphConfig graph_config
Class providing the definition of the 4-vector interface.
std::tuple< std::vector< FTagInputConfig >, std::vector< ConstituentsInputConfig >, FTagOptions > createGetterConfig(GraphConfig &graph_config, FlipTagConfig flip_config, std::map< std::string, std::string > remap_scalar, const std::string &object_link_prefix)
std::tuple< std::vector< std::pair< std::string, internal::VarFromJet > >, FTagDataDependencyNames > createBvarGetters(const std::vector< FTagInputConfig > &inputs)
std::shared_ptr< const ISaltModel > ISaltModelPtr
Definition ISaltModel.h:54
Jet_v1 Jet
Definition of the current "jet version".