ATLAS Offline Software
Loading...
Searching...
No Matches
Graph.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
8
9#include <set>
10#include <memory>
11
12namespace lwtDev {
13
14 // Sources
15 VectorSource::VectorSource(std::vector<VectorXd>&& vv,
16 std::vector<MatrixXd>&& mm):
17 m_inputs(std::move(vv)),
18 m_matrix_inputs(std::move(mm))
19 {
20 }
21 VectorXd VectorSource::at(size_t index) const {
22 if (index >= m_inputs.size()) {
24 "VectorSource: no source vector defined at " + std::to_string(index));
25 }
26 return m_inputs.at(index);
27 }
28 MatrixXd VectorSource::matrix_at(size_t index) const {
29 if (index >= m_matrix_inputs.size()) {
31 "VectorSource: no source matrix defined at " + std::to_string(index));
32 }
33 return m_matrix_inputs.at(index);
34 }
35
36 DummySource::DummySource(const std::vector<size_t>& input_sizes,
37 const std::vector<std::pair<size_t,size_t> >& ma):
38 m_sizes(input_sizes),
40 {
41 }
42 VectorXd DummySource::at(size_t index) const {
43 if (index >= m_sizes.size()) {
45 "Dummy Source: no size defined at " + std::to_string(index));
46 }
47 size_t n_entries = m_sizes.at(index);
48 VectorXd vec(n_entries);
49 for (size_t iii = 0; iii < n_entries; iii++) {
50 vec(iii) = iii;
51 }
52 return vec;
53 }
54 MatrixXd DummySource::matrix_at(size_t index) const {
55 if (index >= m_matrix_sizes.size()) {
57 "Dummy Source: no size defined at " + std::to_string(index));
58 }
59 size_t n_rows = m_matrix_sizes.at(index).first;
60 size_t n_cols = m_matrix_sizes.at(index).second;
61 MatrixXd mat(n_rows, n_cols);
62 for (size_t iii = 0; iii < n_rows; iii++) {
63 for (size_t jjj = 0; jjj < n_cols; jjj++) {
64 mat(iii, jjj) = jjj + n_cols * iii;
65 }
66 }
67 return mat;
68 }
69
70
71 // Nodes
77 VectorXd InputNode::compute(const ISource& source) const {
78 VectorXd output = source.at(m_index);
79 assert(output.rows() > 0);
80 if (static_cast<size_t>(output.rows()) != m_n_outputs) {
81 std::string len = std::to_string(output.rows());
82 std::string found = std::to_string(m_n_outputs);
84 "Found vector of length " + len + ", expected " + found);
85 }
86 return output;
87 }
88 size_t InputNode::n_outputs() const {
89 return m_n_outputs;
90 }
91
92 FeedForwardNode::FeedForwardNode(const Stack* stack, const INode* source):
93 m_stack(stack),
94 m_source(source)
95 {
96 }
97 VectorXd FeedForwardNode::compute(const ISource& source) const {
98 return m_stack->compute(m_source->compute(source));
99 }
101 return m_stack->n_outputs();
102 }
103
104 ConcatenateNode::ConcatenateNode(const std::vector<const INode*>& sources):
105 m_sources(sources),
106 m_n_outputs(0)
107 {
108 for (const auto source: sources) {
109 m_n_outputs += source->n_outputs();
110 }
111 }
112 VectorXd ConcatenateNode::compute(const ISource& source) const {
113 VectorXd output(m_n_outputs);
114 size_t offset = 0;
115 for (const auto node: m_sources) {
116 VectorXd input = node->compute(source);
117 size_t n_elements = input.rows();
118 assert(n_elements == node->n_outputs());
119 output.segment(offset, n_elements) = input;
120 offset += n_elements;
121 }
122 assert(offset == m_n_outputs);
123 return output;
124 }
126 return m_n_outputs;
127 }
128
129 // Sequence nodes
135 MatrixXd InputSequenceNode::scan(const ISource& source) const {
136 MatrixXd output = source.matrix_at(m_index);
137 if (output.rows() == 0) {
138 throw NNEvaluationException("empty input sequence");
139 }
140 if (static_cast<size_t>(output.rows()) != m_n_outputs) {
141 std::string len = std::to_string(output.rows());
142 std::string found = std::to_string(m_n_outputs);
144 "Found vector of length " + len + ", expected " + found);
145 }
146 return output;
147 }
149 return m_n_outputs;
150 }
151
153 const ISequenceNode* source) :
154 m_stack(stack),
155 m_source(source)
156 {
157 }
158 MatrixXd SequenceNode::scan(const ISource& source) const {
159 return m_stack->scan(m_source->scan(source));
160 }
161 VectorXd SequenceNode::compute(const ISource& src) const {
162 MatrixXd mat = scan(src);
163 size_t n_cols = mat.cols();
164 // special handling for empty sequence
165 if (n_cols == 0) {
166 return MatrixXd::Zero(mat.rows(), 1);
167 }
168 return mat.col(n_cols - 1);
169 }
170 size_t SequenceNode::n_outputs() const {
171 return m_stack->n_outputs();
172 }
173
175 const ISequenceNode* source):
176 m_stack(stack),
177 m_source(source)
178 {
179 }
180 MatrixXd TimeDistributedNode::scan(const ISource& source) const {
181 MatrixXd input = m_source->scan(source);
182 MatrixXd output(m_stack->n_outputs(), input.cols());
183 size_t n_columns = input.cols();
184 for (size_t col_n = 0; col_n < n_columns; col_n++) {
185 output.col(col_n) = m_stack->compute(input.col(col_n));
186 }
187 return output;
188 }
190 return m_stack->n_outputs();
191 }
192
194 m_source(source)
195 {
196 }
197 VectorXd SumNode::compute(const ISource& source) const {
198 return m_source->scan(source).rowwise().sum();
199 }
200 size_t SumNode::n_outputs() const {
201 return m_source->n_outputs();
202 }
203
204}
205
206namespace {
207 using namespace lwtDev;
208 void throw_cfg(const std::string & msg, size_t index) {
209 throw NNConfigurationException(msg + " " + std::to_string(index));
210 }
211 void check_compute_node(const NodeConfig& node) {
212 size_t n_source = node.sources.size();
213 if (n_source != 1) throw_cfg("need one source, found", n_source);
214 int layer_n = node.index;
215 if (layer_n < 0) throw_cfg("negative layer number", layer_n);
216 }
217 void check_compute_node(const NodeConfig& node, size_t n_layers) {
218 check_compute_node(node);
219 int layer_n = node.index;
220 if (static_cast<size_t>(layer_n) >= n_layers) {
221 throw_cfg("no layer number", layer_n);
222 }
223 }
224 // NOTE: you own this pointer!
225 INode* get_feedforward_node(
226 const NodeConfig& node,
227 const std::vector<LayerConfig>& layers,
228 const std::unordered_map<size_t, INode*>& node_map,
229 std::unordered_map<size_t, Stack*>& stack_map) {
230
231 // FIXME: merge this block with the time distributed one later on
232 check_compute_node(node, layers.size());
233 INode* source = node_map.at(node.sources.at(0));
234 int layer_n = node.index;
235 if (!stack_map.count(layer_n)) {
236 stack_map[layer_n] = new Stack(source->n_outputs(),
237 {layers.at(layer_n)});
238 }
239 return new FeedForwardNode(stack_map.at(layer_n), source);
240 }
241 SequenceNode* get_sequence_node(
242 const NodeConfig& node,
243 const std::vector<LayerConfig>& layers,
244 const std::unordered_map<size_t, ISequenceNode*>& node_map,
245 std::unordered_map<size_t, RecurrentStack*>& stack_map) {
246
247 check_compute_node(node, layers.size());
248 ISequenceNode* source = node_map.at(node.sources.at(0));
249 int layer_n = node.index;
250 if (!stack_map.count(layer_n)) {
251 stack_map[layer_n] = new RecurrentStack(source->n_outputs(),
252 {layers.at(layer_n)});
253 }
254 return new SequenceNode(stack_map.at(layer_n), source);
255 }
256 TimeDistributedNode* get_time_distributed_node(
257 const NodeConfig& node,
258 const std::vector<LayerConfig>& layers,
259 const std::unordered_map<size_t, ISequenceNode*>& node_map,
260 std::unordered_map<size_t, Stack*>& stack_map) {
261
262 // FIXME: merge this block with the FF block above
263 check_compute_node(node, layers.size());
264 ISequenceNode* source = node_map.at(node.sources.at(0));
265 int layer_n = node.index;
266 if (!stack_map.count(layer_n)) {
267 stack_map[layer_n] = new Stack(source->n_outputs(),
268 {layers.at(layer_n)});
269 }
270 return new TimeDistributedNode(stack_map.at(layer_n), source);
271 }
272}
273
274namespace lwtDev {
275 // graph
277 m_stacks[0] = new Stack;
278
279 m_nodes[0] = new InputNode(0, 2);
280 m_nodes[1] = new InputNode(1, 2);
281 m_nodes[2] = new ConcatenateNode({m_nodes.at(0), m_nodes.at(1)});
282 m_nodes[3] = new FeedForwardNode(m_stacks.at(0), m_nodes.at(2));
283 m_last_node = 3;
284 }
285 Graph::Graph(const std::vector<NodeConfig>& nodes,
286 const std::vector<LayerConfig>& layers):
287 m_last_node(0)
288 {
289 for (size_t iii = 0; iii < nodes.size(); iii++) {
290 build_node(iii, nodes, layers);
291 }
292 // assert(maps.node.size() + maps.seq_node.size() == nodes.size());
293 }
295 for (auto node: m_nodes) {
296 delete node.second;
297 node.second = nullptr;
298 }
299 for (auto node: m_seq_nodes) {
300 // The m_nodes collection is the owner of anything that inherits
301 // from both INode and ISequenceNode. So we try not to delete
302 // anything that the m_nodes would already take care of.
303 if (!m_nodes.count(node.first)) delete node.second;
304 node.second = nullptr;
305 }
306 for (auto stack: m_stacks) {
307 delete stack.second;
308 stack.second = nullptr;
309 }
310 for (auto stack: m_seq_stacks) {
311 delete stack.second;
312 stack.second = nullptr;
313 }
314 }
315 VectorXd Graph::compute(const ISource& source, size_t node_number) const {
316 if (!m_nodes.count(node_number)) {
317 auto num = std::to_string(node_number);
318 if (m_seq_nodes.count(node_number)) {
320 "Graph: output at " + num + " not feed forward");
321 }
322 throw NNEvaluationException("Graph: no output at " + num);
323 }
324 return m_nodes.at(node_number)->compute(source);
325 }
326 VectorXd Graph::compute(const ISource& source) const {
327 if (!m_nodes.count(m_last_node)) {
328 throw OutputRankException("Graph: output is not a feed forward node");
329 }
330 return m_nodes.at(m_last_node)->compute(source);
331 }
332 MatrixXd Graph::scan(const ISource& source, size_t node_number) const {
333 if (!m_seq_nodes.count(node_number)) {
334 auto num = std::to_string(node_number);
335 if (m_nodes.count(node_number)) {
337 "Graph: output at " + num + " not a sequence");
338 }
339 throw NNEvaluationException("Graph: no output at " + num);
340 }
341 return m_seq_nodes.at(node_number)->scan(source);
342 }
343 MatrixXd Graph::scan(const ISource& source) const {
344 if (!m_seq_nodes.count(m_last_node)) {
345 throw OutputRankException("Graph: output is not a sequence node");
346 }
347 return m_seq_nodes.at(m_last_node)->scan(source);
348 }
349
350 // ______________________________________________________________________
351 // private methods
352
353 void Graph::build_node(const size_t iii,
354 const std::vector<NodeConfig>& nodes,
355 const std::vector<LayerConfig>& layers,
356 std::set<size_t> cycle_check) {
357 if (m_nodes.count(iii) || m_seq_nodes.count(iii)) return;
358
359 // we insist that the upstream nodes are built before the
360 // downstream ones, so the last node built should be some kind of
361 // sink for graphs with only one output this will be it.
362 m_last_node = iii;
363
364 if (iii >= nodes.size()) throw_cfg("no node index", iii);
365
366 const NodeConfig& node = nodes.at(iii);
367
368 // if it's an input, build and return
369 if (node.type == NodeConfig::Type::INPUT) {
370 check_compute_node(node);
371 size_t input_number = node.sources.at(0);
372 m_nodes[iii] = new InputNode(input_number, node.index);
373 return;
374 } else if (node.type == NodeConfig::Type::INPUT_SEQUENCE) {
375 check_compute_node(node);
376 size_t input_number = node.sources.at(0);
377 m_seq_nodes[iii] = new InputSequenceNode(input_number, node.index);
378 return;
379 }
380
381 // otherwise build all the inputs first
382 if (cycle_check.count(iii)) {
383 throw NNConfigurationException("found cycle in graph");
384 }
385 cycle_check.insert(iii);
386 for (size_t source_node: node.sources) {
387 build_node(source_node, nodes, layers, cycle_check);
388 }
389
390 // check node types
391 if (node.type == NodeConfig::Type::FEED_FORWARD) {
392 m_nodes[iii] = get_feedforward_node(node, layers,
394 } else if (node.type == NodeConfig::Type::TIME_DISTRIBUTED) {
395 m_seq_nodes[iii] = get_time_distributed_node(node, layers,
397 } else if (node.type == NodeConfig::Type::SEQUENCE) {
398 SequenceNode* seq_node =
399 get_sequence_node(node, layers, m_seq_nodes, m_seq_stacks);
400 // entering in m_nodes means that m_nodes will delete this one
401 m_nodes[iii] = seq_node;
402 m_seq_nodes[iii] = seq_node;
403 } else if (node.type == NodeConfig::Type::CONCATENATE) {
404 // build concatenate layer
405 std::vector<const INode*> in_nodes;
406 for (size_t source_node: node.sources) {
407 in_nodes.push_back(m_nodes.at(source_node));
408 }
409 m_nodes[iii] = new ConcatenateNode(in_nodes);
410 } else if (node.type == NodeConfig::Type::SUM) {
411 if (node.sources.size() != 1) {
412 throw NNConfigurationException("Sum node needs exactly 1 source");
413 }
414 m_nodes[iii] = new SumNode(m_seq_nodes.at(node.sources.at(0)));
415 } else {
416 throw NNConfigurationException("unknown node type");
417 }
418 }
419
420}
std::vector< size_t > vec
std::vector< const INode * > m_sources
Definition Graph.h:62
ConcatenateNode(const std::vector< const INode * > &)
Definition Graph.cxx:104
virtual size_t n_outputs() const override
Definition Graph.cxx:125
virtual VectorXd compute(const ISource &) const override
Definition Graph.cxx:112
std::vector< std::pair< size_t, size_t > > m_matrix_sizes
Definition Source.h:45
std::vector< size_t > m_sizes
Definition Source.h:44
virtual VectorXd at(size_t index) const override
Definition Graph.cxx:42
virtual MatrixXd matrix_at(size_t index) const override
Definition Graph.cxx:54
DummySource(const std::vector< size_t > &input_sizes, const std::vector< std::pair< size_t, size_t > > &={})
Definition Graph.cxx:36
const Stack * m_stack
Definition Graph.h:51
virtual size_t n_outputs() const override
Definition Graph.cxx:100
FeedForwardNode(const Stack *, const INode *source)
Definition Graph.cxx:92
const INode * m_source
Definition Graph.h:52
virtual VectorXd compute(const ISource &) const override
Definition Graph.cxx:97
std::unordered_map< size_t, INode * > m_nodes
Definition Graph.h:138
std::unordered_map< size_t, RecurrentStack * > m_seq_stacks
Definition Graph.h:142
std::unordered_map< size_t, ISequenceNode * > m_seq_nodes
Definition Graph.h:141
MatrixXd scan(const ISource &, size_t node_number) const
Definition Graph.cxx:332
std::unordered_map< size_t, Stack * > m_stacks
Definition Graph.h:140
size_t m_last_node
Definition Graph.h:139
void build_node(const size_t, const std::vector< NodeConfig > &nodes, const std::vector< LayerConfig > &layers, std::set< size_t > cycle_check={})
Definition Graph.cxx:353
VectorXd compute(const ISource &, size_t node_number) const
Definition Graph.cxx:315
virtual VectorXd compute(const ISource &) const override
Definition Graph.cxx:77
virtual size_t n_outputs() const override
Definition Graph.cxx:88
size_t m_index
Definition Graph.h:40
size_t m_n_outputs
Definition Graph.h:41
InputNode(size_t index, size_t n_outputs)
Definition Graph.cxx:72
virtual size_t n_outputs() const override
Definition Graph.cxx:148
virtual MatrixXd scan(const ISource &) const override
Definition Graph.cxx:135
InputSequenceNode(size_t index, size_t n_outputs)
Definition Graph.cxx:130
virtual MatrixXd scan(const ISource &) const override
Definition Graph.cxx:158
const RecurrentStack * m_stack
Definition Graph.h:94
SequenceNode(const RecurrentStack *, const ISequenceNode *source)
Definition Graph.cxx:152
virtual size_t n_outputs() const override
Definition Graph.cxx:170
virtual VectorXd compute(const ISource &) const override
Definition Graph.cxx:161
const ISequenceNode * m_source
Definition Graph.h:95
virtual size_t n_outputs() const override
Definition Graph.cxx:200
SumNode(const ISequenceNode *source)
Definition Graph.cxx:193
virtual VectorXd compute(const ISource &) const override
Definition Graph.cxx:197
const ISequenceNode * m_source
Definition Graph.h:115
const ISequenceNode * m_source
Definition Graph.h:106
virtual MatrixXd scan(const ISource &) const override
Definition Graph.cxx:180
const Stack * m_stack
Definition Graph.h:105
TimeDistributedNode(const Stack *, const ISequenceNode *source)
Definition Graph.cxx:174
virtual size_t n_outputs() const override
Definition Graph.cxx:189
VectorSource(std::vector< VectorXd > &&, std::vector< MatrixXd > &&={})
Definition Graph.cxx:15
virtual VectorXd at(size_t index) const override
Definition Graph.cxx:21
virtual MatrixXd matrix_at(size_t index) const override
Definition Graph.cxx:28
std::vector< VectorXd > m_inputs
Definition Source.h:32
std::vector< MatrixXd > m_matrix_inputs
Definition Source.h:33
Definition node.h:24
void type(TYPE t)
Definition node.h:51
layers(flags, cells_name, *args, **kw)
Here we define wrapper functions to set up all of the standard corrections.
Definition index.py:1
STL namespace.
MsgStream & msg
Definition testRead.cxx:32