23 using Inputs = std::pair<std::vector<float>, std::vector<int64_t>>;
29 std::map<std::string, std::vector<const xAOD::IParticle*>>
constituents;
46 void addVectorLoader(
const std::string& vecName, std::shared_ptr<IConstituentsLoader> loader) {
53 std::vector<float> scalar_feat;
55 std::string varName = varLoader.first;
56 scalar_feat.push_back(varLoader.second(p));
58 std::vector<int64_t> scalar_feat_dim = {1,
static_cast<int64_t
>(scalar_feat.size())};
59 Inputs scalar_inputs(scalar_feat, scalar_feat_dim);
64 std::string input_name = loader.first;
65 auto [input_data, input_objects] = loader.second->getData(*p);
67 salt_model_data.
gnn_inputs.insert({input_name, input_data});
68 salt_model_data.
num_inputs += input_data.first.size();
69 salt_model_data.
constituents[input_name] = input_objects;
71 return salt_model_data;
76 std::cout <<
"-------- Dumping GNN Input Data --------" << std::endl;
78 for (
const auto& [name, inputs] : gnn_inputs) {
79 std::cout <<
"Input Name: " << name << std::endl;
80 std::cout <<
" vec floats: ";
81 for (
const auto& feature : inputs.first) {
82 std::cout << feature <<
" ";
84 std::cout << std::endl;
85 std::cout <<
" vec ints : ";
86 for (
const auto&
id : inputs.second) {
87 std::cout <<
id <<
" ";
89 std::cout << std::endl;
91 std::cout <<
"---------- END GNN Input Data ----------" << std::endl;
virtual SaltModelData loadInputs(const xAOD::IParticle *p) const final
void addVectorLoader(const std::string &vecName, std::shared_ptr< IConstituentsLoader > loader)
void addScalarLoader(const std::string &varName, std::function< float(const xAOD::IParticle *)> loader)
void DumpGnnInputs(const SaltModelInputs &gnn_inputs) const
SaltModelGraphConfig::GraphConfig graph_config
SaltModelEDMLoaderBase(ISaltModelPtr salt_model)
std::string scalarInputName
std::vector< std::pair< std::string, std::function< float(const xAOD::IParticle *)> > > scalarVarLoaders
std::map< std::string, std::shared_ptr< IConstituentsLoader > > vectorVarLoaders
Class providing the definition of the 4-vector interface.
This file contains "getter" functions used for accessing tagger inputs from the EDM.
std::map< std::string, Inputs > SaltModelInputs
std::pair< std::vector< float >, std::vector< int64_t > > Inputs
std::shared_ptr< const ISaltModel > ISaltModelPtr
SaltModelInputs gnn_inputs
std::map< std::string, std::vector< const xAOD::IParticle * > > constituents