17#include <unordered_map>
19#include "CLHEP/Random/RandFlat.h"
22 struct LocalSegSorter{
27 if (
a->chamberIndex() !=
b->chamberIndex()) {
28 return a->chamberIndex() <
b->chamberIndex();
30 if (
a->sector() !=
b->sector()) {
31 return a->sector() <
b->sector();
33 if (
a->etaIndex() !=
b->etaIndex()) {
34 return a->etaIndex() <
b->etaIndex();
36 using namespace MuonR4::SegmentFit;
39 return locParsA < locParsB;
55 m_tree.addBranch(std::make_shared<MuonVal::EventHashBranch>(
m_tree.tree()));
77 return StatusCode::SUCCESS;
82 return StatusCode::SUCCESS;
86 const EventContext& ctx{Gaudi::Hive::currentContext()};
94 return StatusCode::SUCCESS;
102 std::set<const xAOD::MuonSegment*, LocalSegSorter>>;
103 SegmentsPerBucket_t segmentMap{};
121 std::unordered_map<const SpacePointBucket*, std::vector<float>> bucketScores;
122 const std::string& containerName = spacePointKey.
key();
124 ATH_MSG_DEBUG(
"Computing ML bucket scores for container: " << containerName);
126 ATH_MSG_DEBUG(
"ML scoring completed, got scores for " << bucketScores.size() <<
" buckets");
128 ATH_MSG_DEBUG(
"Skipping ML scoring for container: " << containerName
129 <<
" (only MuonSpacePoints is supported)");
139 CLHEP::RandFlat::shoot(rndEngine,0.,1.) >
m_fracToKeep) {
145 auto scoreIt = bucketScores.find(bucket);
146 if (scoreIt != bucketScores.end() && scoreIt->second.size() >= 3) {
147 const auto& logits = scoreIt->second;
148 int predictedClass = 0;
149 float maxLogit = logits[0];
150 if (logits[1] > maxLogit) { maxLogit = logits[1]; predictedClass = 1; }
151 if (logits[2] > maxLogit) { maxLogit = logits[2]; predictedClass = 2; }
152 if (predictedClass == 0)
continue;
165 const Amg::Vector3D bucketPos = bucket->msSector()->localToGlobalTransform(*gctx) *
166 (0.5*(bucket->coveredMin() + bucket->coveredMax()) * Amg::Vector3D::UnitY());
172 m_bucket_truthHit = std::ranges::any_of(*bucket,[
this](
const SpacePointBucket::value_type &
sp){
181 std::unordered_map<const SpacePoint*, std::vector<int16_t>> spacePointToSegment{};
182 std::set<const xAOD::MuonSegment*, LocalSegSorter> truthSegments{};
184 auto match_itr = segmentMap.find(bucket);
185 if (match_itr != segmentMap.end()) {
188 if (meas->fitState() == CalibratedSpacePoint::State::Valid) {
189 spacePointToSegment[meas->spacePoint()].push_back(segment->
index());
200 unsigned truthLink = -1;
202 truthLink = truthPart->index();
206 truthSegments.insert(truthSeg);
216 std::vector<unsigned int> layNumbers{};
217 std::unordered_map<const SpacePoint*, std::vector<const xAOD::MuonSegment*>> spToTrueSeg{};
219 using SegLinkVec_t = std::vector<ElementLink<xAOD::MuonSegmentContainer>>;
221 for (
const auto&
sp : *bucket){
222 for (
const auto& link : segAcc(*
sp->primaryMeasurement())) {
223 spToTrueSeg[
sp.get()].push_back(*link);
224 truthSegments.insert(*link);
229 for(
const SpacePointBucket::value_type&
sp : *bucket) {
232 if (std::find(layNumbers.begin(), layNumbers.end(), layNum) == layNumbers.end()) {
233 layNumbers.push_back(layNum);
235 const unsigned layer = layNumbers.size()-1;
241 if (dc->status() != Muon::MdtDriftCircleStatus::MdtStatusDriftTime){
259 const std::vector<int16_t>& segIdxs = spacePointToSegment[
sp.get()];
263 trueSegLinks.push_back(std::distance(truthSegments.begin(), truthSegments.find(matchedSeg)));
270 m_spoint_covX.push_back(
sp->covariance()[Acts::toUnderlying(CovIdx::phiCov)]);
271 m_spoint_covY.push_back(
sp->covariance()[Acts::toUnderlying(CovIdx::etaCov)]);
283 Amg::Vector3D globalPos =
sp->msSector()->localToGlobalTransform(*gctx) *
sp->localPosition();
301 auto scoreIt = bucketScores.find(bucket);
302 if (scoreIt != bucketScores.end() && scoreIt->second.size() >= 3) {
309 if (bucket && !bucket->empty()) {
310 ATH_MSG_WARNING(
"Non-empty bucket from scored container not found in ML scores map");
324 return StatusCode::FAILURE;
328 return StatusCode::SUCCESS;
335 rngWrapper->
setSeed(rngName, ctx);
341 std::unordered_map<
const SpacePointBucket*, std::vector<float>>& bucketScoreMap)
const {
342 bucketScoreMap.clear();
346 <<
", spContainer=" << (spContainer ?
"valid" :
"null"));
349 return StatusCode::SUCCESS;
353 size_t nonEmptyBuckets = 0;
355 if (bucket && !bucket->empty()) {
361 if (nonEmptyBuckets == 0) {
362 ATH_MSG_DEBUG(
"Container has no non-empty buckets, skipping ML scoring");
363 return StatusCode::SUCCESS;
373 if (!graphData.
graph || graphData.
graph->dataTensor.size() <= 2) {
375 return StatusCode::FAILURE;
378 const Ort::Value& outTensor = graphData.
graph->dataTensor[2];
379 const auto& info = outTensor.GetTensorTypeAndShapeInfo();
380 std::vector<int64_t> outShape = info.GetShape();
382 if (outShape.size() != 2 || outShape[1] != 3) {
384 return StatusCode::FAILURE;
387 const float* logitsPtr = outTensor.GetTensorData<
float>();
390 return StatusCode::FAILURE;
394 size_t totalNumPredictions =
static_cast<size_t>(outShape[0]);
398 if (!bucket || bucket->empty()) {
402 if (predIdx >= totalNumPredictions) {
403 ATH_MSG_ERROR(
"More non-empty buckets than predictions from model");
404 return StatusCode::FAILURE;
407 std::vector<float> scores(3);
408 scores[0] = logitsPtr[3 * predIdx + 0];
409 scores[1] = logitsPtr[3 * predIdx + 1];
410 scores[2] = logitsPtr[3 * predIdx + 2];
412 bucketScoreMap[bucket] = scores;
416 if (predIdx != totalNumPredictions) {
417 ATH_MSG_ERROR(
"Number of non-empty buckets (" << predIdx <<
") does not match predictions ("
418 << totalNumPredictions <<
")");
419 return StatusCode::FAILURE;
422 ATH_MSG_INFO(
"Successfully computed ML bucket scores for " << bucketScoreMap.size() <<
" buckets");
423 return StatusCode::SUCCESS;
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_VERBOSE(x)
#define ATH_MSG_WARNING(x)
Handle class for reading from StoreGate.
static const Attributes_t empty
A wrapper class for event-slot-local random engines.
void setSeed(const std::string &algName, const EventContext &ctx)
Set the random seed using a string (e.g.
CLHEP::HepRandomEngine * getEngine(const EventContext &ctx) const
Retrieve the random engine corresponding to the provided EventContext.
ToolHandle< MuonValR4::IPatternVisualizationTool > m_visionTool
Pattern visualization tool.
MuonVal::ScalarBranch< Char_t > & m_bucket_side
MuonVal::ThreeVectorBranch m_spoint_localPosition
MuonVal::VectorBranch< unsigned short > & m_spoint_measuresEta
MuonVal::ScalarBranch< uint8_t > & m_bucket_sector
MuonVal::VectorBranch< float > & m_truthSegLocTheta
MuonVal::VectorBranch< float > & m_segment_numberDoF
SG::ReadHandleKeyArray< SpacePointContainer > m_spacePointKeys
Gaudi::Property< bool > m_doMLBucketFilter
StatusCode dumpContainer(const EventContext &ctx, const SG::ReadHandleKey< SpacePointContainer > &spacePointKey, const SG::ReadHandleKey< xAOD::MuonSegmentContainer > &segmentKey)
Dumps the space point container with the associated muon segment container.
ToolHandle< MuonML::IGraphInferenceTool > m_inferenceTool
Inference tool for ML bucket scoring (optional)
MuonVal::ScalarBranch< float > & m_bucket_min
MuonVal::VectorBranch< uint16_t > & m_spoint_layer
MuonVal::ThreeVectorBranch m_segmentPos
MuonVal::VectorBranch< float > & m_spoint_covX
MuonVal::VectorBranch< float > & m_segmentLocTheta
MuonVal::VectorBranch< float > & m_truthSegLocY
MuonVal::VectorBranch< float > & m_segment_chiSquared
CLHEP::HepRandomEngine * getRandomEngine(const EventContext &ctx) const
MuonVal::MatrixBranch< int16_t > & m_spoint_mat
MuonVal::ScalarBranch< float > & m_bucket_posZ
MuonVal::ScalarBranch< uint16_t > & m_bucket_layers
MuonVal::VectorBranch< float > & m_truthSegLocPhi
MuonVal::VectorBranch< float > & m_segmentLocY
MuonVal::VectorBranch< uint16_t > & m_segmentTruthIdx
virtual StatusCode initialize() override final
MuonVal::VectorBranch< float > & m_bucket_ml_score_class1
MuonVal::MuonTesterTree m_tree
MuonVal::VectorBranch< float > & m_segmentLocX
Gaudi::Property< double > m_fracToKeep
MuonVal::ScalarBranch< float > & m_bucket_max
virtual StatusCode execute() override final
MuonVal::VectorBranch< uint16_t > & m_spoint_nSegments
Gaudi::Property< bool > m_doMLBucketScore
SG::ReadDecorHandleKeyArray< xAOD::MuonSegmentContainer > m_truthDecorKeys
Gaudi::Property< bool > m_isMC
MuonVal::VectorBranch< float > & m_bucket_ml_score_class0
ML bucket filter scores (3 classes for the filter model)
MuonVal::VectorBranch< unsigned int > & m_spoint_nPhiInstances
StatusCode computeAllBucketScores(const EventContext &ctx, const SpacePointContainer *spContainer, std::unordered_map< const SpacePointBucket *, std::vector< float > > &bucketScoreMap) const
Computes ML bucket scores using the inference tool for a specific container.
MuonVal::ScalarBranch< uint8_t > & m_bucket_truthHit
Gaudi::Property< std::string > m_streamName
MuonVal::ThreeVectorBranch m_segmentDir
MuonVal::VectorBranch< uint16_t > & m_spoint_tdc
MuonVal::VectorBranch< float > & m_spoint_driftR
MuonVal::ScalarBranch< float > & m_bucket_posY
MuonVal::VectorBranch< float > & m_segmentLocPhi
MuonVal::ScalarBranch< float > & m_bucket_posX
MuonVal::VectorBranch< uint16_t > & m_spoint_adc
MuonVal::MuonIdentifierBranch m_spoint_id
SG::ReadHandleKey< ActsTrk::GeometryContext > m_geoCtxKey
MuonVal::ScalarBranch< uint16_t > & m_bucket_segments
MuonVal::VectorBranch< unsigned int > & m_spoint_dimension
SG::ReadHandleKeyArray< xAOD::MuonSegmentContainer > m_inSegmentKeys
MuonVal::ScalarBranch< uint8_t > & m_bucket_chamberIdx
MuonVal::VectorBranch< float > & m_spoint_covY
MuonVal::MatrixBranch< int16_t > & m_spoint_trueSeg
MuonVal::VectorBranch< unsigned short > & m_spoint_isStrip
MuonVal::VectorBranch< unsigned int > & m_spoint_nEtaInstances
MuonVal::VectorBranch< float > & m_truthSegLocX
MuonVal::VectorBranch< unsigned short > & m_spoint_trueLabel
ServiceHandle< IAthRNGSvc > m_rndmSvc
MuonVal::ScalarBranch< uint16_t > & m_bucket_spacePoints
MuonVal::VectorBranch< float > & m_bucket_ml_score_class2
MuonVal::VectorBranch< unsigned short > & m_spoint_isMdt
virtual StatusCode finalize() override final
MuonVal::ThreeVectorBranch m_spoint_globalPosition
MuonVal::VectorBranch< unsigned short > & m_spoint_measuresPhi
const SpacePointBucket * parentBucket() const
Returns the bucket out of which the seed was formed.
const SegmentSeed * parent() const
Returns the seed out of which the segment was built.
const MeasVec & measurements() const
Returns the associated measurements.
: The muon space point bucket represents a collection of points that will be processed together in t...
The SpacePointPerLayerSorter sorts two given space points by their layer Identifier.
unsigned int sectorLayerNum(const SpacePoint &sp) const
method returning the logic layer number
size_t index() const
Return the index of this element within its container.
Helper class to provide constant type-safe access to aux data.
Property holding a SG store/key/clid from which a ReadHandle is made.
const std::string & key() const
Return the StoreGate ID for the referenced object.
StatusCode initialize(bool used=true)
If this object is used as a property, then this should be called during the initialize phase.
float numberDoF() const
Returns the numberDoF.
Amg::Vector3D direction() const
Returns the direction as Amg::Vector.
Amg::Vector3D position() const
Returns the position as Amg::Vector.
Eigen::Matrix< double, 3, 1 > Vector3D
Parameters localSegmentPars(const xAOD::MuonSegment &seg)
Returns the localSegPars decoration from a xAODMuon::Segment.
This header ties the generic definitions in this package.
const xAOD::TruthParticle * getTruthMatchedParticle(const xAOD::MuonSegment &segment)
Returns the particle truth-matched to the segment.
const xAOD::MuonSegment * getMatchedTruthSegment(const xAOD::MuonSegment &segment)
Returns the truth-matched segment.
std::vector< SegLink_t > SegLinkVec_t
DataVector< SpacePointBucket > SpacePointContainer
Abbreviation of the space point container type.
const Segment * detailedSegment(const xAOD::MuonSegment &seg)
Helper function to navigate from the xAOD::MuonSegment to the MuonR4::Segment.
const T * get(const ReadCondHandleKey< T > &key, const EventContext &ctx)
Convenience function to retrieve an object given a ReadCondHandleKey.
MdtDriftCircle_v1 MdtDriftCircle
MuonSegmentContainer_v1 MuonSegmentContainer
Definition of the current "MuonSegment container version".
TruthParticle_v1 TruthParticle
Typedef to implementation.
MuonSegment_v1 MuonSegment
Reference the current persistent version:
Helper struct to ship the Graph from the space point buckets to ONNX.
std::unique_ptr< InferenceGraph > graph
Pointer to the graph to be parsed to ONNX.