ATLAS Offline Software
Tool_ModeDiscriminator.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
9 #include "TFile.h"
10 #include "TTree.h"
11 
13  asg::AsgTool(name),
14  m_MVABDT_List()
15 {}
16 
17 
19 
20 
22 
23  ATH_MSG_DEBUG( name() << " initialize()" );
24  m_init=true;
25 
26  ATH_CHECK( HelperFunctions::bindToolHandle( m_Tool_InformationStore, m_Tool_InformationStoreName ) );
27 
28  ATH_CHECK(m_Tool_InformationStore.retrieve());
29 
30  // get the required information from the informationstore tool
31  ATH_CHECK( m_Tool_InformationStore->getInfo_VecDouble("ModeDiscriminator_BinEdges_Pt", m_BinEdges_Pt));
32  ATH_CHECK( m_Tool_InformationStore->getInfo_String("ModeDiscriminator_TMVAMethod", m_MethodName) );
33 
34  // build the name of the variable that contains the variable list for this discri tool
35  std::string varNameList_Full = "ModeDiscriminator_BDTVariableNames_CellBased_" + m_Name_ModeCase;
36  ATH_CHECK( m_Tool_InformationStore->getInfo_VecString(varNameList_Full, m_List_BDTVariableNames) );
37 
38  std::string varDefaultValueList_Full = "ModeDiscriminator_BDTVariableDefaults_CellBased_" + m_Name_ModeCase;
39  ATH_CHECK( m_Tool_InformationStore->getInfo_VecDouble(varDefaultValueList_Full, m_List_BDTVariableDefaultValues) );
40 
41 
42  // consistency check:
43  // Number of feature names and feature default values has to match
44  if ( m_List_BDTVariableDefaultValues.size() != m_List_BDTVariableNames.size() ) {
45  ATH_MSG_ERROR("Number of variable names does not match number of default values! Check jobOptions!");
46  return StatusCode::FAILURE;
47  }
48 
49  // Create reader for each pT Bin; nBins = Edges-1
50  for (unsigned int iPtBin=0; iPtBin<(m_BinEdges_Pt.size() - 1); iPtBin++) {
51 
52  std::string bin_lowerStr = m_HelperFunctions.convertNumberToString(m_BinEdges_Pt[iPtBin]/1000.);
53  std::string bin_upperStr = m_HelperFunctions.convertNumberToString(m_BinEdges_Pt[iPtBin+1]/1000.);
54  std::string curPtBin = "ET_" + bin_lowerStr + "_" + bin_upperStr;
55 
56  // weight files
57  std::string curWeightFile = m_calib_path + (!m_calib_path.empty() ? "/" : "");
58  curWeightFile += "TrainModes_";
59  curWeightFile += "CellBased_";
60  curWeightFile += curPtBin + "_";
61  curWeightFile += m_Name_ModeCase + "_";
62  curWeightFile += m_MethodName + ".weights.root";
63 
64  std::string resolvedWeightFileName = PathResolverFindCalibFile(curWeightFile);
65 
66  if (resolvedWeightFileName.empty()) {
67  ATH_MSG_ERROR("Weight file " << curWeightFile << " not found!");
68  return StatusCode::FAILURE;
69  }
70 
71  // MVAUtils BDT
72  std::unique_ptr<TFile> fBDT = std::make_unique<TFile>( resolvedWeightFileName.c_str() );
73  TTree* tBDT = dynamic_cast<TTree*> (fBDT->Get("BDT"));
74  std::unique_ptr<MVAUtils::BDT> curBDT = std::make_unique<MVAUtils::BDT>(tBDT);
75  if (curBDT == nullptr) {
76  ATH_MSG_ERROR( "Failed to create MVAUtils::BDT for " << resolvedWeightFileName );
77  return StatusCode::FAILURE;
78  }
79 
80  m_MVABDT_List.push_back(std::move(curBDT));
81 
82  }//end loop over pt bins to get weight files, reference hists and MVAUtils::BDT objects
83 
84  return StatusCode::SUCCESS;
85 }
86 
87 
88 void PanTau::Tool_ModeDiscriminator::updateReaderVariables(PanTau::PanTauSeed* inSeed, std::vector<float>& list_BDTVariableValues) const {
89 
90  //update features used in MVA with values from current seed
91  // use default value for feature if it is not present in current seed
92  //NOTE! This has to be done (even if the seed pt is bad) otherwise problems with details storage
93  // [If this for loop is skipped, it is not guaranteed that all details are set to their proper default value]
94  PanTau::TauFeature* seedFeatures = inSeed->getFeatures();
95 
96  for (unsigned int iVar=0; iVar<m_List_BDTVariableNames.size(); iVar++) {
97  std::string curVar = "CellBased_" + m_List_BDTVariableNames[iVar];
98 
99  bool isValid;
100  double newValue = seedFeatures->value(curVar, isValid);
101  if (!isValid) {
102  ATH_MSG_DEBUG("\tUse default value as the feature (the one below this line) was not calculated");
103  newValue = m_List_BDTVariableDefaultValues[iVar];
104  //add this feature with its default value for the details later
105  seedFeatures->addFeature(curVar, newValue);
106  }
107 
108  list_BDTVariableValues[iVar] = static_cast<float>(newValue);
109  }//end loop over BDT vars
110 
111  }
112 
113 
115 
116  std::vector<float> list_BDTVariableValues(m_List_BDTVariableNames.size());
117 
118  updateReaderVariables(inSeed, list_BDTVariableValues);
119 
121  ATH_MSG_DEBUG("WARNING Seed has bad pt value! " << inSeed->getTauJet()->pt() << " MeV");
122  isOK = false;
123  return -2;
124  }
125 
126  //get the pt bin of input Seed
127  //NOTE: could be moved to decay mode determinator tool...
128  int ptBin = -1;
129  for (unsigned int iPtBin=0; iPtBin<m_BinEdges_Pt.size()-1; iPtBin++) {
130  if (inSeed->p4().Pt() > m_BinEdges_Pt[iPtBin] && inSeed->p4().Pt() < m_BinEdges_Pt[iPtBin+1]) {
131  ptBin = iPtBin;
132  break;
133  }
134  }
135  if (ptBin == -1) {
136  ATH_MSG_WARNING("Could not find ptBin for tau seed with pt " << inSeed->p4().Pt());
137  isOK = false;
138  return -2.;
139  }
140 
141  isOK = true;
142 
143  // return the mva response
144  return m_MVABDT_List[ptBin]->GetGradBoostMVA(list_BDTVariableValues);
145 }
PanTauSeed.h
PanTau::PanTauSeed::getFeatures
const PanTau::TauFeature * getFeatures() const
Definition: PanTauSeed.h:217
PanTau::PanTauSeed::getTauJet
const xAOD::TauJet * getTauJet() const
Definition: PanTauSeed.h:215
HelperFunctions.h
asg
Definition: DataHandleTestTool.h:28
PanTau::TauFeature::addFeature
bool addFeature(const std::string &name, const double value)
adds a new feature
Definition: TauFeature.cxx:44
PanTau::Tool_ModeDiscriminator::Tool_ModeDiscriminator
Tool_ModeDiscriminator(const std::string &name)
Definition: Tool_ModeDiscriminator.cxx:12
Tool_ModeDiscriminator.h
PanTau::TauFeature
Definition: TauFeature.h:19
isValid
bool isValid(const T &p)
Av: we implement here an ATLAS-specific convention: all particles which are 99xxxxx are fine.
Definition: AtlasPID.h:878
PanTau::PanTauSeed::p4
virtual FourMom_t p4() const
The full 4-momentum of the particle as a TLorentzVector.
Definition: PanTauSeed.cxx:133
xAOD::TauJet_v3::pt
virtual double pt() const
The transverse momentum (\(p_T\)) of the particle.
PanTau::Tool_ModeDiscriminator::~Tool_ModeDiscriminator
virtual ~Tool_ModeDiscriminator()
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
PanTau::TauFeature::value
double value(const std::string &name, bool &isValid) const
returns the value of the feature given by its name
Definition: TauFeature.cxx:23
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
PathResolver.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
PanTau::Tool_ModeDiscriminator::getResponse
virtual double getResponse(PanTau::PanTauSeed *inSeed, bool &isOK) const
Definition: Tool_ModeDiscriminator.cxx:114
PanTau::PanTauSeed::t_BadPtValue
@ t_BadPtValue
Definition: PanTauSeed.h:37
PanTau::HelperFunctions::bindToolHandle
static StatusCode bindToolHandle(ToolHandle< T > &, std::string)
Definition: Reconstruction/PanTau/PanTauAlgs/PanTauAlgs/HelperFunctions.h:56
PanTau::Tool_ModeDiscriminator::updateReaderVariables
void updateReaderVariables(PanTau::PanTauSeed *inSeed, std::vector< float > &list_BDTVariableValues) const
Definition: Tool_ModeDiscriminator.cxx:88
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:321
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
PanTau::PanTauSeed::isOfTechnicalQuality
bool isOfTechnicalQuality(int pantauSeed_TechnicalQuality) const
Definition: PanTauSeed.cxx:328
PanTau::Tool_ModeDiscriminator::initialize
virtual StatusCode initialize()
Initialisation function (overrides the dummy base-class implementation): loads the configuration and the per-pT-bin BDTs.
Definition: Tool_ModeDiscriminator.cxx:21
PanTau::PanTauSeed
Definition: PanTauSeed.h:24