ATLAS Offline Software
Public Member Functions | Private Member Functions | Private Attributes | List of all members
lwtDev::Graph Class Reference

#include <Graph.h>

Collaboration diagram for lwtDev::Graph:

Public Member Functions

 Graph ()
 
 Graph (const std::vector< NodeConfig > &nodes, const std::vector< LayerConfig > &layers)
 
 Graph (Graph &)=delete
 
Graph & operator= (Graph &)=delete
 
 ~Graph ()
 
VectorXd compute (const ISource &, size_t node_number) const
 
VectorXd compute (const ISource &) const
 
MatrixXd scan (const ISource &, size_t node_number) const
 
MatrixXd scan (const ISource &) const
 

Private Member Functions

void build_node (const size_t, const std::vector< NodeConfig > &nodes, const std::vector< LayerConfig > &layers, std::set< size_t > cycle_check={})
 

Private Attributes

std::unordered_map< size_t, INode * > m_nodes
 
size_t m_last_node
 
std::unordered_map< size_t, Stack * > m_stacks
 
std::unordered_map< size_t, ISequenceNode * > m_seq_nodes
 
std::unordered_map< size_t, RecurrentStack * > m_seq_stacks
 

Detailed Description

Definition at line 119 of file Graph.h.

Constructor & Destructor Documentation

◆ Graph() [1/3]

lwtDev::Graph::Graph ( )

Definition at line 276 of file Graph.cxx.

276  {
277  m_stacks[0] = new Stack;
278 
279  m_nodes[0] = new InputNode(0, 2);
280  m_nodes[1] = new InputNode(1, 2);
281  m_nodes[2] = new ConcatenateNode({m_nodes.at(0), m_nodes.at(1)});
282  m_nodes[3] = new FeedForwardNode(m_stacks.at(0), m_nodes.at(2));
283  m_last_node = 3;
284  }

◆ Graph() [2/3]

lwtDev::Graph::Graph ( const std::vector< NodeConfig > &  nodes,
const std::vector< LayerConfig > &  layers 
)

Definition at line 285 of file Graph.cxx.

286  :
287  m_last_node(0)
288  {
289  for (size_t iii = 0; iii < nodes.size(); iii++) {
290  build_node(iii, nodes, layers);
291  }
292  // assert(maps.node.size() + maps.seq_node.size() == nodes.size());
293  }

◆ Graph() [3/3]

lwtDev::Graph::Graph ( Graph & )
delete

◆ ~Graph()

lwtDev::Graph::~Graph ( )

Definition at line 294 of file Graph.cxx.

294  {
295  for (auto node: m_nodes) {
296  delete node.second;
297  node.second = nullptr;
298  }
299  for (auto node: m_seq_nodes) {
300  // The m_nodes collection is the owner of anything that inherits
301  // from both INode and ISequenceNode. So we try not to delete
302  // anything that the m_nodes would already take care of.
303  if (!m_nodes.count(node.first)) delete node.second;
304  node.second = nullptr;
305  }
306  for (auto stack: m_stacks) {
307  delete stack.second;
308  stack.second = nullptr;
309  }
310  for (auto stack: m_seq_stacks) {
311  delete stack.second;
312  stack.second = nullptr;
313  }
314  }

Member Function Documentation

◆ build_node()

void lwtDev::Graph::build_node ( const size_t  iii,
const std::vector< NodeConfig > &  nodes,
const std::vector< LayerConfig > &  layers,
std::set< size_t >  cycle_check = {} 
)
private

Definition at line 353 of file Graph.cxx.

356  { // Builds node iii (and, recursively, its sources) into m_nodes and/or m_seq_nodes.
357  if (m_nodes.count(iii) || m_seq_nodes.count(iii)) return; // already built: nothing to do
358 
359  // we insist that the upstream nodes are built before the
360  // downstream ones, so the last node built should be some kind of
361  // sink for graphs with only one output this will be it.
362  m_last_node = iii; // NOTE(review): set before the sources are built below; a recursive build of a not-yet-built source overwrites this and it is never restored, so it is only the sink when sources precede their consumers in `nodes`.
363 
364  if (iii >= nodes.size()) throw_cfg("no node index", iii); // throw_cfg is defined outside this listing; presumably throws a configuration error
365 
366  const NodeConfig& node = nodes.at(iii);
367 
368  // if it's an input, build and return
370  check_compute_node(node); // NOTE(review): source line 369 (presumably `if (node.type == NodeConfig::Type::INPUT) {`) is missing from this listing
371  size_t input_number = node.sources.at(0);
372  m_nodes[iii] = new InputNode(input_number, node.index);
373  return;
375  check_compute_node(node); // NOTE(review): source line 374 (presumably `} else if (node.type == NodeConfig::Type::INPUT_SEQUENCE) {`) is missing from this listing
376  size_t input_number = node.sources.at(0);
377  m_seq_nodes[iii] = new InputSequenceNode(input_number, node.index);
378  return;
379  }
380 
381  // otherwise build all the inputs first
382  if (cycle_check.count(iii)) { // iii is already on the current recursion path
383  throw NNConfigurationException("found cycle in graph");
384  }
385  cycle_check.insert(iii); // cycle_check is passed by value, so it only holds the nodes on the current recursion path
386  for (size_t source_node: node.sources) {
387  build_node(source_node, nodes, layers, cycle_check);
388  }
389 
390  // check node types
392  m_nodes[iii] = get_feedforward_node(node, layers, // NOTE(review): source line 391 (presumably the FEED_FORWARD type test) is missing from this listing
393  m_nodes, m_stacks);
395  m_seq_nodes[iii] = get_time_distributed_node(node, layers, // NOTE(review): source lines 394 (presumably the TIME_DISTRIBUTED type test) and 396 (rest of this argument list) are missing from this listing
397  } else if (node.type == NodeConfig::Type::SEQUENCE) {
398  std::unique_ptr<SequenceNode> seq_node(
399  get_sequence_node(node, layers, m_seq_nodes, m_seq_stacks));
400  // entering in m_nodes means that m_nodes will delete this one
401  m_nodes[iii] = nullptr; // presumably pre-creates the slot so line 403's m_nodes[iii] cannot allocate (and throw) after release(), which would leak the node
402  m_seq_nodes[iii] = seq_node.get(); // same object also registered as a sequence node; ~Graph skips m_seq_nodes entries present in m_nodes
403  m_nodes[iii] = seq_node.release();
404  } else if (node.type == NodeConfig::Type::CONCATENATE) {
405  // build concatenate layer
406  std::vector<const INode*> in_nodes;
407  for (size_t source_node: node.sources) {
408  in_nodes.push_back(m_nodes.at(source_node)); // NOTE(review): a source only in m_seq_nodes makes at() throw std::out_of_range rather than NNConfigurationException
409  }
410  m_nodes[iii] = new ConcatenateNode(in_nodes);
411  } else if (node.type == NodeConfig::Type::SUM) {
412  if (node.sources.size() != 1) {
413  throw NNConfigurationException("Sum node needs exactly 1 source");
414  }
415  m_nodes[iii] = new SumNode(m_seq_nodes.at(node.sources.at(0))); // NOTE(review): a non-sequence source likewise makes at() throw std::out_of_range
416  } else {
417  throw NNConfigurationException("unknown node type");
418  }
419  }

◆ compute() [1/2]

VectorXd lwtDev::Graph::compute ( const ISource & source) const

Definition at line 326 of file Graph.cxx.

326  {
327  if (!m_nodes.count(m_last_node)) {
328  throw OutputRankException("Graph: output is not a feed forward node");
329  }
330  return m_nodes.at(m_last_node)->compute(source);
331  }

◆ compute() [2/2]

VectorXd lwtDev::Graph::compute ( const ISource & source,
size_t  node_number 
) const

Definition at line 315 of file Graph.cxx.

315  {
316  if (!m_nodes.count(node_number)) {
317  auto num = std::to_string(node_number);
318  if (m_seq_nodes.count(node_number)) {
319  throw OutputRankException(
320  "Graph: output at " + num + " not feed forward");
321  }
322  throw NNEvaluationException("Graph: no output at " + num);
323  }
324  return m_nodes.at(node_number)->compute(source);
325  }

◆ operator=()

Graph& lwtDev::Graph::operator= ( Graph & )
delete

◆ scan() [1/2]

MatrixXd lwtDev::Graph::scan ( const ISource & source) const

Definition at line 343 of file Graph.cxx.

343  {
344  if (!m_seq_nodes.count(m_last_node)) {
345  throw OutputRankException("Graph: output is not a sequence node");
346  }
347  return m_seq_nodes.at(m_last_node)->scan(source);
348  }

◆ scan() [2/2]

MatrixXd lwtDev::Graph::scan ( const ISource & source,
size_t  node_number 
) const

Definition at line 332 of file Graph.cxx.

332  {
333  if (!m_seq_nodes.count(node_number)) {
334  auto num = std::to_string(node_number);
335  if (m_nodes.count(node_number)) {
336  throw OutputRankException(
337  "Graph: output at " + num + " not a sequence");
338  }
339  throw NNEvaluationException("Graph: no output at " + num);
340  }
341  return m_seq_nodes.at(node_number)->scan(source);
342  }

Member Data Documentation

◆ m_last_node

size_t lwtDev::Graph::m_last_node
private

Definition at line 139 of file Graph.h.

◆ m_nodes

std::unordered_map<size_t, INode*> lwtDev::Graph::m_nodes
private

Definition at line 138 of file Graph.h.

◆ m_seq_nodes

std::unordered_map<size_t, ISequenceNode*> lwtDev::Graph::m_seq_nodes
private

Definition at line 141 of file Graph.h.

◆ m_seq_stacks

std::unordered_map<size_t, RecurrentStack*> lwtDev::Graph::m_seq_stacks
private

Definition at line 142 of file Graph.h.

◆ m_stacks

std::unordered_map<size_t, Stack*> lwtDev::Graph::m_stacks
private

Definition at line 140 of file Graph.h.


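The documentation for this class was generated from the following files:
- Graph.h
- Graph.cxx

(The entries that follow are hover-tooltip text captured from the HTML page. Several of them appear to be Doxygen cross-links to unrelated symbols that merely share a name with identifiers on this page: module_driven_slicing.layers, trigbs_pickEvents.num, ActsTrk::to_string, copySelective.source, node (memory_hooks-stdcmalloc.h) and node::type.)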
node::type
void type(TYPE t)
Definition: node.h:48
lwtDev::SumNode
Definition: Graph.h:109
lwtDev::NodeConfig::Type::INPUT_SEQUENCE
@ INPUT_SEQUENCE
lwtDev::Graph::m_nodes
std::unordered_map< size_t, INode * > m_nodes
Definition: Graph.h:138
lwtDev::Graph::m_last_node
size_t m_last_node
Definition: Graph.h:139
lwtDev::NodeConfig::Type::SUM
@ SUM
module_driven_slicing.layers
layers
Definition: module_driven_slicing.py:114
lwtDev::NodeConfig::Type::SEQUENCE
@ SEQUENCE
lwtDev::InputSequenceNode
Definition: Graph.h:76
lwtDev::Graph::build_node
void build_node(const size_t, const std::vector< NodeConfig > &nodes, const std::vector< LayerConfig > &layers, std::set< size_t > cycle_check={})
Definition: Graph.cxx:353
lwtDev::NNConfigurationException
Definition: Reconstruction/tauRecTools/tauRecTools/lwtnn/Exceptions.h:21
lwtDev::NodeConfig::Type::CONCATENATE
@ CONCATENATE
lwtDev::Graph::m_stacks
std::unordered_map< size_t, Stack * > m_stacks
Definition: Graph.h:140
lwtDev::NodeConfig::Type::TIME_DISTRIBUTED
@ TIME_DISTRIBUTED
lwtDev::OutputRankException
Definition: Reconstruction/tauRecTools/tauRecTools/lwtnn/Exceptions.h:31
lwtDev::NodeConfig::Type::INPUT
@ INPUT
lwtDev::FeedForwardNode
Definition: Graph.h:45
trigbs_pickEvents.num
num
Definition: trigbs_pickEvents.py:76
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
lwtDev::NNEvaluationException
Definition: Reconstruction/tauRecTools/tauRecTools/lwtnn/Exceptions.h:27
lwtDev::InputNode
Definition: Graph.h:34
lwtDev::Graph::m_seq_nodes
std::unordered_map< size_t, ISequenceNode * > m_seq_nodes
Definition: Graph.h:141
lwtDev::Stack
Definition: Stack.h:48
lwtDev::ConcatenateNode
Definition: Graph.h:56
copySelective.source
string source
Definition: copySelective.py:32
lwtDev::NodeConfig::Type::FEED_FORWARD
@ FEED_FORWARD
lwtDev::Graph::m_seq_stacks
std::unordered_map< size_t, RecurrentStack * > m_seq_stacks
Definition: Graph.h:142
node
Definition: memory_hooks-stdcmalloc.h:74
lwtDev::NodeConfig
Definition: NNLayerConfig.h:69