ATLAS Offline Software
conifer::BDT< T, U, useAddTree > Class Template Reference

#include <conifer.h>

Collaboration diagram for conifer::BDT< T, U, useAddTree >:

Public Member Functions

void init ()
std::vector< U > decision_function (std::vector< T > x) const

Private Attributes

int m_n_classes
int m_n_trees
int m_n_features
std::vector< double > m_init_predict
std::vector< U > m_init_predict_
std::vector< std::vector< DecisionTree< T, U > > > m_trees
OpAdd< U > m_add

Friends

void from_json (const nlohmann::json &j, BDT &o)

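Detailed Description

Boosted-decision-tree (BDT) evaluator whose parameters (n_classes, n_trees, n_features, init_predict, trees) are loaded from a conifer JSON model by from_json(). decision_function() takes a feature vector of type T (its size must equal the number of features) and returns one score of type U per class. Each score is that class's initial prediction plus the sum of the outputs of that class's trees. For a two-class model, init() reduces the class count to 1, so a single score is returned. When useAddTree is true, the per-class sum is formed on a separate code path. That path is not fully visible in the listing below; it presumably uses the m_add operator.

template<class T, class U, bool useAddTree = false>
class conifer::BDT< T, U, useAddTree >

Definition at line 96 of file conifer.h.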

Member Function Documentation

◆ decision_function()

template<class T, class U, bool useAddTree = false>
std::vector< U > conifer::BDT< T, U, useAddTree >::decision_function ( std::vector< T > x) const
inline

Definition at line 128 of file conifer.h.

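128 {
129 /* Do the prediction */
130 assert("Size of feature vector mismatches expected m_n_features" &&
131 x.size() == static_cast<size_t>(m_n_features));
132-134 [not captured in this extraction; from their later use, these lines presumably declare `values` and `values_trees`]
135 values.resize(m_n_classes, U(0));
136 for (int i = 0; i < m_n_classes; i++) {
137 [not captured in this extraction; presumably opens the call that receives the arguments on lines 138-140]
138 m_trees.begin(), m_trees.end(),
139 [not captured in this extraction; presumably the output destination, values_trees.at(i)]
140 [&i, &x](auto tree_v) { return tree_v.at(i).decision_function(x); });
141 if (useAddTree) {
142 values.at(i) = m_init_predict_.at(i);
143 [not captured in this extraction; presumably adds the per-class tree outputs using m_add]
144 } else {
145 values.at(i) =
146 std::accumulate(values_trees.at(i).begin(),
147 values_trees.at(i).end(), U(m_init_predict_.at(i)));
148 }
149 }
150
151 return values;
152 }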
OpAdd< U > m_add
Definition conifer.h:106
std::vector< std::vector< DecisionTree< T, U > > > m_trees
Definition conifer.h:105
int m_n_classes
Definition conifer.h:99
int m_n_features
Definition conifer.h:101
std::vector< U > m_init_predict_
Definition conifer.h:103

◆ init()

template<class T, class U, bool useAddTree = false>
void conifer::BDT< T, U, useAddTree >::init ( )
inline

Definition at line 109 of file conifer.h.

109 {
110 /* Construct the BDT from conifer cpp backend JSON file */
111 // std::ifstream ifs(filename);
112 // nlohmann::json j = nlohmann::json::parse(ifs);
113 // from_json(j, *this);
114 /* Do some transformation to initialise things into the proper emulation T,
115 * U types */
116 if (m_n_classes == 2)
117 m_n_classes = 1;
118-119 [not captured in this extraction; presumably begin the call that converts m_init_predict (double) into m_init_predict_ (U) using the lambda on line 120]
120 [](double ip) -> U { return (U)ip; });
121 for (int i = 0; i < m_n_trees; i++) {
122 for (int j = 0; j < m_n_classes; j++) {
123 m_trees.at(i).at(j).init_();
124 }
125 }
126 }
std::vector< double > m_init_predict
Definition conifer.h:102
int m_n_trees
Definition conifer.h:100

◆ from_json

template<class T, class U, bool useAddTree = false>
void from_json ( const nlohmann::json & j,
BDT< T, U, useAddTree > & o )
friend

Definition at line 155 of file conifer.h.

155 {
156 j.at("n_classes").get_to(o.m_n_classes);
157 j.at("n_trees").get_to(o.m_n_trees);
158 j.at("n_features").get_to(o.m_n_features);
159 j.at("init_predict").get_to(o.m_init_predict);
160 j.at("trees").get_to(o.m_trees);
161 }

Member Data Documentation

◆ m_add

template<class T, class U, bool useAddTree = false>
OpAdd<U> conifer::BDT< T, U, useAddTree >::m_add
private

Definition at line 106 of file conifer.h.

◆ m_init_predict

template<class T, class U, bool useAddTree = false>
std::vector<double> conifer::BDT< T, U, useAddTree >::m_init_predict
private

Definition at line 102 of file conifer.h.

◆ m_init_predict_

template<class T, class U, bool useAddTree = false>
std::vector<U> conifer::BDT< T, U, useAddTree >::m_init_predict_
private

Definition at line 103 of file conifer.h.

◆ m_n_classes

template<class T, class U, bool useAddTree = false>
int conifer::BDT< T, U, useAddTree >::m_n_classes
private

Definition at line 99 of file conifer.h.

◆ m_n_features

template<class T, class U, bool useAddTree = false>
int conifer::BDT< T, U, useAddTree >::m_n_features
private

Definition at line 101 of file conifer.h.

◆ m_n_trees

template<class T, class U, bool useAddTree = false>
int conifer::BDT< T, U, useAddTree >::m_n_trees
private

Definition at line 100 of file conifer.h.

◆ m_trees

template<class T, class U, bool useAddTree = false>
std::vector<std::vector<DecisionTree<T, U> > > conifer::BDT< T, U, useAddTree >::m_trees
private

Definition at line 105 of file conifer.h.


The documentation for this class was generated from the following file: