ATLAS Offline Software
lwtDev::BidirectionalLayer Class Reference

Bidirectional recurrent unit.

#include <Stack.h>


Public Member Functions

 BidirectionalLayer (std::unique_ptr< IRecurrentLayer > forward_layer, std::unique_ptr< IRecurrentLayer > backward_layer, const std::string &merge_mode, bool return_sequence)
 Constructs a bidirectional layer from a forward and a backward recurrent layer.
 
virtual ~BidirectionalLayer ()
 
virtual MatrixXd scan (const MatrixXd &) const override
 

Public Attributes

bool m_go_backwards = false
 
bool m_return_sequence = false
 

Private Attributes

std::unique_ptr< const IRecurrentLayer > m_forward_layer
 
std::unique_ptr< const IRecurrentLayer > m_backward_layer
 
std::string m_merge_mode
 

Detailed Description

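Bidirectional recurrent unit. It runs a forward and a backward recurrent layer over the same input sequence and merges their outputs according to the merge mode (mul, sum, ave, or concat).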

Definition at line 310 of file Stack.h.

Constructor & Destructor Documentation

◆ BidirectionalLayer()

lwtDev::BidirectionalLayer::BidirectionalLayer ( std::unique_ptr< IRecurrentLayer > forward_layer,
std::unique_ptr< IRecurrentLayer > backward_layer,
const std::string & merge_mode,
bool return_sequence )

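Constructs a bidirectional layer. Ownership of forward_layer and backward_layer is transferred to the new object; merge_mode selects how their outputs are combined in scan() (one of mul, sum, ave, concat), and return_sequence sets the inherited m_return_sequence flag.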

Definition at line 611 of file Stack.cxx.

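BidirectionalLayer::BidirectionalLayer(std::unique_ptr<IRecurrentLayer> forward_layer,
                                       std::unique_ptr<IRecurrentLayer> backward_layer,
                                       const std::string& merge_mode,
                                       bool return_sequence)
  : m_forward_layer(std::move(forward_layer)),
    m_backward_layer(std::move(backward_layer)),
    m_merge_mode(merge_mode)
{
  // IRecurrentLayer base-class member, so it is assigned here rather than
  // in the initializer list
  m_return_sequence = return_sequence;
}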

◆ ~BidirectionalLayer()

virtual lwtDev::BidirectionalLayer::~BidirectionalLayer ( )
inline, virtual

Definition at line 318 of file Stack.h.

virtual ~BidirectionalLayer() {};

Member Function Documentation

◆ scan()

MatrixXd lwtDev::BidirectionalLayer::scan ( const MatrixXd &  x) const
override, virtual

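Implements lwtDev::IRecurrentLayer.

Runs the forward and backward layers over the input sequence x, reverses the backward output in time when m_return_sequence is set, and merges the two results according to m_merge_mode. Throws lwtDev::NNEvaluationException (defined in Reconstruction/tauRecTools/tauRecTools/lwtnn/Exceptions.h) if the merge mode is not one of mul, sum, ave, or concat.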

Definition at line 623 of file Stack.cxx.

MatrixXd BidirectionalLayer::scan(const MatrixXd& x) const
{
  const MatrixXd & forward = m_forward_layer->scan(x);
  const MatrixXd & backward = m_backward_layer->scan(x);
  MatrixXd backward_rev;
  if (m_return_sequence){
    // reverse the column order (time steps) of the backward output
    backward_rev = backward.rowwise().reverse();
  }else{
    backward_rev = backward;
  }

  if(m_merge_mode == "mul")
    return forward.array()*backward_rev.array();
  else if(m_merge_mode == "sum")
    return forward.array() + backward_rev.array();
  else if(m_merge_mode == "ave")
    return (forward.array() + backward_rev.array())/2.;
  else if(m_merge_mode == "concat"){
    MatrixXd concatMatr(forward.rows(), forward.cols()+backward_rev.cols());
    concatMatr << forward, backward_rev;
    return concatMatr;
  }else
    throw NNEvaluationException(
      "Merge mode "+m_merge_mode+" not implemented. Choose one of [mul, sum, ave, concat]");

  // unreachable; keeps compilers from warning about a missing return
  return forward;
}

Member Data Documentation

◆ m_backward_layer

std::unique_ptr<const IRecurrentLayer> lwtDev::BidirectionalLayer::m_backward_layer
private

Definition at line 323 of file Stack.h.

◆ m_forward_layer

std::unique_ptr<const IRecurrentLayer> lwtDev::BidirectionalLayer::m_forward_layer
private

Definition at line 322 of file Stack.h.

◆ m_go_backwards

bool lwtDev::IRecurrentLayer::m_go_backwards = false
inherited

Definition at line 218 of file Stack.h.

◆ m_merge_mode

std::string lwtDev::BidirectionalLayer::m_merge_mode
private

Definition at line 325 of file Stack.h.

◆ m_return_sequence

bool lwtDev::IRecurrentLayer::m_return_sequence = false
inherited

Definition at line 219 of file Stack.h.


The documentation for this class was generated from the following files: Stack.h and Stack.cxx.