33 ATH_CHECK(m_Tool_InformationStore.retrieve());
36 ATH_CHECK( m_Tool_InformationStore->getInfo_VecDouble(
"ModeDiscriminator_BinEdges_Pt", m_BinEdges_Pt));
37 ATH_CHECK( m_Tool_InformationStore->getInfo_String(
"ModeDiscriminator_TMVAMethod", m_MethodName) );
40 std::string varNameList_Full =
"ModeDiscriminator_BDTVariableNames_CellBased_" + m_Name_ModeCase;
41 ATH_CHECK( m_Tool_InformationStore->getInfo_VecString(varNameList_Full, m_List_BDTVariableNames) );
43 std::string varDefaultValueList_Full =
"ModeDiscriminator_BDTVariableDefaults_CellBased_" + m_Name_ModeCase;
44 ATH_CHECK( m_Tool_InformationStore->getInfo_VecDouble(varDefaultValueList_Full, m_List_BDTVariableDefaultValues) );
49 if ( m_List_BDTVariableDefaultValues.size() != m_List_BDTVariableNames.size() ) {
50 ATH_MSG_ERROR(
"Number of variable names does not match number of default values! Check jobOptions!");
51 return StatusCode::FAILURE;
55 for (
unsigned int iPtBin=0; iPtBin<(m_BinEdges_Pt.size() - 1); iPtBin++) {
57 std::string bin_lowerStr = m_HelperFunctions.convertNumberToString(m_BinEdges_Pt[iPtBin]/1000.);
58 std::string bin_upperStr = m_HelperFunctions.convertNumberToString(m_BinEdges_Pt[iPtBin+1]/1000.);
59 std::string curPtBin =
"ET_" + bin_lowerStr +
"_" + bin_upperStr;
62 std::string curWeightFile = m_calib_path + (!m_calib_path.empty() ?
"/" :
"");
63 curWeightFile +=
"TrainModes_";
64 curWeightFile +=
"CellBased_";
65 curWeightFile += curPtBin +
"_";
66 curWeightFile += m_Name_ModeCase +
"_";
67 curWeightFile += m_MethodName +
".weights.root";
71 if (resolvedWeightFileName.empty()) {
72 ATH_MSG_ERROR(
"Weight file " << curWeightFile <<
" not found!");
73 return StatusCode::FAILURE;
77 std::unique_ptr<TFile> fBDT = std::make_unique<TFile>( resolvedWeightFileName.c_str() );
78 TTree* tBDT =
dynamic_cast<TTree*
> (fBDT->Get(
"BDT"));
79 std::unique_ptr<MVAUtils::BDT> curBDT = std::make_unique<MVAUtils::BDT>(tBDT);
80 if (curBDT ==
nullptr) {
81 ATH_MSG_ERROR(
"Failed to create MVAUtils::BDT for " << resolvedWeightFileName );
82 return StatusCode::FAILURE;
85 m_MVABDT_List.push_back(std::move(curBDT));
89 return StatusCode::SUCCESS;
101 for (
unsigned int iVar=0; iVar<m_List_BDTVariableNames.size(); iVar++) {
102 std::string curVar =
"CellBased_" + m_List_BDTVariableNames[iVar];
107 ATH_MSG_DEBUG(
"\tUse default value as the feature (the one below this line) was not calculated");
108 newValue = m_List_BDTVariableDefaultValues[iVar];
113 list_BDTVariableValues[iVar] =
static_cast<float>(newValue);
121 std::vector<float> list_BDTVariableValues(m_List_BDTVariableNames.size());
123 updateReaderVariables(inSeed, list_BDTVariableValues);
134 for (
unsigned int iPtBin=0; iPtBin<m_BinEdges_Pt.size()-1; iPtBin++) {
135 if (inSeed->
p4().Pt() > m_BinEdges_Pt[iPtBin] && inSeed->
p4().Pt() < m_BinEdges_Pt[iPtBin+1]) {
141 ATH_MSG_WARNING(
"Could not find ptBin for tau seed with pt " << inSeed->
p4().Pt());
149 return m_MVABDT_List[ptBin]->GetGradBoostMVA(list_BDTVariableValues);