ATLAS Offline Software
Public Member Functions | Private Attributes | Friends | List of all members
conifer::BDT< T, U, useAddTree > Class Template Reference

#include <conifer.h>

Collaboration diagram for conifer::BDT< T, U, useAddTree >:

Public Member Functions

void init ()
 
std::vector< U > decision_function (std::vector< T > x) const
 

Private Attributes

int m_n_classes
 
int m_n_trees
 
int m_n_features
 
std::vector< double > m_init_predict
 
std::vector< U > m_init_predict_
 
std::vector< std::vector< DecisionTree< T, U > > > m_trees
 
OpAdd< U > m_add
 

Friends

void from_json (const nlohmann::json &j, BDT &o)
 

Detailed Description

template<class T, class U, bool useAddTree = false>
class conifer::BDT< T, U, useAddTree >

Definition at line 89 of file conifer.h.

Member Function Documentation

◆ decision_function()

template<class T , class U , bool useAddTree = false>
std::vector<U> conifer::BDT< T, U, useAddTree >::decision_function ( std::vector< T >  x) const
inline

Definition at line 121 of file conifer.h.

121  {
122  /* Do the prediction */
123  assert("Size of feature vector mismatches expected m_n_features" &&
124  x.size() == static_cast<size_t>(m_n_features));
125  std::vector<U> values;
126  std::vector<std::vector<U>> values_trees;
127  values_trees.resize(m_n_classes);
128  values.resize(m_n_classes, U(0));
129  for (int i = 0; i < m_n_classes; i++) {
130  std::transform(
131  m_trees.begin(), m_trees.end(),
132  std::back_inserter(values_trees.at(i)),
133  [&i, &x](auto tree_v) { return tree_v.at(i).decision_function(x); });
134  if (useAddTree) {
135  values.at(i) = m_init_predict_.at(i);
136  values.at(i) += reduce<U, OpAdd<U>>(values_trees.at(i), m_add);
137  } else {
138  values.at(i) =
139  std::accumulate(values_trees.at(i).begin(),
140  values_trees.at(i).end(), U(m_init_predict_.at(i)));
141  }
142  }
143 
144  return values;
145  }

◆ init()

template<class T , class U , bool useAddTree = false>
void conifer::BDT< T, U, useAddTree >::init ( )
inline

Definition at line 102 of file conifer.h.

102  {
103  /* Construct the BDT from conifer cpp backend JSON file */
104  // std::ifstream ifs(filename);
105  // nlohmann::json j = nlohmann::json::parse(ifs);
106  // from_json(j, *this);
107  /* Do some transformation to initialise things into the proper emulation T,
108  * U types */
109  if (m_n_classes == 2)
110  m_n_classes = 1;
111  std::transform(m_init_predict.begin(), m_init_predict.end(),
112  std::back_inserter(m_init_predict_),
113  [](double ip) -> U { return (U)ip; });
114  for (int i = 0; i < m_n_trees; i++) {
115  for (int j = 0; j < m_n_classes; j++) {
116  m_trees.at(i).at(j).init_();
117  }
118  }
119  }

Friends And Related Function Documentation

◆ from_json

template<class T , class U , bool useAddTree = false>
void from_json ( const nlohmann::json &  j,
BDT< T, U, useAddTree > &  o 
)
friend

Definition at line 148 of file conifer.h.

148  {
149  j.at("n_classes").get_to(o.m_n_classes);
150  j.at("n_trees").get_to(o.m_n_trees);
151  j.at("n_features").get_to(o.m_n_features);
152  j.at("init_predict").get_to(o.m_init_predict);
153  j.at("trees").get_to(o.m_trees);
154  }

Member Data Documentation

◆ m_add

template<class T , class U , bool useAddTree = false>
OpAdd<U> conifer::BDT< T, U, useAddTree >::m_add
private

Definition at line 99 of file conifer.h.

◆ m_init_predict

template<class T , class U , bool useAddTree = false>
std::vector<double> conifer::BDT< T, U, useAddTree >::m_init_predict
private

Definition at line 95 of file conifer.h.

◆ m_init_predict_

template<class T , class U , bool useAddTree = false>
std::vector<U> conifer::BDT< T, U, useAddTree >::m_init_predict_
private

Definition at line 96 of file conifer.h.

◆ m_n_classes

template<class T , class U , bool useAddTree = false>
int conifer::BDT< T, U, useAddTree >::m_n_classes
private

Definition at line 92 of file conifer.h.

◆ m_n_features

template<class T , class U , bool useAddTree = false>
int conifer::BDT< T, U, useAddTree >::m_n_features
private

Definition at line 94 of file conifer.h.

◆ m_n_trees

template<class T , class U , bool useAddTree = false>
int conifer::BDT< T, U, useAddTree >::m_n_trees
private

Definition at line 93 of file conifer.h.

◆ m_trees

template<class T , class U , bool useAddTree = false>
std::vector<std::vector<DecisionTree<T, U> > > conifer::BDT< T, U, useAddTree >::m_trees
private

Definition at line 98 of file conifer.h.


The documentation for this class was generated from the following file:
conifer.h
conifer::BDT::m_n_features
int m_n_features
Definition: conifer.h:94
accumulate
bool accumulate(AccumulateMap &map, std::vector< module_t > const &modules, FPGATrackSimMatrixAccumulator const &acc)
Accumulates an accumulator (e.g.
Definition: FPGATrackSimMatrixAccumulator.cxx:22
conifer::BDT::m_init_predict_
std::vector< U > m_init_predict_
Definition: conifer.h:96
x
#define x
python.Bindings.values
values
Definition: Control/AthenaPython/python/Bindings.py:797
conifer::BDT::m_n_classes
int m_n_classes
Definition: conifer.h:92
lumiFormat.i
int i
Definition: lumiFormat.py:92
Amg::transform
Amg::Vector3D transform(Amg::Vector3D &v, Amg::Transform3D &tr)
Transform a point from a Transformation3D.
Definition: GeoPrimitivesHelpers.h:156
find_tgc_unfilled_channelids.ip
ip
Definition: find_tgc_unfilled_channelids.py:3
conifer::BDT::m_add
OpAdd< U > m_add
Definition: conifer.h:99
conifer::BDT::m_n_trees
int m_n_trees
Definition: conifer.h:93
conifer::BDT::m_trees
std::vector< std::vector< DecisionTree< T, U > > > m_trees
Definition: conifer.h:98
conifer::BDT::m_init_predict
std::vector< double > m_init_predict
Definition: conifer.h:95