ATLAS Offline Software
Public Member Functions | Public Attributes | Private Attributes | List of all members
lwtDev::LSTMLayer Class Reference

#include <Stack.h>

Inheritance diagram for lwtDev::LSTMLayer:
Collaboration diagram for lwtDev::LSTMLayer:

Public Member Functions

 LSTMLayer (const ActivationConfig &activation, const ActivationConfig &inner_activation, const MatrixXd &W_i, const MatrixXd &U_i, const VectorXd &b_i, const MatrixXd &W_f, const MatrixXd &U_f, const VectorXd &b_f, const MatrixXd &W_o, const MatrixXd &U_o, const VectorXd &b_o, const MatrixXd &W_c, const MatrixXd &U_c, const VectorXd &b_c, bool go_backwards, bool return_sequence)
 
virtual ~LSTMLayer ()
 
virtual MatrixXd scan (const MatrixXd &) const override
 
void step (const VectorXd &input, LSTMState &) const
 

Public Attributes

bool m_go_backwards = false
 
bool m_return_sequence = false
 

Private Attributes

std::function< double(double)> m_activation_fun
 
std::function< double(double)> m_inner_activation_fun
 
MatrixXd m_W_i
 
MatrixXd m_U_i
 
VectorXd m_b_i
 
MatrixXd m_W_f
 
MatrixXd m_U_f
 
VectorXd m_b_f
 
MatrixXd m_W_o
 
MatrixXd m_U_o
 
VectorXd m_b_o
 
MatrixXd m_W_c
 
MatrixXd m_U_c
 
VectorXd m_b_c
 
int m_n_outputs
 

Detailed Description

Definition at line 236 of file Stack.h.

Constructor & Destructor Documentation

◆ LSTMLayer()

lwtDev::LSTMLayer::LSTMLayer ( const ActivationConfig &  activation,
const ActivationConfig &  inner_activation,
const MatrixXd &  W_i,
const MatrixXd &  U_i,
const VectorXd &  b_i,
const MatrixXd &  W_f,
const MatrixXd &  U_f,
const VectorXd &  b_f,
const MatrixXd &  W_o,
const MatrixXd &  U_o,
const VectorXd &  b_o,
const MatrixXd &  W_c,
const MatrixXd &  U_c,
const VectorXd &  b_c,
bool  go_backwards,
bool  return_sequence 
)

Definition at line 473 of file Stack.cxx.

480  :
481  m_W_i(W_i),
482  m_U_i(U_i),
483  m_b_i(b_i),
484  m_W_f(W_f),
485  m_U_f(U_f),
486  m_b_f(b_f),
487  m_W_o(W_o),
488  m_U_o(U_o),
489  m_b_o(b_o),
490  m_W_c(W_c),
491  m_U_c(U_c),
492  m_b_c(b_c)
493  {
494  m_n_outputs = m_W_o.rows();
495 
496  m_activation_fun = get_activation(activation);
497  m_inner_activation_fun = get_activation(inner_activation);
498  m_go_backwards = go_backwards;
499  m_return_sequence = return_sequence;
500  }

◆ ~LSTMLayer()

virtual lwtDev::LSTMLayer::~LSTMLayer ( )
inline virtual

Definition at line 248 of file Stack.h.

248 {};

Member Function Documentation

◆ scan()

MatrixXd lwtDev::LSTMLayer::scan ( const MatrixXd &  x) const
override virtual

Implements lwtDev::IRecurrentLayer.

Definition at line 535 of file Stack.cxx.

535  {
536  LSTMState state(x.cols(), m_n_outputs);
537 
538  for(state.time = 0; state.time < x.cols(); state.time++) {
539  if(m_go_backwards)
540  step( x.col( x.cols() -1 - state.time ), state );
541  else
542  step( x.col( state.time ), state );
543  }
544 
545  return state.h_t;
546  }

◆ step()

void lwtDev::LSTMLayer::step ( const VectorXd &  input,
LSTMState &  s 
) const

Definition at line 516 of file Stack.cxx.

516  {
517  // https://github.com/fchollet/keras/blob/master/keras/layers/recurrent.py#L740
518 
519  const auto& act_fun = m_activation_fun;
520  const auto& in_act_fun = m_inner_activation_fun;
521 
522  int tm1 = s.time == 0 ? 0 : s.time - 1;
523  VectorXd h_tm1 = s.h_t.col(tm1);
524  VectorXd C_tm1 = s.C_t.col(tm1);
525 
526  VectorXd i = (m_W_i*x_t + m_b_i + m_U_i*h_tm1).unaryExpr(in_act_fun);
527  VectorXd f = (m_W_f*x_t + m_b_f + m_U_f*h_tm1).unaryExpr(in_act_fun);
528  VectorXd o = (m_W_o*x_t + m_b_o + m_U_o*h_tm1).unaryExpr(in_act_fun);
529  VectorXd ct = (m_W_c*x_t + m_b_c + m_U_c*h_tm1).unaryExpr(act_fun);
530 
531  s.C_t.col(s.time) = f.cwiseProduct(C_tm1) + i.cwiseProduct(ct);
532  s.h_t.col(s.time) = o.cwiseProduct(s.C_t.col(s.time).unaryExpr(act_fun));
533  }

