22 using Inputs = std::pair<std::vector<float>, std::vector<int64_t>>;
28 std::map<std::string, std::vector<const xAOD::IParticle*>>
constituents;
45 void addVectorLoader(
const std::string& vecName, std::shared_ptr<IConstituentsLoader> loader) {
52 std::vector<float> scalar_feat;
54 std::string varName = varLoader.first;
55 scalar_feat.push_back(varLoader.second(p));
57 std::vector<int64_t> scalar_feat_dim = {1,
static_cast<int64_t
>(scalar_feat.size())};
58 Inputs scalar_inputs(scalar_feat, scalar_feat_dim);
63 std::string input_name = loader.first;
64 auto [input_data, input_objects] = loader.second->getData(*p);
66 salt_model_data.
gnn_inputs.insert({input_name, input_data});
67 salt_model_data.
num_inputs += input_data.first.size();
68 salt_model_data.
constituents[input_name] = input_objects;
70 return salt_model_data;
75 std::cout <<
"-------- Dumping GNN Input Data --------" << std::endl;
77 for (
const auto& [name, inputs] : gnn_inputs) {
78 std::cout <<
"Input Name: " << name << std::endl;
79 std::cout <<
" vec floats: ";
80 for (
const auto& feature : inputs.first) {
81 std::cout << feature <<
" ";
83 std::cout << std::endl;
84 std::cout <<
" vec ints : ";
85 for (
const auto&
id : inputs.second) {
86 std::cout <<
id <<
" ";
88 std::cout << std::endl;
90 std::cout <<
"---------- END GNN Input Data ----------" << std::endl;
virtual SaltModelData loadInputs(const xAOD::IParticle *p) const final
void addVectorLoader(const std::string &vecName, std::shared_ptr< IConstituentsLoader > loader)
void addScalarLoader(const std::string &varName, std::function< float(const xAOD::IParticle *)> loader)
void DumpGnnInputs(const SaltModelInputs &gnn_inputs) const
SaltModelGraphConfig::GraphConfig graph_config
SaltModelEDMLoaderBase(ISaltModelPtr salt_model)
std::string scalarInputName
std::vector< std::pair< std::string, std::function< float(const xAOD::IParticle *)> > > scalarVarLoaders
std::map< std::string, std::shared_ptr< IConstituentsLoader > > vectorVarLoaders
Class providing the definition of the 4-vector interface.
This file contains "getter" functions used for accessing tagger inputs from the EDM.
std::map< std::string, Inputs > SaltModelInputs
std::pair< std::vector< float >, std::vector< int64_t > > Inputs
std::shared_ptr< const ISaltModel > ISaltModelPtr
SaltModelInputs gnn_inputs
std::map< std::string, std::vector< const xAOD::IParticle * > > constituents