ATLAS Offline Software
Loading...
Searching...
No Matches
lwtDev::LSTMLayer Class Reference

#include <Stack.h>

Inheritance diagram for lwtDev::LSTMLayer:
Collaboration diagram for lwtDev::LSTMLayer:

Public Member Functions

 LSTMLayer (const ActivationConfig &activation, const ActivationConfig &inner_activation, const MatrixXd &W_i, const MatrixXd &U_i, const VectorXd &b_i, const MatrixXd &W_f, const MatrixXd &U_f, const VectorXd &b_f, const MatrixXd &W_o, const MatrixXd &U_o, const VectorXd &b_o, const MatrixXd &W_c, const MatrixXd &U_c, const VectorXd &b_c, bool go_backwards, bool return_sequence)
virtual ~LSTMLayer ()
virtual MatrixXd scan (const MatrixXd &) const override
void step (const VectorXd &input, LSTMState &) const

Public Attributes

bool m_go_backwards = false
bool m_return_sequence = false

Private Attributes

std::function< double(double)> m_activation_fun
std::function< double(double)> m_inner_activation_fun
MatrixXd m_W_i
MatrixXd m_U_i
VectorXd m_b_i
MatrixXd m_W_f
MatrixXd m_U_f
VectorXd m_b_f
MatrixXd m_W_o
MatrixXd m_U_o
VectorXd m_b_o
MatrixXd m_W_c
MatrixXd m_U_c
VectorXd m_b_c
int m_n_outputs

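Detailed Description

Long short-term memory (LSTM) recurrent layer implementing the lwtDev::IRecurrentLayer interface. scan() treats each column of the input matrix as one time step and passes it to step(), which computes the input, forget and output gates and the candidate cell state following the Keras recurrent-layer implementation referenced in its source.

Definition at line 236 of file Stack.h.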

Constructor & Destructor Documentation

◆ LSTMLayer()

lwtDev::LSTMLayer::LSTMLayer ( const ActivationConfig & activation,
const ActivationConfig & inner_activation,
const MatrixXd & W_i,
const MatrixXd & U_i,
const VectorXd & b_i,
const MatrixXd & W_f,
const MatrixXd & U_f,
const VectorXd & b_f,
const MatrixXd & W_o,
const MatrixXd & U_o,
const VectorXd & b_o,
const MatrixXd & W_c,
const MatrixXd & U_c,
const VectorXd & b_c,
bool go_backwards,
bool return_sequence )

Definition at line 473 of file Stack.cxx.

480 :
481 m_W_i(W_i),
482 m_U_i(U_i),
483 m_b_i(b_i),
484 m_W_f(W_f),
485 m_U_f(U_f),
486 m_b_f(b_f),
487 m_W_o(W_o),
488 m_U_o(U_o),
489 m_b_o(b_o),
490 m_W_c(W_c),
491 m_U_c(U_c),
492 m_b_c(b_c)
493 {
494 m_n_outputs = m_W_o.rows();
495
496 m_activation_fun = get_activation(activation);
497 m_inner_activation_fun = get_activation(inner_activation);
498 m_go_backwards = go_backwards;
499 m_return_sequence = return_sequence;
500 }
MatrixXd m_U_c
Definition Stack.h:269
VectorXd m_b_i
Definition Stack.h:258
MatrixXd m_W_i
Definition Stack.h:256
std::function< double(double)> m_inner_activation_fun
Definition Stack.h:254
MatrixXd m_W_c
Definition Stack.h:268
MatrixXd m_U_i
Definition Stack.h:257
MatrixXd m_W_f
Definition Stack.h:260
std::function< double(double)> m_activation_fun
Definition Stack.h:253
MatrixXd m_U_o
Definition Stack.h:265
VectorXd m_b_o
Definition Stack.h:266
MatrixXd m_W_o
Definition Stack.h:264
VectorXd m_b_f
Definition Stack.h:262
MatrixXd m_U_f
Definition Stack.h:261
VectorXd m_b_c
Definition Stack.h:270
std::function< double(double)> get_activation(lwtDev::ActivationConfig act)
Definition Stack.cxx:671

◆ ~LSTMLayer()

virtual lwtDev::LSTMLayer::~LSTMLayer ( )
inline, virtual

Definition at line 248 of file Stack.h.

248 {};

Member Function Documentation

◆ scan()

MatrixXd lwtDev::LSTMLayer::scan ( const MatrixXd & x) const
override, virtual

Implements lwtDev::IRecurrentLayer.

Definition at line 535 of file Stack.cxx.

535 {
536 LSTMState state(x.cols(), m_n_outputs);
537
538 for(state.time = 0; state.time < x.cols(); state.time++) {
539 if (m_go_backwards)
540 step( x.col( x.cols() -1 - state.time ), state );
541 else
542 step( x.col( state.time ), state );
543 }
544
545 return state.h_t;
546 }
#define x
void step(const VectorXd &input, LSTMState &) const
Definition Stack.cxx:516

◆ step()

void lwtDev::LSTMLayer::step ( const VectorXd & input,
LSTMState & s ) const

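Definition at line 516 of file Stack.cxx. The listing below refers to the input vector as x_t, the parameter name used in this definition; the header declaration names the same parameter input.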

516 {
517 // https://github.com/fchollet/keras/blob/master/keras/layers/recurrent.py#L740
518
519 const auto& act_fun = m_activation_fun;
520 const auto& in_act_fun = m_inner_activation_fun;
521
522 int tm1 = s.time == 0 ? 0 : s.time - 1;
523 VectorXd h_tm1 = s.h_t.col(tm1);
524 VectorXd C_tm1 = s.C_t.col(tm1);
525
526 VectorXd i = (m_W_i*x_t + m_b_i + m_U_i*h_tm1).unaryExpr(in_act_fun);
527 VectorXd f = (m_W_f*x_t + m_b_f + m_U_f*h_tm1).unaryExpr(in_act_fun);
528 VectorXd o = (m_W_o*x_t + m_b_o + m_U_o*h_tm1).unaryExpr(in_act_fun);
529 VectorXd ct = (m_W_c*x_t + m_b_c + m_U_c*h_tm1).unaryExpr(act_fun);
530
531 s.C_t.col(s.time) = f.cwiseProduct(C_tm1) + i.cwiseProduct(ct);
532 s.h_t.col(s.time) = o.cwiseProduct(s.C_t.col(s.time).unaryExpr(act_fun));
533 }

Member Data Documentation

◆ m_activation_fun

std::function<double(double)> lwtDev::LSTMLayer::m_activation_fun
private

Definition at line 253 of file Stack.h.

◆ m_b_c

VectorXd lwtDev::LSTMLayer::m_b_c
private

Definition at line 270 of file Stack.h.

◆ m_b_f

VectorXd lwtDev::LSTMLayer::m_b_f
private

Definition at line 262 of file Stack.h.

◆ m_b_i

VectorXd lwtDev::LSTMLayer::m_b_i
private

Definition at line 258 of file Stack.h.

◆ m_b_o

VectorXd lwtDev::LSTMLayer::m_b_o
private

Definition at line 266 of file Stack.h.

◆ m_go_backwards

bool lwtDev::IRecurrentLayer::m_go_backwards = false
inherited

Definition at line 218 of file Stack.h.

◆ m_inner_activation_fun

std::function<double(double)> lwtDev::LSTMLayer::m_inner_activation_fun
private

Definition at line 254 of file Stack.h.

◆ m_n_outputs

int lwtDev::LSTMLayer::m_n_outputs
private

Definition at line 272 of file Stack.h.

◆ m_return_sequence

bool lwtDev::IRecurrentLayer::m_return_sequence = false
inherited

Definition at line 219 of file Stack.h.

◆ m_U_c

MatrixXd lwtDev::LSTMLayer::m_U_c
private

Definition at line 269 of file Stack.h.

◆ m_U_f

MatrixXd lwtDev::LSTMLayer::m_U_f
private

Definition at line 261 of file Stack.h.

◆ m_U_i

MatrixXd lwtDev::LSTMLayer::m_U_i
private

Definition at line 257 of file Stack.h.

◆ m_U_o

MatrixXd lwtDev::LSTMLayer::m_U_o
private

Definition at line 265 of file Stack.h.

◆ m_W_c

MatrixXd lwtDev::LSTMLayer::m_W_c
private

Definition at line 268 of file Stack.h.

◆ m_W_f

MatrixXd lwtDev::LSTMLayer::m_W_f
private

Definition at line 260 of file Stack.h.

◆ m_W_i

MatrixXd lwtDev::LSTMLayer::m_W_i
private

Definition at line 256 of file Stack.h.

◆ m_W_o

MatrixXd lwtDev::LSTMLayer::m_W_o
private

Definition at line 264 of file Stack.h.


The documentation for this class was generated from the following files: