Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Public Member Functions | Public Attributes | Private Attributes | List of all members
lwtDev::GRULayer Class Reference

#include <Stack.h>

Inheritance diagram for lwtDev::GRULayer: (diagram not rendered; GRULayer derives from lwtDev::IRecurrentLayer)
Collaboration diagram for lwtDev::GRULayer:

Public Member Functions

 GRULayer (const ActivationConfig &activation, const ActivationConfig &inner_activation, const MatrixXd &W_z, const MatrixXd &U_z, const VectorXd &b_z, const MatrixXd &W_r, const MatrixXd &U_r, const VectorXd &b_r, const MatrixXd &W_h, const MatrixXd &U_h, const VectorXd &b_h)
 
virtual ~GRULayer ()
 
virtual MatrixXd scan (const MatrixXd &) const override
 
void step (const VectorXd &input, GRUState &) const
 

Public Attributes

bool m_go_backwards = false
 
bool m_return_sequence = false
 

Private Attributes

std::function< double(double)> m_activation_fun
 
std::function< double(double)> m_inner_activation_fun
 
MatrixXd m_W_z
 
MatrixXd m_U_z
 
VectorXd m_b_z
 
MatrixXd m_W_r
 
MatrixXd m_U_r
 
VectorXd m_b_r
 
MatrixXd m_W_h
 
MatrixXd m_U_h
 
VectorXd m_b_h
 
int m_n_outputs
 

Detailed Description

Definition at line 277 of file Stack.h.

Constructor & Destructor Documentation

◆ GRULayer()

lwtDev::GRULayer::GRULayer ( const ActivationConfig &  activation,
const ActivationConfig &  inner_activation,
const MatrixXd &  W_z,
const MatrixXd &  U_z,
const VectorXd &  b_z,
const MatrixXd &  W_r,
const MatrixXd &  U_r,
const VectorXd &  b_r,
const MatrixXd &  W_h,
const MatrixXd &  U_h,
const VectorXd &  b_h 
)

Definition at line 550 of file Stack.cxx.

554  :
555  m_W_z(W_z),
556  m_U_z(U_z),
557  m_b_z(b_z),
558  m_W_r(W_r),
559  m_U_r(U_r),
560  m_b_r(b_r),
561  m_W_h(W_h),
562  m_U_h(U_h),
563  m_b_h(b_h)
564  {
565  m_n_outputs = m_W_h.rows();
566 
567  m_activation_fun = get_activation(activation);
568  m_inner_activation_fun = get_activation(inner_activation);
569  }

◆ ~GRULayer()

virtual lwtDev::GRULayer::~GRULayer ( )
inline virtual

Definition at line 286 of file Stack.h.

286 {};

Member Function Documentation

◆ scan()

MatrixXd lwtDev::GRULayer::scan ( const MatrixXd &  x) const
override virtual

Implements lwtDev::IRecurrentLayer.

Definition at line 598 of file Stack.cxx.

598  {
599 
600  GRUState state(x.cols(), m_n_outputs);
601 
602  for(state.time = 0; state.time < x.cols(); state.time++) {
603  step( x.col( state.time ), state );
604  }
605 
606  return state.h_t;
607  }

◆ step()

void lwtDev::GRULayer::step ( const VectorXd &  input,
GRUState &  s 
) const

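Definition at line 582 of file Stack.cxx. In the listing below, the input vector is referred to as x_t, the parameter name used in the definition.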

582  {
583  // https://github.com/fchollet/keras/blob/master/keras/layers/recurrent.py#L547
584 
585  const auto& act_fun = m_activation_fun;
586  const auto& in_act_fun = m_inner_activation_fun;
587 
588  int tm1 = s.time == 0 ? 0 : s.time - 1;
589  VectorXd h_tm1 = s.h_t.col(tm1);
590  VectorXd z = (m_W_z*x_t + m_b_z + m_U_z*h_tm1).unaryExpr(in_act_fun);
591  VectorXd r = (m_W_r*x_t + m_b_r + m_U_r*h_tm1).unaryExpr(in_act_fun);
592  VectorXd rh = r.cwiseProduct(h_tm1);
593  VectorXd hh = (m_W_h*x_t + m_b_h + m_U_h*rh).unaryExpr(act_fun);
594  VectorXd one = VectorXd::Ones(z.size());
595  s.h_t.col(s.time) = z.cwiseProduct(h_tm1) + (one - z).cwiseProduct(hh);
596  }

Member Data Documentation

◆ m_activation_fun

std::function<double(double)> lwtDev::GRULayer::m_activation_fun
private

Definition at line 291 of file Stack.h.

◆ m_b_h

VectorXd lwtDev::GRULayer::m_b_h
private

Definition at line 304 of file Stack.h.

◆ m_b_r

VectorXd lwtDev::GRULayer::m_b_r
private

Definition at line 300 of file Stack.h.

◆ m_b_z

VectorXd lwtDev::GRULayer::m_b_z
private

Definition at line 296 of file Stack.h.

◆ m_go_backwards

bool lwtDev::IRecurrentLayer::m_go_backwards = false
inherited

Definition at line 218 of file Stack.h.

◆ m_inner_activation_fun

std::function<double(double)> lwtDev::GRULayer::m_inner_activation_fun
private

Definition at line 292 of file Stack.h.

◆ m_n_outputs

int lwtDev::GRULayer::m_n_outputs
private

Definition at line 306 of file Stack.h.

◆ m_return_sequence

bool lwtDev::IRecurrentLayer::m_return_sequence = false
inherited

Definition at line 219 of file Stack.h.

◆ m_U_h

MatrixXd lwtDev::GRULayer::m_U_h
private

Definition at line 303 of file Stack.h.

◆ m_U_r

MatrixXd lwtDev::GRULayer::m_U_r
private

Definition at line 299 of file Stack.h.

◆ m_U_z

MatrixXd lwtDev::GRULayer::m_U_z
private

Definition at line 295 of file Stack.h.

◆ m_W_h

MatrixXd lwtDev::GRULayer::m_W_h
private

Definition at line 302 of file Stack.h.

◆ m_W_r

MatrixXd lwtDev::GRULayer::m_W_r
private

Definition at line 298 of file Stack.h.

◆ m_W_z

MatrixXd lwtDev::GRULayer::m_W_z
private

Definition at line 294 of file Stack.h.


The documentation for this class was generated from the following files:
Stack.h
Stack.cxx
beamspotman.r
def r
Definition: beamspotman.py:676
python.SystemOfUnits.s
int s
Definition: SystemOfUnits.py:131
lwtDev::GRULayer::m_b_h
VectorXd m_b_h
Definition: Stack.h:304
DiTauMassTools::TauTypes::hh
@ hh
Definition: PhysicsAnalysis/TauID/DiTauMassTools/DiTauMassTools/HelperFunctions.h:53
lwtDev::GRULayer::m_W_r
MatrixXd m_W_r
Definition: Stack.h:298
lwtDev::GRULayer::m_U_r
MatrixXd m_U_r
Definition: Stack.h:299
lwtDev::GRULayer::m_activation_fun
std::function< double(double)> m_activation_fun
Definition: Stack.h:291
lwtDev::GRULayer::m_b_r
VectorXd m_b_r
Definition: Stack.h:300
Trk::one
@ one
Definition: TrkDetDescr/TrkSurfaces/TrkSurfaces/RealQuadraticEquation.h:22
lwtDev::GRULayer::m_U_h
MatrixXd m_U_h
Definition: Stack.h:303
x
#define x
lwtDev::GRULayer::m_inner_activation_fun
std::function< double(double)> m_inner_activation_fun
Definition: Stack.h:292
z
#define z
lwtDev::GRULayer::m_b_z
VectorXd m_b_z
Definition: Stack.h:296
lwtDev::GRULayer::step
void step(const VectorXd &input, GRUState &) const
Definition: Stack.cxx:582
lwtDev::GRULayer::m_W_h
MatrixXd m_W_h
Definition: Stack.h:302
lwtDev::GRULayer::m_U_z
MatrixXd m_U_z
Definition: Stack.h:295
lwtDev::get_activation
std::function< double(double)> get_activation(lwtDev::ActivationConfig act)
Definition: Stack.cxx:671
lwtDev::GRULayer::m_W_z
MatrixXd m_W_z
Definition: Stack.h:294
lwtDev::GRULayer::m_n_outputs
int m_n_outputs
Definition: Stack.h:306
lwtDev::GRUState
Definition: Stack.cxx:571