// Visible portion of the tool's configuration/BDT-loading routine. The function
// header and several original lines (including closing braces) are outside this view.
// Step 1: read configuration through the information-store tool.
28 ATH_CHECK(m_Tool_InformationStore.retrieve());
// pT bin edges; one BDT is loaded below for each pair of adjacent edges.
31 ATH_CHECK( m_Tool_InformationStore->getInfo_VecDouble(
"ModeDiscriminator_BinEdges_Pt", m_BinEdges_Pt));
// Method name; used below as the last component of the weight-file name.
32 ATH_CHECK( m_Tool_InformationStore->getInfo_String(
"ModeDiscriminator_TMVAMethod", m_MethodName) );
// BDT input variable names and their fallback (default) values. Both config keys
// are suffixed with m_Name_ModeCase, so each mode case has its own lists.
35 std::string varNameList_Full =
"ModeDiscriminator_BDTVariableNames_CellBased_" + m_Name_ModeCase;
36 ATH_CHECK( m_Tool_InformationStore->getInfo_VecString(varNameList_Full, m_List_BDTVariableNames) );
38 std::string varDefaultValueList_Full =
"ModeDiscriminator_BDTVariableDefaults_CellBased_" + m_Name_ModeCase;
39 ATH_CHECK( m_Tool_InformationStore->getInfo_VecDouble(varDefaultValueList_Full, m_List_BDTVariableDefaultValues) );
// Step 2: names and defaults are later indexed with the same index, so their
// lengths must match.
44 if ( m_List_BDTVariableDefaultValues.size() != m_List_BDTVariableNames.size() ) {
45 ATH_MSG_ERROR(
"Number of variable names does not match number of default values! Check jobOptions!");
46 return StatusCode::FAILURE;
// Step 3: load one BDT per pT bin, in bin order, and append it to m_MVABDT_List.
// NOTE(review): size() is unsigned. If m_BinEdges_Pt is empty, size() - 1 wraps
// to a huge value and m_BinEdges_Pt[iPtBin] reads out of bounds. Consider
// rejecting configurations with fewer than two bin edges before this loop.
50 for (
unsigned int iPtBin=0; iPtBin<(m_BinEdges_Pt.size() - 1); iPtBin++) {
// The bin edges are divided by 1000 only to build the file-name label
// (presumably MeV -> GeV; confirm against the weight-file naming).
52 std::string bin_lowerStr = m_HelperFunctions.convertNumberToString(m_BinEdges_Pt[iPtBin]/1000.);
53 std::string bin_upperStr = m_HelperFunctions.convertNumberToString(m_BinEdges_Pt[iPtBin+1]/1000.);
54 std::string curPtBin =
"ET_" + bin_lowerStr +
"_" + bin_upperStr;
// Weight-file name assembled below:
//   [<m_calib_path>/]TrainModes_CellBased_ET_<lo>_<hi>_<m_Name_ModeCase>_<m_MethodName>.weights.root
57 std::string curWeightFile = m_calib_path + (!m_calib_path.empty() ?
"/" :
"");
58 curWeightFile +=
"TrainModes_";
59 curWeightFile +=
"CellBased_";
60 curWeightFile += curPtBin +
"_";
61 curWeightFile += m_Name_ModeCase +
"_";
62 curWeightFile += m_MethodName +
".weights.root";
// resolvedWeightFileName is set on lines not shown here (presumably by resolving
// curWeightFile to a full path; confirm). An empty result is treated as "file not found".
66 if (resolvedWeightFileName.empty()) {
67 ATH_MSG_ERROR(
"Weight file " << curWeightFile <<
" not found!");
68 return StatusCode::FAILURE;
// NOTE(review): the TFile is not checked for a failed open (e.g. IsZombie()).
// tBDT is not checked for nullptr either: Get() returns nullptr when there is no
// "BDT" key, and dynamic_cast yields nullptr for a non-TTree object. tBDT is
// still passed straight into the MVAUtils::BDT constructor on the next line.
72 std::unique_ptr<TFile> fBDT = std::make_unique<TFile>( resolvedWeightFileName.c_str() );
73 TTree* tBDT =
dynamic_cast<TTree*
> (fBDT->Get(
"BDT"));
74 std::unique_ptr<MVAUtils::BDT> curBDT = std::make_unique<MVAUtils::BDT>(tBDT);
// NOTE(review): std::make_unique either returns a non-null pointer or throws,
// so this branch can never be taken. A meaningful validity check belongs on
// tBDT (and on the file) before constructing the BDT.
75 if (curBDT ==
nullptr) {
76 ATH_MSG_ERROR(
"Failed to create MVAUtils::BDT for " << resolvedWeightFileName );
77 return StatusCode::FAILURE;
// fBDT appears to be local to this loop iteration, so the file is closed when
// it goes out of scope. This presumes MVAUtils::BDT does not keep pointers into
// tBDT after construction (confirm).
80 m_MVABDT_List.push_back(std::move(curBDT));
84 return StatusCode::SUCCESS;
// Visible portion of the routine that fills list_BDTVariableValues, one entry
// per configured BDT variable (index iVar matches m_List_BDTVariableNames).
// NOTE(review): list_BDTVariableValues[iVar] is written without bounds checks.
// The vector must already hold m_List_BDTVariableNames.size() elements; the
// visible caller sizes it that way.
96 for (
unsigned int iVar=0; iVar<m_List_BDTVariableNames.size(); iVar++) {
// Variable names carry a "CellBased_" prefix here. curVar is presumably used
// for the feature lookup on lines not shown (confirm).
97 std::string curVar =
"CellBased_" + m_List_BDTVariableNames[iVar];
// Fallback path (its guarding condition is on lines not shown): use the
// configured default for this variable. Indexing defaults by iVar relies on the
// configuration-time check that names and defaults have equal length.
102 ATH_MSG_DEBUG(
"\tUse default value as the feature (the one below this line) was not calculated");
103 newValue = m_List_BDTVariableDefaultValues[iVar];
// Store the value as float, the element type of the BDT input vector.
108 list_BDTVariableValues[iVar] =
static_cast<float>(newValue);
// Visible portion of the BDT evaluation for one tau seed.
// Build the BDT input vector: one float per configured variable.
116 std::vector<float> list_BDTVariableValues(m_List_BDTVariableNames.size());
118 updateReaderVariables(inSeed, list_BDTVariableValues);
// Find the pT bin that contains inSeed->p4().Pt(). m_MVABDT_List holds one BDT
// per bin in the same order, because it is filled by a loop over the same edges.
// NOTE(review): same unsigned wrap as in the loader if m_BinEdges_Pt is empty.
// NOTE(review): both comparisons are strict, so a pT exactly equal to a bin edge
// matches no bin and falls through to the warning below. Consider a half-open
// [low, high) test.
129 for (
unsigned int iPtBin=0; iPtBin<m_BinEdges_Pt.size()-1; iPtBin++) {
130 if (inSeed->
p4().Pt() > m_BinEdges_Pt[iPtBin] && inSeed->
p4().Pt() < m_BinEdges_Pt[iPtBin+1]) {
// Warning for a seed whose pT matched no bin: on an edge, or outside the
// configured range. The guarding condition is on lines not shown.
136 ATH_MSG_WARNING(
"Could not find ptBin for tau seed with pt " << inSeed->
p4().Pt());
// ptBin is assigned on lines not shown. Confirm that the no-bin path cannot
// reach this line with an index outside m_MVABDT_List.
144 return m_MVABDT_List[ptBin]->GetGradBoostMVA(list_BDTVariableValues);