ATLAS Offline Software
Loading...
Searching...
No Matches
conifer.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef CONIFER_CPP_H__
6#define CONIFER_CPP_H__
7#include "nlohmann/json.hpp"
8#include <vector>
9#include <algorithm>
10#include <numeric> //std::accumulate
11#include <cassert>
12#include <fstream>
13
14namespace conifer {
15
/* ---
 * Balanced tree reduce implementation.
 * Reduces an array of inputs to a single value using the template binary
 * operator 'Op', for example summing all elements with Op_add, or finding the
 * maximum with Op_max. Use only when the input array is fully unrolled. Or,
 * slice out a fully unrolled section before applying and accumulate the result
 * over the rolled dimension. Required for emulation to guarantee equality of
 * ordering.
 * --- */
/// Integer floor of log2(x): the number of times x can be halved (integer
/// division) before it drops below 2. Any x < 2, including zero and negative
/// values, yields 0.
constexpr int floorlog2(int x) {
  int halvings = 0;
  while (x >= 2) {
    x /= 2;
    ++halvings;
  }
  return halvings;
}
26
/// Integer power: returns B raised to the exponent x.
///
/// @tparam B  base of the power
/// @param  x  exponent; any value <= 0 yields 1, so a negative exponent
///            terminates immediately instead of recursing without bound
/// @return B multiplied by itself x times (no overflow checking is done)
template <int B> constexpr int pow(int x) {
  int result = 1;
  for (; x > 0; --x) {
    result *= B;
  }
  return result;
}
30
31constexpr int pow2(int x) { return pow<2>(x); }
32
/// Recursive worker for reduce(): combines the n elements starting at 'first'
/// in place, without copying any sub-range. Requires n >= 1.
///
/// The range is split so that the left part holds the largest power of two
/// strictly smaller than n; both parts are reduced recursively and then
/// combined with op(left_result, right_result).
template <class T, class Op>
T reduce_impl(typename std::vector<T>::const_iterator first, std::size_t n,
              Op &op) {
  if (n == 1) {
    return *first;
  }
  if (n == 2) {
    return op(*first, *(first + 1));
  }
  // Largest power of two strictly below n: this fixes the shape of the tree
  // and therefore the order in which elements are combined.
  std::size_t leftN = 1;
  while (2 * leftN < n) {
    leftN *= 2;
  }
  return op(reduce_impl<T, Op>(first, leftN, op),
            reduce_impl<T, Op>(first + leftN, n - leftN, op));
}

/// Balanced-tree reduction of a vector to a single value.
///
/// @param x   values to combine
/// @param op  binary operator invoked as op(a, b) -> T; the same operator
///            object is used for every combination step
/// @return the single element when x has size 1, otherwise the result of
///         combining the elements pairwise in a balanced tree
/// @throws std::out_of_range when x is empty (there is nothing to return and
///         no identity element is known for an arbitrary Op)
template <class T, class Op> T reduce(std::vector<T> x, Op op) {
  if (x.size() <= 1) {
    // at() rather than operator[]: the empty case throws instead of reading
    // out of bounds.
    return x.at(0);
  }
  return reduce_impl<T, Op>(x.cbegin(), x.size(), op);
}
47
/// Binary addition functor, usable as the 'Op' of reduce().
/// The call operator is const so that const instances can be invoked.
template <class T> class OpAdd {
public:
  /// @return a + b, converted to T
  T operator()(const T &a, const T &b) const { return a + b; }
};
52
53template <class T, class U> class DecisionTree {
54
55private:
56 std::vector<int> m_feature;
57 std::vector<int> m_children_left;
58 std::vector<int> m_children_right;
59 std::vector<T> m_threshold_;
60 std::vector<U> m_value_;
61 std::vector<double> m_threshold;
62 std::vector<double> m_value;
63
64public:
65 U decision_function(const std::vector<T>& x) const {
66 /* Do the prediction */
67 int i = 0;
68 while (m_feature[i] != -2) { // continue until reaching leaf
69 bool comparison = x[m_feature[i]] <= m_threshold_[i];
70 i = comparison ? m_children_left[i] : m_children_right[i];
71 }
72 return m_value_[i];
73 }
74
75 void init_() {
76 /* Since T, U types may not be readable from the JSON, read them to double
77 * and the cast them here */
78 std::transform(m_threshold.begin(), m_threshold.end(),
79 std::back_inserter(m_threshold_),
80 [](double t) -> T { return (T)t; });
81 std::transform(m_value.begin(), m_value.end(), std::back_inserter(m_value_),
82 [](double v) -> U { return (U)v; });
83 }
84
85 // Define how to read this class to/from JSON
86 friend void from_json(const nlohmann::json &j, DecisionTree &o) {
87 j.at("feature").get_to(o.m_feature);
88 j.at("children_left").get_to(o.m_children_left);
89 j.at("children_right").get_to(o.m_children_right);
90 j.at("threshold").get_to(o.m_threshold);
91 j.at("value").get_to(o.m_value);
92 }
93
94}; // class DecisionTree
95
96template <class T, class U, bool useAddTree = false> class BDT {
97
98private:
102 std::vector<double> m_init_predict;
103 std::vector<U> m_init_predict_;
104 // vector of decision trees: outer dimension tree, inner dimension class
105 std::vector<std::vector<DecisionTree<T, U>>> m_trees;
107
108public:
109 void init(/*std::string filename*/) {
110 /* Construct the BDT from conifer cpp backend JSON file */
111 // std::ifstream ifs(filename);
112 // nlohmann::json j = nlohmann::json::parse(ifs);
113 // from_json(j, *this);
114 /* Do some transformation to initialise things into the proper emulation T,
115 * U types */
116 if (m_n_classes == 2)
117 m_n_classes = 1;
118 std::transform(m_init_predict.begin(), m_init_predict.end(),
119 std::back_inserter(m_init_predict_),
120 [](double ip) -> U { return (U)ip; });
121 for (int i = 0; i < m_n_trees; i++) {
122 for (int j = 0; j < m_n_classes; j++) {
123 m_trees.at(i).at(j).init_();
124 }
125 }
126 }
127
128 std::vector<U> decision_function(std::vector<T> x) const {
129 /* Do the prediction */
130 assert("Size of feature vector mismatches expected m_n_features" &&
131 x.size() == static_cast<size_t>(m_n_features));
132 std::vector<U> values;
133 std::vector<std::vector<U>> values_trees;
134 values_trees.resize(m_n_classes);
135 values.resize(m_n_classes, U(0));
136 for (int i = 0; i < m_n_classes; i++) {
137 std::transform(
138 m_trees.begin(), m_trees.end(),
139 std::back_inserter(values_trees.at(i)),
140 [&i, &x](auto tree_v) { return tree_v.at(i).decision_function(x); });
141 if (useAddTree) {
142 values.at(i) = m_init_predict_.at(i);
143 values.at(i) += reduce<U, OpAdd<U>>(values_trees.at(i), m_add);
144 } else {
145 values.at(i) =
146 std::accumulate(values_trees.at(i).begin(),
147 values_trees.at(i).end(), U(m_init_predict_.at(i)));
148 }
149 }
150
151 return values;
152 }
153
154 // Define how to read this class to/from JSON
155 friend void from_json(const nlohmann::json &j, BDT &o) {
156 j.at("n_classes").get_to(o.m_n_classes);
157 j.at("n_trees").get_to(o.m_n_trees);
158 j.at("n_features").get_to(o.m_n_features);
159 j.at("init_predict").get_to(o.m_init_predict);
160 j.at("trees").get_to(o.m_trees);
161 }
162}; // class BDT
163
164} // namespace conifer
165
166#endif
static Double_t a
static const std::map< unsigned int, unsigned int > pow2
#define x
std::vector< double > m_init_predict
Definition conifer.h:102
std::vector< U > decision_function(std::vector< T > x) const
Definition conifer.h:128
OpAdd< U > m_add
Definition conifer.h:106
std::vector< std::vector< DecisionTree< T, U > > > m_trees
Definition conifer.h:105
int m_n_classes
Definition conifer.h:99
int m_n_trees
Definition conifer.h:100
int m_n_features
Definition conifer.h:101
std::vector< U > m_init_predict_
Definition conifer.h:103
friend void from_json(const nlohmann::json &j, BDT &o)
Definition conifer.h:155
void init()
Definition conifer.h:109
std::vector< int > m_feature
Definition conifer.h:56
std::vector< double > m_threshold
Definition conifer.h:61
std::vector< int > m_children_right
Definition conifer.h:58
std::vector< U > m_value_
Definition conifer.h:60
std::vector< T > m_threshold_
Definition conifer.h:59
friend void from_json(const nlohmann::json &j, DecisionTree &o)
Definition conifer.h:86
std::vector< int > m_children_left
Definition conifer.h:57
U decision_function(const std::vector< T > &x) const
Definition conifer.h:65
std::vector< double > m_value
Definition conifer.h:62
T operator()(T a, T b)
Definition conifer.h:50
constexpr int floorlog2(int x)
Definition conifer.h:25
T reduce(std::vector< T > x, Op op)
Definition conifer.h:33
constexpr int pow(int x)
Definition conifer.h:27