ATLAS Offline Software
Public Member Functions | Private Attributes | Friends | List of all members
conifer::BDT< T, U, useAddTree > Class Template Reference

#include <conifer.h>

Collaboration diagram for conifer::BDT< T, U, useAddTree >:

Public Member Functions

void init ()
 
std::vector< U > decision_function (std::vector< T > x) const
 

Private Attributes

int m_n_classes
 
int m_n_trees
 
int m_n_features
 
std::vector< double > m_init_predict
 
std::vector< U > m_init_predict_
 
std::vector< std::vector< DecisionTree< T, U > > > m_trees
 
OpAdd< U > m_add
 

Friends

void from_json (const nlohmann::json &j, BDT &o)
 

Detailed Description

template<class T, class U, bool useAddTree = false>
class conifer::BDT< T, U, useAddTree >

Definition at line 96 of file conifer.h.

Member Function Documentation

◆ decision_function()

template<class T , class U , bool useAddTree = false>
std::vector<U> conifer::BDT< T, U, useAddTree >::decision_function ( std::vector< T >  x) const
inline

Definition at line 128 of file conifer.h.

128  {
129  /* Do the prediction */
130  assert("Size of feature vector mismatches expected m_n_features" &&
131  x.size() == static_cast<size_t>(m_n_features));
132  std::vector<U> values;
133  std::vector<std::vector<U>> values_trees;
134  values_trees.resize(m_n_classes);
135  values.resize(m_n_classes, U(0));
136  for (int i = 0; i < m_n_classes; i++) {
137  std::transform(
138  m_trees.begin(), m_trees.end(),
139  std::back_inserter(values_trees.at(i)),
140  [&i, &x](auto tree_v) { return tree_v.at(i).decision_function(x); });
141  if (useAddTree) {
142  values.at(i) = m_init_predict_.at(i);
143  values.at(i) += reduce<U, OpAdd<U>>(values_trees.at(i), m_add);
144  } else {
145  values.at(i) =
146  std::accumulate(values_trees.at(i).begin(),
147  values_trees.at(i).end(), U(m_init_predict_.at(i)));
148  }
149  }
150 
151  return values;
152  }

◆ init()

template<class T , class U , bool useAddTree = false>
void conifer::BDT< T, U, useAddTree >::init ( )
inline

Definition at line 109 of file conifer.h.

109  {
110  /* Construct the BDT from conifer cpp backend JSON file */
111  // std::ifstream ifs(filename);
112  // nlohmann::json j = nlohmann::json::parse(ifs);
113  // from_json(j, *this);
114  /* Do some transformation to initialise things into the proper emulation T,
115  * U types */
116  if (m_n_classes == 2)
117  m_n_classes = 1;
118  std::transform(m_init_predict.begin(), m_init_predict.end(),
119  std::back_inserter(m_init_predict_),
120  [](double ip) -> U { return (U)ip; });
121  for (int i = 0; i < m_n_trees; i++) {
122  for (int j = 0; j < m_n_classes; j++) {
123  m_trees.at(i).at(j).init_();
124  }
125  }
126  }

Friends And Related Function Documentation

◆ from_json

template<class T , class U , bool useAddTree = false>
void from_json ( const nlohmann::json &  j,
BDT< T, U, useAddTree > &  o 
)
friend

Definition at line 155 of file conifer.h.

155  {
156  j.at("n_classes").get_to(o.m_n_classes);
157  j.at("n_trees").get_to(o.m_n_trees);
158  j.at("n_features").get_to(o.m_n_features);
159  j.at("init_predict").get_to(o.m_init_predict);
160  j.at("trees").get_to(o.m_trees);
161  }

Member Data Documentation

◆ m_add

template<class T , class U , bool useAddTree = false>
OpAdd<U> conifer::BDT< T, U, useAddTree >::m_add
private

Definition at line 106 of file conifer.h.

◆ m_init_predict

template<class T , class U , bool useAddTree = false>
std::vector<double> conifer::BDT< T, U, useAddTree >::m_init_predict
private

Definition at line 102 of file conifer.h.

◆ m_init_predict_

template<class T , class U , bool useAddTree = false>
std::vector<U> conifer::BDT< T, U, useAddTree >::m_init_predict_
private

Definition at line 103 of file conifer.h.

◆ m_n_classes

template<class T , class U , bool useAddTree = false>
int conifer::BDT< T, U, useAddTree >::m_n_classes
private

Definition at line 99 of file conifer.h.

◆ m_n_features

template<class T , class U , bool useAddTree = false>
int conifer::BDT< T, U, useAddTree >::m_n_features
private

Definition at line 101 of file conifer.h.

◆ m_n_trees

template<class T , class U , bool useAddTree = false>
int conifer::BDT< T, U, useAddTree >::m_n_trees
private

Definition at line 100 of file conifer.h.

◆ m_trees

template<class T , class U , bool useAddTree = false>
std::vector<std::vector<DecisionTree<T, U> > > conifer::BDT< T, U, useAddTree >::m_trees
private

Definition at line 105 of file conifer.h.


The documentation for this class was generated from the following file:
conifer.h
conifer::BDT::m_n_features
int m_n_features
Definition: conifer.h:101
runITkAlign.accumulate (Doxygen auto-link; unrelated to the std::accumulate call in decision_function())
accumulate
Update flags based on parser line args.
Definition: runITkAlign.py:62
conifer::BDT::m_init_predict_
std::vector< U > m_init_predict_
Definition: conifer.h:103
x
#define x
python.Bindings.values
values
Definition: Control/AthenaPython/python/Bindings.py:808
conifer::BDT::m_n_classes
int m_n_classes
Definition: conifer.h:99
lumiFormat.i
int i
Definition: lumiFormat.py:85
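Amg::transform (Doxygen auto-link; unrelated to the iterator-based std::transform calls in decision_function() and init())
Amg::Vector3D transform(Amg::Vector3D &v, Amg::Transform3D &tr)
Transform a point from a Transformation3D.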
Definition: GeoPrimitivesHelpers.h:156
find_tgc_unfilled_channelids.ip
ip
Definition: find_tgc_unfilled_channelids.py:3
conifer::BDT::m_add
OpAdd< U > m_add
Definition: conifer.h:106
conifer::BDT::m_n_trees
int m_n_trees
Definition: conifer.h:100
conifer::BDT::m_trees
std::vector< std::vector< DecisionTree< T, U > > > m_trees
Definition: conifer.h:105
conifer::BDT::m_init_predict
std::vector< double > m_init_predict
Definition: conifer.h:102