ATLAS Offline Software
Tool_ModeDiscriminator.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration
3 */
4 
11 #include "TString.h"
12 #include "TFile.h"
13 #include "TTree.h"
14 #include <memory>
15 
// Tool_ModeDiscriminator(name): forward the tool name to asg::AsgTool and give the
// configurable members placeholder values. The "Invalid..." names are presumably meant to be
// overridden through the Name_InputAlg / Name_ModeCase properties declared in the body.
17  asg::AsgTool(name),
18  m_Name_InputAlg("InvalidInputAlg"),
19  m_Name_ModeCase("InvalidModeCase"),
20  m_Tool_InformationStore("PanTau::Tool_InformationStore/Tool_InformationStore"), // default type/name string of the information-store handle
21  m_MVABDT_List() // filled in initialize(): one MVAUtils::BDT per pt bin
22 {
  // Job-option properties. Name_InputAlg and Name_ModeCase select which BDT variable lists
  // and which weight files initialize() loads.
23  declareProperty("calibFolder", m_calib_path, "Location of calib files in cvmfs");//sync'd with tauRecFlags.tauRecToolsCVMFSPath()
24  declareProperty("Name_InputAlg", m_Name_InputAlg, "Name of the input algorithm for this instance");
25  declareProperty("Name_ModeCase", m_Name_ModeCase, "Name of the two modes to be distinguished for this instance");
26  declareProperty("Tool_InformationStore", m_Tool_InformationStore, "Handle to the information store tool");
  // NOTE(review): this description duplicates the one above. The property is a name string which
  // initialize() passes to HelperFunctions::bindToolHandle to bind m_Tool_InformationStore.
27  declareProperty("Tool_InformationStoreName",m_Tool_InformationStoreName,"Handle to the information store tool");
28 }
29 
30 
32 
33 
35 
36  ATH_MSG_DEBUG( name() << " initialize()" );
37  m_init=true;
38 
39  ATH_CHECK( HelperFunctions::bindToolHandle( m_Tool_InformationStore, m_Tool_InformationStoreName ) );
40 
41  ATH_CHECK(m_Tool_InformationStore.retrieve());
42 
43  // get the required information from the informationstore tool
44  ATH_CHECK( m_Tool_InformationStore->getInfo_VecDouble("ModeDiscriminator_BinEdges_Pt", m_BinEdges_Pt));
45  ATH_CHECK( m_Tool_InformationStore->getInfo_String("ModeDiscriminator_ReaderOption", m_ReaderOption) );
46  ATH_CHECK( m_Tool_InformationStore->getInfo_String("ModeDiscriminator_TMVAMethod", m_MethodName) );
47 
48  // build the name of the variable that contains the variable list for this discri tool
49  std::string varNameList_Full = "ModeDiscriminator_BDTVariableNames_" + m_Name_InputAlg + "_" + m_Name_ModeCase;
50  ATH_CHECK( m_Tool_InformationStore->getInfo_VecString(varNameList_Full, m_List_BDTVariableNames) );
51 
52  std::string varDefaultValueList_Full = "ModeDiscriminator_BDTVariableDefaults_" + m_Name_InputAlg + "_" + m_Name_ModeCase;
53  ATH_CHECK( m_Tool_InformationStore->getInfo_VecDouble(varDefaultValueList_Full, m_List_BDTVariableDefaultValues) );
54 
55 
56  // consistency check:
57  // Number of feature names and feature default values has to match
58  if ( m_List_BDTVariableDefaultValues.size() != m_List_BDTVariableNames.size() ) {
59  ATH_MSG_ERROR("Number of variable names does not match number of default values! Check jobOptions!");
60  return StatusCode::FAILURE;
61  }
62 
63  // Create reader for each pT Bin; nBins = Edges-1
64  for (unsigned int iPtBin=0; iPtBin<(m_BinEdges_Pt.size() - 1); iPtBin++) {
65 
66  std::string bin_lowerStr = m_HelperFunctions.convertNumberToString(m_BinEdges_Pt[iPtBin]/1000.);
67  std::string bin_upperStr = m_HelperFunctions.convertNumberToString(m_BinEdges_Pt[iPtBin+1]/1000.);
68  std::string curPtBin = "ET_" + bin_lowerStr + "_" + bin_upperStr;
69 
70  // weight files
71  std::string curWeightFile = m_calib_path + (m_calib_path.length() ? "/" : "");
72  curWeightFile += "TrainModes_";
73  curWeightFile += m_Name_InputAlg + "_";
74  curWeightFile += curPtBin + "_";
75  curWeightFile += m_Name_ModeCase + "_";
76  curWeightFile += m_MethodName + ".weights.root";
77 
78  std::string resolvedWeightFileName = PathResolverFindCalibFile(curWeightFile);
79 
80  if (resolvedWeightFileName.empty()) {
81  ATH_MSG_ERROR("Weight file " << curWeightFile << " not found!");
82  return StatusCode::FAILURE;
83  }
84 
85  // MVAUtils BDT
86  std::unique_ptr<TFile> fBDT = std::make_unique<TFile>( resolvedWeightFileName.c_str() );
87  TTree* tBDT = dynamic_cast<TTree*> (fBDT->Get("BDT"));
88  std::unique_ptr<MVAUtils::BDT> curBDT = std::make_unique<MVAUtils::BDT>(tBDT);
89  if (curBDT == nullptr) {
90  ATH_MSG_ERROR( "Failed to create MVAUtils::BDT for " << resolvedWeightFileName );
91  return StatusCode::FAILURE;
92  }
93 
94  m_MVABDT_List.push_back(std::move(curBDT));
95 
96  }//end loop over pt bins to get weight files, reference hists and MVAUtils::BDT objects
97 
98  return StatusCode::SUCCESS;
99 }
100 
101 
102 void PanTau::Tool_ModeDiscriminator::updateReaderVariables(PanTau::PanTauSeed* inSeed, std::vector<float>& list_BDTVariableValues) const {
103 
104  //update features used in MVA with values from current seed
105  // use default value for feature if it is not present in current seed
106  //NOTE! This has to be done (even if the seed pt is bad) otherwise problems with details storage
107  // [If this for loop is skipped, it is not guaranteed that all details are set to their proper default value]
108  PanTau::TauFeature* seedFeatures = inSeed->getFeatures();
109 
110  for (unsigned int iVar=0; iVar<m_List_BDTVariableNames.size(); iVar++) {
111  std::string curVar = m_Name_InputAlg + "_" + m_List_BDTVariableNames[iVar];
112 
113  bool isValid;
114  double newValue = seedFeatures->value(curVar, isValid);
115  if (!isValid) {
116  ATH_MSG_DEBUG("\tUse default value as the feature (the one below this line) was not calculated");
117  newValue = m_List_BDTVariableDefaultValues[iVar];
118  //add this feature with its default value for the details later
119  seedFeatures->addFeature(curVar, newValue);
120  }
121 
122  list_BDTVariableValues[iVar] = static_cast<float>(newValue);
123  }//end loop over BDT vars
124 
125  }
126 
127 
129 
// getResponse(inSeed, isOK): return m_MVABDT_List[ptBin]->GetGradBoostMVA(inputs) for the pt bin
// containing the seed, and set isOK = true. On failure it sets isOK = false and returns -2.
// Failure means either the seed has a bad pt value (branch below) or its pt lies in no bin.
130  std::vector<float> list_BDTVariableValues(m_List_BDTVariableNames.size());
131 
// Fill the inputs before any early return. This way missing features are written to the seed
// with their default values even for seeds that are rejected below.
132  updateReaderVariables(inSeed, list_BDTVariableValues);
133 
135  ATH_MSG_DEBUG("WARNING Seed has bad pt value! " << inSeed->getTauJet()->pt() << " MeV");
136  isOK = false;
137  return -2;
138  }
139 
140  //get the pt bin of input Seed
141  //NOTE: could be moved to decay mode determinator tool...
// NOTE(review): bins are open intervals (strict > and <), so a pt exactly equal to a bin edge
// matches no bin and triggers the warning below. The loop bound size()-1 is unsigned and would
// wrap around if m_BinEdges_Pt were empty.
142  int ptBin = -1;
143  for (unsigned int iPtBin=0; iPtBin<m_BinEdges_Pt.size()-1; iPtBin++) {
144  if (inSeed->p4().Pt() > m_BinEdges_Pt[iPtBin] && inSeed->p4().Pt() < m_BinEdges_Pt[iPtBin+1]) {
145  ptBin = iPtBin;
146  break;
147  }
148  }
149  if (ptBin == -1) {
150  ATH_MSG_WARNING("Could not find ptBin for tau seed with pt " << inSeed->p4().Pt());
151  isOK = false;
152  return -2.;
153  }
154 
155  isOK = true;
156 
157  // return the mva response
// m_MVABDT_List holds one BDT per pt bin, pushed in bin order by initialize().
158  return m_MVABDT_List[ptBin]->GetGradBoostMVA(list_BDTVariableValues);
159 }
PanTauSeed.h
PanTau::Tool_ModeDiscriminator::m_calib_path
std::string m_calib_path
Definition: Tool_ModeDiscriminator.h:53
PanTau::Tool_ModeDiscriminator::m_Name_InputAlg
std::string m_Name_InputAlg
Definition: Tool_ModeDiscriminator.h:54
PanTau::PanTauSeed::getFeatures
const PanTau::TauFeature * getFeatures() const
Definition: PanTauSeed.h:230
PanTau::PanTauSeed::getTauJet
const xAOD::TauJet * getTauJet() const
Definition: PanTauSeed.h:228
AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T > &t)
Definition: AthCommonDataStore.h:145
HelperFunctions.h
asg
Definition: DataHandleTestTool.h:28
PanTau::TauFeature::addFeature
bool addFeature(const std::string &name, const double value)
adds a new feature
Definition: TauFeature.cxx:44
PanTau::Tool_ModeDiscriminator::Tool_ModeDiscriminator
Tool_ModeDiscriminator(const std::string &name)
Definition: Tool_ModeDiscriminator.cxx:16
Tool_ModeDiscriminator.h
PanTau::TauFeature
Definition: TauFeature.h:19
isValid
bool isValid(const T &p)
Definition: AtlasPID.h:225
PanTau::PanTauSeed::p4
virtual FourMom_t p4() const
The full 4-momentum of the particle as a TLorentzVector.
Definition: PanTauSeed.cxx:141
xAOD::TauJet_v3::pt
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
PanTau::Tool_ModeDiscriminator::~Tool_ModeDiscriminator
virtual ~Tool_ModeDiscriminator()
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
PanTau::TauFeature::value
double value(const std::string &name, bool &isValid) const
returns the value of the feature given by its name
Definition: TauFeature.cxx:23
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
PathResolver.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:221
PanTau::Tool_ModeDiscriminator::getResponse
virtual double getResponse(PanTau::PanTauSeed *inSeed, bool &isOK) const
Definition: Tool_ModeDiscriminator.cxx:128
PanTau::PanTauSeed::t_BadPtValue
@ t_BadPtValue
Definition: PanTauSeed.h:37
PanTau::HelperFunctions::bindToolHandle
static StatusCode bindToolHandle(ToolHandle< T > &, std::string)
Definition: Reconstruction/PanTau/PanTauAlgs/PanTauAlgs/HelperFunctions.h:60
PanTau::Tool_ModeDiscriminator::updateReaderVariables
void updateReaderVariables(PanTau::PanTauSeed *inSeed, std::vector< float > &list_BDTVariableValues) const
Definition: Tool_ModeDiscriminator.cxx:102
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:431
Tool_InformationStore.h
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
PanTau::PanTauSeed::isOfTechnicalQuality
bool isOfTechnicalQuality(int pantauSeed_TechnicalQuality) const
Definition: PanTauSeed.cxx:349
PanTau::Tool_ModeDiscriminator::m_Tool_InformationStore
ToolHandle< PanTau::ITool_InformationStore > m_Tool_InformationStore
Definition: Tool_ModeDiscriminator.h:56
TauFeature.h
PanTau::Tool_ModeDiscriminator::m_Name_ModeCase
std::string m_Name_ModeCase
Definition: Tool_ModeDiscriminator.h:55
PanTau::Tool_ModeDiscriminator::initialize
virtual StatusCode initialize()
Initialises the tool (this override replaces the dummy implementation of the initialisation function in the base class).
Definition: Tool_ModeDiscriminator.cxx:34
PanTau::Tool_ModeDiscriminator::m_Tool_InformationStoreName
std::string m_Tool_InformationStoreName
Definition: Tool_ModeDiscriminator.h:57
PanTau::PanTauSeed
Definition: PanTauSeed.h:24