ATLAS Offline Software
ClassifiedTrackTaggerTool.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
#include "TLorentzVector.h"

#include "MVAUtils/BDT.h"
#include "TFile.h"
#include "TTree.h"
#include "GaudiKernel/IChronoStatSvc.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
13 //
14 //-------------------------------------------------
15 namespace Analysis {
16 //
17 //Constructor--------------------------------------------------------------
19  const std::string& name,
20  const IInterface* parent):
21  base_class(type,name,parent),
22  m_trackClassificator("InDet::InDetTrkInJetType/TrackClassificationTool",this),
23  m_deltaRConeSize(0.4),
24  m_useFivePtJetBinTCT(false),
25  m_calibFileName("CTT_calib_v00.root"),
26  m_jetCollection("AntiKt4EMPFlowJets")
27  {
28  declareProperty("TrackClassificationTool", m_trackClassificator);
29  declareProperty("deltaRConeSize", m_deltaRConeSize);
30  declareProperty("useFivePtJetBinTCT",m_useFivePtJetBinTCT);
31  declareProperty("JetCollection",m_jetCollection);
32  m_timingProfile=nullptr;
33  }
34 
35 //Initialize---------------------------------------------------------------
37  //retrieve calibration file and initialize the m_trkClassBDT
38  std::string fullPathToFile = PathResolverFindCalibFile("BTagging/20221012track/"+m_calibFileName);
39 
40  std::string strBDTName = m_useFivePtJetBinTCT ?
41  "CTTtrainedWithRetrainedTCT" : "CTTtrainedWithDefaultTCT";
42 
43  std::unique_ptr<TFile> rootFile(TFile::Open(fullPathToFile.c_str(), "READ"));
44  if (!rootFile) {
45  ATH_MSG_ERROR("Can not retrieve ClassifiedTrackTagger calibration root file: " << fullPathToFile);
46  return StatusCode::FAILURE;
47  }
48  std::unique_ptr<TTree> training( (TTree*)rootFile->Get(strBDTName.c_str()) );
49  m_CTTBDT = std::make_unique<MVAUtils::BDT>(training.get());
50 
51  //-------
52  //check that the TrackClassificationTool can be accessed-> InDetTrkInJetType to get TCT weights per TrackParticle
53  if (m_trackClassificator.retrieve().isFailure()) {
54  ATH_MSG_DEBUG("Could not find InDet::InDetTrkInJetType - TrackClassificationTool");
55  return StatusCode::FAILURE;
56  } else {
57  ATH_MSG_DEBUG("InDet::InDetTrkInJetType - TrackClassificationTool found");
58  }
59 
60  //check that the fivePtJetBinTCT is actually used, if CTT is configured to do so
62  ATH_MSG_DEBUG("FivePtJetBin version of TCT is used");
63  if(!m_trackClassificator->usesFivePtJetBinVersion())
64  ATH_MSG_ERROR("FivePtJetBin TCT tool is not used, but required by CTT!");
65  }
66  else{
67  ATH_MSG_DEBUG("Default version of TCT is used");
68  if(m_trackClassificator->usesFivePtJetBinVersion())
69  ATH_MSG_ERROR("FivePtJetBin TCT tool is used, but default version required by CTT!");
70  }
71 
72  //SG::WriteDecorHandleKey for CTT jet decoration
73  if(m_jetCollection.empty()) {ATH_MSG_FATAL("No JetCollection specified! ");}
74  else {
75  m_jetWriteDecorKey = m_jetCollection +".CTTScore";
76  ATH_CHECK( m_jetWriteDecorKey.initialize());
77  }
78 
79  if(msgLvl(MSG::DEBUG)) ATH_CHECK(service("ChronoStatSvc", m_timingProfile));
80 //-----
81  return StatusCode::SUCCESS;
82  }
83 
85  {
86  if(m_timingProfile)m_timingProfile->chronoPrint("ClassifiedTrackTaggerTool");
87  ATH_MSG_DEBUG("ClassifiedTrackTaggerTool finalize()");
88  return StatusCode::SUCCESS;
89  }
90 
91  float ClassifiedTrackTaggerTool::bJetWgts(const std::vector<const xAOD::TrackParticle*> & InpTrk, const xAOD::Vertex & PV, const TLorentzVector & Jet) const
92  {
93  std::vector<std::vector<float>> TCTweights;
94  //for each track inside a cone of deltaR around the jet direction save the TCT output (wgtB, wgtL,wgtG)
95  //if it was not rejected by the TCT track quality cuts
96  for (const auto &itrk : InpTrk) {
97  if((itrk->p4()).DeltaR(Jet)<=m_deltaRConeSize) {
98  std::vector<float> v_tctScore = m_trackClassificator->trkTypeWgts(itrk, PV, Jet);
99  bool b_zeroTCTScore = std::all_of(v_tctScore.begin(), v_tctScore.end(), [](float i) { return i==0; });
100  if(!b_zeroTCTScore) { TCTweights.push_back(v_tctScore); }
101  }
102  }
103 
104  ATH_MSG_DEBUG("[ClassifiedTrackTagger]: retrieved TCT score");
105  int ntrk = TCTweights.size();
106  if(ntrk< 3) {return -5; } //if less than three tracks passing quality cuts of TCT -> return default value of -5
107 
108  //get sorted indices w.r.t. wgtB (highest wgtB track -> first index in sorted_indices)
109  std::vector<int> sorted_indices = GetSortedIndices(TCTweights);
110  ATH_MSG_DEBUG("[ClassifiedTrackTagger]: ntrk = " << ntrk << " in jet");
111 
112  float ptjet = Jet.Pt();
113  float trackMultiplicity = ( ((float)ntrk) / ptjet) * 1.e3;
114 
115  if(m_timingProfile)m_timingProfile->chronoStart("ClassifiedTrackTaggerTool");
116 
117  //-----Use MVAUtils to save CPU
118  //order in which variables are given to the BDT wgtB_0,wgtG_0,wgtL_0,wgtB_1,wgtG_1,wgtL_1,wgtB_2,wgtG_2,wgtL_2, (ntrk/ptjet * 1.e3)
119  // (0: track with highest wgtB), (1: track with 2nd highest wgtB), (2: track with 3rd highest wgtB)
120  int iwgtB=0, iwgtL=1, iwgtG=2;
121  ATH_MSG_DEBUG("[ClassifiedTrackTagger]: ordered signal TCT weights = " << TCTweights[sorted_indices[0]][iwgtB] << "," << TCTweights[sorted_indices[1]][iwgtB] << "," << TCTweights[sorted_indices[2]][iwgtB]);
122 
123  //change input variable ordering when final BDT model is chosen!
124  std::vector<float> bdt_vars = {
125  TCTweights[sorted_indices[0]][iwgtB], TCTweights[sorted_indices[0]][iwgtG], TCTweights[sorted_indices[0]][iwgtL],
126  TCTweights[sorted_indices[1]][iwgtB], TCTweights[sorted_indices[1]][iwgtG], TCTweights[sorted_indices[1]][iwgtL],
127  TCTweights[sorted_indices[2]][iwgtB], TCTweights[sorted_indices[2]][iwgtG], TCTweights[sorted_indices[2]][iwgtL],
128  trackMultiplicity};
129  float score=m_CTTBDT->GetGradBoostMVA(bdt_vars);
130 
131  ATH_MSG_DEBUG("[ClassifiedTrackTagger]: CTT classification score = " << score);
132 
133  if(m_timingProfile)m_timingProfile->chronoStop("ClassifiedTrackTaggerTool");
134  return score;
135  }
136 
137 
138  void ClassifiedTrackTaggerTool::decorateJets(const std::vector<const xAOD::TrackParticle*> & InpTrk, const xAOD::Vertex & primVertex, const xAOD::JetContainer & jets) const
139  {
141  for(const auto curjet : jets){
142  ATH_MSG_DEBUG( " Jet pt: " << curjet->pt()<<" eta: "<<curjet->eta()<<" phi: "<< curjet->phi() );
143  float CTTScore = bJetWgts(InpTrk, primVertex, curjet->p4());
144  jetWriteDecorHandle(*curjet) = CTTScore;
145  }
146  }
147 
148  std::vector<int> ClassifiedTrackTaggerTool::GetSortedIndices(std::vector<std::vector<float>> unordered_vec) const
149  {
150  //from https://stackoverflow.com/questions/1577475/c-sorting-and-keeping-track-of-indexes
151  int ntrk = unordered_vec.size();
152  std::vector<int> indices;
153  indices.clear();
154  for(int i=0; i < ntrk; i++) indices.push_back(i);
155 
156  //sort the vector of indices, such that the index corresponding to the highest wgtB stands first, with the lowest last (unordered_vec[itrk][iwgt], wgtB-> iwgt=0)
157  std::sort(std::begin(indices), std::end(indices),[&unordered_vec](size_t itrk1, size_t itrk2) {return unordered_vec[itrk1][0] > unordered_vec[itrk2][0];});
158 
159  return indices;
160  }
161 
162 }// close namespace
ATH_MSG_FATAL
#define ATH_MSG_FATAL(x)
Definition: AthMsgStreamMacros.h:34
Analysis::ClassifiedTrackTaggerTool::m_trackClassificator
ToolHandle< InDet::IInDetTrkInJetType > m_trackClassificator
Definition: ClassifiedTrackTaggerTool.h:82
Analysis::ClassifiedTrackTaggerTool::m_deltaRConeSize
float m_deltaRConeSize
Definition: ClassifiedTrackTaggerTool.h:85
Jet
Basic data class defines behavior for all Jet objects The Jet class is the principal data class for...
Definition: Reconstruction/Jet/JetEvent/JetEvent/Jet.h:47
Analysis::ClassifiedTrackTaggerTool::m_calibFileName
std::string m_calibFileName
Definition: ClassifiedTrackTaggerTool.h:91
Trk::indices
std::pair< long int, long int > indices
Definition: AlSymMatBase.h:24
PlotCalibFromCool.begin
begin
Definition: PlotCalibFromCool.py:94
ClassifiedTrackTaggerTool.h
Analysis::ClassifiedTrackTaggerTool::finalize
virtual StatusCode finalize() override
Definition: ClassifiedTrackTaggerTool.cxx:84
Analysis::ClassifiedTrackTaggerTool::m_timingProfile
IChronoStatSvc * m_timingProfile
Definition: ClassifiedTrackTaggerTool.h:76
mergePhysValFiles.end
end
Definition: DataQuality/DataQualityUtils/scripts/mergePhysValFiles.py:93
Analysis::ClassifiedTrackTaggerTool::m_jetCollection
std::string m_jetCollection
Definition: ClassifiedTrackTaggerTool.h:94
Analysis::ClassifiedTrackTaggerTool::ClassifiedTrackTaggerTool
ClassifiedTrackTaggerTool(const std::string &type, const std::string &name, const IInterface *parent)
Definition: ClassifiedTrackTaggerTool.cxx:18
Analysis::ClassifiedTrackTaggerTool::m_useFivePtJetBinTCT
bool m_useFivePtJetBinTCT
Definition: ClassifiedTrackTaggerTool.h:88
Analysis::ClassifiedTrackTaggerTool::bJetWgts
virtual float bJetWgts(const std::vector< const xAOD::TrackParticle * > &, const xAOD::Vertex &, const TLorentzVector &) const override
Method to retrieve the classifier score of the ClassifiedTrackTagger (CTT)
Definition: ClassifiedTrackTaggerTool.cxx:91
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
CheckAppliedSFs.e3
e3
Definition: CheckAppliedSFs.py:264
lumiFormat.i
int i
Definition: lumiFormat.py:92
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
SG::WriteDecorHandle
Handle class for adding a decoration to an object.
Definition: StoreGate/StoreGate/WriteDecorHandle.h:99
test_pyathena.parent
parent
Definition: test_pyathena.py:15
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
makeComparison.rootFile
rootFile
Definition: makeComparison.py:27
BDT.h
DataVector
Derived DataVector<T>.
Definition: DataVector.h:581
Analysis::ClassifiedTrackTaggerTool::decorateJets
virtual void decorateJets(const std::vector< const xAOD::TrackParticle * > &, const xAOD::Vertex &, const xAOD::JetContainer &) const override
Method to decorate the xAOD::Jet object with the CTT score.
Definition: ClassifiedTrackTaggerTool.cxx:138
Analysis
The namespace of all packages in PhysicsAnalysis/JetTagging.
Definition: BTaggingCnvAlg.h:20
PathResolver.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:195
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:431
xAOD::score
@ score
Definition: TrackingPrimitives.h:513
Analysis::ClassifiedTrackTaggerTool::GetSortedIndices
std::vector< int > GetSortedIndices(std::vector< std::vector< float >> unordered_vec) const
Private method for sorting tracks according to the highest wgtB.
Definition: ClassifiedTrackTaggerTool.cxx:148
xAOD::Vertex_v1
Class describing a Vertex.
Definition: Vertex_v1.h:42
python.CaloScaleNoiseConfig.type
type
Definition: CaloScaleNoiseConfig.py:78
DEBUG
#define DEBUG
Definition: page_access.h:11
declareProperty
#define declareProperty(n, p, h)
Definition: BaseFakeBkgTool.cxx:15
defineDB.jets
list jets
Definition: JetTagCalibration/share/defineDB.py:24
Analysis::ClassifiedTrackTaggerTool::m_jetWriteDecorKey
SG::WriteDecorHandleKey< xAOD::JetContainer > m_jetWriteDecorKey
The write key for adding CTT score to the jets.
Definition: ClassifiedTrackTaggerTool.h:100
Analysis::ClassifiedTrackTaggerTool::m_CTTBDT
std::unique_ptr< MVAUtils::BDT > m_CTTBDT
Definition: ClassifiedTrackTaggerTool.h:79
readCCLHist.float
float
Definition: readCCLHist.py:83
Analysis::ClassifiedTrackTaggerTool::initialize
virtual StatusCode initialize() override
Definition: ClassifiedTrackTaggerTool.cxx:36