18 using namespace Eigen;
24 typedef std::vector<IP*> Preprocs;
27 typedef std::vector<IVP*> VecPreprocs;
// LazySource: an ISource implementation that looks up and preprocesses a
// graph input only when the graph requests it through at() / matrix_at().
// Every data member is a const reference to a container handed to the
// constructor, so those containers must outlive the LazySource object.
// NOTE(review): this is a source-browser extraction; the class braces,
// access specifiers and the m_nodes member (initialized by the constructor
// and read by at()) are not visible here.
32 class LazySource:
public ISource
// Arguments: scalar input nodes, sequence input nodes, scalar
// preprocessors, sequence preprocessors, and the mapping from graph input
// index to position in the node containers.
35 LazySource(
const NodeVec&,
const SeqNodeVec&,
36 const Preprocs&,
const VecPreprocs&,
37 const SourceIndices& input_indices);
// Returns the preprocessed scalar input for graph input `index`.
38 virtual VectorXd at(
size_t index)
const override;
// Returns the preprocessed sequence input for graph input `index`.
39 virtual MatrixXd matrix_at(
size_t index)
const override;
// Borrowed (non-copied) views of the constructor arguments.
42 const SeqNodeVec& m_seqs;
43 const Preprocs& m_preprocs;
44 const VecPreprocs& m_vec_preprocs;
45 const SourceIndices& m_input_indices;
// Binds the reference members to the caller's containers. Nothing is
// copied, so every argument must outlive this LazySource.
// NOTE(review): the SourceIndices parameter and the m_input_indices
// initializer, plus the (presumably empty) body, are not visible in this
// extraction.
48 LazySource::LazySource(
const NodeVec& n,
const SeqNodeVec& s,
49 const Preprocs& p,
const VecPreprocs& v,
51 m_nodes(
n), m_seqs(
s), m_preprocs(
p), m_vec_preprocs(
v),
55 VectorXd LazySource::at(
size_t index)
const
57 const auto& preproc = *m_preprocs.at(index);
58 size_t source_index = m_input_indices.
scalar.at(index);
59 if (source_index >= m_nodes.size()) {
60 throw NNEvaluationException(
61 "The NN needs an input VectorXd at position "
62 + std::to_string(source_index) +
" but only "
63 + std::to_string(m_nodes.size()) +
" inputs were given");
65 return preproc(m_nodes.at(source_index));
67 MatrixXd LazySource::matrix_at(
size_t index)
const
69 const auto& preproc = *m_vec_preprocs.at(index);
70 size_t source_index = m_input_indices.
sequence.at(index);
71 if (source_index >= m_nodes.size()) {
72 throw NNEvaluationException(
73 "The NN needs an input MatrixXd at position "
74 + std::to_string(source_index) +
" but only "
75 + std::to_string(m_nodes.size()) +
" inputs were given");
77 return preproc(m_seqs.at(source_index));
// get_node_indices: for each configured input node, find the position of
// the entry with the same name in `order` (pairs whose .first is the name).
// Throws NNConfigurationException("Missing input <name>") when a configured
// input does not appear in `order`.
// NOTE(review): the declaration of the `order` parameter and the function's
// return statement are not visible in this extraction; presumably it
// returns node_indices -- confirm.
84 std::vector<size_t> get_node_indices(
86 const std::vector<lwt::InputNodeConfig>& inputs)
// Build a name -> position lookup. If `order` repeats a name, the later
// position overwrites the earlier one.
88 std::map<std::string, size_t> order_indices;
89 for (
size_t i = 0;
i <
order.size();
i++) {
90 order_indices[
order.at(i).first] =
i;
92 std::vector<size_t> node_indices;
// Preserve the order of `inputs` in the result.
93 for (
const lwt::InputNodeConfig& input: inputs) {
// NOTE(review): count() followed by at() does two lookups; a single
// find() would do.
94 if (!order_indices.count(
input.name)) {
95 throw NNConfigurationException(
"Missing input " +
input.name);
97 node_indices.push_back(order_indices.at(
input.name));
// lwt::atlas (fragment): pieces of the FastGraph constructor and of a
// compute method. Many lines are missing from this extraction, so the
// visible control flow is incomplete.
104namespace lwt::atlas {
110 std::string default_output):
// Resolve scalar and sequence inputs from the config against the caller's
// InputOrder.
// NOTE(review): the callee is not visible; presumably get_node_indices.
115 order.scalar,
config.inputs);
118 order.sequence,
config.input_sequences);
// For each scalar input node, fetch its variable ordering from
// order.scalar. `input_node` is defined on a line not shown here.
120 for (
size_t i = 0; i <
config.inputs.size(); i++) {
121 const lwt::InputNodeConfig&
node =
config.inputs.at(i);
123 std::vector<std::string> varorder = order.scalar.at(input_node).second;
// Same for each sequence input node, using order.sequence.
127 for (
size_t i = 0; i <
config.input_sequences.size(); i++) {
128 const lwt::InputNodeConfig&
node =
config.input_sequences.at(i);
130 std::vector<std::string> varorder = order.sequence.at(input_node).second;
// Default output selection: a non-empty name must exist in config.outputs.
// NOTE(review): the error message has no space between "no output node"
// and the name.
134 if (default_output.size() > 0) {
135 if (!
config.outputs.count(default_output)) {
136 throw NNConfigurationException(
"no output node" + default_output);
// NOTE(review): source lines 137-141 are missing here. The throw below
// presumably belongs to a hidden final `else` (no default given and not
// exactly one output), not to the single-output branch -- confirm against
// the full source.
139 }
else if (
config.outputs.size() == 1) {
142 throw NNConfigurationException(
"you must specify a default output");
// Tail of a compute method whose signature is not visible: delegates to
// the underlying graph with a source and an output index.
167 return m_graph->compute(source, idx);
SourceIndices m_input_indices
FastGraph(const GraphConfig &config, const InputOrder &order, std::string default_output="")
std::vector< Eigen::MatrixXd > SeqNodeVec
std::vector< Eigen::VectorXd > NodeVec
Eigen::VectorXd compute(const NodeVec &, const SeqNodeVec &={}) const
VecPreprocs m_vec_preprocs
Ensure that the extensions for the Vector3D are properly loaded.
std::vector< std::pair< std::string, std::vector< std::string > > > order_t
FastGraph::NodeVec NodeVec
std::vector< size_t > sequence
std::vector< size_t > scalar