ATLAS Offline Software
Loading...
Searching...
No Matches
lwtDev::BidirectionalLayer Class Reference

Bidirectional recurrent unit. More...

#include <Stack.h>

Inheritance diagram for lwtDev::BidirectionalLayer:
Collaboration diagram for lwtDev::BidirectionalLayer:

Public Member Functions

 BidirectionalLayer (std::unique_ptr< IRecurrentLayer > forward_layer, std::unique_ptr< IRecurrentLayer > backward_layer, const std::string &merge_mode, bool return_sequence)
 Constructs a bidirectional layer.
virtual ~BidirectionalLayer ()
virtual MatrixXd scan (const MatrixXd &) const override

Public Attributes

bool m_go_backwards = false
bool m_return_sequence = false

Private Attributes

std::unique_ptr< const IRecurrentLayer > m_forward_layer
std::unique_ptr< const IRecurrentLayer > m_backward_layer
std::string m_merge_mode

Detailed Description

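Bidirectional recurrent unit: runs a forward and a backward recurrent layer over the same input and merges their outputs according to the configured merge mode (one of mul, sum, ave, or concat). An unsupported merge mode causes scan() to throw NNEvaluationException.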

Definition at line 310 of file Stack.h.

Constructor & Destructor Documentation

◆ BidirectionalLayer()

lwtDev::BidirectionalLayer::BidirectionalLayer ( std::unique_ptr< IRecurrentLayer > forward_layer,
std::unique_ptr< IRecurrentLayer > backward_layer,
const std::string & merge_mode,
bool return_sequence )

Constructs a bidirectional layer from a forward and a backward recurrent layer, a merge mode, and a return-sequence flag.

Definition at line 611 of file Stack.cxx.

614 :
615 m_forward_layer(std::move(forward_layer)),
616 m_backward_layer(std::move(backward_layer)),
617 m_merge_mode(merge_mode)
618 {
619 //baseclass variable
620 m_return_sequence=return_sequence;
621 }
std::unique_ptr< const IRecurrentLayer > m_backward_layer
Definition Stack.h:323
std::string m_merge_mode
Definition Stack.h:325
std::unique_ptr< const IRecurrentLayer > m_forward_layer
Definition Stack.h:322

◆ ~BidirectionalLayer()

virtual lwtDev::BidirectionalLayer::~BidirectionalLayer ( )
inline virtual

Definition at line 318 of file Stack.h.

318{};

Member Function Documentation

◆ scan()

MatrixXd lwtDev::BidirectionalLayer::scan ( const MatrixXd & x) const
override virtual

Implements lwtDev::IRecurrentLayer.

Definition at line 623 of file Stack.cxx.

623 {
624 const MatrixXd & forward = m_forward_layer->scan(x);
625 const MatrixXd & backward = m_backward_layer->scan(x);
626 MatrixXd backward_rev;
628 backward_rev = backward.rowwise().reverse();
629 }else{
630 backward_rev = backward;
631 }
632
633 if(m_merge_mode == "mul")
634 return forward.array()*backward_rev.array();
635 else if(m_merge_mode == "sum")
636 return forward.array() + backward_rev.array();
637 else if(m_merge_mode == "ave")
638 return (forward.array() + backward_rev.array())/2.;
639 else if(m_merge_mode == "concat"){
640 MatrixXd concatMatr(forward.rows(), forward.cols()+backward_rev.cols());
641 concatMatr << forward, backward_rev;
642 return concatMatr;
643 }else
644 throw NNEvaluationException(
645 "Merge mode "+m_merge_mode+"not implemented. Choose one of [mul, sum, ave, concat]");
646
647 // mute compiler
648 return forward;
649 }
#define x

Member Data Documentation

◆ m_backward_layer

std::unique_ptr<const IRecurrentLayer> lwtDev::BidirectionalLayer::m_backward_layer
private

Definition at line 323 of file Stack.h.

◆ m_forward_layer

std::unique_ptr<const IRecurrentLayer> lwtDev::BidirectionalLayer::m_forward_layer
private

Definition at line 322 of file Stack.h.

◆ m_go_backwards

bool lwtDev::IRecurrentLayer::m_go_backwards = false
inherited

Definition at line 218 of file Stack.h.

◆ m_merge_mode

std::string lwtDev::BidirectionalLayer::m_merge_mode
private

Definition at line 325 of file Stack.h.

◆ m_return_sequence

bool lwtDev::IRecurrentLayer::m_return_sequence = false
inherited

Definition at line 219 of file Stack.h.


The documentation for this class was generated from the following files: