ATLAS Offline Software
Loading...
Searching...
No Matches
lwtDev::GRULayer Class Reference

#include <Stack.h>

Inheritance diagram for lwtDev::GRULayer:
Collaboration diagram for lwtDev::GRULayer:

Public Member Functions

 GRULayer (const ActivationConfig &activation, const ActivationConfig &inner_activation, const MatrixXd &W_z, const MatrixXd &U_z, const VectorXd &b_z, const MatrixXd &W_r, const MatrixXd &U_r, const VectorXd &b_r, const MatrixXd &W_h, const MatrixXd &U_h, const VectorXd &b_h)
virtual ~GRULayer ()
virtual MatrixXd scan (const MatrixXd &) const override
void step (const VectorXd &input, GRUState &) const

Public Attributes

bool m_go_backwards = false
bool m_return_sequence = false

Private Attributes

std::function< double(double)> m_activation_fun
std::function< double(double)> m_inner_activation_fun
MatrixXd m_W_z
MatrixXd m_U_z
VectorXd m_b_z
MatrixXd m_W_r
MatrixXd m_U_r
VectorXd m_b_r
MatrixXd m_W_h
MatrixXd m_U_h
VectorXd m_b_h
int m_n_outputs

Detailed Description

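Recurrent layer implementing lwtDev::IRecurrentLayer. For each column x_t of the input, step() computes
z = f_inner(W_z x_t + b_z + U_z h_{t-1}),
r = f_inner(W_r x_t + b_r + U_r h_{t-1}),
hh = f(W_h x_t + b_h + U_h (r ∘ h_{t-1})),
h_t = z ∘ h_{t-1} + (1 − z) ∘ hh,
where f and f_inner are the functions obtained from the `activation` and `inner_activation` configurations, ∘ is the element-wise product, and at time 0 column 0 of the state matrix is used as h_{t-1}. scan() runs step() over every input column and returns the state matrix h_t.

Definition at line 277 of file Stack.h.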

Constructor & Destructor Documentation

◆ GRULayer()

lwtDev::GRULayer::GRULayer ( const ActivationConfig & activation,
const ActivationConfig & inner_activation,
const MatrixXd & W_z,
const MatrixXd & U_z,
const VectorXd & b_z,
const MatrixXd & W_r,
const MatrixXd & U_r,
const VectorXd & b_r,
const MatrixXd & W_h,
const MatrixXd & U_h,
const VectorXd & b_h )

Definition at line 550 of file Stack.cxx.

554 :
555 m_W_z(W_z),
556 m_U_z(U_z),
557 m_b_z(b_z),
558 m_W_r(W_r),
559 m_U_r(U_r),
560 m_b_r(b_r),
561 m_W_h(W_h),
562 m_U_h(U_h),
563 m_b_h(b_h)
564 {
565 m_n_outputs = m_W_h.rows();
566
567 m_activation_fun = get_activation(activation);
568 m_inner_activation_fun = get_activation(inner_activation);
569 }
std::function< double(double)> m_inner_activation_fun
Definition Stack.h:292
MatrixXd m_W_z
Definition Stack.h:294
MatrixXd m_W_h
Definition Stack.h:302
std::function< double(double)> m_activation_fun
Definition Stack.h:291
MatrixXd m_W_r
Definition Stack.h:298
MatrixXd m_U_z
Definition Stack.h:295
MatrixXd m_U_h
Definition Stack.h:303
VectorXd m_b_h
Definition Stack.h:304
MatrixXd m_U_r
Definition Stack.h:299
VectorXd m_b_r
Definition Stack.h:300
VectorXd m_b_z
Definition Stack.h:296
std::function< double(double)> get_activation(lwtDev::ActivationConfig act)
Definition Stack.cxx:671

◆ ~GRULayer()

virtual lwtDev::GRULayer::~GRULayer ( )
inline virtual

Definition at line 286 of file Stack.h.

286 {};

Member Function Documentation

◆ scan()

MatrixXd lwtDev::GRULayer::scan ( const MatrixXd & x) const
override virtual

Implements lwtDev::IRecurrentLayer.

Definition at line 598 of file Stack.cxx.

598 {
599
600 GRUState state(x.cols(), m_n_outputs);
601
602 for(state.time = 0; state.time < x.cols(); state.time++) {
603 step( x.col( state.time ), state );
604 }
605
606 return state.h_t;
607 }
#define x
void step(const VectorXd &input, GRUState &) const
Definition Stack.cxx:582

◆ step()

void lwtDev::GRULayer::step ( const VectorXd & input,
GRUState & s ) const

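Definition at line 582 of file Stack.cxx.

Note: the declaration in Stack.h names the first parameter `input`; the body listed below refers to that same argument as `x_t`.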

582 {
583 // https://github.com/fchollet/keras/blob/master/keras/layers/recurrent.py#L547
584
585 const auto& act_fun = m_activation_fun;
586 const auto& in_act_fun = m_inner_activation_fun;
587
588 int tm1 = s.time == 0 ? 0 : s.time - 1;
589 VectorXd h_tm1 = s.h_t.col(tm1);
590 VectorXd z = (m_W_z*x_t + m_b_z + m_U_z*h_tm1).unaryExpr(in_act_fun);
591 VectorXd r = (m_W_r*x_t + m_b_r + m_U_r*h_tm1).unaryExpr(in_act_fun);
592 VectorXd rh = r.cwiseProduct(h_tm1);
593 VectorXd hh = (m_W_h*x_t + m_b_h + m_U_h*rh).unaryExpr(act_fun);
594 VectorXd one = VectorXd::Ones(z.size());
595 s.h_t.col(s.time) = z.cwiseProduct(h_tm1) + (one - z).cwiseProduct(hh);
596 }
#define z
int r
Definition globals.cxx:22

Member Data Documentation

◆ m_activation_fun

std::function<double(double)> lwtDev::GRULayer::m_activation_fun
private

Definition at line 291 of file Stack.h.

◆ m_b_h

VectorXd lwtDev::GRULayer::m_b_h
private

Definition at line 304 of file Stack.h.

◆ m_b_r

VectorXd lwtDev::GRULayer::m_b_r
private

Definition at line 300 of file Stack.h.

◆ m_b_z

VectorXd lwtDev::GRULayer::m_b_z
private

Definition at line 296 of file Stack.h.

◆ m_go_backwards

bool lwtDev::IRecurrentLayer::m_go_backwards = false
inherited

Definition at line 218 of file Stack.h.

◆ m_inner_activation_fun

std::function<double(double)> lwtDev::GRULayer::m_inner_activation_fun
private

Definition at line 292 of file Stack.h.

◆ m_n_outputs

int lwtDev::GRULayer::m_n_outputs
private

Definition at line 306 of file Stack.h.

◆ m_return_sequence

bool lwtDev::IRecurrentLayer::m_return_sequence = false
inherited

Definition at line 219 of file Stack.h.

◆ m_U_h

MatrixXd lwtDev::GRULayer::m_U_h
private

Definition at line 303 of file Stack.h.

◆ m_U_r

MatrixXd lwtDev::GRULayer::m_U_r
private

Definition at line 299 of file Stack.h.

◆ m_U_z

MatrixXd lwtDev::GRULayer::m_U_z
private

Definition at line 295 of file Stack.h.

◆ m_W_h

MatrixXd lwtDev::GRULayer::m_W_h
private

Definition at line 302 of file Stack.h.

◆ m_W_r

MatrixXd lwtDev::GRULayer::m_W_r
private

Definition at line 298 of file Stack.h.

◆ m_W_z

MatrixXd lwtDev::GRULayer::m_W_z
private

Definition at line 294 of file Stack.h.


The documentation for this class was generated from the following files: