ATLAS Offline Software
GNNDataLoader.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
6 
// GNNDataLoader constructor: builds one constituent loader per constituent
// config and the jet-level variable getters that loadInputs() later uses to
// assemble the model inputs.
//
// @param saltModel    Model whose graph config drives input preparation. It is
//                     dereferenced immediately (getGraphConfig()), so it must be
//                     non-null.
// @param gnn_options  Options; copied into m_gnn_options.
// @throws std::runtime_error("Unknown constituent type") when a constituent
//         config has a type other than TRACK, FLOW_ELEMENT, HIT or ELECTRON.
//
// NOTE(review): this listing comes from a Doxygen page. Original source lines 14
// (the argument list of createGetterConfig) and 38 are missing here, and the
// code lines below still carry Doxygen line numbers.
7 FlavorTagInference::GNNDataLoader::GNNDataLoader(std::shared_ptr<const SaltModel> saltModel, const GNNOptions& gnn_options) :
8  salt_model(saltModel),
9  graph_config(saltModel->getGraphConfig()),
10  m_gnn_options(gnn_options)
11  {
12  // Create configuration objects for data preprocessing.
// createGetterConfig yields three things: the input configs later turned into
// jet / b-tagging getters (`inputs`), one config per constituent type
// (`constituents_configs`), and the FTag options (`fo`).
13  auto [inputs, constituents_configs, fo] = dataprep::createGetterConfig<SaltModelGraphConfig::GraphConfig, SaltModelGraphConfig::OutputNodeConfig>(
15 
// Instantiate one loader per constituent config, dispatching on config.type.
// Every loader is constructed from `fo` here, before `fo` is moved into
// ftag_options at the end of the constructor, so this loop must stay ahead of
// that move.
// NOTE(review): `auto config` copies each config on every iteration;
// `const auto&` would avoid the copy, assuming the loader constructors accept a
// const reference. TODO confirm.
16  for (auto config : constituents_configs){
17  switch (config.type){
18  using enum ConstituentsType;
19  case TRACK:
20  constituents_loaders.push_back(std::make_shared<TracksLoader>(config, fo));
21  break;
22  case FLOW_ELEMENT:
23  constituents_loaders.push_back(std::make_shared<FlowElementsLoader>(config, fo));
24  break;
25  case HIT:
26  constituents_loaders.push_back(std::make_shared<HitsLoader>(config, fo));
27  break;
28  case ELECTRON:
29  constituents_loaders.push_back(std::make_shared<ElectronsLoader>(config, fo));
30  break;
31  default:
32  throw std::runtime_error("Unknown constituent type");
33  }
34  }
35  // Initialize jet and b-tagging input getters.
36  auto [vb, vj, ds] = dataprep::createBvarGetters(inputs);
// Only the jet-level getters (vj) are stored on the visible lines. vb and ds
// are presumably consumed on the missing original line 38 (perhaps into
// data_dependency_names). TODO confirm against the real source.
37  vars_from_jet = vj;
// Moved last: every constituent loader above has already been constructed
// from fo.
39  ftag_options = std::move(fo);
40  }
41 
// Body of GNNDataLoader::loadInputs(const xAOD::IParticle* p) const. The
// signature (original line 42) is missing from this Doxygen listing.
// Assembles the model inputs for a single jet: one jet-level feature row plus
// one entry per constituent loader.
// @return SaltModelData with gnn_inputs (input name -> Inputs), num_inputs, and
//         the constituent objects that back each constituent input.
//
// NOTE(review): the dynamic_cast result is never checked. If `p` is not an
// xAOD::Jet, `jet` is nullptr and every `*jet` below is undefined behaviour.
// Consider an explicit null check that throws.
43  auto jet = dynamic_cast<const xAOD::Jet*>(p);
44 
45  SaltModelData salt_model_data;
// Jet-level inputs: evaluate every jet getter and keep the `.second` member of
// its result.
46  // jet level inputs
47  std::vector<float> jet_feat;
48  for (const auto& getter: vars_from_jet) {
49  jet_feat.push_back(getter(*jet).second);
50  }
// Shape {1, N}: a single row with one column per jet variable.
51  std::vector<int64_t> jet_feat_dim = {1, static_cast<int64_t>(jet_feat.size())};
52  Inputs jet_info(jet_feat, jet_feat_dim);
// SaltModelVersion::V2 models name the jet input "jets"; all other versions use
// "jet_features".
53  if (salt_model->getSaltModelVersion() == SaltModelVersion::V2) {
54  salt_model_data.gnn_inputs.insert({"jets", jet_info});
55  } else {
56  salt_model_data.gnn_inputs.insert({"jet_features", jet_info});
57  }
58 
// Constituent-level inputs: one entry per loader, keyed by the name the loader
// returns. For non-V2 models that name is rewritten by dropping its last
// character and appending "_features" (presumably "tracks" -> "track_features").
// This assumes the name is non-empty, because pop_back() on an empty string is
// undefined.
59  // constituent level inputs
60  for (const auto& loader : constituents_loaders){
61  auto [input_name, input_data, input_objects] = loader->getData(*jet);
62  if (salt_model->getSaltModelVersion() != SaltModelVersion::V2) {
63  input_name.pop_back();
64  input_name.append("_features");
65  }
// NOTE(review): if gnn_inputs is a std::map-like container, insert() keeps the
// first entry for a duplicate name, while constituents[...] = keeps the last.
// Duplicate loader names would therefore leave the two out of sync.
// num_inputs accumulates only the constituent float values; the jet features
// above are not counted.
66  salt_model_data.gnn_inputs.insert({input_name, input_data});
67  salt_model_data.num_inputs += input_data.first.size();
68  salt_model_data.constituents[input_name] = input_objects;
69  }
70  return salt_model_data;
71 }
FlavorTagInference::GNNOptions::flip_config
FlipTagConfig flip_config
Definition: GNNOptions.h:17
checkxAOD.ds
ds
Definition: Tools/PyUtils/bin/checkxAOD.py:260
FlavorTagInference::GNNDataLoader::ftag_options
FTagOptions ftag_options
Definition: GNNDataLoader.h:31
FlavorTagInference::GNNDataLoader::GNNDataLoader
GNNDataLoader(std::shared_ptr< const SaltModel > salt_model, const GNNOptions &opts)
Definition: GNNDataLoader.cxx:7
FlavorTagInference::GNNDataLoader::data_dependency_names
FTagDataDependencyNames data_dependency_names
Definition: GNNDataLoader.h:34
FlavorTagInference::GNNDataLoader::graph_config
SaltModelGraphConfig::GraphConfig graph_config
Definition: GNNDataLoader.h:30
xAOD::IParticle
Class providing the definition of the 4-vector interface.
Definition: Event/xAOD/xAODBase/xAODBase/IParticle.h:41
GNNDataLoader.h
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
FlavorTagInference::GNNDataLoader::salt_model
std::shared_ptr< const SaltModel > salt_model
Definition: GNNDataLoader.h:29
FlavorTagInference::SaltModelData::gnn_inputs
SaltModelInputs gnn_inputs
Definition: GNNDataLoader.h:20
FlavorTagInference::dataprep::createBvarGetters
std::tuple< std::vector< internal::VarFromBTag >, std::vector< internal::VarFromJet >, FTagDataDependencyNames > createBvarGetters(const std::vector< FTagInputConfig > &inputs)
Definition: PhysicsAnalysis/JetTagging/FlavorTagInference/Root/DataPrepUtilities.cxx:358
FlavorTagInference::ConstituentsType::HIT
@ HIT
FlavorTagInference::ConstituentsType::FLOW_ELEMENT
@ FLOW_ELEMENT
python.utils.AtlRunQueryDQUtils.p
p
Definition: AtlRunQueryDQUtils.py:209
jet
Definition: JetCalibTools_PlotJESFactors.cxx:23
FlavorTagInference::SaltModelData::constituents
std::map< std::string, std::vector< const xAOD::IParticle * > > constituents
Definition: GNNDataLoader.h:22
FlavorTagInference::SaltModelVersion::V2
@ V2
FlavorTagInference::SaltModelData
Definition: GNNDataLoader.h:19
FlavorTagInference::GNNDataLoader::constituents_loaders
std::vector< std::shared_ptr< IConstituentsLoader > > constituents_loaders
Definition: GNNDataLoader.h:33
FlavorTagInference::ConstituentsType
ConstituentsType
Definition: PhysicsAnalysis/JetTagging/FlavorTagInference/FlavorTagInference/ConstituentsLoader.h:48
FlavorTagInference::GNNOptions
Definition: GNNOptions.h:16
FlavorTagInference::GNNDataLoader::vars_from_jet
std::vector< internal::VarFromJet > vars_from_jet
Definition: GNNDataLoader.h:32
xAOD::Jet_v1
Class describing a jet.
Definition: Jet_v1.h:57
FlavorTagInference::GNNDataLoader::loadInputs
SaltModelData loadInputs(const xAOD::IParticle *p) const
Definition: GNNDataLoader.cxx:42
FlavorTagInference::GNNDataLoader::m_gnn_options
GNNOptions m_gnn_options
Definition: GNNDataLoader.h:36
FlavorTagInference::GNNOptions::track_link_type
TrackLinkType track_link_type
Definition: GNNOptions.h:19
xAOD::TrackParticle_v1
Class describing a TrackParticle.
Definition: TrackParticle_v1.h:43
FlavorTagInference::Inputs
std::pair< std::vector< float >, std::vector< int64_t > > Inputs
Definition: GNNDataLoader.h:16
FlavorTagInference::GNNOptions::variable_remapping
std::map< std::string, std::string > variable_remapping
Definition: GNNOptions.h:18
FlavorTagInference::SaltModelData::num_inputs
size_t num_inputs
Definition: GNNDataLoader.h:21