ATLAS Offline Software
Graph.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef GRAPH_HH_TAURECTOOLS
6 #define GRAPH_HH_TAURECTOOLS
7 
8 #include "NNLayerConfig.h"
9 #include "Source.h"
10 
11 #include <Eigen/Dense>
12 
13 #include <vector>
14 #include <map>
15 #include <unordered_map>
16 #include <set>
17 
18 namespace lwtDev {
19 
20  class Stack;
21  class RecurrentStack;
22 
23 
24  // Abstract node interface: compute() produces a VectorXd from an ISource.
25  class INode
26  {
27  public:
28  virtual ~INode() {} // virtual, so concrete nodes can be destroyed through an INode*
29  virtual VectorXd compute(const ISource&) const = 0; // pure virtual: each concrete node supplies the evaluation
30  virtual size_t n_outputs() const = 0; // pure virtual; presumably the length of compute()'s result -- confirm in Graph.cxx
31  };
32 
33  class InputNode: public INode // INode constructed from an index and an output count
34  {
35  public:
36  InputNode(size_t index, size_t n_outputs); // defined in Graph.cxx; presumably stores both arguments in the members below
37  virtual VectorXd compute(const ISource&) const override; // implements INode::compute
38  virtual size_t n_outputs() const override; // implements INode::n_outputs
39  private:
40  size_t m_index; // NOTE(review): looks like the slot to read from the ISource -- confirm in Graph.cxx
41  size_t m_n_outputs;
42  };
43 
44  class FeedForwardNode: public INode // INode built from a Stack and an upstream INode
45  {
46  public:
47  FeedForwardNode(const Stack*, const INode* source); // defined in Graph.cxx
48  virtual VectorXd compute(const ISource&) const override; // implements INode::compute
49  virtual size_t n_outputs() const override; // implements INode::n_outputs
50  private:
51  const Stack* m_stack; // pointer-to-const; ownership is not expressed here (Graph keeps Stack* in m_stacks)
52  const INode* m_source; // upstream node; pointer-to-const, ownership not expressed here
53  };
54 
55  class ConcatenateNode: public INode // INode built from a list of upstream INodes
56  {
57  public:
58  ConcatenateNode(const std::vector<const INode*>&); // defined in Graph.cxx; not explicit, so a vector converts implicitly
59  virtual VectorXd compute(const ISource&) const override; // implements INode::compute
60  virtual size_t n_outputs() const override; // implements INode::n_outputs
61  private:
62  std::vector<const INode*> m_sources; // upstream nodes held as pointers-to-const
63  size_t m_n_outputs; // presumably the combined output size of m_sources -- confirm in Graph.cxx
64  };
65 
66  // sequence nodes
67  class ISequenceNode // abstract sequence-node interface: scan() produces a MatrixXd from an ISource
68  {
69  public:
70  virtual ~ISequenceNode() {} // virtual, so concrete nodes can be destroyed through an ISequenceNode*
71  virtual MatrixXd scan(const ISource&) const = 0; // pure virtual: each concrete sequence node supplies the evaluation
72  virtual size_t n_outputs() const = 0; // pure virtual; the matrix axis it describes is not visible here -- confirm in Graph.cxx
73  };
74 
75  class InputSequenceNode: public ISequenceNode // sequence node constructed from an index and an output count
76  {
77  public:
78  InputSequenceNode(size_t index, size_t n_outputs); // defined in Graph.cxx; presumably stores both arguments in the members below
79  virtual MatrixXd scan(const ISource&) const override; // implements ISequenceNode::scan
80  virtual size_t n_outputs() const override; // implements ISequenceNode::n_outputs
81  private:
82  size_t m_index; // NOTE(review): looks like the slot to read from the ISource -- confirm in Graph.cxx
83  size_t m_n_outputs;
84  };
85 
86  class SequenceNode: public ISequenceNode, public INode // implements both interfaces: scan() and compute()
87  {
88  public:
89  SequenceNode(const RecurrentStack*, const ISequenceNode* source); // defined in Graph.cxx
90  virtual MatrixXd scan(const ISource&) const override; // implements ISequenceNode::scan
91  virtual VectorXd compute(const ISource&) const override; // implements INode::compute
92  virtual size_t n_outputs() const override; // one override satisfies n_outputs() of both bases
93  private:
94  const RecurrentStack* m_stack; // pointer-to-const; ownership not expressed here (Graph keeps RecurrentStack* in m_seq_stacks)
95  const ISequenceNode* m_source; // upstream sequence node; pointer-to-const, ownership not expressed here
96  };
97 
98  class TimeDistributedNode: public ISequenceNode // sequence node built from a Stack and an upstream sequence node
99  {
100  public:
101  TimeDistributedNode(const Stack*, const ISequenceNode* source); // defined in Graph.cxx
102  virtual MatrixXd scan(const ISource&) const override; // implements ISequenceNode::scan
103  virtual size_t n_outputs() const override; // implements ISequenceNode::n_outputs
104  private:
105  const Stack* m_stack; // pointer-to-const; ownership not expressed here (Graph keeps Stack* in m_stacks)
106  const ISequenceNode* m_source; // upstream sequence node; pointer-to-const, ownership not expressed here
107  };
108  class SumNode: public INode // INode fed by a sequence node: takes an ISequenceNode, yields a VectorXd
109  {
110  public:
111  SumNode(const ISequenceNode* source); // defined in Graph.cxx; not explicit, so a pointer converts implicitly
112  virtual VectorXd compute(const ISource&) const override; // implements INode::compute
113  virtual size_t n_outputs() const override; // implements INode::n_outputs
114  private:
115  const ISequenceNode* m_source; // upstream sequence node; pointer-to-const, ownership not expressed here
116  };
117 
118  // Graph class, owns the nodes. Built from NodeConfig/LayerConfig lists; not copyable.
119  class Graph
120  {
121  public:
122  Graph(); // dummy constructor
123  Graph(const std::vector<NodeConfig>& nodes,
124  const std::vector<LayerConfig>& layers);
125  Graph(Graph&) = delete; // NOTE(review): deleted copy takes a non-const reference; the conventional form is const Graph&
126  Graph& operator=(Graph&) = delete; // NOTE(review): same non-const reference form as the deleted constructor
127  ~Graph(); // defined in Graph.cxx; the maps below hold raw pointers, so cleanup presumably happens there
128  VectorXd compute(const ISource&, size_t node_number) const; // vector result for an explicitly numbered node
129  VectorXd compute(const ISource&) const; // no node number; presumably uses m_last_node -- confirm in Graph.cxx
130  MatrixXd scan(const ISource&, size_t node_number) const; // matrix result for an explicitly numbered node
131  MatrixXd scan(const ISource&) const; // no node number; presumably uses m_last_node -- confirm in Graph.cxx
132  private:
133  void build_node(const size_t,
134  const std::vector<NodeConfig>& nodes,
135  const std::vector<LayerConfig>& layers,
136  std::set<size_t> cycle_check = {}); // set is taken by value and defaults to empty
137 
138  std::unordered_map<size_t, INode*> m_nodes; // raw INode pointers keyed by size_t (presumably the node number)
139  size_t m_last_node; // <-- convenience for graphs with one output
140  std::unordered_map<size_t, Stack*> m_stacks; // raw Stack pointers keyed by size_t
141  std::unordered_map<size_t, ISequenceNode*> m_seq_nodes; // raw ISequenceNode pointers keyed by size_t
142  std::unordered_map<size_t, RecurrentStack*> m_seq_stacks; // raw RecurrentStack pointers keyed by size_t
143  // At some point maybe also convolutional nodes, but we'd have to
144  // have a use case for that first.
145  };
146 }
147 
148 #endif // GRAPH_HH_TAURECTOOLS
lwtDev::FeedForwardNode::FeedForwardNode
FeedForwardNode(const Stack *, const INode *source)
Definition: Graph.cxx:92
lwtDev::TimeDistributedNode::n_outputs
virtual size_t n_outputs() const override
Definition: Graph.cxx:189
lwtDev::SequenceNode::n_outputs
virtual size_t n_outputs() const override
Definition: Graph.cxx:170
lwtDev::FeedForwardNode::m_source
const INode * m_source
Definition: Graph.h:52
lwtDev::Graph::~Graph
~Graph()
Definition: Graph.cxx:294
lwtDev::SumNode
Definition: Graph.h:109
lwtDev::Graph::Graph
Graph()
Definition: Graph.cxx:276
lwtDev::SequenceNode::m_source
const ISequenceNode * m_source
Definition: Graph.h:95
lwtDev::INode
Definition: Graph.h:26
index
Definition: index.py:1
lwtDev::Graph::m_nodes
std::unordered_map< size_t, INode * > m_nodes
Definition: Graph.h:138
lwtDev::INode::n_outputs
virtual size_t n_outputs() const =0
lwtDev::InputSequenceNode::InputSequenceNode
InputSequenceNode(size_t index, size_t n_outputs)
Definition: Graph.cxx:130
lwtDev::InputSequenceNode::scan
virtual MatrixXd scan(const ISource &) const override
Definition: Graph.cxx:135
lwtDev::SumNode::compute
virtual VectorXd compute(const ISource &) const override
Definition: Graph.cxx:197
lwtDev::Graph::m_last_node
size_t m_last_node
Definition: Graph.h:139
lwtDev::SequenceNode::compute
virtual VectorXd compute(const ISource &) const override
Definition: Graph.cxx:161
module_driven_slicing.layers
layers
Definition: module_driven_slicing.py:114
lwtDev::InputNode::InputNode
InputNode(size_t index, size_t n_outputs)
Definition: Graph.cxx:72
lwtDev::InputSequenceNode
Definition: Graph.h:76
lwtDev::Graph::build_node
void build_node(const size_t, const std::vector< NodeConfig > &nodes, const std::vector< LayerConfig > &layers, std::set< size_t > cycle_check={})
Definition: Graph.cxx:353
Source.h
lwtDev::SequenceNode::SequenceNode
SequenceNode(const RecurrentStack *, const ISequenceNode *source)
Definition: Graph.cxx:152
lwtDev::SumNode::SumNode
SumNode(const ISequenceNode *source)
Definition: Graph.cxx:193
lwtDev::TimeDistributedNode::m_source
const ISequenceNode * m_source
Definition: Graph.h:106
lwtDev::ConcatenateNode::compute
virtual VectorXd compute(const ISource &) const override
Definition: Graph.cxx:112
lwtDev::Graph::m_stacks
std::unordered_map< size_t, Stack * > m_stacks
Definition: Graph.h:140
lwtDev::SumNode::n_outputs
virtual size_t n_outputs() const override
Definition: Graph.cxx:200
lwtDev::Graph
Definition: Graph.h:120
lwtDev::InputSequenceNode::m_n_outputs
size_t m_n_outputs
Definition: Graph.h:83
lwtDev::Graph::Graph
Graph(Graph &)=delete
lwtDev::ISequenceNode
Definition: Graph.h:68
lwtDev::ISequenceNode::scan
virtual MatrixXd scan(const ISource &) const =0
lwtDev::SequenceNode
Definition: Graph.h:87
lwtDev::ConcatenateNode::m_sources
std::vector< const INode * > m_sources
Definition: Graph.h:62
lwtDev::SumNode::m_source
const ISequenceNode * m_source
Definition: Graph.h:115
lwtDev::InputNode::compute
virtual VectorXd compute(const ISource &) const override
Definition: Graph.cxx:77
lwtDev::INode::compute
virtual VectorXd compute(const ISource &) const =0
lwtDev::SequenceNode::scan
virtual MatrixXd scan(const ISource &) const override
Definition: Graph.cxx:158
lwtDev::FeedForwardNode
Definition: Graph.h:45
lwtDev::InputSequenceNode::m_index
size_t m_index
Definition: Graph.h:82
lwtDev::INode::~INode
virtual ~INode()
Definition: Graph.h:28
lwtDev::TimeDistributedNode::TimeDistributedNode
TimeDistributedNode(const Stack *, const ISequenceNode *source)
Definition: Graph.cxx:174
lwtDev::RecurrentStack
Definition: Stack.h:174
lwtDev
Definition: Reconstruction/tauRecTools/Root/lwtnn/Exceptions.cxx:8
lwtDev::InputSequenceNode::n_outputs
virtual size_t n_outputs() const override
Definition: Graph.cxx:148
lwtDev::ConcatenateNode::ConcatenateNode
ConcatenateNode(const std::vector< const INode * > &)
Definition: Graph.cxx:104
lwtDev::TimeDistributedNode::scan
virtual MatrixXd scan(const ISource &) const override
Definition: Graph.cxx:180
lwtDev::InputNode::m_index
size_t m_index
Definition: Graph.h:40
lwtDev::InputNode
Definition: Graph.h:34
lwtDev::Graph::m_seq_nodes
std::unordered_map< size_t, ISequenceNode * > m_seq_nodes
Definition: Graph.h:141
lwtDev::ConcatenateNode::m_n_outputs
size_t m_n_outputs
Definition: Graph.h:63
lwtDev::SequenceNode::m_stack
const RecurrentStack * m_stack
Definition: Graph.h:94
lwtDev::ISource
Definition: Source.h:18
lwtDev::Graph::compute
VectorXd compute(const ISource &, size_t node_number) const
Definition: Graph.cxx:315
lwtDev::InputNode::n_outputs
virtual size_t n_outputs() const override
Definition: Graph.cxx:88
lwtDev::Stack
Definition: Stack.h:48
lwtDev::ConcatenateNode
Definition: Graph.h:56
lwtDev::ConcatenateNode::n_outputs
virtual size_t n_outputs() const override
Definition: Graph.cxx:125
lwtDev::Graph::scan
MatrixXd scan(const ISource &, size_t node_number) const
Definition: Graph.cxx:332
lwtDev::TimeDistributedNode::m_stack
const Stack * m_stack
Definition: Graph.h:105
lwtDev::Graph::operator=
Graph & operator=(Graph &)=delete
lwtDev::TimeDistributedNode
Definition: Graph.h:99
lwtDev::ISequenceNode::n_outputs
virtual size_t n_outputs() const =0
lwtDev::FeedForwardNode::compute
virtual VectorXd compute(const ISource &) const override
Definition: Graph.cxx:97
lwtDev::ISequenceNode::~ISequenceNode
virtual ~ISequenceNode()
Definition: Graph.h:70
lwtDev::FeedForwardNode::m_stack
const Stack * m_stack
Definition: Graph.h:51
lwtDev::Graph::m_seq_stacks
std::unordered_map< size_t, RecurrentStack * > m_seq_stacks
Definition: Graph.h:142
lwtDev::InputNode::m_n_outputs
size_t m_n_outputs
Definition: Graph.h:41
lwtDev::FeedForwardNode::n_outputs
virtual size_t n_outputs() const override
Definition: Graph.cxx:100
NNLayerConfig.h