ATLAS Offline Software
Loading...
Searching...
No Matches
Graph.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef GRAPH_HH_TAURECTOOLS
6#define GRAPH_HH_TAURECTOOLS
7
8#include "NNLayerConfig.h"
9#include "Source.h"
10
11#include <Eigen/Dense>
12
13#include <vector>
14#include <map>
15#include <unordered_map>
16#include <set>
17
18namespace lwtDev {
19
// Forward declarations: the node classes and Graph below only hold or accept
// pointers to these layer stacks, so their full definitions are not needed here.
class Stack;
class RecurrentStack;
22
23
24 // node class: will return a VectorXd from ISource
25 class INode
26 {
27 public:
28 virtual ~INode() {}
29 virtual VectorXd compute(const ISource&) const = 0;
30 virtual size_t n_outputs() const = 0;
31 };
32
33 class InputNode: public INode
34 {
35 public:
36 InputNode(size_t index, size_t n_outputs);
37 virtual VectorXd compute(const ISource&) const override;
38 virtual size_t n_outputs() const override;
39 private:
40 size_t m_index;
42 };
43
44 class FeedForwardNode: public INode
45 {
46 public:
47 FeedForwardNode(const Stack*, const INode* source);
48 virtual VectorXd compute(const ISource&) const override;
49 virtual size_t n_outputs() const override;
50 private:
51 const Stack* m_stack;
53 };
54
55 class ConcatenateNode: public INode
56 {
57 public:
58 ConcatenateNode(const std::vector<const INode*>&);
59 virtual VectorXd compute(const ISource&) const override;
60 virtual size_t n_outputs() const override;
61 private:
62 std::vector<const INode*> m_sources;
64 };
65
66 // sequence nodes
68 {
69 public:
70 virtual ~ISequenceNode() {}
71 virtual MatrixXd scan(const ISource&) const = 0;
72 virtual size_t n_outputs() const = 0;
73 };
74
76 {
77 public:
78 InputSequenceNode(size_t index, size_t n_outputs);
79 virtual MatrixXd scan(const ISource&) const override;
80 virtual size_t n_outputs() const override;
81 private:
82 size_t m_index;
84 };
85
86 class SequenceNode: public ISequenceNode, public INode
87 {
88 public:
89 SequenceNode(const RecurrentStack*, const ISequenceNode* source);
90 virtual MatrixXd scan(const ISource&) const override;
91 virtual VectorXd compute(const ISource&) const override;
92 virtual size_t n_outputs() const override;
93 private:
96 };
97
99 {
100 public:
101 TimeDistributedNode(const Stack*, const ISequenceNode* source);
102 virtual MatrixXd scan(const ISource&) const override;
103 virtual size_t n_outputs() const override;
104 private:
107 };
108 class SumNode: public INode
109 {
110 public:
111 SumNode(const ISequenceNode* source);
112 virtual VectorXd compute(const ISource&) const override;
113 virtual size_t n_outputs() const override;
114 private:
116 };
117
118 // Graph class, owns the nodes
119 class Graph
120 {
121 public:
122 Graph(); // dummy constructor
123 Graph(const std::vector<NodeConfig>& nodes,
124 const std::vector<LayerConfig>& layers);
125 Graph(Graph&) = delete;
126 Graph& operator=(Graph&) = delete;
127 ~Graph();
128 VectorXd compute(const ISource&, size_t node_number) const;
129 VectorXd compute(const ISource&) const;
130 MatrixXd scan(const ISource&, size_t node_number) const;
131 MatrixXd scan(const ISource&) const;
132 private:
133 void build_node(const size_t,
134 const std::vector<NodeConfig>& nodes,
135 const std::vector<LayerConfig>& layers,
136 std::set<size_t> cycle_check = {});
137
138 std::unordered_map<size_t, INode*> m_nodes;
139 size_t m_last_node; // <-- convenience for graphs with one output
140 std::unordered_map<size_t, Stack*> m_stacks;
141 std::unordered_map<size_t, ISequenceNode*> m_seq_nodes;
142 std::unordered_map<size_t, RecurrentStack*> m_seq_stacks;
143 // At some point maybe also convolutional nodes, but we'd have to
144 // have a use case for that first.
145 };
146}
147
148#endif // GRAPH_HH_TAURECTOOLS
std::vector< const INode * > m_sources
Definition Graph.h:62
ConcatenateNode(const std::vector< const INode * > &)
Definition Graph.cxx:104
virtual size_t n_outputs() const override
Definition Graph.cxx:125
virtual VectorXd compute(const ISource &) const override
Definition Graph.cxx:112
const Stack * m_stack
Definition Graph.h:51
virtual size_t n_outputs() const override
Definition Graph.cxx:100
FeedForwardNode(const Stack *, const INode *source)
Definition Graph.cxx:92
const INode * m_source
Definition Graph.h:52
virtual VectorXd compute(const ISource &) const override
Definition Graph.cxx:97
std::unordered_map< size_t, INode * > m_nodes
Definition Graph.h:138
std::unordered_map< size_t, RecurrentStack * > m_seq_stacks
Definition Graph.h:142
std::unordered_map< size_t, ISequenceNode * > m_seq_nodes
Definition Graph.h:141
MatrixXd scan(const ISource &, size_t node_number) const
Definition Graph.cxx:332
Graph & operator=(Graph &)=delete
std::unordered_map< size_t, Stack * > m_stacks
Definition Graph.h:140
size_t m_last_node
Definition Graph.h:139
void build_node(const size_t, const std::vector< NodeConfig > &nodes, const std::vector< LayerConfig > &layers, std::set< size_t > cycle_check={})
Definition Graph.cxx:353
VectorXd compute(const ISource &, size_t node_number) const
Definition Graph.cxx:315
Graph(Graph &)=delete
virtual VectorXd compute(const ISource &) const =0
virtual ~INode()
Definition Graph.h:28
virtual size_t n_outputs() const =0
virtual size_t n_outputs() const =0
virtual MatrixXd scan(const ISource &) const =0
virtual ~ISequenceNode()
Definition Graph.h:70
virtual VectorXd compute(const ISource &) const override
Definition Graph.cxx:77
virtual size_t n_outputs() const override
Definition Graph.cxx:88
size_t m_index
Definition Graph.h:40
size_t m_n_outputs
Definition Graph.h:41
InputNode(size_t index, size_t n_outputs)
Definition Graph.cxx:72
virtual size_t n_outputs() const override
Definition Graph.cxx:148
virtual MatrixXd scan(const ISource &) const override
Definition Graph.cxx:135
InputSequenceNode(size_t index, size_t n_outputs)
Definition Graph.cxx:130
virtual MatrixXd scan(const ISource &) const override
Definition Graph.cxx:158
const RecurrentStack * m_stack
Definition Graph.h:94
SequenceNode(const RecurrentStack *, const ISequenceNode *source)
Definition Graph.cxx:152
virtual size_t n_outputs() const override
Definition Graph.cxx:170
virtual VectorXd compute(const ISource &) const override
Definition Graph.cxx:161
const ISequenceNode * m_source
Definition Graph.h:95
virtual size_t n_outputs() const override
Definition Graph.cxx:200
SumNode(const ISequenceNode *source)
Definition Graph.cxx:193
virtual VectorXd compute(const ISource &) const override
Definition Graph.cxx:197
const ISequenceNode * m_source
Definition Graph.h:115
const ISequenceNode * m_source
Definition Graph.h:106
virtual MatrixXd scan(const ISource &) const override
Definition Graph.cxx:180
const Stack * m_stack
Definition Graph.h:105
TimeDistributedNode(const Stack *, const ISequenceNode *source)
Definition Graph.cxx:174
virtual size_t n_outputs() const override
Definition Graph.cxx:189
Definition index.py:1