ATLAS Offline Software
lwtDev::EmbeddingLayer Class Reference

#include <Stack.h>

Inherits lwtDev::IRecurrentLayer.

Public Member Functions

 EmbeddingLayer (int var_row_index, const MatrixXd &W)
 
virtual ~EmbeddingLayer ()
 
virtual MatrixXd scan (const MatrixXd &) const override
 

Public Attributes

bool m_go_backwards = false
 
bool m_return_sequence = false
 

Private Attributes

int m_var_row_index
 
MatrixXd m_W
 

Detailed Description

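EmbeddingLayer replaces one input variable, the row at m_var_row_index, with the corresponding column of the embedding matrix m_W. That row must hold an integer category index at every sequence step (column). All other rows pass through unchanged. The output therefore has m_W.rows() - 1 more rows than the input and the same number of columns.

Definition at line 222 of file Stack.h.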
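The output of scan() can be written element by element. This assumes zero-based indices, an input x with n rows (variables) and T columns (sequence steps), an embedding matrix W with d rows and V columns, and k = m_var_row_index. The symbols n, T, d, V and k are introduced here only for illustration:

```latex
\mathrm{out}_{r,t} =
\begin{cases}
x_{r,t}, & 0 \le r < k,\\
W_{r-k,\;x_{k,t}}, & k \le r < k+d,\\
x_{r-d+1,\;t}, & k+d \le r < n-1+d,
\end{cases}
\qquad t = 0,\dots,T-1 .
```

Each x_{k,t} must be an integer in {0, ..., V-1}, and k must be smaller than n. Otherwise scan() throws NNEvaluationException.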

Constructor & Destructor Documentation

◆ EmbeddingLayer()

lwtDev::EmbeddingLayer::EmbeddingLayer (int var_row_index, const MatrixXd &W)

Definition at line 427 of file Stack.cxx.

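EmbeddingLayer::EmbeddingLayer(int var_row_index, const MatrixXd &W):
  m_var_row_index(var_row_index),
  m_W(W)
{
  // a negative value can never address a matrix row; the upper bound
  // depends on the input and is checked in scan()
  if(var_row_index < 0)
    throw NNConfigurationException(
      "EmbeddingLayer::EmbeddingLayer - can not set var_row_index<0,"
      " it is an index for a matrix row!");
}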

◆ ~EmbeddingLayer()

virtual lwtDev::EmbeddingLayer::~EmbeddingLayer ( )
inline, virtual

Definition at line 226 of file Stack.h.

{};

Member Function Documentation

◆ scan()

MatrixXd lwtDev::EmbeddingLayer::scan (const MatrixXd &x) const
override, virtual

Implements lwtDev::IRecurrentLayer.

Definition at line 437 of file Stack.cxx.

MatrixXd EmbeddingLayer::scan( const MatrixXd& x) const {

  if( m_var_row_index >= x.rows() )
    throw NNEvaluationException(
      "EmbeddingLayer::scan - var_row_index is larger than input matrix"
      " number of rows!");

  // one embedding vector (a column of m_W) per sequence step (column of x)
  MatrixXd embedded(m_W.rows(), x.cols());

  for(int icol=0; icol<x.cols(); icol++) {
    double vector_idx = x(m_var_row_index, icol);
    // the category index must be an integer in [0, m_W.cols())
    bool is_int = std::floor(vector_idx) == vector_idx;
    bool is_valid = (vector_idx >= 0) && (vector_idx < m_W.cols());
    if (!is_int || !is_valid) throw NNEvaluationException(
      "Invalid embedded index: " + std::to_string(vector_idx));
    embedded.col(icol) = m_W.col( vector_idx );
  }

  //only embed 1 variable at a time, so this should be correct size
  MatrixXd out(m_W.rows() + (x.rows() - 1), x.cols());

  //assuming m_var_row_index is an index with first possible value of 0
  if(m_var_row_index > 0)
    out.topRows(m_var_row_index) = x.topRows(m_var_row_index);

  out.block(m_var_row_index, 0, embedded.rows(), embedded.cols()) = embedded;

  // the rows after the embedded variable: x.rows() - 1 - m_var_row_index of them
  if( m_var_row_index < (x.rows()-1) )
    out.bottomRows( x.rows() - 1 - m_var_row_index)
      = x.bottomRows( x.rows() - 1 - m_var_row_index);

  return out;
}
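Because scan() splices one input row into several, its row bookkeeping is easy to get wrong. The following is a minimal Eigen-free sketch of the same splicing. It is an illustration, not the lwtDev code, and rests on three assumptions: a row-major std::vector<std::vector<double>> stands in for MatrixXd (rows are variables, columns are sequence steps), std::invalid_argument stands in for NNEvaluationException, and the function name embed_row is hypothetical.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for Eigen's MatrixXd: m[row][col], all rows equal length.
using Matrix = std::vector<std::vector<double>>;

// Sketch of EmbeddingLayer::scan: row `var_row_index` of x holds one integer
// category index per column (sequence step); it is replaced by the matching
// column of W, so the output has W.size() + x.size() - 1 rows.
// std::invalid_argument stands in for lwtDev::NNEvaluationException.
Matrix embed_row(const Matrix& x, const Matrix& W, int var_row_index) {
  const int x_rows = static_cast<int>(x.size());
  if (var_row_index < 0 || var_row_index >= x_rows)
    throw std::invalid_argument("var_row_index is not a row of the input");

  const std::size_t n_cols = x[0].size();                       // sequence steps
  const std::size_t emb_dim = W.size();                         // rows of W
  const std::size_t n_categories = W.empty() ? 0 : W[0].size(); // columns of W

  Matrix out(emb_dim + x.size() - 1, std::vector<double>(n_cols));
  for (std::size_t icol = 0; icol < n_cols; ++icol) {
    const double idx = x[var_row_index][icol];
    const bool is_int = std::floor(idx) == idx;
    const bool is_valid = idx >= 0 && idx < static_cast<double>(n_categories);
    if (!is_int || !is_valid)
      throw std::invalid_argument("Invalid embedded index: " + std::to_string(idx));
    const std::size_t k = static_cast<std::size_t>(idx);

    // Rows before the embedded variable are copied unchanged.
    for (int r = 0; r < var_row_index; ++r)
      out[r][icol] = x[r][icol];
    // The index row becomes emb_dim rows taken from column k of W.
    for (std::size_t e = 0; e < emb_dim; ++e)
      out[var_row_index + e][icol] = W[e][k];
    // The x.size() - 1 - var_row_index trailing rows follow the embedding.
    for (int r = var_row_index + 1; r < x_rows; ++r)
      out[r - 1 + emb_dim][icol] = x[r][icol];
  }
  return out;
}
```

Working column by column mirrors the loop over icol in scan(). The number of trailing rows, x.size() - 1 - var_row_index, depends only on the number of input rows, never on the number of sequence steps.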

Member Data Documentation

◆ m_go_backwards

bool lwtDev::IRecurrentLayer::m_go_backwards = false
inherited

Definition at line 218 of file Stack.h.

◆ m_return_sequence

bool lwtDev::IRecurrentLayer::m_return_sequence = false
inherited

Definition at line 219 of file Stack.h.

◆ m_var_row_index

int lwtDev::EmbeddingLayer::m_var_row_index
private

Definition at line 230 of file Stack.h.

◆ m_W

MatrixXd lwtDev::EmbeddingLayer::m_W
private

Definition at line 231 of file Stack.h.


The documentation for this class was generated from the following files:
Stack.h
Stack.cxx