ATLAS Offline Software
FastGraph.cxx
Go to the documentation of this file.
1 // this is -*- C++ -*-
2 /*
3  Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
4 */
5 
6 // Any modifications to this file may be copied to lwtnn[1] without
7 // attribution.
8 //
9 // [1]: https://www.github.com/lwtnn/lwtnn
10 
11 #include "LwtnnUtils/FastGraph.h"
13 #include "LwtnnUtils/InputOrder.h"
14 
15 #include <Eigen/Dense>
16 
17 namespace {
18  using namespace Eigen;
19  using namespace lwt;
20  using namespace lwt::atlas;
21 
24  typedef std::vector<IP*> Preprocs;
25  typedef atlas::FastGraph::SeqNodeVec SeqNodeVec;
27  typedef std::vector<IVP*> VecPreprocs;
28 
29 
30  // this is used internally to ensure that we only look up map inputs
31  // when the network asks for them.
32  class LazySource: public ISource
33  {
34  public:
35  LazySource(const NodeVec&, const SeqNodeVec&,
36  const Preprocs&, const VecPreprocs&,
37  const SourceIndices& input_indices);
38  virtual VectorXd at(size_t index) const override;
39  virtual MatrixXd matrix_at(size_t index) const override;
40  private:
41  const NodeVec& m_nodes;
42  const SeqNodeVec& m_seqs;
43  const Preprocs& m_preprocs;
44  const VecPreprocs& m_vec_preprocs;
45  const SourceIndices& m_input_indices;
46  };
47 
48  LazySource::LazySource(const NodeVec& n, const SeqNodeVec& s,
49  const Preprocs& p, const VecPreprocs& v,
50  const SourceIndices& i):
51  m_nodes(n), m_seqs(s), m_preprocs(p), m_vec_preprocs(v),
52  m_input_indices(i)
53  {
54  }
55  VectorXd LazySource::at(size_t index) const
56  {
57  const auto& preproc = *m_preprocs.at(index);
58  size_t source_index = m_input_indices.scalar.at(index);
59  if (source_index >= m_nodes.size()) {
60  throw NNEvaluationException(
61  "The NN needs an input VectorXd at position "
62  + std::to_string(source_index) + " but only "
63  + std::to_string(m_nodes.size()) + " inputs were given");
64  }
65  return preproc(m_nodes.at(source_index));
66  }
67  MatrixXd LazySource::matrix_at(size_t index) const
68  {
69  const auto& preproc = *m_vec_preprocs.at(index);
70  size_t source_index = m_input_indices.sequence.at(index);
71  if (source_index >= m_nodes.size()) {
72  throw NNEvaluationException(
73  "The NN needs an input MatrixXd at position "
74  + std::to_string(source_index) + " but only "
75  + std::to_string(m_nodes.size()) + " inputs were given");
76  }
77  return preproc(m_seqs.at(source_index));
78  }
79 
80  // utility functions
81  //
82  // Build a mapping from the inputs in the saved network to the
83  // inputs that the user is going to hand us.
84  std::vector<size_t> get_node_indices(
85  const order_t& order,
86  const std::vector<lwt::InputNodeConfig>& inputs)
87  {
88  std::map<std::string, size_t> order_indices;
89  for (size_t i = 0; i < order.size(); i++) {
90  order_indices[order.at(i).first] = i;
91  }
92  std::vector<size_t> node_indices;
93  for (const lwt::InputNodeConfig& input: inputs) {
94  if (!order_indices.count(input.name)) {
95  throw NNConfigurationException("Missing input " + input.name);
96  }
97  node_indices.push_back(order_indices.at(input.name));
98  }
99  return node_indices;
100  }
101 
102 
103 }
104 namespace lwt::atlas {
105  // ______________________________________________________________________
106  // Fast Graph
107 
109  FastGraph::FastGraph(const GraphConfig& config, const InputOrder& order,
110  std::string default_output):
111  m_graph(new Graph(config.nodes, config.layers))
112  {
113 
114  m_input_indices.scalar = get_node_indices(
115  order.scalar, config.inputs);
116 
117  m_input_indices.sequence = get_node_indices(
118  order.sequence, config.input_sequences);
119 
120  for (size_t i = 0; i < config.inputs.size(); i++) {
121  const lwt::InputNodeConfig& node = config.inputs.at(i);
122  size_t input_node = m_input_indices.scalar.at(i);
123  std::vector<std::string> varorder = order.scalar.at(input_node).second;
124  m_preprocs.emplace_back(
125  new FastInputPreprocessor(node.variables, varorder));
126  }
127  for (size_t i = 0; i < config.input_sequences.size(); i++) {
128  const lwt::InputNodeConfig& node = config.input_sequences.at(i);
129  size_t input_node = m_input_indices.sequence.at(i);
130  std::vector<std::string> varorder = order.sequence.at(input_node).second;
131  m_vec_preprocs.emplace_back(
132  new FastInputVectorPreprocessor(node.variables, varorder));
133  }
134  if (default_output.size() > 0) {
135  if (!config.outputs.count(default_output)) {
136  throw NNConfigurationException("no output node" + default_output);
137  }
138  m_default_output = config.outputs.at(default_output).node_index;
139  } else if (config.outputs.size() == 1) {
140  m_default_output = config.outputs.begin()->second.node_index;
141  } else {
142  throw NNConfigurationException("you must specify a default output");
143  }
144  }
145 
147  delete m_graph;
148  for (auto& preproc: m_preprocs) {
149  delete preproc;
150  preproc = 0;
151  }
152  for (auto& preproc: m_vec_preprocs) {
153  delete preproc;
154  preproc = 0;
155  }
156  }
157 
158  VectorXd FastGraph::compute(const NodeVec& nodes,
159  const SeqNodeVec& seq) const {
160  return compute(nodes, seq, m_default_output);
161  }
162  VectorXd FastGraph::compute(const NodeVec& nodes,
163  const SeqNodeVec& seq,
164  size_t idx) const {
165  LazySource source(nodes, seq, m_preprocs, m_vec_preprocs,
167  return m_graph->compute(source, idx);
168  }
169 
170 }
test_athena_ntuple_filter.seq
seq
filter configuration ## -> we use the special sequence 'AthMasterSeq' which is run before any other a...
Definition: test_athena_ntuple_filter.py:18
python.SystemOfUnits.s
int s
Definition: SystemOfUnits.py:131
lwt::atlas::FastGraph::m_vec_preprocs
VecPreprocs m_vec_preprocs
Definition: FastGraph.h:64
FastGraph.h
index
Definition: index.py:1
CSV_InDetExporter.new
new
Definition: CSV_InDetExporter.py:145
module_driven_slicing.layers
layers
Definition: module_driven_slicing.py:114
lwt::atlas::FastGraph::m_graph
Graph * m_graph
Definition: FastGraph.h:62
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
lwt::atlas::FastGraph::m_input_indices
SourceIndices m_input_indices
Definition: FastGraph.h:67
lwt::atlas::FastGraph::compute
Eigen::VectorXd compute(const NodeVec &, const SeqNodeVec &={}) const
Definition: FastGraph.cxx:158
python.utils.AtlRunQueryDQUtils.p
p
Definition: AtlRunQueryDQUtils.py:210
lwt::atlas::FastGraph::FastGraph
FastGraph(const GraphConfig &config, const InputOrder &order, std::string default_output="")
Definition: FastGraph.cxx:109
lwt::atlas::FastGraph::~FastGraph
~FastGraph()
Definition: FastGraph.cxx:146
lumiFormat.i
int i
Definition: lumiFormat.py:85
beamspotman.n
n
Definition: beamspotman.py:731
mc.order
order
Configure Herwig7.
Definition: mc.Herwig7_Dijet.py:12
lwt::atlas::InputOrder
Definition: InputOrder.h:28
lwt::atlas::order_t
std::vector< std::pair< std::string, std::vector< std::string > > > order_t
Definition: InputOrder.h:25
PlotPulseshapeFromCool.input
input
Definition: PlotPulseshapeFromCool.py:106
lwt
Definition: NnClusterizationFactory.h:52
lwt::atlas::FastInputVectorPreprocessor
Definition: FastInputPreprocessor.h:46
lwt::atlas::FastGraph::m_default_output
size_t m_default_output
Definition: FastGraph.h:65
lwt::atlas
Ensure that the extensions for the Vector3D are properly loaded.
Definition: LWTNNCondAlg.h:24
lwt::atlas::SourceIndices
Definition: FastGraph.h:26
lwt::atlas::FastGraph::m_preprocs
Preprocs m_preprocs
Definition: FastGraph.h:63
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
InputOrder.h
python.PyAthena.v
v
Definition: PyAthena.py:154
lwt::atlas::FastGraph::SeqNodeVec
std::vector< Eigen::MatrixXd > SeqNodeVec
Definition: FastGraph.h:37
lwt::atlas::FastGraph::NodeVec
std::vector< Eigen::VectorXd > NodeVec
Definition: FastGraph.h:36
lwt::atlas::SourceIndices::sequence
std::vector< size_t > sequence
Definition: FastGraph.h:28
lwt::atlas::NodeVec
FastGraph::NodeVec NodeVec
Definition: FastGraph.cxx:108
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
copySelective.source
string source
Definition: copySelective.py:32
lwt::atlas::SourceIndices::scalar
std::vector< size_t > scalar
Definition: FastGraph.h:27
FastInputPreprocessor.h
node
Definition: memory_hooks-stdcmalloc.h:74
lwt::atlas::FastInputPreprocessor
Definition: FastInputPreprocessor.h:33