2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
5 #include "MeasurementToTruthAssociationAlg.h"
7 #include "ActsGeometry/ATLASSourceLink.h"
11 #include <type_traits>
// Constructor: forwards the algorithm name and the service locator to the
// AthReentrantAlgorithm base class.
//
// @param name         passed unchanged to the base class constructor.
// @param pSvcLocator  passed unchanged to the base class constructor.
//
// NOTE(review): the class template has four parameters but only three template
// arguments are visible in the specialisation below, and the constructor body is
// not visible either; short lines seem to be missing from this excerpt - confirm
// against the full file.
17 template <class T_MeasurementCollection, class T_SimDataCollection, class T_TruthEventCollection, bool IsDebug >
18 MeasurementToTruthAssociationAlg<T_MeasurementCollection,
20 T_TruthEventCollection,
21 IsDebug>::MeasurementToTruthAssociationAlg(const std::string &name,
22 ISvcLocator *pSvcLocator)
23 : AthReentrantAlgorithm(name, pSvcLocator)
// initialize(): initialises the data handle keys used by this algorithm
// (measurements, sim data, association output, truth event collection).
//
// @return StatusCode::SUCCESS when the end of the function is reached.
//         ATH_CHECK is defined outside this file; presumably it returns early
//         with a failure status when the checked call fails - confirm.
27 template <class T_MeasurementCollection, class T_SimDataCollection, class T_TruthEventCollection, bool IsDebug>
28 StatusCode MeasurementToTruthAssociationAlg<T_MeasurementCollection,
30 T_TruthEventCollection,
31 IsDebug>::initialize()
// Input keys: measurement collection and sim data collection.
33 ATH_CHECK( m_measurementKey.initialize() );
34 ATH_CHECK( m_simDataKey.initialize() );
// Visits every entry of m_statRDO.
// NOTE(review): the loop body is not visible in this excerpt; presumably it
// resets the counters - confirm against the full file.
35 for (unsigned int i=0; i<m_statRDO.size(); ++i) {
// Output key under which execute() records the association.
38 ATH_CHECK( m_associationOutKey.initialize() );
// Compile-time switch on the truth event collection type: for the `void`
// instantiation the key is initialised with the argument `false`
// (presumably marking the key as unused - confirm).
39 if constexpr(std::is_same<T_TruthEventCollection, void>::value) {
40 ATH_CHECK( m_truthEventCollectionKey.initialize(false) );
// NOTE(review): the line(s) separating the two cases are not visible here; the
// statements below presumably form the else branch (non-void collection type).
// They log the type name and key of the handle key, then initialise it.
43 ATH_MSG_INFO( " Truth event " << typeid(m_truthEventCollectionKey).name()
44 << " " << m_truthEventCollectionKey.key());
45 ATH_CHECK( m_truthEventCollectionKey.initialize() );
47 return StatusCode::SUCCESS;
// Prints the statistics accumulated in m_statRDO, m_stat and m_depositCounts at
// INFO level and returns StatusCode::SUCCESS.
//
// NOTE(review): the line carrying the member function name is not visible in
// this excerpt; judging from the content this is presumably finalize() - confirm.
50 template <class T_MeasurementCollection, class T_SimDataCollection, class T_TruthEventCollection, bool IsDebug>
51 StatusCode MeasurementToTruthAssociationAlg<T_MeasurementCollection,
53 T_TruthEventCollection,
// Human readable labels, one per counter category; the trailing comment on each
// label names the category enumerator it belongs to. names[i] is printed next
// to m_statRDO[i] below.
56 std::array<std::string, kNCategories> names {
57 "Measurements without SimData", // kNoTruth
58 "Measurements with SimData", // kHasSimHit
59 "Deposits without xAOD TruthParticle", // kInvalidTruthLink
60 "SimData without deposits above threshold", // kHasSimHitNoParticle
61 "Associated truth exceeds small vector size" // kBeyondSmallVectorSize
// The longest label defines the column width, so that all counter values are
// printed in the same column.
63 auto max_name_iter = std::max_element(names.begin(),names.end(),[](std::string &a,std::string &b) { return a.size() < b.size(); } );
64 for (unsigned int i=0; i<names.size(); ++i) {
65 ATH_MSG_INFO( "RDO truth stat " << std::left << std::setw(max_name_iter->size()) << names[i] << std::right << " " << m_statRDO[i]);
// The statistics objects of m_stat are printed only in the IsDebug
// instantiation: first the object itself, then its histogramToString() output.
67 if constexpr(IsDebug) {
68 ATH_MSG_INFO("Truth particles per RDO " << m_stat.m_particlesPerMeasurement << std::endl
69 << m_stat.m_particlesPerMeasurement.histogramToString() );
70 ATH_MSG_INFO("Measurements per particle " << m_stat.m_measurementsPerParticle << std::endl
71 << m_stat.m_measurementsPerParticle.histogramToString() );
72 ATH_MSG_INFO("Log10 of deposited energy per RDO " << m_stat.m_depositedEnergy << std::endl
73 << m_stat.m_depositedEnergy.histogramToString() );
// Deposit counters: entries 1 and 0 are labelled HS and PU, entries 3 and 2 are
// the corresponding counts "without truth particle".
// NOTE(review): whether this message is inside the IsDebug block cannot be
// determined from this excerpt (closing braces are not visible) - confirm.
75 ATH_MSG_INFO("Deposits HS: " << m_depositCounts[1]
76 << " PU: " << m_depositCounts[0]
77 << " ; without truth particle HS: "
78 << " " << m_depositCounts[1+2]
79 << " PU: " << m_depositCounts[0+2] );
81 return StatusCode::SUCCESS;
// execute(): creates a MeasurementToTruthParticleAssociation for the measurement
// collection read via m_measurementKey, fills it using the sim data read via
// m_simDataKey, and records it under m_associationOutKey.
//
// @param ctx  event context used to create all read and write handles.
// @return StatusCode::FAILURE if the sim data, the measurements or (for a
//         non-void T_TruthEventCollection) the truth event collection handle is
//         not valid, or if recording the output fails; otherwise
//         StatusCode::SUCCESS.
// @throws std::range_error if a measurement's index() is not smaller than the
//         size of the association.
//
// NOTE(review): this excerpt seems to have lost short source lines (closing
// braces, else branches, and possibly short statements such as counter
// increments), so the exact nesting and some statements must be confirmed
// against the full file.
84 template <class T_MeasurementCollection, class T_SimDataCollection, class T_TruthEventCollection, bool IsDebug>
85 StatusCode MeasurementToTruthAssociationAlg<T_MeasurementCollection,
87 T_TruthEventCollection,
88 IsDebug>::execute(const EventContext &ctx) const
// Output object; filled below and finally moved into the write handle.
90 std::unique_ptr<MeasurementToTruthParticleAssociation>
91 association( std::make_unique<MeasurementToTruthParticleAssociation>() );
// Input: sim data collection.
93 SG::ReadHandle<T_SimDataCollection> simDataHandle = SG::makeHandle(m_simDataKey, ctx);
94 if (!simDataHandle.isValid()) {
95 ATH_MSG_ERROR("No sim data for key " << m_simDataKey.key() );
96 return StatusCode::FAILURE;
// Input: measurement collection.
99 ATH_MSG_DEBUG("Retrieving measurement collection with key: " << m_measurementKey.key());
100 SG::ReadHandle<T_MeasurementCollection> measurementHandle = SG::makeHandle(m_measurementKey, ctx);
101 if (!measurementHandle.isValid()) {
102 ATH_MSG_ERROR("No measurements for key " << m_measurementKey.key() );
103 return StatusCode::FAILURE;
105 ATH_MSG_DEBUG("Retrieved " << measurementHandle->size() << " input measurements" );
// The association gets as many entries as there are measurements; below the
// entries are addressed with measurement->index().
106 association->resize( measurementHandle->size() );
107 association->setSourceContainer(*measurementHandle, ctx);
// The truth event collection is only retrieved when the template argument is
// not void; otherwise the pointer stays null.
109 const T_TruthEventCollection *truth_event_collection=nullptr;
110 if constexpr(!std::is_same<T_TruthEventCollection, void>::value) {
111 SG::ReadHandle<T_TruthEventCollection>
112 truthEventCollectionHandle = SG::makeHandle(m_truthEventCollectionKey, ctx);
113 if (!truthEventCollectionHandle.isValid()) {
114 ATH_MSG_ERROR("No truth event collection for key " << m_truthEventCollectionKey.key() );
115 return StatusCode::FAILURE;
117 truth_event_collection = truthEventCollectionHandle.cptr();
// Helper object queried below via getTruthParticle(deposit) and
// isHardScatter(deposit); it is defined outside this excerpt.
120 auto deposit_to_truth_map = makeDepositToTruthParticleMap(truth_event_collection);
// Event-local statistics. `stat` is a Dbg::HistTemp in the IsDebug
// instantiation and a Dbg::Empty otherwise.
122 typename std::conditional<IsDebug,Dbg::HistTemp,Dbg::Empty>::type stat(m_stat);
// Event-local hit counter per truth particle (map type in the IsDebug case).
// NOTE(review): the alternative type for the non-debug case is on a line that
// is not visible in this excerpt.
123 typename std::conditional<IsDebug,
124 std::unordered_map<const xAOD::TruthParticle *,unsigned int>,
126 hits_per_truthparticle;
// Event-local deposit counters, indexed by
// isHardScatter(deposit) + (2 if the deposit has no truth particle).
127 std::array<std::size_t,4> depositCounts {0u, 0u, 0u, 0u};
// Event-local counters per category (kNoTruth, kHasSimHit, ...), zero-initialised.
128 std::array<unsigned int, kNCategories> truth_stat{};
130 const T_SimDataCollection *simData = simDataHandle.cptr();
// Summed deposited energy per truth particle of the current measurement; the
// index equals the particle's position in the measurement's association entry.
// The buffer is declared outside the loop and cleared for each measurement.
132 std::vector<float> summed_contribution;
133 const T_MeasurementCollection *measurements = measurementHandle.cptr();
134 for ( const auto *measurement : *measurements ) {
135 summed_contribution.clear();
// Look up the sim data of every RDO of the measurement.
136 for (const auto& a_rdo : getRDOList(*measurement) ) {
137 typename T_SimDataCollection::const_iterator sim_data_iter(simData->find(a_rdo));
139 if(sim_data_iter != simData->end() ) {
140 auto deposits_for_measurement = getSimDataDeposits(*simData, sim_data_iter);
// Counter local to this RDO; it is used for the statistics after the deposit loop.
// NOTE(review): the statement incrementing it is not visible in this excerpt.
141 unsigned int n_particles=0;
142 for (const auto &deposit : deposits_for_measurement ) {
143 auto *truth_particle = deposit_to_truth_map.getTruthParticle(deposit);
// .at() is bounds checked; the index is described at the declaration of depositCounts.
144 ++(depositCounts.at(deposit_to_truth_map.isHardScatter(deposit)+(truth_particle ? 0u : 2u)));
// Deposits without truth particle are counted as invalid truth links.
// NOTE(review): lines following the counter increment are not visible here;
// whether such deposits are skipped afterwards must be confirmed.
145 if (!truth_particle) {
146 ++truth_stat[kInvalidTruthLink];
// Debug only: histogram log10 of the deposited energy.
149 if constexpr(IsDebug) { stat.m_depositedEnergy.add( std::log10(getDepositedEnergy(deposit)) ); }
// Only deposits with an energy of at least m_depositedEnergyMin are associated.
151 if (getDepositedEnergy(deposit) >= m_depositedEnergyMin.value()) {
152 if (measurement->index() >= association->size()) {
153 throw std::range_error("Measurement index out of range");
// Search the measurement's particle list and insert truth_particle at the
// returned position if the search ended at end() or on a different particle.
// NOTE(review): the value searched for is on a line not visible here;
// presumably it is truth_particle - confirm.
155 ParticleVector::iterator
156 particle_iter = std::find( (*association)[ measurement->index() ].begin(),
157 (*association)[ measurement->index() ].end(),
159 if ( particle_iter == (*association)[ measurement->index() ].end()
160 || *particle_iter != truth_particle) {
161 particle_iter = (*association)[ measurement->index() ].insert( particle_iter, truth_particle);
// Position of the particle in the measurement's list; grows summed_contribution
// on demand and adds the deposit's energy to the particle's sum.
164 unsigned int idx = (/*dest_iter*/ particle_iter - (*association)[ measurement->index() ].begin());
165 if (idx >= summed_contribution.size() ) {
166 summed_contribution.resize(idx+1, 0.);
168 summed_contribution[idx] += getDepositedEnergy(deposit);
// Debug only: count hits per truth particle; a new particle starts at 1,
// a known one is incremented.
169 if constexpr(IsDebug) {
170 std::unordered_map<const xAOD::TruthParticle *,unsigned int>::iterator
171 part_hits_iter = hits_per_truthparticle.find(truth_particle);
172 if (part_hits_iter == hits_per_truthparticle.end()) {
173 hits_per_truthparticle.insert(std::make_pair(truth_particle,1));
176 ++(part_hits_iter->second);
// Statistics for this RDO: with or without particles, and whether the particle
// count exceeds ActsTrk::NTruthParticlesPerMeasurement.
181 if constexpr(IsDebug) { stat.m_particlesPerMeasurement.add(n_particles); }
182 ++truth_stat[ n_particles == 0 ? kHasSimHitNoParticle : kHasSimHit];
183 if (n_particles>ActsTrk::NTruthParticlesPerMeasurement) {
184 ++truth_stat[kBeyondSmallVectorSize];
// NOTE(review): presumably the branch for RDOs without sim data entry (the
// branch line itself is not visible in this excerpt).
188 ++truth_stat[kNoTruth];
191 if ( (*association)[ measurement->index() ].size()>1) {
192 // move truth particle with largest contribution to the measurement to the beginning of the list
193 std::vector<float>::const_iterator iter_max=std::max_element(summed_contribution.begin(),summed_contribution.end());
194 unsigned int idx = iter_max - summed_contribution.begin();
196 std::swap((*association)[ measurement->index() ][0],
197 (*association)[ measurement->index() ][idx]);
// Debug only: merge the event-local histograms into m_stat while holding
// m_stat.m_mutex.
201 if constexpr(IsDebug) {
202 std::lock_guard<std::mutex> lock(m_stat.m_mutex);
203 m_stat.m_depositedEnergy += stat.m_depositedEnergy;
204 m_stat.m_particlesPerMeasurement += stat.m_particlesPerMeasurement;
205 for (const auto &particle_hit_counts : hits_per_truthparticle) {
206 m_stat.m_measurementsPerParticle.add( particle_hit_counts.second);
// Add the event-local counters to the member counters.
// NOTE(review): these members are modified in a const method, so they have to
// be declared mutable elsewhere; how concurrent updates are handled depends on
// their declared type - confirm in the header.
209 for (unsigned int i=0; i<m_depositCounts.size(); ++i) {
210 m_depositCounts[i] += depositCounts[i];
212 for (unsigned int i=0; i<m_statRDO.size(); ++i) {
213 m_statRDO[i]+=truth_stat[i];
// Record the association; ownership is moved into the write handle.
216 SG::WriteHandle<MeasurementToTruthParticleAssociation> associationOutHandle(m_associationOutKey, ctx);
217 if (associationOutHandle.record( std::move(association)).isFailure()) {
218 ATH_MSG_ERROR("Failed to record measurement to truth assocition with key " << m_associationOutKey.key() );
219 return StatusCode::FAILURE;
222 return StatusCode::SUCCESS;