ATLAS Offline Software
Loading...
Searching...
No Matches
BDT.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
3*/
4
5#include "MVAUtils/BDT.h"
9
10#include "TTree.h"
11#include <cmath>
12
13#include <memory>
14#include <set>
15#include <sstream>
16#include <stack>
17#include <stdexcept>
18#include <string>
19#include <utility>
20
21using namespace MVAUtils;
22
23namespace{
24
/* utility functions : to split options (e.g. "creator=lgbm;node_type=lgbm_simple")
 * into a std::map {{"creator", "lgbm"}, {"node_type", "lgbm_simple"}}
 */
/* Look up `key` in `m`.
 * Returns a copy of the mapped value when the key is present,
 * otherwise a copy of `defval` (empty string by default).
 */
std::string get_default_string_map(const std::map <std::string, std::string> & m,
                                   const std::string& key, const std::string & defval="")
{
  const auto found = m.find(key);
  return (found != m.end()) ? found->second : defval;
}
35
/* Parse a ';'-separated list of "key=value" items (as found in the title of
 * the input TTree) into a std::map.
 *
 *  - Each item is split at its first '='; everything after it is the value,
 *    so "a=b=c" yields {"a", "b=c"}.
 *  - An item without '=' is stored as a key with an empty value.
 *  - Empty items (e.g. from ";;" or a leading ';') are ignored, so they can
 *    neither create an empty key nor trigger a spurious duplicate error.
 *
 * Throws std::runtime_error if the same key appears more than once.
 */
std::map<std::string, std::string> parseOptions(const std::string& raw_options)
{
  std::map<std::string, std::string> options;
  std::stringstream ss(raw_options);
  std::string item;
  while (std::getline(ss, item, ';')) {
    if (item.empty()) { continue; }

    const auto pos = item.find('=');
    // substr(0, npos) returns the whole item, which is the key when no '=' is present
    const std::string left = item.substr(0, pos);
    // handle the "no '='" case explicitly instead of relying on npos + 1 wrapping to 0
    std::string right = (pos == std::string::npos) ? std::string() : item.substr(pos + 1);

    if (!options.emplace(left, std::move(right)).second)
    {
      throw std::runtime_error(std::string("option ") + left +
                               " duplicated in title of TTree used as input");
    }
  }

  return options;
}
54}
55
57BDT::BDT(::TTree *tree)
58{
59 // at runtime decide which flavour of BDT we need to build
60 // the information is coming from the title of the TTree
61
62 if(!tree){
63 throw std::runtime_error("nullptr to a TTree passed ");
64 }
65 tree->SetCacheSize(0); // Avoid unnecessary memory allocations
66 std::map<std::string, std::string> options = parseOptions(tree->GetTitle());
67 std::string creator = get_default_string_map(options, std::string("creator"));
68 if (creator == "lgbm")
69 {
70 std::string node_type = get_default_string_map (options, std::string("node_type"));
71 if (node_type == "lgbm") {
72 m_forest = std::make_unique<ForestLGBM>(tree);
73 } else if (node_type == "lgbm_simple") {
74 m_forest = std::make_unique<ForestLGBMSimple>(
75 tree); // this do not support nan as inputs
76 } else {
77 throw std::runtime_error(
78 "the title of the input tree is misformatted: cannot understand which "
79 "BDT implementation to use");
80 }
81 } else if (creator == "xgboost") {
82 // this do support nan as inputs
83 m_forest = std::make_unique<ForestXGBoost>(tree);
84 } else {
85 // default for compatibility: old TTree (based on TMVA) don't have a special title
86 m_forest = std::make_unique<ForestTMVA>(tree);
87 }
88}
89
90
91TTree* BDT::WriteTree(TString name) const { return m_forest->WriteTree(std::move(name)); }
92void BDT::PrintForest() const { m_forest->PrintForest(); }
93void BDT::PrintTree(unsigned int itree) const { m_forest->PrintTree(itree); }
static Double_t ss
TTree * WriteTree(TString name="BDT") const
Return a TTree representing the BDT: each entry is a binary tree, each element of the vectors is a no...
Definition BDT.cxx:91
std::unique_ptr< IForest > m_forest
the implementation of the forest, doing the hard work
void PrintTree(unsigned int itree) const
Definition BDT.cxx:93
void PrintForest() const
for debugging, print out tree or forest to stdout
Definition BDT.cxx:92
BDT()=delete
TChain * tree