ATLAS Offline Software
Loading...
Searching...
No Matches
SaltModelEDMLoaderBase.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5#pragma once
9
14
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
19
20namespace FlavorTagInference {
21
/// One model input: a vector of float values paired with a vector of int64 values.
/// For the scalar input built in loadInputs() the int64 vector holds the shape {1, N}.
22 using Inputs = std::pair<std::vector<float>, std::vector<int64_t>>;
/// All model inputs, keyed by input name.
23 using SaltModelInputs = std::map<std::string, Inputs>;
24
27 size_t num_inputs = 0;
28 std::map<std::string, std::vector<const xAOD::IParticle*>> constituents;
29 };
30
32 public:
34 graph_config(salt_model->getGraphConfig()) {};
/// Key under which loadInputs() stores the packed scalar features in gnn_inputs.
36 std::string scalarInputName;
/// Scalar loaders, evaluated in this order by loadInputs(); each yields one float for the parent particle.
37 std::vector<std::pair<std::string /* varName */, std::function<float(const xAOD::IParticle* /* parent */)>>> scalarVarLoaders;
/// Vector loaders keyed by input name; loadInputs() calls getData() on each and stores the result under that name.
38 std::map<std::string /* vecInputName */, std::shared_ptr<IConstituentsLoader>> vectorVarLoaders;
39
40
41 void addScalarLoader(const std::string& varName, std::function<float(const xAOD::IParticle*)> loader) {
42 scalarVarLoaders.emplace_back(varName, loader);
43 }
44
45 void addVectorLoader(const std::string& vecName, std::shared_ptr<IConstituentsLoader> loader) {
46 vectorVarLoaders.try_emplace(vecName, std::move(loader));
47 }
48
49 virtual SaltModelData loadInputs(const xAOD::IParticle* p) const final {
50 SaltModelData salt_model_data;
51 // loading scalar inputs.
52 std::vector<float> scalar_feat;
53 for (const auto& varLoader : scalarVarLoaders) {
54 std::string varName = varLoader.first;
55 scalar_feat.push_back(varLoader.second(p));
56 }
57 std::vector<int64_t> scalar_feat_dim = {1, static_cast<int64_t>(scalar_feat.size())};
58 Inputs scalar_inputs(scalar_feat, scalar_feat_dim);
59 salt_model_data.gnn_inputs.insert({scalarInputName, scalar_inputs});
60
61 //load vector inputs.
62 for (auto loader : vectorVarLoaders) {
63 std::string input_name = loader.first;
64 auto [input_data, input_objects] = loader.second->getData(*p);
65
66 salt_model_data.gnn_inputs.insert({input_name, input_data});
67 salt_model_data.num_inputs += input_data.first.size();
68 salt_model_data.constituents[input_name] = input_objects;
69 }
70 return salt_model_data;
71 }
72
73 void DumpGnnInputs(const SaltModelInputs& gnn_inputs) const {
74 // Implementation for dumping GNN input data
75 std::cout << "-------- Dumping GNN Input Data --------" << std::endl;
76
77 for (const auto& [name, inputs] : gnn_inputs) {
78 std::cout << "Input Name: " << name << std::endl;
79 std::cout << " vec floats: ";
80 for (const auto& feature : inputs.first) {
81 std::cout << feature << " ";
82 }
83 std::cout << std::endl;
84 std::cout << " vec ints : ";
85 for (const auto& id : inputs.second) {
86 std::cout << id << " ";
87 }
88 std::cout << std::endl;
89 }
90 std::cout << "---------- END GNN Input Data ----------" << std::endl;
91 }
92 }; // class SaltModelEDMLoaderBase
93} // namespace FlavorTagInference
virtual SaltModelData loadInputs(const xAOD::IParticle *p) const final
void addVectorLoader(const std::string &vecName, std::shared_ptr< IConstituentsLoader > loader)
void addScalarLoader(const std::string &varName, std::function< float(const xAOD::IParticle *)> loader)
void DumpGnnInputs(const SaltModelInputs &gnn_inputs) const
SaltModelGraphConfig::GraphConfig graph_config
std::vector< std::pair< std::string, std::function< float(const xAOD::IParticle *)> > > scalarVarLoaders
std::map< std::string, std::shared_ptr< IConstituentsLoader > > vectorVarLoaders
Class providing the definition of the 4-vector interface.
This file contains "getter" functions used for accessing tagger inputs from the EDM.
std::map< std::string, Inputs > SaltModelInputs
std::pair< std::vector< float >, std::vector< int64_t > > Inputs
std::shared_ptr< const ISaltModel > ISaltModelPtr
Definition ISaltModel.h:54
std::map< std::string, std::vector< const xAOD::IParticle * > > constituents