8#include "Identifier/Identifier.h"
27 ATH_MSG_DEBUG(
"Successfully initialized SPIdDumperAlg with ONNX model filtering.");
28 return StatusCode::SUCCESS;
32 const EventContext& ctx{Gaudi::Hive::currentContext()};
36 ATH_MSG_ERROR(
"Failed to retrieve SpacePointContainer from StoreGate.");
37 return StatusCode::FAILURE;
43 if (graphData.
graph->dataTensor.size() < 3) {
45 return StatusCode::SUCCESS;
48 const float* predictions = graphData.
graph->dataTensor[2].GetTensorMutableData<
float>();
49 size_t predictionIndex = 0;
51 size_t totalNodes = graphData.
graph->dataTensor[2].GetTensorTypeAndShapeInfo().GetElementCount();
52 size_t totalFeatureElements = graphData.
graph->dataTensor[0].GetTensorTypeAndShapeInfo().GetElementCount();
54 if (totalNodes == 0) {
55 ATH_MSG_ERROR(
"Total number of nodes is zero! Cannot divide.");
56 return StatusCode::FAILURE;
59 size_t numFeaturesPerNode = totalFeatureElements / totalNodes;
60 ATH_MSG_DEBUG(
"Inferred " << numFeaturesPerNode <<
" features per node from tensor of size "
61 << totalFeatureElements <<
" and " << totalNodes <<
" nodes.");
63 std::unordered_map <const SpacePointBucket*, std::vector<const MuonR4::Segment*>> segmentMap;
67 segmentMap[segment->parent()->parentBucket()].push_back(segment);
72 std::unordered_map<const SpacePoint*, std::vector<int16_t>> spacePointToSegment;
73 auto match_itr = segmentMap.find(bucket);
74 if (match_itr != segmentMap.end()) {
75 unsigned int segIdx{0};
78 spacePointToSegment[meas->spacePoint()].push_back(segIdx);
85 unsigned int layer{0};
87 for (
const auto& hitsInLay : splitter.
mdtHits()) {
88 for (
const auto sp : hitsInLay){
91 if (dc->status() != Muon::MdtDriftCircleStatus::MdtStatusDriftTime){
101 if (spacePointToSegment.count(
sp) > 0) {
107 if (predictionIndex < graphData.graph->dataTensor[2].GetTensorTypeAndShapeInfo().GetElementCount()) {
121 for (
const auto& hitsInLay : splitter.
stripHits()) {
123 for (
const auto sp : hitsInLay){
132 if (spacePointToSegment.count(
sp) > 0) {
138 if (predictionIndex < graphData.graph->dataTensor[2].GetTensorTypeAndShapeInfo().GetElementCount()) {
153 if (!
m_tree.fill(ctx))
return StatusCode::FAILURE;
154 return StatusCode::SUCCESS;
159 return StatusCode::SUCCESS;
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_WARNING(x)
Handle class for reading from StoreGate.
SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey
MuonVal::VectorBranch< float > & m_spoint_driftR
MuonVal::VectorBranch< float > & m_spoint_x
ServiceHandle< Muon::IMuonIdHelperSvc > m_idHelperSvc
MuonVal::MuonTesterTree m_tree
MuonVal::VectorBranch< uint16_t > & m_spoint_layer
MuonVal::VectorBranch< uint8_t > & m_spoint_station
MuonVal::VectorBranch< float > & m_spoint_z
virtual StatusCode finalize() override
virtual StatusCode initialize() override
SG::ReadHandleKey< MuonR4::SegmentContainer > m_inSegmentKey
ToolHandle< MuonML::IGraphInferenceTool > m_graphFilterTool
MuonVal::VectorBranch< uint8_t > & m_spoint_label
virtual StatusCode execute() override
MuonVal::VectorBranch< float > & m_spoint_y
MuonVal::VectorBranch< float > & m_spoint_predictions
Placeholder for what will later be the muon segment EDM representation.
const MeasVec & measurements() const
Returns the associated measurements.
: The muon space point bucket represents a collection of points that will bre processed together in t...
The SpacePointPerLayerSplitter takes a set of spacepoints already sorted by layer Identifier (see Muo...
const HitLayVec & mdtHits() const
Returns the sorted Mdt hits.
const HitLayVec & stripHits() const
Returns the sorted strip hits.
virtual bool isValid() override final
Can the handle be successfully dereferenced?
bool isPresent() const
Is the referenced object present in SG?
This header ties the generic definitions in this package.
MdtDriftCircle_v1 MdtDriftCircle
Helper struct to ship the Graph from the space point buckets to ONNX.
std::unique_ptr< InferenceGraph > graph
Pointer to the graph to be parsed to ONNX.