Member Data Documentation

◆ m_activation_fun

std::function<double(double)> lwtDev::LSTMLayer::m_activation_fun
private

Definition at line 253 of file Stack.h.

◆ m_b_c

VectorXd lwtDev::LSTMLayer::m_b_c
private

Definition at line 270 of file Stack.h.

◆ m_b_f

VectorXd lwtDev::LSTMLayer::m_b_f
private

Definition at line 262 of file Stack.h.

◆ m_b_i

VectorXd lwtDev::LSTMLayer::m_b_i
private

Definition at line 258 of file Stack.h.

◆ m_b_o

VectorXd lwtDev::LSTMLayer::m_b_o
private

Definition at line 266 of file Stack.h.

◆ m_go_backwards

bool lwtDev::IRecurrentLayer::m_go_backwards = false
inherited

Definition at line 218 of file Stack.h.

◆ m_inner_activation_fun

std::function<double(double)> lwtDev::LSTMLayer::m_inner_activation_fun
private

Definition at line 254 of file Stack.h.

◆ m_n_outputs

int lwtDev::LSTMLayer::m_n_outputs
private

Definition at line 272 of file Stack.h.

◆ m_return_sequence

bool lwtDev::IRecurrentLayer::m_return_sequence = false
inherited

Definition at line 219 of file Stack.h.

◆ m_U_c

MatrixXd lwtDev::LSTMLayer::m_U_c
private

Definition at line 269 of file Stack.h.

◆ m_U_f

MatrixXd lwtDev::LSTMLayer::m_U_f
private

Definition at line 261 of file Stack.h.

◆ m_U_i

MatrixXd lwtDev::LSTMLayer::m_U_i
private

Definition at line 257 of file Stack.h.

◆ m_U_o

MatrixXd lwtDev::LSTMLayer::m_U_o
private

Definition at line 265 of file Stack.h.

◆ m_W_c

MatrixXd lwtDev::LSTMLayer::m_W_c
private

Definition at line 268 of file Stack.h.

◆ m_W_f

MatrixXd lwtDev::LSTMLayer::m_W_f
private

Definition at line 260 of file Stack.h.

◆ m_W_i

MatrixXd lwtDev::LSTMLayer::m_W_i
private

Definition at line 256 of file Stack.h.

◆ m_W_o

MatrixXd lwtDev::LSTMLayer::m_W_o
private

Definition at line 264 of file Stack.h.


The documentation for this class was generated from the following files:
Stack.h
Stack.cxx
lwtDev::LSTMLayer::m_U_f
MatrixXd m_U_f
Definition: Stack.h:261
python.SystemOfUnits.s
int s
Definition: SystemOfUnits.py:131
lwtDev::LSTMState
Definition: Stack.cxx:503
lwtDev::LSTMLayer::m_b_f
VectorXd m_b_f
Definition: Stack.h:262
lwtDev::LSTMLayer::m_inner_activation_fun
std::function< double(double)> m_inner_activation_fun
Definition: Stack.h:254
lwtDev::LSTMLayer::m_b_i
VectorXd m_b_i
Definition: Stack.h:258
lwtDev::LSTMLayer::m_U_c
MatrixXd m_U_c
Definition: Stack.h:269
lwtDev::LSTMLayer::m_U_o
MatrixXd m_U_o
Definition: Stack.h:265
x
#define x
lwtDev::IRecurrentLayer::m_go_backwards
bool m_go_backwards
Definition: Stack.h:218
lwtDev::LSTMLayer::m_n_outputs
int m_n_outputs
Definition: Stack.h:272
lwtDev::LSTMLayer::m_W_o
MatrixXd m_W_o
Definition: Stack.h:264
lwtDev::LSTMLayer::m_b_o
VectorXd m_b_o
Definition: Stack.h:266
lwtDev::LSTMLayer::m_W_i
MatrixXd m_W_i
Definition: Stack.h:256
lumiFormat.i
int i
Definition: lumiFormat.py:85
lwtDev::LSTMLayer::m_activation_fun
std::function< double(double)> m_activation_fun
Definition: Stack.h:253
lwtDev::LSTMLayer::m_W_f
MatrixXd m_W_f
Definition: Stack.h:260
lwtDev::IRecurrentLayer::m_return_sequence
bool m_return_sequence
Definition: Stack.h:219
hist_file_dump.f
f
Definition: hist_file_dump.py:135
lwtDev::LSTMLayer::step
void step(const VectorXd &input, LSTMState &) const
Definition: Stack.cxx:516
calibdata.ct
ct
Definition: calibdata.py:418
lwtDev::LSTMLayer::m_U_i
MatrixXd m_U_i
Definition: Stack.h:257
lwtDev::get_activation
std::function< double(double)> get_activation(lwtDev::ActivationConfig act)
Definition: Stack.cxx:671
lwtDev::LSTMLayer::m_W_c
MatrixXd m_W_c
Definition: Stack.h:268
lwtDev::LSTMLayer::m_b_c
VectorXd m_b_c
Definition: Stack.h:270