ATLAS Offline Software
LightweightGraph.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
8 #include <iostream>
9 #include <Eigen/Dense>
10 
11 namespace {
12  using namespace Eigen;
13  using namespace lwtDev;
14 
16  typedef InputPreprocessor IP;
17  typedef std::vector<std::pair<std::string, IP*> > Preprocs;
19  typedef InputVectorPreprocessor IVP;
20  typedef std::vector<std::pair<std::string, IVP*> > VecPreprocs;
21 
22  // this is used internally to ensure that we only look up map inputs
23  // when the network asks for them.
24  class LazySource: public ISource
25  {
26  public:
27  LazySource(const NodeMap&, const SeqNodeMap&,
28  const Preprocs&, const VecPreprocs&);
29  virtual VectorXd at(size_t index) const override;
30  virtual MatrixXd matrix_at(size_t index) const override;
31  private:
32  const NodeMap& m_nodes;
33  const SeqNodeMap& m_seqs;
34  const Preprocs& m_preprocs;
35  const VecPreprocs& m_vec_preprocs;
36  };
37 
38  LazySource::LazySource(const NodeMap& n, const SeqNodeMap& s,
39  const Preprocs& p, const VecPreprocs& v):
40  m_nodes(n), m_seqs(s), m_preprocs(p), m_vec_preprocs(v)
41  {
42  }
43  VectorXd LazySource::at(size_t index) const
44  {
45  const auto& proc = m_preprocs.at(index);
46  if (!m_nodes.count(proc.first)) {
47  throw NNEvaluationException("Can't find node " + proc.first);
48  }
49  const auto& preproc = *proc.second;
50  return preproc(m_nodes.at(proc.first));
51  }
52  MatrixXd LazySource::matrix_at(size_t index) const
53  {
54  const auto& proc = m_vec_preprocs.at(index);
55  if (!m_seqs.count(proc.first)) {
56  throw NNEvaluationException("Can't find sequence node " + proc.first);
57  }
58  const auto& preproc = *proc.second;
59  return preproc(m_seqs.at(proc.first));
60  }
61 
62 }
63 namespace lwtDev {
64  // ______________________________________________________________________
65  // Lightweight Graph
66 
69  const std::string& default_output):
70  m_graph(new Graph(config.nodes, config.layers))
71  {
72  for (const auto& node: config.inputs) {
73  m_preprocs.emplace_back(
74  node.name, new InputPreprocessor(node.variables));
75  }
76  for (const auto& node: config.input_sequences) {
77  m_vec_preprocs.emplace_back(
78  node.name, new InputVectorPreprocessor(node.variables));
79  }
80  size_t output_n = 0;
81  for (const auto& node: config.outputs) {
82  m_outputs.emplace_back(node.second.node_index, node.second.labels);
83  m_output_indices.emplace(node.first, output_n);
84  output_n++;
85  }
86  if (default_output.size() > 0) {
87  if (!m_output_indices.count(default_output)) {
88  throw NNConfigurationException("no output node" + default_output);
89  }
90  m_default_output = m_output_indices.at(default_output);
91  } else if (output_n == 1) {
92  m_default_output = 0;
93  } else {
94  throw NNConfigurationException("you must specify a default output");
95  }
96  }
97 
99  delete m_graph;
100  for (auto& preproc: m_preprocs) {
101  delete preproc.second;
102  preproc.second = 0;
103  }
104  for (auto& preproc: m_vec_preprocs) {
105  delete preproc.second;
106  preproc.second = 0;
107  }
108  }
109 
111  const SeqNodeMap& seq) const {
112  return compute(nodes, seq, m_default_output);
113  }
115  const SeqNodeMap& seq,
116  const std::string& output) const {
117  if (!m_output_indices.count(output)) {
118  throw NNEvaluationException("no output node " + output);
119  }
120  return compute(nodes, seq, m_output_indices.at(output));
121  }
123  const SeqNodeMap& seq,
124  size_t idx) const {
125  LazySource source(nodes, seq, m_preprocs, m_vec_preprocs);
126  VectorXd result = m_graph->compute(source, m_outputs.at(idx).first);
127  const std::vector<std::string>& labels = m_outputs.at(idx).second;
128  std::map<std::string, double> output;
129  for (size_t iii = 0; iii < labels.size(); iii++) {
130  output[labels.at(iii)] = result(iii);
131  }
132  return output;
133  }
134 
136  const SeqNodeMap& seq) const {
137  return scan(nodes, seq, m_default_output);
138  }
140  const SeqNodeMap& seq,
141  const std::string& output) const {
142  if (!m_output_indices.count(output)) {
143  throw NNEvaluationException("no output node " + output);
144  }
145  return scan(nodes, seq, m_output_indices.at(output));
146  }
148  const SeqNodeMap& seq,
149  size_t idx) const {
150  LazySource source(nodes, seq, m_preprocs, m_vec_preprocs);
151  MatrixXd result = m_graph->scan(source, m_outputs.at(idx).first);
152  const std::vector<std::string>& labels = m_outputs.at(idx).second;
153  std::map<std::string, std::vector<double> > output;
154  for (size_t iii = 0; iii < labels.size(); iii++) {
155  VectorXd row = result.row(iii);
156  std::vector<double> out_vector(row.data(), row.data() + row.size());
157  output[labels.at(iii)] = out_vector;
158  }
159  return output;
160  }
161 
162 }
test_athena_ntuple_filter.seq
seq
filter configuration ## -> we use the special sequence 'AthMasterSeq' which is run before any other a...
Definition: test_athena_ntuple_filter.py:18
query_example.row
row
Definition: query_example.py:24
lwtDev::ValueMap
std::map< std::string, double > ValueMap
Definition: InputPreprocessor.h:22
python.SystemOfUnits.s
int s
Definition: SystemOfUnits.py:131
lwtDev::VectorMap
std::map< std::string, std::vector< double > > VectorMap
Definition: InputPreprocessor.h:24
get_generator_info.result
result
Definition: get_generator_info.py:21
lwtDev::InputPreprocessor
Definition: InputPreprocessor.h:30
index
Definition: index.py:1
LightweightGraph.h
CSV_InDetExporter.new
new
Definition: CSV_InDetExporter.py:145
module_driven_slicing.layers
layers
Definition: module_driven_slicing.py:114
lwtDev::InputVectorPreprocessor
Definition: InputPreprocessor.h:42
lwtDev::LightweightGraph::m_graph
Graph * m_graph
Definition: LightweightGraph.h:105
lwtDev::NodeMap
LightweightGraph::NodeMap NodeMap
Definition: LightweightGraph.cxx:67
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
lwtDev::NNConfigurationException
Definition: Reconstruction/tauRecTools/tauRecTools/lwtnn/Exceptions.h:21
lwtDev::LightweightGraph::m_preprocs
Preprocs m_preprocs
Definition: LightweightGraph.h:106
lwtDev::Graph
Definition: Graph.h:120
beamspotnt.labels
list labels
Definition: bin/beamspotnt.py:1447
python.utils.AtlRunQueryDQUtils.p
p
Definition: AtlRunQueryDQUtils.py:210
lwtDev::LightweightGraph::m_outputs
std::vector< std::pair< size_t, std::vector< std::string > > > m_outputs
Definition: LightweightGraph.h:108
lwtDev::LightweightGraph::compute
ValueMap compute(const NodeMap &, const SeqNodeMap &={}) const
Definition: LightweightGraph.cxx:110
beamspotman.n
n
Definition: beamspotman.py:731
lwtDev::LightweightGraph::SeqNodeMap
std::map< std::string, VectorMap > SeqNodeMap
Definition: LightweightGraph.h:69
InputPreprocessor.h
lwtDev::LightweightGraph::~LightweightGraph
~LightweightGraph()
Definition: LightweightGraph.cxx:98
lwtDev::LightweightGraph::NodeMap
std::map< std::string, ValueMap > NodeMap
Definition: LightweightGraph.h:68
merge.output
output
Definition: merge.py:17
node::name
void name(const std::string &n)
Definition: node.h:37
lwtDev
Definition: Reconstruction/tauRecTools/Root/lwtnn/Exceptions.cxx:8
lwtDev::LightweightGraph::LightweightGraph
LightweightGraph(const GraphConfig &config, const std::string &default_output="")
Definition: LightweightGraph.cxx:68
mc.proc
proc
Definition: mc.PhPy8EG_A14NNPDF23_gg4l_example.py:22
lwtDev::NNEvaluationException
Definition: Reconstruction/tauRecTools/tauRecTools/lwtnn/Exceptions.h:27
lwtDev::LightweightGraph::m_vec_preprocs
VecPreprocs m_vec_preprocs
Definition: LightweightGraph.h:107
Graph.h
python.PyAthena.v
v
Definition: PyAthena.py:154
lwtDev::ISource
Definition: Source.h:18
lwtDev::LightweightGraph::scan
VectorMap scan(const NodeMap &, const SeqNodeMap &={}) const
Definition: LightweightGraph.cxx:135
lwtDev::Graph::compute
VectorXd compute(const ISource &, size_t node_number) const
Definition: Graph.cxx:315
lwtDev::LightweightGraph::m_output_indices
std::map< std::string, size_t > m_output_indices
Definition: LightweightGraph.h:109
lwtDev::Graph::scan
MatrixXd scan(const ISource &, size_t node_number) const
Definition: Graph.cxx:332
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
copySelective.source
string source
Definition: copySelective.py:32
tauRecTools::SeqNodeMap
std::map< std::string, VectorMap > SeqNodeMap
Definition: TauTrackRNNClassifier.h:44
lwtDev::GraphConfig
Definition: lightweight_network_config.h:58
node
Definition: memory_hooks-stdcmalloc.h:74
lwtDev::LightweightGraph::m_default_output
size_t m_default_output
Definition: LightweightGraph.h:110