ATLAS Offline Software
conifer.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef CONIFER_CPP_H__
6 #define CONIFER_CPP_H__
7 #include "nlohmann/json.hpp"
8 #include <vector>
9 #include <algorithm>
10 #include <numeric> //std::accumulate
11 #include <cassert>
12 #include <fstream>
13 
14 namespace conifer {
15 
/* ---
 * Small compile-time integer helpers used by the balanced tree reduce.
 * --- */

/// floor(log2(x)) for x >= 2; returns 0 for any x < 2 (including 0 and
/// negative values).
constexpr int floorlog2(int x) { return (x < 2) ? 0 : 1 + floorlog2(x / 2); }

/// B raised to the integer power x. Non-positive exponents yield 1, so a
/// negative argument terminates instead of recursing without bound.
template <int B> constexpr int pow(int x) {
  return x <= 0 ? 1 : B * pow<B>(x - 1);
}

/// 2 raised to the integer power x (1 for x <= 0).
constexpr int pow2(int x) { return pow<2>(x); }

namespace detail {
/// Recursive worker for reduce(): reduces x[first, first + count) with op.
/// Precondition: count >= 1 and the range lies inside x.
template <class T, class Op>
T reduce_range(const std::vector<T>& x, int first, int count, Op& op) {
  if (count == 1) {
    return x[first];
  }
  // The left part is the largest power of two strictly smaller than count
  // (2 -> 1, 3 -> 2, 5 -> 4, 8 -> 4, ...). This is the same split the
  // original copy-based implementation used, so the combination order (and
  // hence the rounding behaviour of fixed-point types) is unchanged.
  const int leftN = pow2(floorlog2(count - 1));
  T left = reduce_range(x, first, leftN, op);
  T right = reduce_range(x, first + leftN, count - leftN, op);
  return op(left, right);
}
} // namespace detail

/* ---
 * Balanced tree reduce implementation.
 * Reduces an array of inputs to a single value using the template binary
 * operator 'Op', for example summing all elements with OpAdd or taking the
 * maximum with a max functor. Use only when the input array is fully
 * unrolled, or slice out a fully unrolled section before applying and
 * accumulate the result over the rolled dimension. Required for emulation to
 * guarantee equality of ordering.
 *
 * @param x   values to reduce
 * @param op  binary operator combining two T values into one
 * @return    the reduced value; for an empty input a value-initialised T
 *            (0 for arithmetic types, i.e. the identity of OpAdd)
 * --- */
template <class T, class Op> T reduce(std::vector<T> x, Op op) {
  if (x.empty()) {
    return T{};
  }
  // Recurse on index ranges instead of copying sub-vectors at every level.
  return detail::reduce_range<T, Op>(x, 0, static_cast<int>(x.size()), op);
}
47 
/// Binary addition functor, used as the reduction operator when summing
/// per-tree scores with reduce().
template <class T> class OpAdd {
public:
  /// Returns a + b. Declared const so the functor can be invoked through a
  /// const object or reference (e.g. a member used inside a const method).
  T operator()(T a, T b) const { return a + b; }
};
52 
53 template <class T, class U> class DecisionTree {
54 
55 private:
56  std::vector<int> m_feature;
57  std::vector<int> m_children_left;
58  std::vector<int> m_children_right;
59  std::vector<T> m_threshold_;
60  std::vector<U> m_value_;
61  std::vector<double> m_threshold;
62  std::vector<double> m_value;
63 
64 public:
65  U decision_function(const std::vector<T>& x) const {
66  /* Do the prediction */
67  int i = 0;
68  while (m_feature[i] != -2) { // continue until reaching leaf
69  bool comparison = x[m_feature[i]] <= m_threshold_[i];
70  i = comparison ? m_children_left[i] : m_children_right[i];
71  }
72  return m_value_[i];
73  }
74 
75  void init_() {
76  /* Since T, U types may not be readable from the JSON, read them to double
77  * and the cast them here */
78  std::transform(m_threshold.begin(), m_threshold.end(),
79  std::back_inserter(m_threshold_),
80  [](double t) -> T { return (T)t; });
81  std::transform(m_value.begin(), m_value.end(), std::back_inserter(m_value_),
82  [](double v) -> U { return (U)v; });
83  }
84 
85  // Define how to read this class to/from JSON
86  friend void from_json(const nlohmann::json &j, DecisionTree &o) {
87  j.at("feature").get_to(o.m_feature);
88  j.at("children_left").get_to(o.m_children_left);
89  j.at("children_right").get_to(o.m_children_right);
90  j.at("threshold").get_to(o.m_threshold);
91  j.at("value").get_to(o.m_value);
92  }
93 
94 }; // class DecisionTree
95 
96 template <class T, class U, bool useAddTree = false> class BDT {
97 
98 private:
102  std::vector<double> m_init_predict;
103  std::vector<U> m_init_predict_;
104  // vector of decision trees: outer dimension tree, inner dimension class
105  std::vector<std::vector<DecisionTree<T, U>>> m_trees;
107 
108 public:
109  void init(/*std::string filename*/) {
110  /* Construct the BDT from conifer cpp backend JSON file */
111  // std::ifstream ifs(filename);
112  // nlohmann::json j = nlohmann::json::parse(ifs);
113  // from_json(j, *this);
114  /* Do some transformation to initialise things into the proper emulation T,
115  * U types */
116  if (m_n_classes == 2)
117  m_n_classes = 1;
119  std::back_inserter(m_init_predict_),
120  [](double ip) -> U { return (U)ip; });
121  for (int i = 0; i < m_n_trees; i++) {
122  for (int j = 0; j < m_n_classes; j++) {
123  m_trees.at(i).at(j).init_();
124  }
125  }
126  }
127 
128  std::vector<U> decision_function(std::vector<T> x) const {
129  /* Do the prediction */
130  assert("Size of feature vector mismatches expected m_n_features" &&
131  x.size() == static_cast<size_t>(m_n_features));
132  std::vector<U> values;
133  std::vector<std::vector<U>> values_trees;
134  values_trees.resize(m_n_classes);
135  values.resize(m_n_classes, U(0));
136  for (int i = 0; i < m_n_classes; i++) {
138  m_trees.begin(), m_trees.end(),
139  std::back_inserter(values_trees.at(i)),
140  [&i, &x](auto tree_v) { return tree_v.at(i).decision_function(x); });
141  if (useAddTree) {
142  values.at(i) = m_init_predict_.at(i);
143  values.at(i) += reduce<U, OpAdd<U>>(values_trees.at(i), m_add);
144  } else {
145  values.at(i) =
146  std::accumulate(values_trees.at(i).begin(),
147  values_trees.at(i).end(), U(m_init_predict_.at(i)));
148  }
149  }
150 
151  return values;
152  }
153 
154  // Define how to read this class to/from JSON
155  friend void from_json(const nlohmann::json &j, BDT &o) {
156  j.at("n_classes").get_to(o.m_n_classes);
157  j.at("n_trees").get_to(o.m_n_trees);
158  j.at("n_features").get_to(o.m_n_features);
159  j.at("init_predict").get_to(o.m_init_predict);
160  j.at("trees").get_to(o.m_trees);
161  }
162 }; // class BDT
163 
164 } // namespace conifer
165 
166 #endif
conifer::DecisionTree::m_threshold_
std::vector< T > m_threshold_
Definition: conifer.h:59
conifer::BDT::from_json
friend void from_json(const nlohmann::json &j, BDT &o)
Definition: conifer.h:155
conifer::BDT::m_n_features
int m_n_features
Definition: conifer.h:101
conifer::DecisionTree::m_threshold
std::vector< double > m_threshold
Definition: conifer.h:61
json
nlohmann::json json
Definition: HistogramDef.cxx:9
runITkAlign.accumulate
accumulate
Update flags based on parser line args.
Definition: runITkAlign.py:62
conifer::DecisionTree::m_feature
std::vector< int > m_feature
Definition: conifer.h:56
conifer::pow
constexpr int pow(int x)
Definition: conifer.h:27
conifer::OpAdd::operator()
T operator()(T a, T b)
Definition: conifer.h:50
conifer::DecisionTree::decision_function
U decision_function(const std::vector< T > &x) const
Definition: conifer.h:65
conifer
Definition: conifer.h:14
JetTiledMap::N
@ N
Definition: TiledEtaPhiMap.h:44
read_hist_ntuple.t
t
Definition: read_hist_ntuple.py:5
conifer::BDT::m_init_predict_
std::vector< U > m_init_predict_
Definition: conifer.h:103
conifer::DecisionTree::m_value
std::vector< double > m_value
Definition: conifer.h:62
x
#define x
python.Bindings.values
values
Definition: Control/AthenaPython/python/Bindings.py:808
conifer::floorlog2
constexpr int floorlog2(int x)
Definition: conifer.h:25
conifer::BDT::m_n_classes
int m_n_classes
Definition: conifer.h:99
conifer::pow2
constexpr int pow2(int x)
Definition: conifer.h:31
lumiFormat.i
int i
Definition: lumiFormat.py:85
conifer::BDT
Definition: conifer.h:96
conifer::BDT::decision_function
std::vector< U > decision_function(std::vector< T > x) const
Definition: conifer.h:128
Amg::transform
Amg::Vector3D transform(Amg::Vector3D &v, Amg::Transform3D &tr)
Transform a point from a Trasformation3D.
Definition: GeoPrimitivesHelpers.h:156
find_tgc_unfilled_channelids.ip
ip
Definition: find_tgc_unfilled_channelids.py:3
conifer::DecisionTree::m_value_
std::vector< U > m_value_
Definition: conifer.h:60
conifer::BDT::init
void init()
Definition: conifer.h:109
conifer::BDT::m_add
OpAdd< U > m_add
Definition: conifer.h:106
conifer::reduce
T reduce(std::vector< T > x, Op op)
Definition: conifer.h:33
plotBeamSpotMon.b
b
Definition: plotBeamSpotMon.py:76
conifer::DecisionTree::m_children_right
std::vector< int > m_children_right
Definition: conifer.h:58
conifer::BDT::m_n_trees
int m_n_trees
Definition: conifer.h:100
conifer::DecisionTree::from_json
friend void from_json(const nlohmann::json &j, DecisionTree &o)
Definition: conifer.h:86
conifer::DecisionTree::m_children_left
std::vector< int > m_children_left
Definition: conifer.h:57
python.PyAthena.v
v
Definition: PyAthena.py:154
conifer::OpAdd
Definition: conifer.h:48
a
TList * a
Definition: liststreamerinfos.cxx:10
conifer::DecisionTree
Definition: conifer.h:53
conifer::DecisionTree::init_
void init_()
Definition: conifer.h:75
conifer::BDT::m_trees
std::vector< std::vector< DecisionTree< T, U > > > m_trees
Definition: conifer.h:105
conifer::BDT::m_init_predict
std::vector< double > m_init_predict
Definition: conifer.h:102