ATLAS Offline Software
Loading...
Searching...
No Matches
SaltModelEDMLoaderBase.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
5#pragma once
9
17
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
22
23namespace FlavorTagInference {
24
// Payload of one model input: `first` holds the feature values and `second` the
// tensor shape (loadInputs() stores {1, N} for the N scalar features).
// NOTE(review): assumed to be the same layout returned by IConstituentsLoader::getData — confirm.
25 using Inputs = std::pair<std::vector<float>, std::vector<int64_t>>;
// All model inputs for one particle, keyed by input name (see loadInputs()).
26 using SaltModelInputs = std::map<std::string, Inputs>;
27
32
34 public:
36 graph_config(salt_model->getGraphConfig()) {};
// Key under which loadInputs() stores the packed scalar features.
38 std::string scalarInputName;
// Scalar feature getters. loadInputs() evaluates them in registration order and packs
// the results into a single input of shape {1, N}. The stored name is not used by
// loadInputs().
39 std::vector<std::pair<std::string /* varName */, std::function<float(const xAOD::IParticle* /* parent */)>>> scalarVarLoaders;
// Loaders for the remaining (vector) inputs, keyed by input name. Each loader's getData()
// yields one entry of the model inputs. addVectorLoader() keeps only the first loader
// registered under a given name.
40 std::map<std::string /* vecInputName */, std::shared_ptr<IConstituentsLoader>> vectorVarLoaders;
41
42
43 void addScalarLoader(const std::string& varName, std::function<float(const xAOD::IParticle*)> loader) {
44 scalarVarLoaders.emplace_back(varName, loader);
45 }
46
47 void addVectorLoader(const std::string& vecName, std::shared_ptr<IConstituentsLoader> loader) {
48 vectorVarLoaders.try_emplace(vecName, std::move(loader));
49 }
50
51 virtual SaltModelData loadInputs(const xAOD::IParticle* p) const final {
52 SaltModelData salt_model_data;
53 // loading scalar inputs.
54 std::vector<float> scalar_feat;
55 for (const auto& varLoader : scalarVarLoaders) {
56 std::string varName = varLoader.first;
57 scalar_feat.push_back(varLoader.second(p));
58 }
59 std::vector<int64_t> scalar_feat_dim = {1, static_cast<int64_t>(scalar_feat.size())};
60 Inputs scalar_inputs(scalar_feat, scalar_feat_dim);
61 salt_model_data.gnn_inputs.insert({scalarInputName, scalar_inputs});
62
63 //load vector inputs.
64 for (auto loader : vectorVarLoaders) {
65 std::string input_name = loader.first;
66 Inputs input_data = loader.second->getData(*p);
67
68 salt_model_data.gnn_inputs.insert({input_name, input_data});
69 salt_model_data.num_inputs += input_data.first.size();
70 }
71 return salt_model_data;
72 }
73
74 void DumpGnnInputs(const SaltModelInputs& gnn_inputs) const {
75 // Implementation for dumping GNN input data
76 std::cout << "-------- Dumping GNN Input Data --------" << std::endl;
77
78 for (const auto& [name, inputs] : gnn_inputs) {
79 std::cout << "Input Name: " << name << std::endl;
80 std::cout << " vec floats: ";
81 for (const auto& feature : inputs.first) {
82 std::cout << feature << " ";
83 }
84 std::cout << std::endl;
85 std::cout << " vec ints : ";
86 for (const auto& id : inputs.second) {
87 std::cout << id << " ";
88 }
89 std::cout << std::endl;
90 }
91 std::cout << "---------- END GNN Input Data ----------" << std::endl;
92 }
93 }; // class SaltModelEDMLoaderBase
94} // namespace FlavorTagInference
virtual SaltModelData loadInputs(const xAOD::IParticle *p) const final
void addVectorLoader(const std::string &vecName, std::shared_ptr< IConstituentsLoader > loader)
void addScalarLoader(const std::string &varName, std::function< float(const xAOD::IParticle *)> loader)
void DumpGnnInputs(const SaltModelInputs &gnn_inputs) const
SaltModelGraphConfig::GraphConfig graph_config
std::vector< std::pair< std::string, std::function< float(const xAOD::IParticle *)> > > scalarVarLoaders
std::map< std::string, std::shared_ptr< IConstituentsLoader > > vectorVarLoaders
Class providing the definition of the 4-vector interface.
This file contains "getter" functions used for accessing tagger inputs from the EDM.
std::map< std::string, Inputs > SaltModelInputs
std::pair< std::vector< float >, std::vector< int64_t > > Inputs
std::shared_ptr< const ISaltModel > ISaltModelPtr
Definition ISaltModel.h:54