ATLAS Offline Software
Loading...
Searching...
No Matches
BDTHelper.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
6
7#include "TFile.h"
8#include "TTree.h"
9#include "TObjArray.h"
10#include "TObjString.h"
11
12namespace tauRecTools {
13
15 asg::AsgMessaging("BDTHelper"),
16 m_BDT(nullptr) {
17}
18
19
22
23StatusCode BDTHelper::initialize(const TString& weightFileName) {
24
25 std::unique_ptr<TFile> file(TFile::Open(weightFileName));
26 if (!file) {
27 ATH_MSG_ERROR("Cannot find input BDT file: " << weightFileName);
28 return StatusCode::FAILURE;
29 }
30 ATH_MSG_INFO( "Open file: " << weightFileName);
31
32 TTree* tree = dynamic_cast<TTree*> (file->Get("BDT"));
33 if (!tree) {
34 ATH_MSG_ERROR("Cannot find input BDT tree");
35 return StatusCode::FAILURE;
36 }
37 m_BDT = std::make_unique<MVAUtils::BDT>(tree);
38
39 TNamed* varList = dynamic_cast<TNamed*> (file->Get("varList"));
40 if (!varList) {
41 ATH_MSG_ERROR("No variable list in file: " << weightFileName);
42 return StatusCode::FAILURE;
43 }
44 TString names = varList->GetTitle();
45 delete varList;
46
47 // abtain the list of input variables
49
50 file->Close();
51
52 return StatusCode::SUCCESS;
53}
54
55std::vector<TString> BDTHelper::parseString(const TString& str, const TString& delim/*=","*/) const {
56 std::vector<TString> parsedString;
57
58 TObjArray* objList = str.Tokenize(delim);
59 size_t arraySize = objList->GetEntries();
60
61 // split the string with ",", and put them into a vector
62 for(size_t i = 0; i < arraySize; ++i) {
63 if (auto *str = dynamic_cast<TObjString*> (objList->At(i))) {
64 TString var = str->String();
65 var.ReplaceAll(" ", "");
66 if(var.Contains(":=")) {
67 var=var(var.Index(":=")+2, var.Length()-var.Index(":=")-2);
68 }
69 if(0==var.Length()) continue;
70 parsedString.push_back(var);
71 }
72 }
73
74 delete objList;
75
76 return parsedString;
77}
78
79std::vector<float> BDTHelper::getInputVariables(const std::map<TString, float>& availableVariables) const {
80 std::vector<float> values;
81
82 // sort the input variables by the order in varList (from BDT)
83 for (const TString& name : m_inputVariableNames) {
84 std::map<TString, float>::const_iterator itr = availableVariables.find(name);
85 if(itr==availableVariables.end()) {
86 ATH_MSG_ERROR(name << " not available");
87 }
88 else {
89 values.push_back(itr->second);
90 }
91 }
92
93 return values;
94}
95
96std::vector<float> BDTHelper::getInputVariables(const std::map<TString, float*>& availableVariables) const {
97 std::vector<float> values;
98
99 // sort the input variables by the order in varList (from BDT)
100 for (const TString& name : m_inputVariableNames) {
101 std::map<TString, float*>::const_iterator itr = availableVariables.find(name);
102 if(itr==availableVariables.end()) {
103 ATH_MSG_ERROR(name << " not available");
104 }
105 else {
106 values.push_back(*itr->second);
107 }
108 }
109
110 return values;
111}
112
113float BDTHelper::getGradBoostMVA(const std::map<TString, float>& availableVariables) const {
114 std::vector<float> values = getInputVariables(availableVariables);
115
116 if (values.size() < m_inputVariableNames.size()) {
117 ATH_MSG_ERROR("There are missing variables when calculating the BDT score, will return -999");
118 return -999;
119 }
120 else {
121 return m_BDT->GetGradBoostMVA(values);
122 }
123}
124
125
126float BDTHelper::getResponse(const std::map<TString, float*>& availableVariables) const {
127 std::vector<float> values = getInputVariables(availableVariables);
128
129 if (values.size() < m_inputVariableNames.size()) {
130 ATH_MSG_ERROR("There are missing variables when calculating the BDT score, will return -999");
131 return -999;
132 }
133 else {
134 return m_BDT->GetResponse(values);
135 }
136}
137
138} // end of namespace tauRecTools
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
AsgMessaging(const std::string &name)
Constructor with a name.
std::vector< TString > parseString(const TString &str, const TString &delim=",") const
Definition BDTHelper.cxx:55
StatusCode initialize(const TString &weightFileName)
Definition BDTHelper.cxx:23
std::vector< TString > m_inputVariableNames
Definition BDTHelper.h:37
float getResponse(const std::map< TString, float * > &availableVariables) const
float getGradBoostMVA(const std::map< TString, float > &availableVariables) const
std::vector< float > getInputVariables(const std::map< TString, float > &availableVariables) const
Definition BDTHelper.cxx:79
std::unique_ptr< MVAUtils::BDT > m_BDT
Definition BDTHelper.h:36
Implementation of a TrackClassifier based on an RNN.
Definition BDTHelper.cxx:12
TChain * tree
TFile * file