ATLAS Offline Software
convertXmlToRootTree.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 
6 #include "TMVAToMVAUtils.h"
7 #include "MVAUtils/BDT.h"
8 #include "TMVA/Reader.h"
9 #include "TMVA/MethodBDT.h"
10 
11 #include <TString.h>
12 #include <TXMLEngine.h>
13 #include <TSystem.h>
14 #include <TFile.h>
15 #include <TRandom3.h>
16 
#include <algorithm>
#include <cmath>
#include <iostream>
#include <memory>
#include <vector>
21 
22 using namespace std;
23 
35  TString expression;
36  TString label;
37  TString varType;
38  TString nodeName;
39  float min = 0.0F;
40  float max = 0.0F;
41 };
42 TString AnalysisType;
43 unsigned int NClass;
44 
45 std::vector<XmlVariableInfo>
46 parseVariables(TXMLEngine *xml, void* node, const TString & nodeName)
47 {
48  std::vector<XmlVariableInfo> result;
49  if (!xml || !node) return result;
50 
51  // loop over all children inside <Variables> or <Spectators>
52  for (XMLNodePointer_t info_node = xml->GetChild(node); info_node != nullptr;
53  info_node = xml->GetNext(info_node))
54  {
55  XMLAttrPointer_t attr = xml->GetFirstAttr(info_node);
56  XmlVariableInfo varInfo;
57  // loop over the attributes of each child
58  while (attr != nullptr)
59  {
60  TString name = xml->GetAttrName(attr);
61  if (name == "Expression")
62  varInfo.expression = xml->GetAttrValue(attr);
63  else if (name == "Label")
64  varInfo.label = xml->GetAttrValue(attr);
65  else if (name == "Type")
66  varInfo.varType = xml->GetAttrValue(attr);
67  else if (name == "Min") varInfo.min=TString(xml->GetAttrValue(attr)).Atof();
68  else if (name == "Max") varInfo.max=TString(xml->GetAttrValue(attr)).Atof();
69 
70  attr = xml->GetNextAttr(attr);
71  }
72  // ATH_MSG_DEBUG("Expression: " << expression << " Label: " << label << " varType: " << varType);
73  varInfo.nodeName = nodeName;
74  result.push_back(varInfo);
75  }
76  return result;
77 }
78 
79 /*
80  * gSystem is a static expression of type TSystem
81  * so this is no re-entrant.
82  */
83 std::vector<XmlVariableInfo>
84 parseXml ATLAS_NOT_REENTRANT (const TString & xml_filename)
85 {
86  std::vector<XmlVariableInfo> result;
87 
88  TXMLEngine xml;
89  XMLDocPointer_t xmldoc = xml.ParseFile(xml_filename);
90  if (!xmldoc) {
91  std::cerr<<" file not found " <<xml_filename.Data() << " current directory is: " << gSystem->WorkingDirectory()<<std::endl;
92  gSystem->Abort();
93  }
94  XMLNodePointer_t mainnode = xml.DocGetRootElement(xmldoc);
95 
96  // loop to find <Variables> and <Spectators>
97  XMLNodePointer_t node = xml.GetChild(mainnode);
98  while (node)
99  {
100  TString nodeName = xml.GetNodeName(node);
101  if (nodeName == "Variables" || nodeName == "Spectators") {
102  std::vector<XmlVariableInfo> r = parseVariables(&xml, node, nodeName);
103  result.insert(result.end(), r.begin(), r.end());
104  }
105  else if (nodeName == "GeneralInfo" && node) {
106  for (XMLNodePointer_t info_node = xml.GetChild(node); info_node != nullptr;
107  info_node = xml.GetNext(info_node)){
108  XMLAttrPointer_t attr = xml.GetFirstAttr(info_node);
109  // loop over the attributes of each child
110  while (attr != nullptr) {
111  // TString name = xml.GetAttrName(attr);
112  TString value = xml.GetAttrValue(attr);
113  attr = xml.GetNextAttr(attr);
114  if (value == "AnalysisType")
115  AnalysisType = TString(xml.GetAttrValue(attr));
116  }
117  }
118  }
119  else if (nodeName == "Classes") {
120  NClass = TString(xml.GetAttr(node,"NClass")).Atoi();
121  }
122  node = xml.GetNext(node);
123  }
124  xml.FreeDoc(xmldoc);
125  return result;
126 }
127 
128 
130  TRandom3 rand;
131  //float dummyFloat;
132  TMVA::Reader *reader = new TMVA::Reader("Silent");
133 
134  TString xmlFileName="";
135  if(argc>1) xmlFileName=argv[1];
136  else return 0;
137  int last_slash=xmlFileName.Last('/');
138  last_slash = (last_slash<0 ? 0 : last_slash+1);
139  TString outFileName=xmlFileName(last_slash, xmlFileName.Length()-last_slash+1);
140  outFileName+=".root";
141  outFileName.ReplaceAll(".xml.root", ".root");
142  if(argc>2) outFileName=argv[2];
143 
144  std::vector<XmlVariableInfo> variable_infos = parseXml(xmlFileName);
145  bool isRegression = (AnalysisType == "Regression");
146  bool isMulti = (AnalysisType == "Multiclass");
147  TString varList;
148  vector<float*> vars;
149  vector<float> var_avgerage;
150 
151 
152  cout << "Boosted Decision Tree for " << AnalysisType << endl;
153 
154  for (const auto & variable_info : variable_infos){
155 
156  TString infoType = (TString(variable_info.nodeName).Contains("Variable") ?
157  "variable" : "spectator");
158  TString expression = variable_info.expression;
159  TString varName = variable_info.label;
160  TString type = variable_info.varType;
161 
162  TString varDefinition(varName);
163  if (varName != expression){
164  varDefinition += " := " + expression;
165  }
166 
167  float average_value = (variable_info.min+variable_info.max)/2 ;
168  var_avgerage.push_back(average_value);
169  vars.push_back(new float(average_value));
170  if (infoType == "variable"){
171  varList+=varDefinition+",";
172  reader->AddVariable(varDefinition, vars.back());
173  cout << "Add variable: " << varDefinition << " " << type << endl;
174  }
175  else if (infoType == "spectator"){
176  reader->AddSpectator(varDefinition, vars.back());
177  cout << "Add spectator: " << varDefinition << " " << type << endl;
178  }
179  else // should never happen
180  {
181  cerr <<"Unknown type from parser "<< infoType.Data()<<endl;
182  //throw std::runtime_error("Unknown type from parser");
183  // delete vars.back();
184  vars.pop_back();
185  return 0;
186  }
187  }
188 
189  reader->BookMVA("BDTG", xmlFileName);
190 
191  TMVA::MethodBDT* method_bdt = dynamic_cast<TMVA::MethodBDT*> (reader->FindMVA("BDTG"));
192  bool useYesNoLeaf = false;
193  bool isGrad = false;
194  if(method_bdt->GetOptions().Contains("UseYesNoLeaf=True")) useYesNoLeaf = true;
195  if(method_bdt->GetOptions().Contains("BoostType=Grad")) isGrad = true;
196  cout << "UseYesNoLeaf? " << useYesNoLeaf << endl;
197  cout << "Gradient Boost? " << isGrad << endl;
198  std::unique_ptr<MVAUtils::BDT> bdt= TMVAToMVAUtils::convert(method_bdt, isRegression || isGrad, useYesNoLeaf);
199  bdt->SetPointers(vars);
200 
201 
202  cout << endl << "Testing MVA produced from TMVA::Reader " << endl;
203 
204  cout << "MVAUtils::BDT : "
205  << (isRegression ? bdt->GetResponse() : isMulti ? bdt->GetMultiResponse(NClass)[NClass-1] : isGrad ? bdt->GetGradBoostMVA(vars) : bdt->GetClassification())
206  << " , TMVA::Reader : "
207  << (isRegression ? reader->EvaluateRegression(0, "BDTG") : isMulti ? reader->EvaluateMulticlass("BDTG")[NClass-1] : reader->EvaluateMVA("BDTG"))
208  << endl;
209 
210  for(auto & var : vars) *var = 0;
211  cout << "MVAUtils::BDT : "
212  << (isRegression ? bdt->GetResponse() : isMulti ? bdt->GetMultiResponse(vars,NClass)[NClass-1] : isGrad ? bdt->GetGradBoostMVA(vars) : bdt->GetClassification())
213  << " , TMVA::Reader : "
214  << (isRegression ? reader->EvaluateRegression(0, "BDTG") : isMulti ? reader->EvaluateMulticlass("BDTG")[NClass-1] : reader->EvaluateMVA("BDTG"))
215  << endl;
216 
217 
218  cout << "Writing MVAUtils::BDT in " << outFileName << endl;
219  TFile* f = TFile::Open(outFileName, "RECREATE");
220  bdt->WriteTree("BDT")->Write();
221  TNamed* n = new TNamed("varList", varList.Data());
222  n->Write();
223  f->Close();
224  delete f;
225  cout << endl << "Reading BDT from root file and testing " << outFileName << endl;
226 
227  f = TFile::Open(outFileName, "READ");
228  TTree* bdt_tree = dynamic_cast<TTree*> (f->Get("BDT"));
229  if(!bdt_tree){
230  cerr <<"Could not Retrieve BDT TTree from file , should not happen" <<endl;
231  return 0;
232  }
233 
234  bdt = std::make_unique<MVAUtils::BDT>(bdt_tree);
235  bdt->SetPointers(vars);
236  cout << bdt->GetResponse() << endl;
237  cout << "MVAUtils::BDT : "
238  << (isRegression ? bdt->GetResponse() : isMulti ? bdt->GetMultiResponse(NClass)[NClass-1] : isGrad ? bdt->GetGradBoostMVA(vars) : bdt->GetClassification())
239  << " , TMVA::Reader : "
240  << (isRegression ? reader->EvaluateRegression(0, "BDTG") : isMulti ? reader->EvaluateMulticlass("BDTG")[NClass-1] : reader->EvaluateMVA("BDTG"))
241  << endl;
242  for(uint i = 0; i != vars.size(); ++i) *vars[i] = var_avgerage[i];
243  cout << "MVAUtils::BDT : "
244  << (isRegression && !isGrad ? bdt->GetResponse() : isMulti ? bdt->GetMultiResponse(NClass)[NClass-1] : isGrad ? bdt->GetGradBoostMVA(vars) : bdt->GetClassification())
245  << " , TMVA::Reader : "
246  << (isRegression ? reader->EvaluateRegression(0, "BDTG") : isMulti ? reader->EvaluateMulticlass("BDTG")[NClass-1] : reader->EvaluateMVA("BDTG"))
247  << endl;
248 
249  cout << "Checking over many random events" << endl;
250  int n_events=0;
251  for(int i = 0; i != 100; ++i){
252  for(uint i = 0; i != vars.size(); ++i) *vars[i] = (1+(rand.Rndm()-0.5)/5)*var_avgerage[i];
253  float mva = (isRegression ? bdt->GetResponse() : isMulti ? bdt->GetMultiResponse(vars,NClass)[NClass-1] : isGrad ? bdt->GetGradBoostMVA(vars) : bdt->GetClassification());
254  float tmva = (isRegression ? reader->EvaluateRegression(0, "BDTG") : isMulti ? reader->EvaluateMulticlass("BDTG")[NClass-1] : reader->EvaluateMVA("BDTG"));
255  if( (tmva-mva)/mva > 0.00001 ){
256  cout << "MVAUtils::BDT : " << mva << " , TMVA::Reader : " << tmva << endl;
257  n_events++;
258  }
259  }
260  cout << "Found " << n_events << " events in disagreement " << endl;
261  cout << endl;
262 }
beamspotman.r
def r
Definition: beamspotman.py:676
BDT.h
beamspotnt.var
var
Definition: bin/beamspotnt.py:1394
XmlVariableInfo::varType
TString varType
Definition: convertXmlToRootTree.cxx:37
ATLAS_NOT_REENTRANT
std::vector< XmlVariableInfo > parseXml ATLAS_NOT_REENTRANT(const TString &xml_filename)
Definition: convertXmlToRootTree.cxx:84
get_generator_info.result
result
Definition: get_generator_info.py:21
max
constexpr double max()
Definition: ap_fixedTest.cxx:33
XmlVariableInfo::max
float max
Definition: convertXmlToRootTree.cxx:40
min
constexpr double min()
Definition: ap_fixedTest.cxx:26
MVAUtils::BDT::SetPointers
void SetPointers(const std::vector< float * > &pointers)
Set the stored pointers so that one can use methods with no args.
athena.value
value
Definition: athena.py:124
TMVAToMVAUtils.h
python.HION12.expression
string expression
Definition: HION12.py:55
NClass
unsigned int NClass
Definition: convertXmlToRootTree.cxx:43
AnalysisType
TString AnalysisType
Definition: convertXmlToRootTree.cxx:42
ExtractEBRunDetails.xml
xml
Definition: ExtractEBRunDetails.py:239
DumpGeoConfig.outFileName
string outFileName
Definition: DumpGeoConfig.py:252
main
int main(int, char **)
Main class for all the CppUnit test classes
Definition: CppUnit_SGtestdriver.cxx:141
uint
unsigned int uint
Definition: LArOFPhaseFill.cxx:20
XmlVariableInfo::min
float min
Definition: convertXmlToRootTree.cxx:39
LArG4FSStartPointFilter.rand
rand
Definition: LArG4FSStartPointFilter.py:80
lumiFormat.i
int i
Definition: lumiFormat.py:85
LArCellNtuple.argv
argv
Definition: LArCellNtuple.py:152
beamspotman.n
n
Definition: beamspotman.py:731
PixelAthClusterMonAlgCfg.varName
string varName
end cluster ToT and charge
Definition: PixelAthClusterMonAlgCfg.py:125
LHEF::Reader
Pythia8::Reader Reader
Definition: Prophecy4fMerger.cxx:11
XmlVariableInfo::label
TString label
Definition: convertXmlToRootTree.cxx:36
hist_file_dump.f
f
Definition: hist_file_dump.py:135
XmlVariableInfo
Utility to convert xml files from TMVA into root TTrees for this package.
Definition: convertXmlToRootTree.cxx:34
MVAUtils::BDT::GetClassification
float GetClassification(const std::vector< float > &values) const
Get response of the forest, for classification.
DQHistogramMergeRegExp.argc
argc
Definition: DQHistogramMergeRegExp.py:20
parseVariables
std::vector< XmlVariableInfo > parseVariables(TXMLEngine *xml, void *node, const TString &nodeName)
Definition: convertXmlToRootTree.cxx:46
ATLAS_NOT_THREAD_SAFE
int main ATLAS_NOT_THREAD_SAFE(int argc, char **argv)
Definition: convertXmlToRootTree.cxx:129
MVAUtils::BDT::WriteTree
TTree * WriteTree(TString name="BDT") const
Return a TTree representing the BDT: each entry is a binary tree, each element of the vectors is a no...
Definition: BDT.cxx:91
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:228
TMVAToMVAUtils::convert
std::unique_ptr< MVAUtils::BDT > convert(TMVA::MethodBDT *bdt, bool isRegression=true, bool useYesNoLeaf=false)
Definition: TMVAToMVAUtils.h:114
MVAUtils::BDT::GetGradBoostMVA
float GetGradBoostMVA(const std::vector< float > &values) const
python.CaloScaleNoiseConfig.type
type
Definition: CaloScaleNoiseConfig.py:78
MVAUtils::BDT::GetResponse
float GetResponse(const std::vector< float > &values) const
Get response of the forest, for regression.
beamspotnt.varList
list varList
Definition: bin/beamspotnt.py:1108
collisions.reader
reader
read the goodrunslist xml file(s)
Definition: collisions.py:22
checker_macros.h
Define macros for attributes used to control the static checker.
MVAUtils::BDT::GetMultiResponse
std::vector< float > GetMultiResponse(const std::vector< float > &values, unsigned int numClasses) const
Get response of the forest, for multiclassification (e.g.
XmlVariableInfo::nodeName
TString nodeName
Definition: convertXmlToRootTree.cxx:38
node
Definition: memory_hooks-stdcmalloc.h:74
XmlVariableInfo::expression
TString expression
Definition: convertXmlToRootTree.cxx:35