ATLAS Offline Software
Loading...
Searching...
No Matches
LightweightGraph.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
8#include <iostream>
9#include <Eigen/Dense>
10
11namespace {
12 using namespace Eigen;
13 using namespace lwtDev;
14
16 typedef InputPreprocessor IP;
17 typedef std::vector<std::pair<std::string, IP*> > Preprocs;
19 typedef InputVectorPreprocessor IVP;
20 typedef std::vector<std::pair<std::string, IVP*> > VecPreprocs;
21
22 // this is used internally to ensure that we only look up map inputs
23 // when the network asks for them.
24 class LazySource: public ISource
25 {
26 public:
27 LazySource(const NodeMap&, const SeqNodeMap&,
28 const Preprocs&, const VecPreprocs&);
29 virtual VectorXd at(size_t index) const override;
30 virtual MatrixXd matrix_at(size_t index) const override;
31 private:
32 const NodeMap& m_nodes;
33 const SeqNodeMap& m_seqs;
34 const Preprocs& m_preprocs;
35 const VecPreprocs& m_vec_preprocs;
36 };
37
38 LazySource::LazySource(const NodeMap& n, const SeqNodeMap& s,
39 const Preprocs& p, const VecPreprocs& v):
40 m_nodes(n), m_seqs(s), m_preprocs(p), m_vec_preprocs(v)
41 {
42 }
43 VectorXd LazySource::at(size_t index) const
44 {
45 const auto& proc = m_preprocs.at(index);
46 if (!m_nodes.count(proc.first)) {
47 throw NNEvaluationException("Can't find node " + proc.first);
48 }
49 const auto& preproc = *proc.second;
50 return preproc(m_nodes.at(proc.first));
51 }
52 MatrixXd LazySource::matrix_at(size_t index) const
53 {
54 const auto& proc = m_vec_preprocs.at(index);
55 if (!m_seqs.count(proc.first)) {
56 throw NNEvaluationException("Can't find sequence node " + proc.first);
57 }
58 const auto& preproc = *proc.second;
59 return preproc(m_seqs.at(proc.first));
60 }
61
62}
63namespace lwtDev {
64 // ______________________________________________________________________
65 // Lightweight Graph
66
  // Constructor: builds the Graph from config.nodes/config.layers, creates
  // one InputPreprocessor per entry of config.inputs and one
  // InputVectorPreprocessor per entry of config.input_sequences, and records
  // each output's (node index, labels) in m_outputs with its name -> position
  // mapping in m_output_indices.
  // Throws NNConfigurationException if `default_output` is non-empty but
  // names no output, or if it is empty and there is not exactly one output.
69 const std::string& default_output):
70 m_graph(new Graph(config.nodes, config.layers))
71 {
  // A constructor that throws never runs the destructor, so anything already
  // allocated (m_graph and any preprocessors) must be released here before
  // the exception propagates; otherwise e.g. a misspelled default_output
  // leaks the whole graph.
  try {
72 for (const auto& node: config.inputs) {
73 m_preprocs.emplace_back(
74 node.name, new InputPreprocessor(node.variables));
75 }
76 for (const auto& node: config.input_sequences) {
77 m_vec_preprocs.emplace_back(
78 node.name, new InputVectorPreprocessor(node.variables));
79 }
80 size_t output_n = 0;
81 for (const auto& node: config.outputs) {
82 m_outputs.emplace_back(node.second.node_index, node.second.labels);
83 m_output_indices.emplace(node.first, output_n);
84 output_n++;
85 }
86 if (default_output.size() > 0) {
87 if (!m_output_indices.count(default_output)) {
88 throw NNConfigurationException("no output node " + default_output);
89 }
90 m_default_output = m_output_indices.at(default_output);
91 } else if (output_n == 1) {
92 m_default_output = 0;
93 } else {
94 throw NNConfigurationException("you must specify a default output");
95 }
  } catch (...) {
    for (auto& preproc: m_vec_preprocs) {
      delete preproc.second;
    }
    for (auto& preproc: m_preprocs) {
      delete preproc.second;
    }
    delete m_graph;
    throw;
  }
96 }
97
  // Destructor: releases the Graph and every preprocessor allocated by the
  // constructor.
  // NOTE(review): this class owns raw pointers; confirm the header disables
  // copy construction/assignment, otherwise a copy would double-delete.
