ATLAS Offline Software
Loading...
Searching...
No Matches
convertXmlToRootTree.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
3*/
4
5
6#include "TMVAToMVAUtils.h"
7#include "MVAUtils/BDT.h"
8#include "TMVA/Reader.h"
9#include "TMVA/MethodBDT.h"
10
11#include <TString.h>
12#include <TXMLEngine.h>
13#include <TSystem.h>
14#include <TFile.h>
15#include <TRandom3.h>
16
18#include <iostream>
19#include <memory>
20#include <vector>
21
22using namespace std;
23
33
35 TString expression;
36 TString label;
37 TString varType;
38 TString nodeName;
39 float min = 0.0F;
40 float max = 0.0F;
41};
// Number of classes of the MVA; parseXml fills it from the "NClass" attribute
// of the <Classes> node, and main uses it to index the multiclass response.
43unsigned int NClass;
44
45std::vector<XmlVariableInfo>
46parseVariables(TXMLEngine *xml, void* node, const TString & nodeName)
47{
48 std::vector<XmlVariableInfo> result;
49 if (!xml || !node) return result;
50
51 // loop over all children inside <Variables> or <Spectators>
52 for (XMLNodePointer_t info_node = xml->GetChild(node); info_node != nullptr;
53 info_node = xml->GetNext(info_node))
54 {
55 XMLAttrPointer_t attr = xml->GetFirstAttr(info_node);
56 XmlVariableInfo varInfo;
57 // loop over the attributes of each child
58 while (attr != nullptr)
59 {
60 TString name = xml->GetAttrName(attr);
61 if (name == "Expression")
62 varInfo.expression = xml->GetAttrValue(attr);
63 else if (name == "Label")
64 varInfo.label = xml->GetAttrValue(attr);
65 else if (name == "Type")
66 varInfo.varType = xml->GetAttrValue(attr);
67 else if (name == "Min") varInfo.min=TString(xml->GetAttrValue(attr)).Atof();
68 else if (name == "Max") varInfo.max=TString(xml->GetAttrValue(attr)).Atof();
69
70 attr = xml->GetNextAttr(attr);
71 }
72 // ATH_MSG_DEBUG("Expression: " << expression << " Label: " << label << " varType: " << varType);
73 varInfo.nodeName = nodeName;
74 result.push_back(varInfo);
75 }
76 return result;
77}
78
79/*
80 * gSystem is a static expression of type TSystem
81 * so this is no re-entrant.
82 */
83std::vector<XmlVariableInfo>
84parseXml ATLAS_NOT_REENTRANT (const TString & xml_filename)
85{
86 std::vector<XmlVariableInfo> result;
87
88 TXMLEngine xml;
89 XMLDocPointer_t xmldoc = xml.ParseFile(xml_filename);
90 if (!xmldoc) {
91 std::cerr<<" file not found " <<xml_filename.Data() << " current directory is: " << gSystem->WorkingDirectory()<<std::endl;
92 gSystem->Abort();
93 }
94 XMLNodePointer_t mainnode = xml.DocGetRootElement(xmldoc);
95
96 // loop to find <Variables> and <Spectators>
97 XMLNodePointer_t node = xml.GetChild(mainnode);
98 while (node)
99 {
100 TString nodeName = xml.GetNodeName(node);
101 if (nodeName == "Variables" || nodeName == "Spectators") {
102 std::vector<XmlVariableInfo> r = parseVariables(&xml, node, nodeName);
103 result.insert(result.end(), r.begin(), r.end());
104 }
105 else if (nodeName == "GeneralInfo" && node) {
106 for (XMLNodePointer_t info_node = xml.GetChild(node); info_node != nullptr;
107 info_node = xml.GetNext(info_node)){
108 XMLAttrPointer_t attr = xml.GetFirstAttr(info_node);
109 // loop over the attributes of each child
110 while (attr != nullptr) {
111 // TString name = xml.GetAttrName(attr);
112 TString value = xml.GetAttrValue(attr);
113 attr = xml.GetNextAttr(attr);
114 if (value == "AnalysisType")
115 AnalysisType = TString(xml.GetAttrValue(attr));
116 }
117 }
118 }
119 else if (nodeName == "Classes") {
120 NClass = TString(xml.GetAttr(node,"NClass")).Atoi();
121 }
122 node = xml.GetNext(node);
123 }
124 xml.FreeDoc(xmldoc);
125 return result;
126}
127
128
129int main ATLAS_NOT_THREAD_SAFE (int argc, char** argv){
130 TRandom3 rand;
131 //float dummyFloat;
132 TMVA::Reader *reader = new TMVA::Reader("Silent");
133
134 TString xmlFileName="";
135 if(argc>1) xmlFileName=argv[1];
136 else return 0;
137 int last_slash=xmlFileName.Last('/');
138 last_slash = (last_slash<0 ? 0 : last_slash+1);
139 TString outFileName=xmlFileName(last_slash, xmlFileName.Length()-last_slash+1);
140 outFileName+=".root";
141 outFileName.ReplaceAll(".xml.root", ".root");
142 if(argc>2) outFileName=argv[2];
143
144 std::vector<XmlVariableInfo> variable_infos = parseXml(xmlFileName);
145 bool isRegression = (AnalysisType == "Regression");
146 bool isMulti = (AnalysisType == "Multiclass");
147 TString varList;
148 vector<float*> vars;
149 vector<float> var_avgerage;
150
151
152 cout << "Boosted Decision Tree for " << AnalysisType << endl;
153
154 for (const auto & variable_info : variable_infos){
155
156 TString infoType = (TString(variable_info.nodeName).Contains("Variable") ?
157 "variable" : "spectator");
158 TString expression = variable_info.expression;
159 TString varName = variable_info.label;
160 TString type = variable_info.varType;
161
162 TString varDefinition(varName);
163 if (varName != expression){
164 varDefinition += " := " + expression;
165 }
166
167 float average_value = (variable_info.min+variable_info.max)/2 ;
168 var_avgerage.push_back(average_value);
169 vars.push_back(new float(average_value));
170 if (infoType == "variable"){
171 varList+=varDefinition+",";
172 reader->AddVariable(varDefinition, vars.back());
173 cout << "Add variable: " << varDefinition << " " << type << endl;
174 }
175 else if (infoType == "spectator"){
176 reader->AddSpectator(varDefinition, vars.back());
177 cout << "Add spectator: " << varDefinition << " " << type << endl;
178 }
179 else // should never happen
180 {
181 cerr <<"Unknown type from parser "<< infoType.Data()<<endl;
182 //throw std::runtime_error("Unknown type from parser");
183 // delete vars.back();
184 vars.pop_back();
185 return 0;
186 }
187 }
188
189 reader->BookMVA("BDTG", xmlFileName);
190
191 TMVA::MethodBDT* method_bdt = dynamic_cast<TMVA::MethodBDT*> (reader->FindMVA("BDTG"));
192 bool useYesNoLeaf = false;
193 bool isGrad = false;
194 if(method_bdt->GetOptions().Contains("UseYesNoLeaf=True")) useYesNoLeaf = true;
195 if(method_bdt->GetOptions().Contains("BoostType=Grad")) isGrad = true;
196 cout << "UseYesNoLeaf? " << useYesNoLeaf << endl;
197 cout << "Gradient Boost? " << isGrad << endl;
198 std::unique_ptr<MVAUtils::BDT> bdt= TMVAToMVAUtils::convert(method_bdt, isRegression || isGrad, useYesNoLeaf);
199 bdt->SetPointers(vars);
200
201
202 cout << endl << "Testing MVA produced from TMVA::Reader " << endl;
203
204 cout << "MVAUtils::BDT : "
205 << (isRegression ? bdt->GetResponse() : isMulti ? bdt->GetMultiResponse(NClass)[NClass-1] : isGrad ? bdt->GetGradBoostMVA(vars) : bdt->GetClassification())
206 << " , TMVA::Reader : "
207 << (isRegression ? reader->EvaluateRegression(0, "BDTG") : isMulti ? reader->EvaluateMulticlass("BDTG")[NClass-1] : reader->EvaluateMVA("BDTG"))
208 << endl;
209
210 for(auto & var : vars) *var = 0;
211 cout << "MVAUtils::BDT : "
212 << (isRegression ? bdt->GetResponse() : isMulti ? bdt->GetMultiResponse(vars,NClass)[NClass-1] : isGrad ? bdt->GetGradBoostMVA(vars) : bdt->GetClassification())
213 << " , TMVA::Reader : "
214 << (isRegression ? reader->EvaluateRegression(0, "BDTG") : isMulti ? reader->EvaluateMulticlass("BDTG")[NClass-1] : reader->EvaluateMVA("BDTG"))
215 << endl;
216
217
218 cout << "Writing MVAUtils::BDT in " << outFileName << endl;
219 TFile* f = TFile::Open(outFileName, "RECREATE");
220 bdt->WriteTree("BDT")->Write();
221 TNamed* n = new TNamed("varList", varList.Data());
222 n->Write();
223 f->Close();
224 delete f;
225 cout << endl << "Reading BDT from root file and testing " << outFileName << endl;
226
227 f = TFile::Open(outFileName, "READ");
228 TTree* bdt_tree = dynamic_cast<TTree*> (f->Get("BDT"));
229 if(!bdt_tree){
230 cerr <<"Could not Retrieve BDT TTree from file , should not happen" <<endl;
231 return 0;
232 }
233
234 bdt = std::make_unique<MVAUtils::BDT>(bdt_tree);
235 bdt->SetPointers(vars);
236 cout << bdt->GetResponse() << endl;
237 cout << "MVAUtils::BDT : "
238 << (isRegression ? bdt->GetResponse() : isMulti ? bdt->GetMultiResponse(NClass)[NClass-1] : isGrad ? bdt->GetGradBoostMVA(vars) : bdt->GetClassification())
239 << " , TMVA::Reader : "
240 << (isRegression ? reader->EvaluateRegression(0, "BDTG") : isMulti ? reader->EvaluateMulticlass("BDTG")[NClass-1] : reader->EvaluateMVA("BDTG"))
241 << endl;
242 for(uint i = 0; i != vars.size(); ++i) *vars[i] = var_avgerage[i];
243 cout << "MVAUtils::BDT : "
244 << (isRegression && !isGrad ? bdt->GetResponse() : isMulti ? bdt->GetMultiResponse(NClass)[NClass-1] : isGrad ? bdt->GetGradBoostMVA(vars) : bdt->GetClassification())
245 << " , TMVA::Reader : "
246 << (isRegression ? reader->EvaluateRegression(0, "BDTG") : isMulti ? reader->EvaluateMulticlass("BDTG")[NClass-1] : reader->EvaluateMVA("BDTG"))
247 << endl;
248
249 cout << "Checking over many random events" << endl;
250 int n_events=0;
251 for(int i = 0; i != 100; ++i){
252 for(uint i = 0; i != vars.size(); ++i) *vars[i] = (1+(rand.Rndm()-0.5)/5)*var_avgerage[i];
253 float mva = (isRegression ? bdt->GetResponse() : isMulti ? bdt->GetMultiResponse(vars,NClass)[NClass-1] : isGrad ? bdt->GetGradBoostMVA(vars) : bdt->GetClassification());
254 float tmva = (isRegression ? reader->EvaluateRegression(0, "BDTG") : isMulti ? reader->EvaluateMulticlass("BDTG")[NClass-1] : reader->EvaluateMVA("BDTG"));
255 if( (tmva-mva)/mva > 0.00001 ){
256 cout << "MVAUtils::BDT : " << mva << " , TMVA::Reader : " << tmva << endl;
257 n_events++;
258 }
259 }
260 cout << "Found " << n_events << " events in disagreement " << endl;
261 cout << endl;
262}
int main(int, char **)
Main class for all the CppUnit test classes.
unsigned int uint
Define macros for attributes used to control the static checker.
#define ATLAS_NOT_THREAD_SAFE
getNoisyStrip() Find noisy strips from hitmaps and write out into xml/db formats
Definition node.h:24
TString AnalysisType
std::vector< XmlVariableInfo > parseVariables(TXMLEngine *xml, void *node, const TString &nodeName)
unsigned int NClass
int r
Definition globals.cxx:22
std::unique_ptr< MVAUtils::BDT > convert(TMVA::MethodBDT *bdt, bool isRegression=true, bool useYesNoLeaf=false)
STL namespace.
#define ATLAS_NOT_REENTRANT
Definition random.h:27
Utility to convert xml files from TMVA into root TTrees for this package.