ATLAS Offline Software
BDTHelper.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
6 
7 #include "TFile.h"
8 #include "TTree.h"
9 #include "TObjArray.h"
10 #include "TObjString.h"
11 
12 namespace tauRecTools {
13 
15  asg::AsgMessaging("BDTHelper"),
16  m_BDT(nullptr) {
17 }
18 
19 
21 }
22 
23 StatusCode BDTHelper::initialize(const TString& weightFileName) {
24 
25  std::unique_ptr<TFile> file(TFile::Open(weightFileName));
26  if (!file) {
27  ATH_MSG_ERROR("Cannot find input BDT file: " << weightFileName);
28  return StatusCode::FAILURE;
29  }
30  ATH_MSG_INFO( "Open file: " << weightFileName);
31 
32  TTree* tree = dynamic_cast<TTree*> (file->Get("BDT"));
33  if (!tree) {
34  ATH_MSG_ERROR("Cannot find input BDT tree");
35  return StatusCode::FAILURE;
36  }
37  m_BDT = std::make_unique<MVAUtils::BDT>(tree);
38 
39  TNamed* varList = dynamic_cast<TNamed*> (file->Get("varList"));
40  if (!varList) {
41  ATH_MSG_ERROR("No variable list in file: " << weightFileName);
42  return StatusCode::FAILURE;
43  }
44  TString names = varList->GetTitle();
45  delete varList;
46 
47  // abtain the list of input variables
49 
50  file->Close();
51 
52  return StatusCode::SUCCESS;
53 }
54 
55 std::vector<TString> BDTHelper::parseString(const TString& str, const TString& delim/*=","*/) const {
56  std::vector<TString> parsedString;
57 
58  TObjArray* objList = str.Tokenize(delim);
59  size_t arraySize = objList->GetEntries();
60 
61  // split the string with ",", and put them into a vector
62  for(size_t i = 0; i < arraySize; ++i) {
63  if (auto *str = dynamic_cast<TObjString*> (objList->At(i))) {
64  TString var = str->String();
65  var.ReplaceAll(" ", "");
66  if(var.Contains(":=")) {
67  var=var(var.Index(":=")+2, var.Length()-var.Index(":=")-2);
68  }
69  if(0==var.Length()) continue;
70  parsedString.push_back(var);
71  }
72  }
73 
74  delete objList;
75 
76  return parsedString;
77 }
78 
79 std::vector<float> BDTHelper::getInputVariables(const std::map<TString, float>& availableVariables) const {
80  std::vector<float> values;
81 
82  // sort the input variables by the order in varList (from BDT)
83  for (const TString& name : m_inputVariableNames) {
84  std::map<TString, float>::const_iterator itr = availableVariables.find(name);
85  if(itr==availableVariables.end()) {
86  ATH_MSG_ERROR(name << " not available");
87  }
88  else {
89  values.push_back(itr->second);
90  }
91  }
92 
93  return values;
94 }
95 
96 std::vector<float> BDTHelper::getInputVariables(const std::map<TString, float*>& availableVariables) const {
97  std::vector<float> values;
98 
99  // sort the input variables by the order in varList (from BDT)
100  for (const TString& name : m_inputVariableNames) {
101  std::map<TString, float*>::const_iterator itr = availableVariables.find(name);
102  if(itr==availableVariables.end()) {
103  ATH_MSG_ERROR(name << " not available");
104  }
105  else {
106  values.push_back(*itr->second);
107  }
108  }
109 
110  return values;
111 }
112 
113 float BDTHelper::getGradBoostMVA(const std::map<TString, float>& availableVariables) const {
114  std::vector<float> values = getInputVariables(availableVariables);
115 
116  if (values.size() < m_inputVariableNames.size()) {
117  ATH_MSG_ERROR("There are missing variables when calculating the BDT score, will return -999");
118  return -999;
119  }
120  else {
121  return m_BDT->GetGradBoostMVA(values);
122  }
123 }
124 
125 
126 float BDTHelper::getResponse(const std::map<TString, float*>& availableVariables) const {
127  std::vector<float> values = getInputVariables(availableVariables);
128 
129  if (values.size() < m_inputVariableNames.size()) {
130  ATH_MSG_ERROR("There are missing variables when calculating the BDT score, will return -999");
131  return -999;
132  }
133  else {
134  return m_BDT->GetResponse(values);
135  }
136 }
137 
138 } // end of namespace tauRecTools
beamspotnt.var
var
Definition: bin/beamspotnt.py:1393
tauRecTools::BDTHelper::m_BDT
std::unique_ptr< MVAUtils::BDT > m_BDT
Definition: BDTHelper.h:36
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
BDTHelper.h
tree
TChain * tree
Definition: tile_monitor.h:30
asg
Definition: DataHandleTestTool.h:28
tauRecTools::BDTHelper::getResponse
float getResponse(const std::map< TString, float * > &availableVariables) const
Definition: BDTHelper.cxx:126
python.Bindings.values
values
Definition: Control/AthenaPython/python/Bindings.py:808
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
lumiFormat.i
int i
Definition: lumiFormat.py:85
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
python.subdetectors.mmg.names
names
Definition: mmg.py:8
tauRecTools::BDTHelper::~BDTHelper
~BDTHelper()
Definition: BDTHelper.cxx:20
file
TFile * file
Definition: tile_monitor.h:29
PyPoolBrowser.objList
dictionary objList
Definition: PyPoolBrowser.py:103
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
tauRecTools::BDTHelper::getGradBoostMVA
float getGradBoostMVA(const std::map< TString, float > &availableVariables) const
Definition: BDTHelper.cxx:113
tauRecTools::BDTHelper::BDTHelper
BDTHelper()
Definition: BDTHelper.cxx:14
tauRecTools::BDTHelper::parseString
std::vector< TString > parseString(const TString &str, const TString &delim=",") const
Definition: BDTHelper.cxx:55
beamspotnt.varList
list varList
Definition: bin/beamspotnt.py:1107
tauRecTools::BDTHelper::initialize
StatusCode initialize(const TString &weightFileName)
Definition: BDTHelper.cxx:23
str
Definition: BTagTrackIpAccessor.cxx:11
tauRecTools
Implementation of a TrackClassifier based on an RNN.
Definition: BDTHelper.cxx:12
tauRecTools::BDTHelper::getInputVariables
std::vector< float > getInputVariables(const std::map< TString, float > &availableVariables) const
Definition: BDTHelper.cxx:79
tauRecTools::BDTHelper::m_inputVariableNames
std::vector< TString > m_inputVariableNames
Definition: BDTHelper.h:37