ATLAS Offline Software
BDTHelper.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration
3 */
4 
6 
7 #include "TFile.h"
8 #include "TTree.h"
9 #include "TObjArray.h"
10 #include "TObjString.h"
11 
12 namespace tauRecTools {
13 
15  asg::AsgMessaging("BDTHelper"),
16  m_BDT(nullptr) {
17 }
18 
19 
20 
22 }
23 
24 
25 
26 StatusCode BDTHelper::initialize(const TString& weightFileName) {
27 
28  std::unique_ptr<TFile> file(TFile::Open(weightFileName));
29  if (!file) {
30  ATH_MSG_ERROR("Cannot find input BDT file: " << weightFileName);
31  return StatusCode::FAILURE;
32  }
33  ATH_MSG_INFO( "Open file: " << weightFileName);
34 
35  TTree* tree = dynamic_cast<TTree*> (file->Get("BDT"));
36  if (!tree) {
37  ATH_MSG_ERROR("Cannot find input BDT tree");
38  return StatusCode::FAILURE;
39  }
40  m_BDT = std::make_unique<MVAUtils::BDT>(tree);
41 
42  TNamed* varList = dynamic_cast<TNamed*> (file->Get("varList"));
43  if (!varList) {
44  ATH_MSG_ERROR("No variable list in file: " << weightFileName);
45  return StatusCode::FAILURE;
46  }
47  TString names = varList->GetTitle();
48  delete varList;
49 
50  // abtain the list of input variables
52 
53  file->Close();
54 
55  return StatusCode::SUCCESS;
56 }
57 
58 std::vector<TString> BDTHelper::parseString(const TString& str, const TString& delim/*=","*/) const {
59  std::vector<TString> parsedString;
60 
61  TObjArray* objList = str.Tokenize(delim);
62  size_t arraySize = objList->GetEntries();
63 
64  // split the string with ",", and put them into a vector
65  for(size_t i = 0; i < arraySize; ++i) {
66  if (auto *str = dynamic_cast<TObjString*> (objList->At(i))) {
67  TString var = str->String();
68  var.ReplaceAll(" ", "");
69  if(var.Contains(":=")) {
70  var=var(var.Index(":=")+2, var.Length()-var.Index(":=")-2);
71  }
72  if(0==var.Length()) continue;
73  parsedString.push_back(var);
74  }
75  }
76 
77  delete objList;
78 
79  return parsedString;
80 }
81 
82 std::vector<float> BDTHelper::getInputVariables(const std::map<TString, float>& availableVariables) const {
83  std::vector<float> values;
84 
85  // sort the input variables by the order in varList (from BDT)
86  for (const TString& name : m_inputVariableNames) {
87  std::map<TString, float>::const_iterator itr = availableVariables.find(name);
88  if(itr==availableVariables.end()) {
89  ATH_MSG_ERROR(name << " not available");
90  }
91  else {
92  values.push_back(itr->second);
93  }
94  }
95 
96  return values;
97 }
98 
99 std::vector<float> BDTHelper::getInputVariables(const std::map<TString, float*>& availableVariables) const {
100  std::vector<float> values;
101 
102  // sort the input variables by the order in varList (from BDT)
103  for (const TString& name : m_inputVariableNames) {
104  std::map<TString, float*>::const_iterator itr = availableVariables.find(name);
105  if(itr==availableVariables.end()) {
106  ATH_MSG_ERROR(name << " not available");
107  }
108  else {
109  values.push_back(*itr->second);
110  }
111  }
112 
113  return values;
114 }
115 
116 std::vector<float> BDTHelper::getInputVariables(const xAOD::TauJet& tau) const {
117  std::vector<float> values;
118 
119  // obtain the values of input variables by the name
120  // all the variables should be decorated to tau already
121  for (TString name : m_inputVariableNames) {
122  // remove prefix (::TauJets.centFrac -> cenFrac)
123  if(name.Index(".")>=0){
124  name = name(name.Last('.')+1, name.Length()-name.Last('.')-1);
125  }
126 
128  float value = accessor(tau);
129  values.push_back(value);
130  }
131 
132  return values;
133 }
134 
135 
136 float BDTHelper::getGradBoostMVA(const std::map<TString, float>& availableVariables) const {
137  std::vector<float> values = getInputVariables(availableVariables);
138 
139  if (values.size() < m_inputVariableNames.size()) {
140  ATH_MSG_ERROR("There are missing variables when calculating the BDT score, will return -999");
141  return -999;
142  }
143  else {
144  return m_BDT->GetGradBoostMVA(values);
145  }
146 }
147 
148 
149 float BDTHelper::getResponse(const std::map<TString, float*>& availableVariables) const {
150  std::vector<float> values = getInputVariables(availableVariables);
151 
152  if (values.size() < m_inputVariableNames.size()) {
153  ATH_MSG_ERROR("There are missing variables when calculating the BDT score, will return -999");
154  return -999;
155  }
156  else {
157  return m_BDT->GetResponse(values);
158  }
159 }
160 
161 
162 float BDTHelper::getClassification(const std::map<TString, float*>& availableVariables) const {
163  std::vector<float> values = getInputVariables(availableVariables);
164 
165  if (values.size() < m_inputVariableNames.size()) {
166  ATH_MSG_ERROR("There are missing variables when calculating the BDT score, will return -999");
167  return -999;
168  }
169  else {
170  return m_BDT->GetClassification(values);
171  }
172 }
173 
174 
175 float BDTHelper::getGradBoostMVA(const xAOD::TauJet& tau) const {
176  std::vector<float> values = getInputVariables(tau);
177 
178  if (values.size() < m_inputVariableNames.size()) {
179  ATH_MSG_ERROR("There are missing variables when calculating the BDT score, will return -999");
180  return -999;
181  }
182  else {
183  return m_BDT->GetGradBoostMVA(values);
184  }
185 }
186 
187 } // end of namespace tauRecTools
beamspotnt.var
var
Definition: bin/beamspotnt.py:1394
tauRecTools::BDTHelper::m_BDT
std::unique_ptr< MVAUtils::BDT > m_BDT
Definition: BDTHelper.h:42
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
BDTHelper.h
tree
TChain * tree
Definition: tile_monitor.h:30
asg
Definition: DataHandleTestTool.h:28
athena.value
value
Definition: athena.py:122
SG::ConstAccessor
Helper class to provide constant type-safe access to aux data.
Definition: ConstAccessor.h:54
tauRecTools::BDTHelper::getResponse
float getResponse(const std::map< TString, float * > &availableVariables) const
Definition: BDTHelper.cxx:149
python.Bindings.values
values
Definition: Control/AthenaPython/python/Bindings.py:797
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
tauRecTools::BDTHelper::getClassification
float getClassification(const std::map< TString, float * > &availableVariables) const
Definition: BDTHelper.cxx:162
lumiFormat.i
int i
Definition: lumiFormat.py:92
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
python.subdetectors.mmg.names
names
Definition: mmg.py:8
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
tauRecTools::BDTHelper::~BDTHelper
~BDTHelper()
Definition: BDTHelper.cxx:21
file
TFile * file
Definition: tile_monitor.h:29
PyPoolBrowser.objList
dictionary objList
Definition: PyPoolBrowser.py:103
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:192
tauRecTools::BDTHelper::getGradBoostMVA
float getGradBoostMVA(const std::map< TString, float > &availableVariables) const
Definition: BDTHelper.cxx:136
xAOD::JetAttributeAccessor::accessor
const AccessorWrapper< T > * accessor(xAOD::JetAttribute::AttributeID id)
Returns an attribute accessor corresponding to an AttributeID.
Definition: JetAccessorMap.h:26
tauRecTools::BDTHelper::BDTHelper
BDTHelper()
Definition: BDTHelper.cxx:14
tauRecTools::BDTHelper::parseString
std::vector< TString > parseString(const TString &str, const TString &delim=",") const
Definition: BDTHelper.cxx:58
beamspotnt.varList
list varList
Definition: bin/beamspotnt.py:1108
tauRecTools::BDTHelper::initialize
StatusCode initialize(const TString &weightFileName)
Definition: BDTHelper.cxx:26
str
Definition: BTagTrackIpAccessor.cxx:11
tauRecTools
Implementation of a TrackClassifier based on an RNN.
Definition: BDTHelper.cxx:12
tauRecTools::BDTHelper::getInputVariables
std::vector< float > getInputVariables(const std::map< TString, float > &availableVariables) const
Definition: BDTHelper.cxx:82
tauRecTools::BDTHelper::m_inputVariableNames
std::vector< TString > m_inputVariableNames
Definition: BDTHelper.h:43