ATLAS Offline Software
Loading...
Searching...
No Matches
GNNDataLoader.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
6
8 SaltModelEDMLoaderBase(saltModel),
9 m_gnn_options(gnn_options)
10 {
11 // Create configuration objects for data preprocessing.
12 auto [inputs_config, constituents_configs, fo] =
16 > (
18 m_gnn_options.flip_config,
19 m_gnn_options.variable_remapping
20 );
21 auto salt_model_version = saltModel->getSaltModelVersion();
22
23 for (auto config : constituents_configs){
24 switch (config.type){
25 using enum ConstituentsType;
26 case TRACK:
27 addVectorLoader(getVecInputName(salt_model_version, config), std::make_shared<TracksLoader>(config, fo));
28 break;
29 case FLOW_ELEMENT:
30 addVectorLoader(getVecInputName(salt_model_version, config), std::make_shared<FlowElementsLoader>(config, fo));
31 break;
32 case HIT:
33 addVectorLoader(getVecInputName(salt_model_version, config), std::make_shared<HitsLoader>(config, fo));
34 break;
35 case ELECTRON:
36 addVectorLoader(getVecInputName(salt_model_version, config), std::make_shared<ElectronsLoader>(config, fo));
37 break;
38 default:
39 throw std::runtime_error("Unknown constituent type");
40 }
41 }
42 // Initialize jet and b-tagging input getters.
43 scalarInputName = (salt_model_version == SaltModelVersion::V2 ? "jets" : "jet_features");
44 auto [vars_from_jet, ds] = dataprep::createBvarGetters(inputs_config);
45 data_dependency_names = std::move(ds);
46 ftag_options = std::move(fo);
47 for (const auto& [name, getter]: vars_from_jet) {
49 name,
50 [getter](const xAOD::IParticle* p) {
51 auto jet = dynamic_cast<const xAOD::Jet*>(p);
52 return getter(*jet).second;
53 });
54 }
55 }
56
57std::string FlavorTagInference::GNNDataLoader::getVecInputName(const SaltModelVersion salt_model_version, const ConstituentsInputConfig& constituents_config) const {
58 if (salt_model_version == SaltModelVersion::V2){
59 return constituents_config.output_name;
60 } else {
61 auto out = constituents_config.output_name;
62 out.pop_back();
63 return out + "_features";
64 }
65}
std::string getVecInputName(const SaltModelVersion salt_model_version, const ConstituentsInputConfig &constituents_config) const
GNNDataLoader(ISaltModelPtr salt_model, const GNNOptions &opts)
FTagDataDependencyNames data_dependency_names
void addVectorLoader(const std::string &vecName, std::shared_ptr< IConstituentsLoader > loader)
void addScalarLoader(const std::string &varName, std::function< float(const xAOD::IParticle *)> loader)
SaltModelGraphConfig::GraphConfig graph_config
Class providing the definition of the 4-vector interface.
std::tuple< std::vector< std::pair< std::string, internal::VarFromJet > >, FTagDataDependencyNames > createBvarGetters(const std::vector< FTagInputConfig > &inputs)
std::tuple< std::vector< FTagInputConfig >, std::vector< ConstituentsInputConfig >, FTagOptions > createGetterConfig(GraphConfig &graph_config, FlipTagConfig flip_config, std::map< std::string, std::string > remap_scalar)
std::shared_ptr< const ISaltModel > ISaltModelPtr
Definition ISaltModel.h:54
Jet_v1 Jet
Definition of the current "jet version".