ATLAS Offline Software
Loading...
Searching...
No Matches
lwtDev::Graph Class Reference

#include <Graph.h>

Collaboration diagram for lwtDev::Graph:

Public Member Functions

 Graph ()
 Graph (const std::vector< NodeConfig > &nodes, const std::vector< LayerConfig > &layers)
 Graph (Graph &)=delete
Graph & operator= (Graph &)=delete
 ~Graph ()
VectorXd compute (const ISource &, size_t node_number) const
VectorXd compute (const ISource &) const
MatrixXd scan (const ISource &, size_t node_number) const
MatrixXd scan (const ISource &) const

Private Member Functions

void build_node (const size_t, const std::vector< NodeConfig > &nodes, const std::vector< LayerConfig > &layers, std::set< size_t > cycle_check={})

Private Attributes

std::unordered_map< size_t, INode * > m_nodes
size_t m_last_node
std::unordered_map< size_t, Stack * > m_stacks
std::unordered_map< size_t, ISequenceNode * > m_seq_nodes
std::unordered_map< size_t, RecurrentStack * > m_seq_stacks

Detailed Description

Definition at line 119 of file Graph.h.

Constructor & Destructor Documentation

◆ Graph() [1/3]

lwtDev::Graph::Graph ( )

Definition at line 276 of file Graph.cxx.

276 {
277 m_stacks[0] = new Stack;
278
279 m_nodes[0] = new InputNode(0, 2);
280 m_nodes[1] = new InputNode(1, 2);
281 m_nodes[2] = new ConcatenateNode({m_nodes.at(0), m_nodes.at(1)});
282 m_nodes[3] = new FeedForwardNode(m_stacks.at(0), m_nodes.at(2));
283 m_last_node = 3;
284 }
std::unordered_map< size_t, INode * > m_nodes
Definition Graph.h:138
std::unordered_map< size_t, Stack * > m_stacks
Definition Graph.h:140
size_t m_last_node
Definition Graph.h:139

◆ Graph() [2/3]

lwtDev::Graph::Graph ( const std::vector< NodeConfig > & nodes,
const std::vector< LayerConfig > & layers )

Definition at line 285 of file Graph.cxx.

286 :
287 m_last_node(0)
288 {
289 for (size_t iii = 0; iii < nodes.size(); iii++) {
290 build_node(iii, nodes, layers);
291 }
292 // assert(maps.node.size() + maps.seq_node.size() == nodes.size());
293 }
void build_node(const size_t, const std::vector< NodeConfig > &nodes, const std::vector< LayerConfig > &layers, std::set< size_t > cycle_check={})
Definition Graph.cxx:353

◆ Graph() [3/3]

lwtDev::Graph::Graph ( Graph & )
delete

◆ ~Graph()

lwtDev::Graph::~Graph ( )

Definition at line 294 of file Graph.cxx.

294 {
295 for (auto node: m_nodes) {
296 delete node.second;
297 node.second = nullptr;
298 }
299 for (auto node: m_seq_nodes) {
300 // The m_nodes collection is the owner of anything that inherits
301 // from both INode and ISequenceNode. So we try not to delete
302 // anything that the m_nodes would already take care of.
303 if (!m_nodes.count(node.first)) delete node.second;
304 node.second = nullptr;
305 }
306 for (auto stack: m_stacks) {
307 delete stack.second;
308 stack.second = nullptr;
309 }
310 for (auto stack: m_seq_stacks) {
311 delete stack.second;
312 stack.second = nullptr;
313 }
314 }
std::unordered_map< size_t, RecurrentStack * > m_seq_stacks
Definition Graph.h:142
std::unordered_map< size_t, ISequenceNode * > m_seq_nodes
Definition Graph.h:141

Member Function Documentation

◆ build_node()

void lwtDev::Graph::build_node ( const size_t iii,
const std::vector< NodeConfig > & nodes,
const std::vector< LayerConfig > & layers,
std::set< size_t > cycle_check = {} )
private

Definition at line 353 of file Graph.cxx.

