ATLAS Offline Software
Tool_ModeDiscriminator.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
11 #include "TString.h"
12 #include "TFile.h"
13 #include "TTree.h"
14 #include <memory>
15 
// Constructor: forwards the tool name to asg::AsgTool and starts with an empty
// list of per-pT-bin BDTs; that list is filled in initialize(), one entry per pT bin.
17  asg::AsgTool(name),
18  m_MVABDT_List()
19 {
20 }
21 
22 
24 
25 
27 
// initialize(): configure this mode discriminator.
//  - binds and retrieves the information-store tool,
//  - reads the pT bin edges ("ModeDiscriminator_BinEdges_Pt"), the TMVA method name,
//    and the BDT variable names / default values for mode case m_Name_ModeCase,
//  - loads one MVAUtils::BDT per pT bin into m_MVABDT_List (index == pT bin index).
// Returns StatusCode::FAILURE when the number of variable names and default values
// differ, when a weight file cannot be resolved, or when any ATH_CHECKed call fails.
28  ATH_MSG_DEBUG( name() << " initialize()" );
// NOTE(review): m_init is set before any step below can fail, so it is true even
// when initialize() returns FAILURE.
29  m_init=true;
30 
31  ATH_CHECK( HelperFunctions::bindToolHandle( m_Tool_InformationStore, m_Tool_InformationStoreName ) );
32 
33  ATH_CHECK(m_Tool_InformationStore.retrieve());
34 
35  // get the required information from the information store tool
36  ATH_CHECK( m_Tool_InformationStore->getInfo_VecDouble("ModeDiscriminator_BinEdges_Pt", m_BinEdges_Pt));
37  ATH_CHECK( m_Tool_InformationStore->getInfo_String("ModeDiscriminator_TMVAMethod", m_MethodName) );
38 
39  // build the key of the entry that holds the BDT variable list for this discriminator (mode case)
40  std::string varNameList_Full = "ModeDiscriminator_BDTVariableNames_CellBased_" + m_Name_ModeCase;
41  ATH_CHECK( m_Tool_InformationStore->getInfo_VecString(varNameList_Full, m_List_BDTVariableNames) );
42 
43  std::string varDefaultValueList_Full = "ModeDiscriminator_BDTVariableDefaults_CellBased_" + m_Name_ModeCase;
44  ATH_CHECK( m_Tool_InformationStore->getInfo_VecDouble(varDefaultValueList_Full, m_List_BDTVariableDefaultValues) );
45 
46 
47  // consistency check:
48  // Number of feature names and feature default values has to match
49  if ( m_List_BDTVariableDefaultValues.size() != m_List_BDTVariableNames.size() ) {
50  ATH_MSG_ERROR("Number of variable names does not match number of default values! Check jobOptions!");
51  return StatusCode::FAILURE;
52  }
53 
54  // Create reader for each pT Bin; nBins = Edges-1
// NOTE(review): size() is unsigned, so if m_BinEdges_Pt is empty, size() - 1 wraps
// around and the loop reads past the end of the vector. Guard with an explicit
// "at least two edges" check before the loop.
55  for (unsigned int iPtBin=0; iPtBin<(m_BinEdges_Pt.size() - 1); iPtBin++) {
56 
57  std::string bin_lowerStr = m_HelperFunctions.convertNumberToString(m_BinEdges_Pt[iPtBin]/1000.);
58  std::string bin_upperStr = m_HelperFunctions.convertNumberToString(m_BinEdges_Pt[iPtBin+1]/1000.);
59  std::string curPtBin = "ET_" + bin_lowerStr + "_" + bin_upperStr;
60 
61  // weight files, named
//   [<m_calib_path>/]TrainModes_CellBased_ET_<lo>_<hi>_<m_Name_ModeCase>_<m_MethodName>.weights.root
// where <lo>/<hi> are the bin edges divided by 1000 (presumably MeV -> GeV; confirm).
62  std::string curWeightFile = m_calib_path + (!m_calib_path.empty() ? "/" : "");
63  curWeightFile += "TrainModes_";
64  curWeightFile += "CellBased_";
65  curWeightFile += curPtBin + "_";
66  curWeightFile += m_Name_ModeCase + "_";
67  curWeightFile += m_MethodName + ".weights.root";
68 
69  std::string resolvedWeightFileName = PathResolverFindCalibFile(curWeightFile);
70 
71  if (resolvedWeightFileName.empty()) {
72  ATH_MSG_ERROR("Weight file " << curWeightFile << " not found!");
73  return StatusCode::FAILURE;
74  }
75 
76  // MVAUtils BDT
// fBDT owns the TFile and goes out of scope at the end of each loop iteration,
// so the file is closed after the BDT has been constructed from the "BDT" tree.
77  std::unique_ptr<TFile> fBDT = std::make_unique<TFile>( resolvedWeightFileName.c_str() );
78  TTree* tBDT = dynamic_cast<TTree*> (fBDT->Get("BDT"));
// NOTE(review): two failure modes are not checked here:
//  - the file failing to open (e.g. fBDT->IsZombie());
//  - Get("BDT") not returning a TTree, in which case tBDT == nullptr is passed on
//    to the BDT constructor.
// The nullptr check below cannot fire, because std::make_unique never returns nullptr.
79  std::unique_ptr<MVAUtils::BDT> curBDT = std::make_unique<MVAUtils::BDT>(tBDT);
80  if (curBDT == nullptr) {
81  ATH_MSG_ERROR( "Failed to create MVAUtils::BDT for " << resolvedWeightFileName );
82  return StatusCode::FAILURE;
83  }
84 
85  m_MVABDT_List.push_back(std::move(curBDT));
86 
87  }//end loop over pt bins (resolve weight file -> build MVAUtils::BDT -> store in m_MVABDT_List)
88 
89  return StatusCode::SUCCESS;
90 }
91 
92 
93 void PanTau::Tool_ModeDiscriminator::updateReaderVariables(PanTau::PanTauSeed* inSeed, std::vector<float>& list_BDTVariableValues) const {
94 
95  //update features used in MVA with values from current seed
96  // use default value for feature if it is not present in current seed
97  //NOTE! This has to be done (even if the seed pt is bad) otherwise problems with details storage
98  // [If this for loop is skipped, it is not guaranteed that all details are set to their proper default value]
99  PanTau::TauFeature* seedFeatures = inSeed->getFeatures();
100 
101  for (unsigned int iVar=0; iVar<m_List_BDTVariableNames.size(); iVar++) {
102  std::string curVar = "CellBased_" + m_List_BDTVariableNames[iVar];
103 
104  bool isValid;
105  double newValue = seedFeatures->value(curVar, isValid);
106  if (!isValid) {
107  ATH_MSG_DEBUG("\tUse default value as the feature (the one below this line) was not calculated");
108  newValue = m_List_BDTVariableDefaultValues[iVar];
109  //add this feature with its default value for the details later
110  seedFeatures->addFeature(curVar, newValue);
111  }
112 
113  list_BDTVariableValues[iVar] = static_cast<float>(newValue);
114  }//end loop over BDT vars
115 
116  }
117 
118 
120 
// getResponse(inSeed, isOK): BDT response of this mode discriminator for one seed.
//  - Always fills the BDT input vector first (updateReaderVariables). As a result,
//    missing features are added to the seed with their default values even if the
//    seed is rejected afterwards.
//  - Seeds that take the bad-pt-value branch (debug message below), or whose pt lies
//    in no pT bin, get isOK = false and a return value of -2.
//  - Otherwise isOK = true, and the GetGradBoostMVA output of the BDT for the
//    seed's pT bin is returned.
121  std::vector<float> list_BDTVariableValues(m_List_BDTVariableNames.size());
122 
123  updateReaderVariables(inSeed, list_BDTVariableValues);
124 
126  ATH_MSG_DEBUG("WARNING Seed has bad pt value! " << inSeed->getTauJet()->pt() << " MeV");
127  isOK = false;
128  return -2;
129  }
130 
131  //get the pt bin of input Seed
132  //NOTE: could be moved to decay mode determinator tool...
133  int ptBin = -1;
// NOTE(review): the two issues below mirror the loop in initialize().
//  - Both comparisons are strict, so a pt exactly equal to a bin edge matches no
//    bin and the seed is rejected with the warning below. Using ">=" on the lower
//    edge would make the bins half-open [lo, hi).
//  - size()-1 is unsigned and wraps around if m_BinEdges_Pt is empty.
134  for (unsigned int iPtBin=0; iPtBin<m_BinEdges_Pt.size()-1; iPtBin++) {
135  if (inSeed->p4().Pt() > m_BinEdges_Pt[iPtBin] && inSeed->p4().Pt() < m_BinEdges_Pt[iPtBin+1]) {
136  ptBin = iPtBin;
137  break;
138  }
139  }
140  if (ptBin == -1) {
141  ATH_MSG_WARNING("Could not find ptBin for tau seed with pt " << inSeed->p4().Pt());
142  isOK = false;
143  return -2.;
144  }
145 
146  isOK = true;
147 
148  // return the mva response; initialize() stores one BDT per pT bin, in bin order
149  return m_MVABDT_List[ptBin]->GetGradBoostMVA(list_BDTVariableValues);
150 }
PanTauSeed.h
PanTau::PanTauSeed::getFeatures
const PanTau::TauFeature * getFeatures() const
Definition: PanTauSeed.h:217
PanTau::PanTauSeed::getTauJet
const xAOD::TauJet * getTauJet() const
Definition: PanTauSeed.h:215
HelperFunctions.h
asg
Definition: DataHandleTestTool.h:28
PanTau::TauFeature::addFeature
bool addFeature(const std::string &name, const double value)
adds a new feature
Definition: TauFeature.cxx:44
PanTau::Tool_ModeDiscriminator::Tool_ModeDiscriminator
Tool_ModeDiscriminator(const std::string &name)
Definition: Tool_ModeDiscriminator.cxx:16
Tool_ModeDiscriminator.h
PanTau::TauFeature
Definition: TauFeature.h:19
isValid
bool isValid(const T &p)
Av: we implement here an ATLAS-specific convention: all particles which are 99xxxxx are accepted as valid.
Definition: AtlasPID.h:867
PanTau::PanTauSeed::p4
virtual FourMom_t p4() const
The full 4-momentum of the particle as a TLorentzVector.
Definition: PanTauSeed.cxx:134
xAOD::TauJet_v3::pt
virtual double pt() const
The transverse momentum ( ) of the particle.
PanTau::Tool_ModeDiscriminator::~Tool_ModeDiscriminator
virtual ~Tool_ModeDiscriminator()
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
PanTau::TauFeature::value
double value(const std::string &name, bool &isValid) const
returns the value of the feature given by its name
Definition: TauFeature.cxx:23
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
PathResolver.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
PanTau::Tool_ModeDiscriminator::getResponse
virtual double getResponse(PanTau::PanTauSeed *inSeed, bool &isOK) const
Definition: Tool_ModeDiscriminator.cxx:119
PanTau::PanTauSeed::t_BadPtValue
@ t_BadPtValue
Definition: PanTauSeed.h:37
PanTau::HelperFunctions::bindToolHandle
static StatusCode bindToolHandle(ToolHandle< T > &, std::string)
Definition: Reconstruction/PanTau/PanTauAlgs/PanTauAlgs/HelperFunctions.h:56
PanTau::Tool_ModeDiscriminator::updateReaderVariables
void updateReaderVariables(PanTau::PanTauSeed *inSeed, std::vector< float > &list_BDTVariableValues) const
Definition: Tool_ModeDiscriminator.cxx:93
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:283
Tool_InformationStore.h
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
PanTau::PanTauSeed::isOfTechnicalQuality
bool isOfTechnicalQuality(int pantauSeed_TechnicalQuality) const
Definition: PanTauSeed.cxx:329
TauFeature.h
PanTau::Tool_ModeDiscriminator::initialize
virtual StatusCode initialize()
Dummy implementation of the initialisation function.
Definition: Tool_ModeDiscriminator.cxx:26
PanTau::PanTauSeed
Definition: PanTauSeed.h:24