ATLAS Offline Software
lwtDev::HighwayLayer Class Reference

#include <Stack.h>

[Inheritance and collaboration diagrams for lwtDev::HighwayLayer]

Public Member Functions

 HighwayLayer (const MatrixXd &W, const VectorXd &b, const MatrixXd &W_carry, const VectorXd &b_carry, ActivationConfig activation)
virtual VectorXd compute (const VectorXd &) const override

Private Attributes

MatrixXd m_w_t
VectorXd m_b_t
MatrixXd m_w_c
VectorXd m_b_c
std::function< double(double)> m_act

Detailed Description

Definition at line 153 of file Stack.h.
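The compute() definition further down this page gives the layer's forward pass, summarised below. Here x is the input; W_t, b_t, W_c, b_c stand for m_w_t, m_b_t, m_w_c, m_b_c; f is the activation held in m_act; ⊙ is element-wise multiplication. σ is taken to be the standard logistic function, which is an assumption about nn_sigmoid because its body is not shown here:

```latex
\begin{aligned}
c &= \sigma(W_c x + b_c), \qquad \sigma(z) = \frac{1}{1 + e^{-z}},\\
t &= f(W_t x + b_t),\\
y &= c \odot t + (1 - c) \odot x .
\end{aligned}
```

When c is close to 0 the layer passes x through almost unchanged. When c is close to 1 it behaves like an ordinary dense layer with activation f.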

Constructor & Destructor Documentation

◆ HighwayLayer()

lwtDev::HighwayLayer::HighwayLayer ( const MatrixXd & W,
const VectorXd & b,
const MatrixXd & W_carry,
const VectorXd & b_carry,
ActivationConfig activation )

Definition at line 252 of file Stack.cxx.

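HighwayLayer::HighwayLayer(const MatrixXd& W,
                           const VectorXd& b,
                           const MatrixXd& W_carry,
                           const VectorXd& b_carry,
                           ActivationConfig activation)
  : m_w_t(W), m_b_t(b), m_w_c(W_carry), m_b_c(b_carry),
    m_act(get_activation(activation))
{
}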
References: m_act (Stack.h:167); std::function<double(double)> get_activation(lwtDev::ActivationConfig act) (Stack.cxx:671).
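The constructor does no arithmetic. It copies the four weight arrays and resolves the ActivationConfig once, through get_activation, into the std::function<double(double)> stored in m_act, so compute() never has to switch on the activation type. The following is a minimal standalone sketch of that pattern. The enum HypotheticalActivation and the function make_activation are illustrative names, not the lwtDev API, and the set of activations shown is an assumption.

```cpp
#include <cassert>
#include <cmath>
#include <functional>
#include <stdexcept>

// Hypothetical stand-in for lwtDev::ActivationConfig (not the real type).
enum class HypotheticalActivation { Linear, Sigmoid, Tanh, Rectified };

// Resolve the activation once and return a callable that can later be
// applied element-wise, the same idea as get_activation feeding m_act.
std::function<double(double)> make_activation(HypotheticalActivation act) {
  switch (act) {
    case HypotheticalActivation::Linear:
      return [](double x) { return x; };
    case HypotheticalActivation::Sigmoid:
      return [](double x) { return 1.0 / (1.0 + std::exp(-x)); };
    case HypotheticalActivation::Tanh:
      return [](double x) { return std::tanh(x); };
    case HypotheticalActivation::Rectified:
      return [](double x) { return x > 0.0 ? x : 0.0; };
  }
  throw std::invalid_argument("unknown activation");
}
```

Storing the result in a std::function member trades a small per-call indirection for a constructor-time lookup, which matches how m_act is later passed straight to unaryExpr in compute().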

Member Function Documentation

◆ compute()

VectorXd lwtDev::HighwayLayer::compute (const VectorXd & in) const
override, virtual

Implements lwtDev::ILayer.

Definition at line 261 of file Stack.cxx.

VectorXd HighwayLayer::compute(const VectorXd& in) const {
  const std::function<double(double)> sig(nn_sigmoid);
  ArrayXd c = (m_w_c * in + m_b_c).unaryExpr(sig);
  ArrayXd t = (m_w_t * in + m_b_t).unaryExpr(m_act);
  return c * t + (1 - c) * in.array();
}
References: double nn_sigmoid(double x) (Stack.cxx:690).
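compute() evaluates a gate c = sigmoid(W_c·in + b_c) and a candidate t = act(W_t·in + b_t), then blends them element-wise with the input. The sketch below reproduces that arithmetic with std::vector instead of Eigen so it compiles on its own. The names affine and highway_forward, and the vector-of-rows matrix layout, are hypothetical and do not describe how lwtDev stores its weights. A plain logistic function stands in for nn_sigmoid, whose exact implementation may differ in detail.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

using Vector = std::vector<double>;
using Matrix = std::vector<Vector>;  // row-major: Matrix[i] is row i

// y = W x + b for a dense block (rows(W) == b.size(), cols(W) == x.size()).
Vector affine(const Matrix& W, const Vector& b, const Vector& x) {
  assert(W.size() == b.size());
  Vector y(b);
  for (std::size_t i = 0; i < W.size(); ++i) {
    assert(W[i].size() == x.size());
    for (std::size_t j = 0; j < x.size(); ++j) y[i] += W[i][j] * x[j];
  }
  return y;
}

// Mirrors HighwayLayer::compute: out = c * t + (1 - c) * in, element-wise,
// with c = sigmoid(W_c in + b_c) and t = act(W_t in + b_t).
Vector highway_forward(const Matrix& W_t, const Vector& b_t,
                       const Matrix& W_c, const Vector& b_c,
                       const std::function<double(double)>& act,
                       const Vector& in) {
  Vector c = affine(W_c, b_c, in);
  Vector t = affine(W_t, b_t, in);
  // The (1 - c) * in term needs gate and candidate to match the input size.
  assert(c.size() == in.size() && t.size() == in.size());
  Vector out(in.size());
  for (std::size_t i = 0; i < in.size(); ++i) {
    double gate = 1.0 / (1.0 + std::exp(-c[i]));  // role of nn_sigmoid
    out[i] = gate * act(t[i]) + (1.0 - gate) * in[i];
  }
  return out;
}
```

With all-zero carry weights and biases the gate is exactly 0.5, so the output is the average of the candidate and the input. This makes a convenient sanity check when porting weights.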

Member Data Documentation

◆ m_act

std::function<double(double)> lwtDev::HighwayLayer::m_act
private

Definition at line 167 of file Stack.h.

◆ m_b_c

VectorXd lwtDev::HighwayLayer::m_b_c
private

Definition at line 166 of file Stack.h.

◆ m_b_t

VectorXd lwtDev::HighwayLayer::m_b_t
private

Definition at line 164 of file Stack.h.

◆ m_w_c

MatrixXd lwtDev::HighwayLayer::m_w_c
private

Definition at line 165 of file Stack.h.

◆ m_w_t

MatrixXd lwtDev::HighwayLayer::m_w_t
private

Definition at line 163 of file Stack.h.
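The four weight members are not independent. In compute(), (1 - c) multiplies in.array() element-wise and c multiplies t, so the gate and the candidate must both have the input's length n. That forces m_w_t and m_w_c to be n × n and m_b_t and m_b_c to have n entries. The constructor copies its arguments without checking this. A load-time check along the following lines could catch mismatched weights early. It is a hypothetical helper (check_highway_shapes, with a plain Shape struct in place of Eigen objects), not part of lwtDev.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>

// Hypothetical description of one weight block: rows x cols.
struct Shape {
  std::size_t rows;
  std::size_t cols;
};

// Throws std::invalid_argument unless W_t and W_c are n x n and both biases
// have n entries, the conditions compute() relies on. Returns n.
std::size_t check_highway_shapes(Shape w_t, std::size_t b_t_size,
                                 Shape w_c, std::size_t b_c_size) {
  const std::size_t n = w_t.cols;  // input length
  if (w_t.rows != n)
    throw std::invalid_argument("W must be square, got " +
                                std::to_string(w_t.rows) + "x" +
                                std::to_string(w_t.cols));
  if (w_c.rows != n || w_c.cols != n)
    throw std::invalid_argument("W_carry must have the same n x n shape as W");
  if (b_t_size != n || b_c_size != n)
    throw std::invalid_argument("bias lengths must equal n");
  return n;
}
```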


The documentation for this class was generated from the following files: