ATLAS Offline Software
Loading...
Searching...
No Matches
FastGraph.cxx
Go to the documentation of this file.
1// this is -*- C++ -*-
2/*
3 Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
4*/
5
6// Any modifications to this file may be copied to lwtnn[1] without
7// attribution.
8//
// [1]: https://www.github.com/lwtnn/lwtnn
10
14
15#include <Eigen/Dense>
16
17namespace {
18 using namespace Eigen;
19 using namespace lwt;
20 using namespace lwt::atlas;
21
24 typedef std::vector<IP*> Preprocs;
25 typedef atlas::FastGraph::SeqNodeVec SeqNodeVec;
27 typedef std::vector<IVP*> VecPreprocs;
28
29
30 // this is used internally to ensure that we only look up map inputs
31 // when the network asks for them.
32 class LazySource: public ISource
33 {
34 public:
35 LazySource(const NodeVec&, const SeqNodeVec&,
36 const Preprocs&, const VecPreprocs&,
37 const SourceIndices& input_indices);
38 virtual VectorXd at(size_t index) const override;
39 virtual MatrixXd matrix_at(size_t index) const override;
40 private:
41 const NodeVec& m_nodes;
42 const SeqNodeVec& m_seqs;
43 const Preprocs& m_preprocs;
44 const VecPreprocs& m_vec_preprocs;
45 const SourceIndices& m_input_indices;
46 };
47
48 LazySource::LazySource(const NodeVec& n, const SeqNodeVec& s,
49 const Preprocs& p, const VecPreprocs& v,
50 const SourceIndices& i):
51 m_nodes(n), m_seqs(s), m_preprocs(p), m_vec_preprocs(v),
52 m_input_indices(i)
53 {
54 }
55 VectorXd LazySource::at(size_t index) const
56 {
57 const auto& preproc = *m_preprocs.at(index);
58 size_t source_index = m_input_indices.scalar.at(index);
59 if (source_index >= m_nodes.size()) {
60 throw NNEvaluationException(
61 "The NN needs an input VectorXd at position "
62 + std::to_string(source_index) + " but only "
63 + std::to_string(m_nodes.size()) + " inputs were given");
64 }
65 return preproc(m_nodes.at(source_index));
66 }
67 MatrixXd LazySource::matrix_at(size_t index) const
68 {
69 const auto& preproc = *m_vec_preprocs.at(index);
70 size_t source_index = m_input_indices.sequence.at(index);
71 if (source_index >= m_nodes.size()) {
72 throw NNEvaluationException(
73 "The NN needs an input MatrixXd at position "
74 + std::to_string(source_index) + " but only "
75 + std::to_string(m_nodes.size()) + " inputs were given");
76 }
77 return preproc(m_seqs.at(source_index));
78 }
79
80 // utility functions
81 //
82 // Build a mapping from the inputs in the saved network to the
83 // inputs that the user is going to hand us.
84 std::vector<size_t> get_node_indices(
85 const order_t& order,
86 const std::vector<lwt::InputNodeConfig>& inputs)
87 {
88 std::map<std::string, size_t> order_indices;
89 for (size_t i = 0; i < order.size(); i++) {
90 order_indices[order.at(i).first] = i;
91 }
92 std::vector<size_t> node_indices;
93 for (const lwt::InputNodeConfig& input: inputs) {
94 if (!order_indices.count(input.name)) {
95 throw NNConfigurationException("Missing input " + input.name);
96 }
97 node_indices.push_back(order_indices.at(input.name));
98 }
99 return node_indices;
100 }
101
102
103}
104namespace lwt::atlas {
105 // ______________________________________________________________________
106 // Fast Graph
107
109 FastGraph::FastGraph(const GraphConfig& config, const InputOrder& order,
110 std::string default_output):
111 m_graph(new Graph(config.nodes, config.layers))
112 {
113
114 m_input_indices.scalar = get_node_indices(
115 order.scalar, config.inputs);
116
117 m_input_indices.sequence = get_node_indices(
118 order.sequence, config.input_sequences);
119
120 for (size_t i = 0; i < config.inputs.size(); i++) {
121 const lwt::InputNodeConfig& node = config.inputs.at(i);
122 size_t input_node = m_input_indices.scalar.at(i);
123 std::vector<std::string> varorder = order.scalar.at(input_node).second;
124 m_preprocs.emplace_back(
125 new FastInputPreprocessor(node.variables, varorder));
126 }
127 for (size_t i = 0; i < config.input_sequences.size(); i++) {
128 const lwt::InputNodeConfig& node = config.input_sequences.at(i);
129 size_t input_node = m_input_indices.sequence.at(i);
130 std::vector<std::string> varorder = order.sequence.at(input_node).second;
131 m_vec_preprocs.emplace_back(
132 new FastInputVectorPreprocessor(node.variables, varorder));
133 }
134 if (default_output.size() > 0) {
135 if (!config.outputs.count(default_output)) {
136 throw NNConfigurationException("no output node" + default_output);
137 }
138 m_default_output = config.outputs.at(default_output).node_index;
139 } else if (config.outputs.size() == 1) {
140 m_default_output = config.outputs.begin()->second.node_index;
141 } else {
142 throw NNConfigurationException("you must specify a default output");
143 }
144 }
145
147 delete m_graph;
148 for (auto& preproc: m_preprocs) {
149 delete preproc;
150 preproc = 0;
151 }
152 for (auto& preproc: m_vec_preprocs) {
153 delete preproc;
154 preproc = 0;
155 }
156 }
157
158 VectorXd FastGraph::compute(const NodeVec& nodes,
159 const SeqNodeVec& seq) const {
160 return compute(nodes, seq, m_default_output);
161 }
162 VectorXd FastGraph::compute(const NodeVec& nodes,
163 const SeqNodeVec& seq,
164 size_t idx) const {
165 LazySource source(nodes, seq, m_preprocs, m_vec_preprocs,
167 return m_graph->compute(source, idx);
168 }
169
170}
SourceIndices m_input_indices
Definition FastGraph.h:67
FastGraph(const GraphConfig &config, const InputOrder &order, std::string default_output="")
std::vector< Eigen::MatrixXd > SeqNodeVec
Definition FastGraph.h:37
std::vector< Eigen::VectorXd > NodeVec
Definition FastGraph.h:36
Eigen::VectorXd compute(const NodeVec &, const SeqNodeVec &={}) const
VecPreprocs m_vec_preprocs
Definition FastGraph.h:64
Definition node.h:24
Definition index.py:1
Ensure that the extensions for the Vector3D are properly loaded.
std::vector< std::pair< std::string, std::vector< std::string > > > order_t
Definition InputOrder.h:25
FastGraph::NodeVec NodeVec
order
Configure Herwig7.
std::vector< size_t > sequence
Definition FastGraph.h:28
std::vector< size_t > scalar
Definition FastGraph.h:27