ATLAS Offline Software
BDT.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #include "MVAUtils/BDT.h"
6 #include "MVAUtils/ForestTMVA.h"
7 #include "MVAUtils/ForestLGBM.h"
9 
10 #include "TTree.h"
11 #include <cmath>
12 
13 #include <memory>
14 #include <set>
15 #include <sstream>
16 #include <stack>
17 #include <stdexcept>
18 #include <string>
19 #include <utility>
20 
21 using namespace MVAUtils;
22 
namespace {

/* Utility functions: split an option string such as
 * "creator=lgbm;node_type=lgbm_simple" into a std::map
 * {{"creator", "lgbm"}, {"node_type", "lgbm_simple"}} and query it.
 */

/// Return the value stored under @p key in @p m, or @p defval when the key is absent.
std::string get_default_string_map(const std::map<std::string, std::string>& m,
                                   const std::string& key,
                                   const std::string& defval = "")
{
  const auto it = m.find(key);
  return (it == m.end()) ? defval : it->second;
}

/// Parse a ';'-separated list of "key=value" tokens into a map.
///
/// - Only the first '=' separates key from value ("k=a=b" -> {"k", "a=b"}).
/// - A token without '=' is stored as a key with an empty value.
/// - Empty tokens (e.g. from ";;" or a trailing ';') are ignored.
///
/// @param raw_options  the option string (typically the title of a TTree)
/// @return             map from option key to option value
/// @throws std::runtime_error if the same key appears more than once
std::map<std::string, std::string> parseOptions(const std::string& raw_options)
{
  std::stringstream ss(raw_options);
  std::map<std::string, std::string> options;
  std::string item;
  while (std::getline(ss, item, ';')) {
    // Empty tokens carry no option; without this skip a second empty token
    // would be reported as a duplicated "" option.
    if (item.empty()) { continue; }

    const std::string::size_type pos = item.find('=');
    std::string left;
    std::string right;
    if (pos == std::string::npos) {
      // Bare flag: avoid the npos+1 wrap-around that would copy the whole
      // token into the value as well.
      left = item;
    } else {
      left = item.substr(0, pos);
      right = item.substr(pos + 1);
    }

    if (!options.emplace(left, right).second) {
      throw std::runtime_error(std::string("option ") + left +
                               " duplicated in title of TTree used as input");
    }
  }

  return options;
}

} // namespace
55 
57 BDT::BDT(::TTree *tree)
58 {
59  // at runtime decide which flavour of BDT we need to build
60  // the information is coming from the title of the TTree
61 
62  if(!tree){
63  throw std::runtime_error("nullptr to a TTree passed ");
64  }
65  tree->SetCacheSize(0); // Avoid unnecessary memory allocations
66  std::map<std::string, std::string> options = parseOptions(tree->GetTitle());
67  std::string creator = get_default_string_map(options, std::string("creator"));
68  if (creator == "lgbm")
69  {
70  std::string node_type = get_default_string_map (options, std::string("node_type"));
71  if (node_type == "lgbm") {
72  m_forest = std::make_unique<ForestLGBM>(tree);
73  } else if (node_type == "lgbm_simple") {
74  m_forest = std::make_unique<ForestLGBMSimple>(
75  tree); // this do not support nan as inputs
76  } else {
77  throw std::runtime_error(
78  "the title of the input tree is misformatted: cannot understand which "
79  "BDT implementation to use");
80  }
81  } else if (creator == "xgboost") {
82  // this do support nan as inputs
83  m_forest = std::make_unique<ForestXGBoost>(tree);
84  } else {
85  // default for compatibility: old TTree (based on TMVA) don't have a special title
86  m_forest = std::make_unique<ForestTMVA>(tree);
87  }
88 }
89 
90 
91 TTree* BDT::WriteTree(TString name) const { return m_forest->WriteTree(std::move(name)); }
92 void BDT::PrintForest() const { m_forest->PrintForest(); }
93 void BDT::PrintTree(unsigned int itree) const { m_forest->PrintTree(itree); }
python.SystemOfUnits.m
int m
Definition: SystemOfUnits.py:91
MVAUtils
Definition: InDetTrkInJetType.h:47
PowhegControl_ttHplus_NLO.ss
ss
Definition: PowhegControl_ttHplus_NLO.py:83
MVAUtils::BDT::PrintTree
void PrintTree(unsigned int itree) const
Definition: BDT.cxx:93
MVAUtils::BDT::BDT
BDT()=delete
tree
TChain * tree
Definition: tile_monitor.h:30
skel.it
it
Definition: skel.GENtoEVGEN.py:423
ForestLGBM.h
MVAUtils::BDT::m_forest
std::unique_ptr< IForest > m_forest
the implementation of the forest, doing the hard work
Definition: BDT.h:106
PyPoolBrowser.item
item
Definition: PyPoolBrowser.py:129
ForestTMVA.h
ForestXGBoost.h
python.AtlRunQueryLib.options
options
Definition: AtlRunQueryLib.py:379
BDT.h
MVAUtils::BDT::PrintForest
void PrintForest() const
for debugging, print out tree or forest to stdout
Definition: BDT.cxx:92
MVAUtils::BDT::WriteTree
TTree * WriteTree(TString name="BDT") const
Return a TTree representing the BDT: each entry is a binary tree, each element of the vectors is a no...
Definition: BDT.cxx:91
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:195
item
Definition: ItemListSvc.h:43
python.LumiBlobConversion.pos
pos
Definition: LumiBlobConversion.py:18
mapkey::key
key
Definition: TElectronEfficiencyCorrectionTool.cxx:37