99 delete m_graph;
100 for (auto& preproc: m_preprocs) {
101 delete preproc.second;
102 preproc.second = 0;
103 }
104 for (auto& preproc: m_vec_preprocs) {
105 delete preproc.second;
106 preproc.second = 0;
107 }
108 }
109
  // compute(nodes, seq): evaluates the default output.
111 const SeqNodeMap& seq) const {
112 return compute(nodes, seq, m_default_output);
113 }
  // compute(nodes, seq, output): evaluates the output named `output`.
  // Throws NNEvaluationException if no output has that name.
115 const SeqNodeMap& seq,
116 const std::string& output) const {
117 if (!m_output_indices.count(output)) {
118 throw NNEvaluationException("no output node " + output);
119 }
120 return compute(nodes, seq, m_output_indices.at(output));
121 }
  // compute(nodes, seq, idx): evaluates the output stored at position `idx`
  // of m_outputs and returns a map from each of its labels to the matching
  // element of the graph result. Inputs are looked up lazily via LazySource.
  // Throws std::out_of_range for a bad `idx` (vector::at) and
  // NNEvaluationException if the result has fewer values than labels.
123 const SeqNodeMap& seq,
124 size_t idx) const {
125 LazySource source(nodes, seq, m_preprocs, m_vec_preprocs);
126 VectorXd result = m_graph->compute(source, m_outputs.at(idx).first);
127 const std::vector<std::string>& labels = m_outputs.at(idx).second;
  // Eigen coefficient access is only range-checked by eigen_assert, which is
  // compiled out in optimised builds; refuse to read past the result instead.
  if (static_cast<size_t>(result.size()) < labels.size()) {
    throw NNEvaluationException("output has fewer values than labels");
  }
128 std::map<std::string, double> output;
129 for (size_t iii = 0; iii < labels.size(); iii++) {
130 output[labels.at(iii)] = result(iii);
131 }
132 return output;
133 }
134
  // scan(nodes, seq): scans the default output.
136 const SeqNodeMap& seq) const {
137 return scan(nodes, seq, m_default_output);
138 }
  // scan(nodes, seq, output): scans the output named `output`.
  // Throws NNEvaluationException if no output has that name.
140 const SeqNodeMap& seq,
141 const std::string& output) const {
142 if (!m_output_indices.count(output)) {
143 throw NNEvaluationException("no output node " + output);
144 }
145 return scan(nodes, seq, m_output_indices.at(output));
146 }
  // scan(nodes, seq, idx): runs Graph::scan for the output at position `idx`
  // of m_outputs and returns, for each label, the corresponding row of the
  // result matrix as a std::vector<double>.
  // Throws std::out_of_range for a bad `idx` (vector::at) and
  // NNEvaluationException if the result has fewer rows than labels.
148 const SeqNodeMap& seq,
149 size_t idx) const {
150 LazySource source(nodes, seq, m_preprocs, m_vec_preprocs);
151 MatrixXd result = m_graph->scan(source, m_outputs.at(idx).first);
152 const std::vector<std::string>& labels = m_outputs.at(idx).second;
  // result.row() is only range-checked by eigen_assert; see compute() above.
  if (static_cast<size_t>(result.rows()) < labels.size()) {
    throw NNEvaluationException("output has fewer rows than labels");
  }
153 std::map<std::string, std::vector<double> > output;
154 for (size_t iii = 0; iii < labels.size(); iii++) {
  // Copy the row into a contiguous vector so data()/size() span it exactly.
155 VectorXd row = result.row(iii);
156 std::vector<double> out_vector(row.data(), row.data() + row.size());
157 output[labels.at(iii)] = std::move(out_vector);
158 }
159 return output;
160 }
161
162}
if(febId1==febId2)
LightweightGraph(const GraphConfig &config, const std::string &default_output="")
VectorMap scan(const NodeMap &, const SeqNodeMap &={}) const
std::map< std::string, ValueMap > NodeMap
std::map< std::string, size_t > m_output_indices
std::map< std::string, VectorMap > SeqNodeMap
ValueMap compute(const NodeMap &, const SeqNodeMap &={}) const
std::vector< std::pair< size_t, std::vector< std::string > > > m_outputs
Definition node.h:24
Definition index.py:1
std::map< std::string, std::vector< double > > VectorMap
std::map< std::string, double > ValueMap
LightweightGraph::NodeMap NodeMap
std::map< std::string, VectorMap > SeqNodeMap