ATLAS Offline Software
Loading...
Searching...
No Matches
SaltModelEDMLoaderBase.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
5#pragma once
9
15
16#include <map>
17#include <vector>
18#include <utility>
19#include <functional>
20
21namespace FlavorTagInference {
22
/// One model input tensor:
///  - first  : the flattened feature values,
///  - second : the tensor dimensions (e.g. {1, N} for the scalar row built in loadInputs).
using Inputs = std::pair<
    std::vector<float>,
    std::vector<int64_t>
>;

/// All model input tensors for one parent object, keyed by model input name.
using SaltModelInputs = std::map<std::string, Inputs>;
25
28 size_t num_inputs = 0;
29 std::map<std::string, std::vector<const xAOD::IParticle*>> constituents;
30 };
31
33 public:
35 graph_config(salt_model->getGraphConfig()) {};
37 std::string scalarInputName;
38 std::vector<std::pair<std::string /* varName */, std::function<float(const xAOD::IParticle* /* parent */)>>> scalarVarLoaders;
39 std::map<std::string /* vecInputName */, std::shared_ptr<IConstituentsLoader>> vectorVarLoaders;
40
41
42 void addScalarLoader(const std::string& varName, std::function<float(const xAOD::IParticle*)> loader) {
43 scalarVarLoaders.emplace_back(varName, loader);
44 }
45
46 void addVectorLoader(const std::string& vecName, std::shared_ptr<IConstituentsLoader> loader) {
47 vectorVarLoaders.try_emplace(vecName, std::move(loader));
48 }
49
50 virtual SaltModelData loadInputs(const xAOD::IParticle* p) const final {
51 SaltModelData salt_model_data;
52 // loading scalar inputs.
53 std::vector<float> scalar_feat;
54 for (const auto& varLoader : scalarVarLoaders) {
55 std::string varName = varLoader.first;
56 scalar_feat.push_back(varLoader.second(p));
57 }
58 std::vector<int64_t> scalar_feat_dim = {1, static_cast<int64_t>(scalar_feat.size())};
59 Inputs scalar_inputs(scalar_feat, scalar_feat_dim);
60 salt_model_data.gnn_inputs.insert({scalarInputName, scalar_inputs});
61
62 //load vector inputs.
63 for (auto loader : vectorVarLoaders) {
64 std::string input_name = loader.first;
65 auto [input_data, input_objects] = loader.second->getData(*p);
66
67 salt_model_data.gnn_inputs.insert({input_name, input_data});
68 salt_model_data.num_inputs += input_data.first.size();
69 salt_model_data.constituents[input_name] = input_objects;
70 }
71 return salt_model_data;
72 }
73
74 void DumpGnnInputs(const SaltModelInputs& gnn_inputs) const {
75 // Implementation for dumping GNN input data
76 std::cout << "-------- Dumping GNN Input Data --------" << std::endl;
77
78 for (const auto& [name, inputs] : gnn_inputs) {
79 std::cout << "Input Name: " << name << std::endl;
80 std::cout << " vec floats: ";
81 for (const auto& feature : inputs.first) {
82 std::cout << feature << " ";
83 }
84 std::cout << std::endl;
85 std::cout << " vec ints : ";
86 for (const auto& id : inputs.second) {
87 std::cout << id << " ";
88 }
89 std::cout << std::endl;
90 }
91 std::cout << "---------- END GNN Input Data ----------" << std::endl;
92 }
93 }; // class SaltModelEDMLoaderBase
94} // namespace FlavorTagInference
virtual SaltModelData loadInputs(const xAOD::IParticle *p) const final
void addVectorLoader(const std::string &vecName, std::shared_ptr< IConstituentsLoader > loader)
void addScalarLoader(const std::string &varName, std::function< float(const xAOD::IParticle *)> loader)
void DumpGnnInputs(const SaltModelInputs &gnn_inputs) const
SaltModelGraphConfig::GraphConfig graph_config
std::vector< std::pair< std::string, std::function< float(const xAOD::IParticle *)> > > scalarVarLoaders
std::map< std::string, std::shared_ptr< IConstituentsLoader > > vectorVarLoaders
Class providing the definition of the 4-vector interface.
This file contains "getter" functions used for accessing tagger inputs from the EDM.
std::map< std::string, Inputs > SaltModelInputs
std::pair< std::vector< float >, std::vector< int64_t > > Inputs
std::shared_ptr< const ISaltModel > ISaltModelPtr
Definition ISaltModel.h:54
std::map< std::string, std::vector< const xAOD::IParticle * > > constituents