ATLAS Offline Software
Loading...
Searching...
No Matches
Tool_ModeDiscriminator.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
9#include "TFile.h"
10#include "TTree.h"
11
16
17
19
20
22
23 ATH_MSG_DEBUG( name() << " initialize()" );
24 m_init=true;
25
27
29
30 // get the required information from the informationstore tool
31 ATH_CHECK( m_Tool_InformationStore->getInfo_VecDouble("ModeDiscriminator_BinEdges_Pt", m_BinEdges_Pt));
32 ATH_CHECK( m_Tool_InformationStore->getInfo_String("ModeDiscriminator_TMVAMethod", m_MethodName) );
33
34 // build the name of the variable that contains the variable list for this discri tool
35 std::string varNameList_Full = "ModeDiscriminator_BDTVariableNames_CellBased_" + m_Name_ModeCase;
36 ATH_CHECK( m_Tool_InformationStore->getInfo_VecString(varNameList_Full, m_List_BDTVariableNames) );
37
38 std::string varDefaultValueList_Full = "ModeDiscriminator_BDTVariableDefaults_CellBased_" + m_Name_ModeCase;
39 ATH_CHECK( m_Tool_InformationStore->getInfo_VecDouble(varDefaultValueList_Full, m_List_BDTVariableDefaultValues) );
40
41
42 // consistency check:
43 // Number of feature names and feature default values has to match
45 ATH_MSG_ERROR("Number of variable names does not match number of default values! Check jobOptions!");
46 return StatusCode::FAILURE;
47 }
48
49 // Create reader for each pT Bin; nBins = Edges-1
50 for (unsigned int iPtBin=0; iPtBin<(m_BinEdges_Pt.size() - 1); iPtBin++) {
51
52 std::string bin_lowerStr = m_HelperFunctions.convertNumberToString(m_BinEdges_Pt[iPtBin]/1000.);
53 std::string bin_upperStr = m_HelperFunctions.convertNumberToString(m_BinEdges_Pt[iPtBin+1]/1000.);
54 std::string curPtBin = "ET_" + bin_lowerStr + "_" + bin_upperStr;
55
56 // weight files
57 std::string curWeightFile = m_calib_path + (!m_calib_path.empty() ? "/" : "");
58 curWeightFile += "TrainModes_";
59 curWeightFile += "CellBased_";
60 curWeightFile += curPtBin + "_";
61 curWeightFile += m_Name_ModeCase + "_";
62 curWeightFile += m_MethodName + ".weights.root";
63
64 std::string resolvedWeightFileName = PathResolverFindCalibFile(curWeightFile);
65
66 if (resolvedWeightFileName.empty()) {
67 ATH_MSG_ERROR("Weight file " << curWeightFile << " not found!");
68 return StatusCode::FAILURE;
69 }
70
71 // MVAUtils BDT
72 std::unique_ptr<TFile> fBDT = std::make_unique<TFile>( resolvedWeightFileName.c_str() );
73 TTree* tBDT = dynamic_cast<TTree*> (fBDT->Get("BDT"));
74 std::unique_ptr<MVAUtils::BDT> curBDT = std::make_unique<MVAUtils::BDT>(tBDT);
75 if (curBDT == nullptr) {
76 ATH_MSG_ERROR( "Failed to create MVAUtils::BDT for " << resolvedWeightFileName );
77 return StatusCode::FAILURE;
78 }
79
80 m_MVABDT_List.push_back(std::move(curBDT));
81
82 }//end loop over pt bins to get weight files, reference hists and MVAUtils::BDT objects
83
84 return StatusCode::SUCCESS;
85}
86
87
88void PanTau::Tool_ModeDiscriminator::updateReaderVariables(PanTau::PanTauSeed* inSeed, std::vector<float>& list_BDTVariableValues) const {
89
90 //update features used in MVA with values from current seed
91 // use default value for feature if it is not present in current seed
92 //NOTE! This has to be done (even if the seed pt is bad) otherwise problems with details storage
93 // [If this for loop is skipped, it is not guaranteed that all details are set to their proper default value]
94 PanTau::TauFeature* seedFeatures = inSeed->getFeatures();
95
96 for (unsigned int iVar=0; iVar<m_List_BDTVariableNames.size(); iVar++) {
97 std::string curVar = "CellBased_" + m_List_BDTVariableNames[iVar];
98
99 bool isValid;
100 double newValue = seedFeatures->value(curVar, isValid);
101 if (!isValid) {
102 ATH_MSG_DEBUG("\tUse default value as the feature (the one below this line) was not calculated");
103 newValue = m_List_BDTVariableDefaultValues[iVar];
104 //add this feature with its default value for the details later
105 seedFeatures->addFeature(curVar, newValue);
106 }
107
108 list_BDTVariableValues[iVar] = static_cast<float>(newValue);
109 }//end loop over BDT vars
110
111 }
112
113
115
116 std::vector<float> list_BDTVariableValues(m_List_BDTVariableNames.size());
117
118 updateReaderVariables(inSeed, list_BDTVariableValues);
119
121 ATH_MSG_DEBUG("WARNING Seed has bad pt value! " << inSeed->getTauJet()->pt() << " MeV");
122 isOK = false;
123 return -2;
124 }
125
126 //get the pt bin of input Seed
127 //NOTE: could be moved to decay mode determinator tool...
128 int ptBin = -1;
129 for (unsigned int iPtBin=0; iPtBin<m_BinEdges_Pt.size()-1; iPtBin++) {
130 if (inSeed->p4().Pt() > m_BinEdges_Pt[iPtBin] && inSeed->p4().Pt() < m_BinEdges_Pt[iPtBin+1]) {
131 ptBin = iPtBin;
132 break;
133 }
134 }
135 if (ptBin == -1) {
136 ATH_MSG_WARNING("Could not find ptBin for tau seed with pt " << inSeed->p4().Pt());
137 isOK = false;
138 return -2.;
139 }
140
141 isOK = true;
142
143 // return the mva response
144 return m_MVABDT_List[ptBin]->GetGradBoostMVA(list_BDTVariableValues);
145}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
bool isValid(const T &p)
Av: we implement here an ATLAS-specific convention: all particles which are 99xxxxx are fine.
Definition AtlasPID.h:878
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
static StatusCode bindToolHandle(ToolHandle< T > &, std::string)
virtual FourMom_t p4() const
The full 4-momentum of the particle as a TLorentzVector.
bool isOfTechnicalQuality(int pantauSeed_TechnicalQuality) const
const xAOD::TauJet * getTauJet() const
Definition PanTauSeed.h:215
const PanTau::TauFeature * getFeatures() const
Definition PanTauSeed.h:217
Class containing features of a tau seed.
Definition TauFeature.h:19
bool addFeature(const std::string &name, const double value)
adds a new feature
double value(const std::string &name, bool &isValid) const
returns the value of the feature given by its name
PanTau::HelperFunctions m_HelperFunctions
std::vector< std::unique_ptr< MVAUtils::BDT > > m_MVABDT_List
virtual StatusCode initialize()
Dummy implementation of the initialisation function.
ToolHandle< PanTau::ITool_InformationStore > m_Tool_InformationStore
std::vector< double > m_List_BDTVariableDefaultValues
Gaudi::Property< std::string > m_Name_ModeCase
Gaudi::Property< std::string > m_calib_path
Tool_ModeDiscriminator(const std::string &name)
void updateReaderVariables(PanTau::PanTauSeed *inSeed, std::vector< float > &list_BDTVariableValues) const
Gaudi::Property< std::string > m_Tool_InformationStoreName
std::vector< std::string > m_List_BDTVariableNames
virtual double getResponse(PanTau::PanTauSeed *inSeed, bool &isOK) const
AsgTool(const std::string &name)
Constructor specifying the tool instance's name.
Definition AsgTool.cxx:58
virtual double pt() const
The transverse momentum ($p_T$) of the particle.