Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
LightweightGraph.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef LIGHTWEIGHT_GRAPH_HH_TAURECTOOLS
6 #define LIGHTWEIGHT_GRAPH_HH_TAURECTOOLS
7 
8 /* Lightweight Graph
9 
10  The LightweightGraph class is a more flexible version of the
11  LightweightNeuralNetwork class. This flexibility comes from the
12  ability to read from multiple inputs, merge them, and then expose
13  multiple outputs.
14 
15  For example, a conventional feed-forward network may be structured
16  as follows:
17 
18  I <-- simple input vector
19  |
20  D <-- dense feed-forward layer
21  |
22  O <-- output activation function
23 
24  A graph is more flexible, allowing structures like the following:
25 
26  I_s <-- sequential input
27  |
28  GRU I_v <-- simple input vector
29  \ /
30  M <-- merge layer
31  |
32  D <-- dense layer
33  / \
34  D2 D3
35  | |
36  | O_c <-- multiclass output (softmax activation)
37  |
38  O_r <-- regression output (linear output)
39 
40  i.e. a graph can combine any number of sequential and "standard"
41  rank-1 inputs, and can use the same internal features to infer many
42  different attributes from the input pattern.
43 
44  Like the LightweightNeuralNetwork, it contains no Eigen code: it
45  only serves as a high-level wrapper to convert std::map objects to
46  Eigen objects and Eigen objects back to std::maps. For the
47  underlying implementation, see Graph.h. */
48 
50 
51 namespace lwtDev {
52 
53  class Graph;
54  class InputPreprocessor;
55  class InputVectorPreprocessor;
56 
57  // We currently allow several input types
58  // The "ValueMap" is for simple rank-1 inputs
59  typedef std::map<std::string, double> ValueMap;
60  // The "VectorMap" is for sequence inputs
61  typedef std::map<std::string, std::vector<double> > VectorMap;
62 
63  // Graph class
65  {
66  public:
67 // Since a graph has multiple input nodes, we actually call it with maps keyed by input node name
68  typedef std::map<std::string, ValueMap> NodeMap;
69  typedef std::map<std::string, VectorMap> SeqNodeMap;
70 
71  // In cases where the graph has multiple outputs, we have to
72  // define a "default" output, so that calling "compute" with no
73  // output specified doesn't lead to ambiguity.
75  const std::string& default_output = "");
76 
80 
81  // The simpler "compute" function
82  ValueMap compute(const NodeMap&, const SeqNodeMap& = {}) const;
83 
84  // More complicated version, only needed when you have multiple
85  // output nodes and need to specify the non-default ones
86  ValueMap compute(const NodeMap&, const SeqNodeMap&,
87  const std::string& output) const;
88 
89  // The simpler "scan" function
90  VectorMap scan(const NodeMap&, const SeqNodeMap& = {}) const;
91 
92  // More complicated version, only needed when you have multiple
93  // output nodes and need to specify the non-default ones
94  VectorMap scan(const NodeMap&, const SeqNodeMap&,
95  const std::string& output) const;
96 
97  private:
100  typedef std::vector<std::pair<std::string, IP*> > Preprocs;
101  typedef std::vector<std::pair<std::string, IVP*> > VecPreprocs;
102 
103  ValueMap compute(const NodeMap&, const SeqNodeMap&, size_t) const;
104  VectorMap scan(const NodeMap&, const SeqNodeMap&, size_t) const;
108  std::vector<std::pair<size_t, std::vector<std::string> > > m_outputs;
109  std::map<std::string, size_t> m_output_indices;
111  };
112 }
113 
114 #endif
lwtDev::ValueMap
std::map< std::string, double > ValueMap
Definition: InputPreprocessor.h:22
lwtDev::VectorMap
std::map< std::string, std::vector< double > > VectorMap
Definition: InputPreprocessor.h:24
lwtDev::InputPreprocessor
Definition: InputPreprocessor.h:30
lwtDev::LightweightGraph::operator=
LightweightGraph & operator=(LightweightGraph &)=delete
lwtDev::InputVectorPreprocessor
Definition: InputPreprocessor.h:42
lwtDev::LightweightGraph::Preprocs
std::vector< std::pair< std::string, IP * > > Preprocs
Definition: LightweightGraph.h:100
lwtDev::LightweightGraph::m_graph
Graph * m_graph
Definition: LightweightGraph.h:105
lwtDev::NodeMap
LightweightGraph::NodeMap NodeMap
Definition: LightweightGraph.cxx:67
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
lwtDev::LightweightGraph::m_preprocs
Preprocs m_preprocs
Definition: LightweightGraph.h:106
lwtDev::LightweightGraph::LightweightGraph
LightweightGraph(LightweightGraph &)=delete
lwtDev::Graph
Definition: Graph.h:120
lwtDev::LightweightGraph
Definition: LightweightGraph.h:65
lwtDev::LightweightGraph::m_outputs
std::vector< std::pair< size_t, std::vector< std::string > > > m_outputs
Definition: LightweightGraph.h:108
lwtDev::LightweightGraph::compute
ValueMap compute(const NodeMap &, const SeqNodeMap &={}) const
Definition: LightweightGraph.cxx:110
lwtDev::LightweightGraph::SeqNodeMap
std::map< std::string, VectorMap > SeqNodeMap
Definition: LightweightGraph.h:69
lwtDev::LightweightGraph::IP
InputPreprocessor IP
Definition: LightweightGraph.h:98
lwtDev::LightweightGraph::~LightweightGraph
~LightweightGraph()
Definition: LightweightGraph.cxx:98
lwtDev::LightweightGraph::NodeMap
std::map< std::string, ValueMap > NodeMap
Definition: LightweightGraph.h:68
merge.output
output
Definition: merge.py:17
lwtDev
Definition: Reconstruction/tauRecTools/Root/lwtnn/Exceptions.cxx:8
lwtDev::LightweightGraph::LightweightGraph
LightweightGraph(const GraphConfig &config, const std::string &default_output="")
Definition: LightweightGraph.cxx:68
lightweight_network_config.h
lwtDev::LightweightGraph::m_vec_preprocs
VecPreprocs m_vec_preprocs
Definition: LightweightGraph.h:107
VectorMap
std::map< std::string, std::vector< double > > VectorMap
Definition: TauDecayModeNNClassifier.cxx:23
lwtDev::LightweightGraph::scan
VectorMap scan(const NodeMap &, const SeqNodeMap &={}) const
Definition: LightweightGraph.cxx:135
lwtDev::LightweightGraph::m_output_indices
std::map< std::string, size_t > m_output_indices
Definition: LightweightGraph.h:109
lwtDev::LightweightGraph::VecPreprocs
std::vector< std::pair< std::string, IVP * > > VecPreprocs
Definition: LightweightGraph.h:101
lwtDev::GraphConfig
Definition: lightweight_network_config.h:58
lwtDev::LightweightGraph::m_default_output
size_t m_default_output
Definition: LightweightGraph.h:110
lwtDev::LightweightGraph::IVP
InputVectorPreprocessor IVP
Definition: LightweightGraph.h:99