8 #include "Identifier/Identifier.h"
27 ATH_MSG_DEBUG(
"Successfully initialized SPIdDumperAlg with ONNX model filtering.");
28 return StatusCode::SUCCESS;
32 const EventContext& ctx{Gaudi::Hive::currentContext()};
35 if (!readHandle.isValid()) {
36 ATH_MSG_ERROR(
"Failed to retrieve SpacePointContainer from StoreGate.");
37 return StatusCode::FAILURE;
43 if (graphData.
graph->dataTensor.size() < 3) {
45 return StatusCode::SUCCESS;
48 const float* predictions = graphData.
graph->dataTensor[2].GetTensorMutableData<
float>();
49 size_t predictionIndex = 0;
51 size_t totalNodes = graphData.
graph->dataTensor[2].GetTensorTypeAndShapeInfo().GetElementCount();
52 size_t totalFeatureElements = graphData.
graph->dataTensor[0].GetTensorTypeAndShapeInfo().GetElementCount();
54 if (totalNodes == 0) {
55 ATH_MSG_ERROR(
"Total number of nodes is zero! Cannot divide.");
56 return StatusCode::FAILURE;
59 size_t numFeaturesPerNode = totalFeatureElements / totalNodes;
60 ATH_MSG_DEBUG(
"Inferred " << numFeaturesPerNode <<
" features per node from tensor of size "
61 << totalFeatureElements <<
" and " << totalNodes <<
" nodes.");
63 std::unordered_map <const SpacePointBucket*, std::vector<const MuonR4::Segment*>> segmentMap;
70 for (
const auto& bucket : *readHandle) {
72 std::unordered_map<const SpacePoint*, std::vector<int16_t>> spacePointToSegment;
73 auto match_itr = segmentMap.find(bucket);
74 if (match_itr != segmentMap.end()) {
75 unsigned int segIdx{0};
77 for (
const auto& meas :
segment->measurements()) {
78 spacePointToSegment[meas->spacePoint()].push_back(segIdx);
85 unsigned int layer{0};
87 for (
const auto& hitsInLay : splitter.mdtHits()) {
88 for (
const auto sp : hitsInLay){
101 if (spacePointToSegment.count(sp) > 0) {
107 if (predictionIndex < graphData.graph->dataTensor[2].GetTensorTypeAndShapeInfo().GetElementCount()) {
121 for (
const auto& hitsInLay : splitter.stripHits()) {
123 for (
const auto sp : hitsInLay){
132 if (spacePointToSegment.count(sp) > 0) {
138 if (predictionIndex < graphData.graph->dataTensor[2].GetTensorTypeAndShapeInfo().GetElementCount()) {
153 if (!
m_tree.
fill(ctx))
return StatusCode::FAILURE;
154 return StatusCode::SUCCESS;
159 return StatusCode::SUCCESS;