356 {
357 if (m_nodes.count(iii) || m_seq_nodes.count(iii)) return;
358
359 // we insist that the upstream nodes are built before the
360 // downstream ones, so the last node built should be some kind of
361 // sink for graphs with only one output this will be it.
362 m_last_node = iii;
363
364 if (iii >= nodes.size()) throw_cfg("no node index", iii);
365
366 const NodeConfig& node = nodes.at(iii);
367
368 // if it's an input, build and return
369 if (node.type == NodeConfig::Type::INPUT) {
370 check_compute_node(node);
371 size_t input_number = node.sources.at(0);
372 m_nodes[iii] = new InputNode(input_number, node.index);
373 return;
374 } else if (node.type == NodeConfig::Type::INPUT_SEQUENCE) {
375 check_compute_node(node);
376 size_t input_number = node.sources.at(0);
377 m_seq_nodes[iii] = new InputSequenceNode(input_number, node.index);
378 return;
379 }
380
381 // otherwise build all the inputs first
382 if (cycle_check.count(iii)) {
383 throw NNConfigurationException("found cycle in graph");
384 }
385 cycle_check.insert(iii);
386 for (size_t source_node: node.sources) {
387 build_node(source_node, nodes, layers, cycle_check);
388 }
389
390 // check node types
391 if (node.type == NodeConfig::Type::FEED_FORWARD) {
392 m_nodes[iii] = get_feedforward_node(node, layers,
393 m_nodes, m_stacks);
394 } else if (node.type == NodeConfig::Type::TIME_DISTRIBUTED) {
395 m_seq_nodes[iii] = get_time_distributed_node(node, layers,
396 m_seq_nodes, m_stacks);
397 } else if (node.type == NodeConfig::Type::SEQUENCE) {
398 SequenceNode* seq_node =
399 get_sequence_node(node, layers, m_seq_nodes, m_seq_stacks);
400 // entering in m_nodes means that m_nodes will delete this one
401 m_nodes[iii] = seq_node;
402 m_seq_nodes[iii] = seq_node;
403 } else if (node.type == NodeConfig::Type::CONCATENATE) {
404 // build concatenate layer
405 std::vector<const INode*> in_nodes;
406 for (size_t source_node: node.sources) {
407 in_nodes.push_back(m_nodes.at(source_node));
408 }
409 m_nodes[iii] = new ConcatenateNode(in_nodes);
410 } else if (node.type == NodeConfig::Type::SUM) {
411 if (node.sources.size() != 1) {
412 throw NNConfigurationException("Sum node needs exactly 1 source");
413 }
414 m_nodes[iii] = new SumNode(m_seq_nodes.at(node.sources.at(0)));
415 } else {
416 throw NNConfigurationException("unknown node type");
417 }
418 }

◆ compute() [1/2]

VectorXd lwtDev::Graph::compute ( const ISource & source) const

Definition at line 326 of file Graph.cxx.

326 {
327 if (!m_nodes.count(m_last_node)) {
328 throw OutputRankException("Graph: output is not a feed forward node");
329 }
330 return m_nodes.at(m_last_node)->compute(source);
331 }

◆ compute() [2/2]

VectorXd lwtDev::Graph::compute ( const ISource & source,
size_t node_number ) const

Definition at line 315 of file Graph.cxx.

315 {
316 if (!m_nodes.count(node_number)) {
317 auto num = std::to_string(node_number);
318 if (m_seq_nodes.count(node_number)) {
319 throw OutputRankException(
320 "Graph: output at " + num + " not feed forward");
321 }
322 throw NNEvaluationException("Graph: no output at " + num);
323 }
324 return m_nodes.at(node_number)->compute(source);
325 }

◆ operator=()

Graph & lwtDev::Graph::operator= ( Graph & )
delete

◆ scan() [1/2]

MatrixXd lwtDev::Graph::scan ( const ISource & source) const

Definition at line 343 of file Graph.cxx.

343 {
344 if (!m_seq_nodes.count(m_last_node)) {
345 throw OutputRankException("Graph: output is not a sequence node");
346 }
347 return m_seq_nodes.at(m_last_node)->scan(source);
348 }

◆ scan() [2/2]

MatrixXd lwtDev::Graph::scan ( const ISource & source,
size_t node_number ) const

Definition at line 332 of file Graph.cxx.

332 {
333 if (!m_seq_nodes.count(node_number)) {
334 auto num = std::to_string(node_number);
335 if (m_nodes.count(node_number)) {
336 throw OutputRankException(
337 "Graph: output at " + num + " not a sequence");
338 }
339 throw NNEvaluationException("Graph: no output at " + num);
340 }
341 return m_seq_nodes.at(node_number)->scan(source);
342 }

Member Data Documentation

◆ m_last_node

size_t lwtDev::Graph::m_last_node
private

Definition at line 139 of file Graph.h.

◆ m_nodes

std::unordered_map<size_t, INode*> lwtDev::Graph::m_nodes
private

Definition at line 138 of file Graph.h.

◆ m_seq_nodes

std::unordered_map<size_t, ISequenceNode*> lwtDev::Graph::m_seq_nodes
private

Definition at line 141 of file Graph.h.

◆ m_seq_stacks

std::unordered_map<size_t, RecurrentStack*> lwtDev::Graph::m_seq_stacks
private

Definition at line 142 of file Graph.h.

◆ m_stacks

std::unordered_map<size_t, Stack*> lwtDev::Graph::m_stacks
private

Definition at line 140 of file Graph.h.


The documentation for this class was generated from the following files: