ATLAS Offline Software
Loading...
Searching...
No Matches
lwtDev::EmbeddingLayer Class Reference

#include <Stack.h>

Inheritance diagram for lwtDev::EmbeddingLayer:
Collaboration diagram for lwtDev::EmbeddingLayer:

Public Member Functions

 EmbeddingLayer (int var_row_index, const MatrixXd &W)
virtual ~EmbeddingLayer ()
virtual MatrixXd scan (const MatrixXd &) const override

Public Attributes

bool m_go_backwards = false
bool m_return_sequence = false

Private Attributes

int m_var_row_index
MatrixXd m_W

Detailed Description

Definition at line 222 of file Stack.h.

Constructor & Destructor Documentation

◆ EmbeddingLayer()

// lwtDev::EmbeddingLayer constructor. Definition at line 427 of file Stack.cxx.
/// @brief Create an embedding layer for the input row @p var_row_index.
///
/// scan() reads row @p var_row_index of its input as an index and substitutes
/// the corresponding column of @p W for that row.
///
/// @param var_row_index  Row of the input matrix that holds the index to embed;
///                       must be non-negative.
/// @param W              Embedding matrix; scan() uses column j as the
///                       embedding of index j. Stored by copy.
/// @throws NNConfigurationException if @p var_row_index is negative.
lwtDev::EmbeddingLayer::EmbeddingLayer(int var_row_index, const MatrixXd& W)
  : m_var_row_index(var_row_index),
    m_W(W)
{
  // A negative row can never address a matrix row, so reject it at
  // configuration time rather than at evaluation time.
  const bool row_is_negative = m_var_row_index < 0;
  if (row_is_negative) {
    throw NNConfigurationException(
      "EmbeddingLayer::EmbeddingLayer - can not set var_row_index<0, it is an index for a matrix row!");
  }
}

◆ ~EmbeddingLayer()

virtual lwtDev::EmbeddingLayer::~EmbeddingLayer ( )
inline, virtual

Definition at line 226 of file Stack.h.

226 {};

Member Function Documentation

◆ scan()

// lwtDev::EmbeddingLayer::scan (override, virtual).
// Implements lwtDev::IRecurrentLayer.
// Definition at line 437 of file Stack.cxx.
/**
 * @brief Replace one input row with its embedding vector, column by column.
 *
 * Row m_var_row_index of @p x holds a categorical index for each column
 * (presumably one sequence step per column -- TODO confirm against
 * IRecurrentLayer). For every column that index selects a column of m_W, and
 * the output stacks, top to bottom:
 *   - rows [0, m_var_row_index) of x, unchanged;
 *   - the selected column of m_W (m_W.rows() values);
 *   - rows (m_var_row_index, x.rows()) of x, unchanged.
 * The result has m_W.rows() + x.rows() - 1 rows and x.cols() columns.
 *
 * @param x  Input matrix; each x(m_var_row_index, c) must be an
 *           integer-valued number in [0, m_W.cols()).
 * @return   The embedded matrix described above.
 * @throws NNEvaluationException if m_var_row_index >= x.rows(), or if any
 *         index value is non-integral, NaN, negative, or >= m_W.cols().
 */
MatrixXd lwtDev::EmbeddingLayer::scan(const MatrixXd& x) const
{
  if (m_var_row_index >= x.rows())
    throw NNEvaluationException(
      "EmbeddingLayer::scan - var_row_index is larger than input matrix"
      " number of rows!");

  // Look up the embedding vector for every column of the input.
  MatrixXd embedded(m_W.rows(), x.cols());
  for (int icol = 0; icol < x.cols(); icol++) {
    const double vector_idx = x(m_var_row_index, icol);
    // NaN compares unequal to itself, so it is rejected by this check too.
    const bool is_int = std::floor(vector_idx) == vector_idx;
    const bool is_valid = (vector_idx >= 0) && (vector_idx < m_W.cols());
    if (!is_int || !is_valid)
      throw NNEvaluationException(
        "Invalid embedded index: " + std::to_string(vector_idx));
    // Safe: vector_idx is integral and lies in [0, m_W.cols()).
    embedded.col(icol) = m_W.col(static_cast<int>(vector_idx));
  }

  // Only one variable is embedded, so exactly one input row is replaced by
  // m_W.rows() rows.
  MatrixXd out(m_W.rows() + (x.rows() - 1), x.cols());

  // m_var_row_index is zero-based: it is also the number of rows above it.
  if (m_var_row_index > 0)
    out.topRows(m_var_row_index) = x.topRows(m_var_row_index);

  out.block(m_var_row_index, 0, embedded.rows(), embedded.cols()) = embedded;

  // Rows below the embedded variable. The count must come from x.rows();
  // using x.cols() copies the wrong number of rows whenever the input is
  // not square.
  const auto n_tail_rows = x.rows() - 1 - m_var_row_index;
  if (n_tail_rows > 0)
    out.bottomRows(n_tail_rows) = x.bottomRows(n_tail_rows);

  return out;
}

Member Data Documentation

◆ m_go_backwards

bool lwtDev::IRecurrentLayer::m_go_backwards = false
inherited

Definition at line 218 of file Stack.h.

◆ m_return_sequence

bool lwtDev::IRecurrentLayer::m_return_sequence = false
inherited

Definition at line 219 of file Stack.h.

◆ m_var_row_index

int lwtDev::EmbeddingLayer::m_var_row_index
private

Definition at line 230 of file Stack.h.

◆ m_W

MatrixXd lwtDev::EmbeddingLayer::m_W
private

Definition at line 231 of file Stack.h.


The documentation for this class was generated from the following files: