ATLAS Offline Software
lwtDev::ReductionStack Class Reference

#include <Stack.h>

Public Member Functions

 ReductionStack (size_t n_in, const std::vector< LayerConfig > &layers)
 
 ~ReductionStack ()
 
 ReductionStack (ReductionStack &)=delete
 
ReductionStack & operator= (ReductionStack &)=delete
 
VectorXd reduce (MatrixXd inputs) const
 
size_t n_outputs () const
 

Private Attributes

RecurrentStack * m_recurrent
 
Stack * m_stack
 

Detailed Description

Definition at line 195 of file Stack.h.

Constructor & Destructor Documentation

◆ ReductionStack() [1/2]

lwtDev::ReductionStack::ReductionStack ( size_t  n_in,
const std::vector< LayerConfig > &  layers 
)

Definition at line 396 of file Stack.cxx.

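ReductionStack::ReductionStack(size_t n_in,
                               const std::vector<LayerConfig>& layers) {
  // Partition the layers by architecture, keeping their relative order.
  std::vector<LayerConfig> recurrent;
  std::vector<LayerConfig> feed_forward;
  std::set<Architecture> recurrent_arcs{
    Architecture::LSTM, Architecture::GRU, Architecture::EMBEDDING};
  for (const auto& layer: layers) {
    if (recurrent_arcs.count(layer.architecture)) {
      recurrent.push_back(layer);
    } else {
      feed_forward.push_back(layer);
    }
  }
  // The feed-forward stack takes its input size from the recurrent stack.
  m_recurrent = new RecurrentStack(n_in, recurrent);
  m_stack = new Stack(m_recurrent->n_outputs(), feed_forward);
}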
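
The partitioning step of the constructor can be tried in isolation. The following is a minimal sketch, not the lwtDev implementation: `Architecture`, `LayerConfig`, and `split_layers` are hypothetical stand-ins that carry only what the loop needs, and `DENSE` is an assumed name for a non-recurrent architecture.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the lwtDev types; only the field used here.
enum class Architecture { DENSE, LSTM, GRU, EMBEDDING };
struct LayerConfig { Architecture architecture; };

// Split a layer list into (recurrent, feed-forward) groups.
// Relative order is preserved inside each group.
std::pair<std::vector<LayerConfig>, std::vector<LayerConfig>>
split_layers(const std::vector<LayerConfig>& layers) {
  const std::set<Architecture> recurrent_arcs{
    Architecture::LSTM, Architecture::GRU, Architecture::EMBEDDING};
  std::vector<LayerConfig> recurrent;
  std::vector<LayerConfig> feed_forward;
  for (const auto& layer : layers) {
    if (recurrent_arcs.count(layer.architecture)) {
      recurrent.push_back(layer);
    } else {
      feed_forward.push_back(layer);
    }
  }
  // Every layer lands in exactly one of the two groups.
  assert(recurrent.size() + feed_forward.size() == layers.size());
  return {recurrent, feed_forward};
}
```

Because the split depends only on the architecture, a recurrent layer listed after a feed-forward one still ends up in the recurrent group; the interleaving of the two kinds in the input list is not kept.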

◆ ~ReductionStack()

lwtDev::ReductionStack::~ReductionStack ( )

Definition at line 412 of file Stack.cxx.

ReductionStack::~ReductionStack() {
  delete m_recurrent;
  delete m_stack;
}

◆ ReductionStack() [2/2]

lwtDev::ReductionStack::ReductionStack ( ReductionStack & )
delete

Member Function Documentation

◆ n_outputs()

size_t lwtDev::ReductionStack::n_outputs ( ) const

Definition at line 420 of file Stack.cxx.

size_t ReductionStack::n_outputs() const {
  return m_stack->n_outputs();
}

◆ operator=()

ReductionStack& lwtDev::ReductionStack::operator= ( ReductionStack & )
delete

◆ reduce()

VectorXd lwtDev::ReductionStack::reduce ( MatrixXd  inputs) const

Definition at line 416 of file Stack.cxx.

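VectorXd ReductionStack::reduce(MatrixXd in) const {
  // The definition names the by-value parameter `in`; it is overwritten with the scan output.
  in = m_recurrent->scan(in);
  // Only the last column of the scanned sequence reaches the feed-forward stack.
  return m_stack->compute(in.col(in.cols() - 1));
}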
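
The scan-then-take-last-column pattern of reduce() can be illustrated without Eigen. This is a toy sketch under stated assumptions: `Column`, `Sequence`, `toy_scan`, `toy_compute`, and `toy_reduce` are hypothetical names, the running sum merely stands in for a recurrent layer, and each column is assumed to represent one step of the input sequence.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Column = std::vector<double>;    // one step of the sequence (stand-in for a matrix column)
using Sequence = std::vector<Column>;  // columns in sequence order

// Toy "recurrent" scan: output column t is the element-wise sum of inputs 0..t.
Sequence toy_scan(const Sequence& in) {
  Sequence out;
  Column state(in.empty() ? 0 : in.front().size(), 0.0);
  for (const auto& col : in) {
    assert(col.size() == state.size());
    for (std::size_t i = 0; i < col.size(); ++i) {
      state[i] += col[i];
    }
    out.push_back(state);
  }
  return out;
}

// Toy "feed-forward" stage: doubles every element.
Column toy_compute(const Column& x) {
  Column y(x);
  for (auto& v : y) {
    v *= 2.0;
  }
  return y;
}

// Same shape as reduce(): scan the whole sequence, then pass only the
// last output column to the feed-forward stage.
Column toy_reduce(const Sequence& in) {
  assert(!in.empty());  // there is no last column in an empty sequence
  const Sequence scanned = toy_scan(in);
  return toy_compute(scanned.back());
}
```

The excerpt of reduce() shows no guard for an input with zero columns, which is why the sketch asserts a non-empty sequence before taking the last column.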

Member Data Documentation

◆ m_recurrent

RecurrentStack* lwtDev::ReductionStack::m_recurrent
private

Definition at line 205 of file Stack.h.

◆ m_stack

Stack* lwtDev::ReductionStack::m_stack
private

Definition at line 206 of file Stack.h.


The documentation for this class was generated from the following files:
Stack.h
Stack.cxx