Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
SPIdDumperAlg.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #include "SPIdDumperAlg.h"
6 
8 #include "Identifier/Identifier.h"
10 #include "StoreGate/ReadHandle.h"
14 
16 
17 #include <fstream>
18 
19 namespace MuonR4 {
20 
// NOTE(review): this listing was scraped from a rendered documentation page and several
// source lines are missing (the embedded line numbers skip, e.g. 20 -> 22, 30 -> 32,
// 33 -> 35, 90 -> 92, 156 -> 159). In particular the signatures of initialize(),
// execute() and finalize() are not shown, so the text below is not compilable as-is.
// The comments describe only what the visible lines do.

    // initialize() (SPIdDumperAlg.cxx:21): initializes the space-point and segment
    // read-handle keys, initializes the output tree with this algorithm as owner and
    // retrieves the graph-inference tool. Any failing step is propagated by ATH_CHECK.
22  ATH_CHECK(m_readKey.initialize());
23  ATH_CHECK(m_inSegmentKey.initialize());
24  ATH_CHECK(m_tree.init(this));
25  ATH_CHECK(m_graphFilterTool.retrieve());
26 
27  ATH_MSG_DEBUG("Successfully initialized SPIdDumperAlg with ONNX model filtering.");
28  return StatusCode::SUCCESS;
29  }
30 
    // execute() (SPIdDumperAlg.cxx:31): per event, runs the graph-inference tool, then walks
    // every space-point bucket and appends one entry per kept space point to the output
    // branches (position, drift radius, station, model prediction, plus the fills on the
    // missing lines noted below), and finally fills the tree once for the event.
32  const EventContext& ctx{Gaudi::Hive::currentContext()};
33 
    // NOTE(review): `readHandle` is declared on a line missing from this listing (original
    // line 34); presumably an SG::ReadHandle<SpacePointContainer> built from m_readKey and
    // ctx -- confirm against the full source.
35  if (!readHandle.isValid()) {
36  ATH_MSG_ERROR("Failed to retrieve SpacePointContainer from StoreGate.");
37  return StatusCode::FAILURE;
38  }
39 
40  MuonML::GraphRawData graphData;
41  ATH_CHECK(m_graphFilterTool->runGraphInference(ctx, graphData));
42 
    // At least three output tensors are needed: index 2 is read below as the float
    // predictions and index 0 is only used to derive the features-per-node count.
    // NOTE(review): graphData.graph is dereferenced without a null check -- confirm that
    // runGraphInference always populates it when it returns success.
    // NOTE(review): this early return (and the FAILURE return below) skips the
    // m_tree.fill(ctx) at the end of execute(), so no tree entry is written for such events.
43  if (graphData.graph->dataTensor.size() < 3) {
44  ATH_MSG_DEBUG("ONNX inference output tensor is missing.");
45  return StatusCode::SUCCESS;
46  }
47 
48  const float* predictions = graphData.graph->dataTensor[2].GetTensorMutableData<float>();
49  size_t predictionIndex = 0;
50 
    // The element count of tensor 2 is taken as the number of graph nodes
    // (one prediction per node).
51  size_t totalNodes = graphData.graph->dataTensor[2].GetTensorTypeAndShapeInfo().GetElementCount();
52  size_t totalFeatureElements = graphData.graph->dataTensor[0].GetTensorTypeAndShapeInfo().GetElementCount();
53 
    // Guards the division below.
    // NOTE(review): an event with zero graph nodes is treated as a hard failure; confirm
    // this is intended rather than skipping the event.
54  if (totalNodes == 0) {
55  ATH_MSG_ERROR("Total number of nodes is zero! Cannot divide.");
56  return StatusCode::FAILURE;
57  }
58 
    // Only used for the debug printout below.
59  size_t numFeaturesPerNode = totalFeatureElements / totalNodes;
60  ATH_MSG_DEBUG("Inferred " << numFeaturesPerNode << " features per node from tensor of size "
61  << totalFeatureElements << " and " << totalNodes << " nodes.");
62 
    // Group the input segments by the space-point bucket of their parent.
    // NOTE(review): segment->parent() is dereferenced unchecked -- confirm it is never null.
63  std::unordered_map <const SpacePointBucket*, std::vector<const MuonR4::Segment*>> segmentMap;
64  SG::ReadHandle readSegment(m_inSegmentKey, ctx);
    // isPresent() only checks that the object exists in StoreGate; the handle is
    // dereferenced right after. NOTE(review): isValid() may be the intended check -- confirm.
65  ATH_CHECK(readSegment.isPresent());
66  for (const MuonR4::Segment* segment : *readSegment) {
67  segmentMap[segment->parent()->parentBucket()].push_back(segment);
68  }
69 
70  for (const auto& bucket : *readHandle) {
71 
    // Map each space point of this bucket to the indices (within this bucket's segment
    // list) of the segments whose measurements reference it.
    // NOTE(review): segIdx is unsigned int but stored as int16_t (implicit narrowing).
72  std::unordered_map<const SpacePoint*, std::vector<int16_t>> spacePointToSegment;
73  auto match_itr = segmentMap.find(bucket);
74  if (match_itr != segmentMap.end()) {
75  unsigned int segIdx{0};
76  for (const MuonR4::Segment* segment : match_itr->second) {
77  for (const auto& meas : segment->measurements()) {
78  spacePointToSegment[meas->spacePoint()].push_back(segIdx);
79  }
80  ++segIdx;
81  }
82  }
83 
    // Split the bucket's space points per layer. `layer` counts MDT layers first and is NOT
    // reset before the strip loop, so strip layers continue the same numbering.
84  SpacePointPerLayerSplitter splitter{*bucket};
85  unsigned int layer{0};
86 
87  for (const auto& hitsInLay : splitter.mdtHits()) {
88  for (const auto sp : hitsInLay){
89 
90  const auto* dc = static_cast<const xAOD::MdtDriftCircle*>(sp->primaryMeasurement());
    // NOTE(review): the condition guarding this `continue` (original line 91) is missing
    // from the listing; it presumably tests dc's status (e.g. against
    // Muon::MdtStatusDriftTime) -- confirm. Skipped hits do not advance predictionIndex,
    // so the prediction alignment assumes the graph builder drops the same hits.
92  continue;
93  }
94  m_spoint_x.push_back(sp->positionInChamber().x());
95  m_spoint_y.push_back(sp->positionInChamber().y());
96  m_spoint_z.push_back(sp->positionInChamber().z());
97  m_spoint_driftR.push_back(sp->driftRadius());
98  m_spoint_station.push_back(m_idHelperSvc->stationName(sp->identify()));
    // NOTE(review): original line 99 is missing; presumably it records `layer`
    // (e.g. into m_spoint_layer) -- confirm.
100 
    // Segment-membership label. The branch bodies (original lines 102 and 104) are missing
    // from this listing; presumably they fill m_spoint_label -- confirm.
101  if (spacePointToSegment.count(sp) > 0) {
103  } else {
105  }
106 
    // Predictions are consumed sequentially across all buckets, MDT layers first, then
    // strip layers, one per emitted space point. This assumes the inference tool orders its
    // nodes identically -- TODO confirm. Out-of-range indices are logged and filled with
    // the -999 sentinel.
    // NOTE(review): the bound recomputed here equals totalNodes; it could be reused.
107  if (predictionIndex < graphData.graph->dataTensor[2].GetTensorTypeAndShapeInfo().GetElementCount()) {
108  m_spoint_predictions.push_back(predictions[predictionIndex]);
109  predictionIndex++;
110  } else {
111  ATH_MSG_WARNING("Prediction index exceeded ONNX output size.");
112  m_spoint_predictions.push_back(-999); // Default invalid value
113  }
114 
115 
116  }
117  ++layer;
118 
119  }
120 
    // Strip hits: same branches as for MDT hits. No status filter appears in the visible
    // lines of this loop.
121  for (const auto& hitsInLay : splitter.stripHits()) {
122 
123  for (const auto sp : hitsInLay){
124 
125  m_spoint_x.push_back(sp->positionInChamber().x());
126  m_spoint_y.push_back(sp->positionInChamber().y());
127  m_spoint_z.push_back(sp->positionInChamber().z());
128  m_spoint_driftR.push_back(sp->driftRadius());
129  m_spoint_station.push_back(m_idHelperSvc->stationName(sp->identify()));
    // NOTE(review): original line 130 is missing; presumably it records `layer` -- confirm.
131 
    // Segment-membership label; branch bodies (original lines 133 and 135) are missing.
132  if (spacePointToSegment.count(sp) > 0) {
134  } else {
136  }
137 
    // Same sequential prediction lookup as in the MDT loop (duplicated logic).
138  if (predictionIndex < graphData.graph->dataTensor[2].GetTensorTypeAndShapeInfo().GetElementCount()) {
139  m_spoint_predictions.push_back(predictions[predictionIndex]);
140  predictionIndex++;
141  } else {
142  ATH_MSG_WARNING("Prediction index exceeded ONNX output size.");
143  m_spoint_predictions.push_back(-999); // Default invalid value
144  }
145 
146  }
147  ++layer;
148 
149  }
150 
151  }
152 
    // One tree entry per event, written after all buckets have been processed.
153  if (!m_tree.fill(ctx)) return StatusCode::FAILURE;
154  return StatusCode::SUCCESS;
155  }
156 
    // finalize() (SPIdDumperAlg.cxx:157): original lines 157-158, including the signature,
    // are missing from this listing; presumably they write the tree (m_tree.write())
    // -- confirm against the full source.
159  return StatusCode::SUCCESS;
160  }
161 
162 }
MuonR4::SPIdDumperAlg::m_graphFilterTool
ToolHandle< MuonML::IGraphInferenceTool > m_graphFilterTool
Definition: SPIdDumperAlg.h:46
MuonR4::SPIdDumperAlg::m_spoint_label
MuonVal::VectorBranch< uint8_t > & m_spoint_label
Definition: SPIdDumperAlg.h:61
MuonR4::SpacePointPerLayerSplitter
The SpacePointPerLayerSplitter takes a set of spacepoints already sorted by layer Identifier (see Muo...
Definition: SpacePointPerLayerSplitter.h:16
MuonVal::MuonTesterTree::init
StatusCode init(OWNER *instance)
Initialize method.
AthMsgStreamMacros.h
SG::ReadHandle< SpacePointContainer >
MuonR4::SPIdDumperAlg::m_readKey
SG::ReadHandleKey< MuonR4::SpacePointContainer > m_readKey
Definition: SPIdDumperAlg.h:41
MuonR4::SPIdDumperAlg::m_idHelperSvc
ServiceHandle< Muon::IMuonIdHelperSvc > m_idHelperSvc
Definition: SPIdDumperAlg.h:42
MuonR4::Segment
Placeholder for what will later be the muon segment EDM representation.
Definition: MuonSpectrometer/MuonPhaseII/Event/MuonPatternEvent/MuonPatternEvent/Segment.h:19
MuonR4::SPIdDumperAlg::m_spoint_y
MuonVal::VectorBranch< float > & m_spoint_y
Definition: SPIdDumperAlg.h:53
MuonML::GraphRawData::graph
std::unique_ptr< InferenceGraph > graph
Pointer to the graph to be parsed to ONNX.
Definition: GraphData.h:40
MuonR4::SPIdDumperAlg::m_spoint_driftR
MuonVal::VectorBranch< float > & m_spoint_driftR
Definition: SPIdDumperAlg.h:57
MuonR4::SPIdDumperAlg::execute
virtual StatusCode execute() override
Definition: SPIdDumperAlg.cxx:31
MuonR4::SPIdDumperAlg::m_spoint_layer
MuonVal::VectorBranch< uint16_t > & m_spoint_layer
Definition: SPIdDumperAlg.h:59
EventHashBranch.h
MuonR4::SPIdDumperAlg::m_spoint_x
MuonVal::VectorBranch< float > & m_spoint_x
Definition: SPIdDumperAlg.h:52
MuonR4::SPIdDumperAlg::initialize
virtual StatusCode initialize() override
Definition: SPIdDumperAlg.cxx:21
Muon::MdtStatusDriftTime
@ MdtStatusDriftTime
The tube produced a valid measurement.
Definition: MdtDriftCircleStatus.h:34
SpacePointPerLayerSplitter.h
MuonR4::SPIdDumperAlg::m_spoint_predictions
MuonVal::VectorBranch< float > & m_spoint_predictions
Definition: SPIdDumperAlg.h:62
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
SPIdDumperAlg.h
TRT::Hit::layer
@ layer
Definition: HitInfo.h:79
MuonR4::SPIdDumperAlg::m_inSegmentKey
SG::ReadHandleKey< MuonR4::SegmentContainer > m_inSegmentKey
Definition: SPIdDumperAlg.h:44
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
MuonML::GraphRawData
Helper struct to ship the Graph from the space point buckets to ONNX.
Definition: GraphData.h:25
MuonR4::SPIdDumperAlg::finalize
virtual StatusCode finalize() override
Definition: SPIdDumperAlg.cxx:157
MdtDriftCircle.h
MuonVal::VectorBranch::push_back
void push_back(const T &value)
Adds a new element at the end of the vector.
GraphData.h
MuonR4
This header ties the generic definitions in this package.
Definition: HoughEventData.h:16
MuonR4::SPIdDumperAlg::m_spoint_station
MuonVal::VectorBranch< uint8_t > & m_spoint_station
Definition: SPIdDumperAlg.h:56
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
MuonVal::MuonTesterTree::fill
bool fill(const EventContext &ctx)
Fills the tree per call.
Definition: MuonTesterTree.cxx:89
MuonVal::MuonTesterTree::write
StatusCode write()
Finally write the TTree objects.
Definition: MuonTesterTree.cxx:178
SG::VarHandleBase::isPresent
bool isPresent() const
Is the referenced object present in SG?
Definition: StoreGate/src/VarHandleBase.cxx:400
ReadHandle.h
Handle class for reading from StoreGate.
xAOD::MdtDriftCircle_v1
https://gitlab.cern.ch/atlas/athena/-/blob/master/MuonSpectrometer/MuonReconstruction/MuonRecEvent/Mu...
Definition: MdtDriftCircle_v1.h:21
MuonR4::SPIdDumperAlg::m_tree
MuonVal::MuonTesterTree m_tree
Definition: SPIdDumperAlg.h:51
IMuonIdHelperSvc.h
MuonR4::SPIdDumperAlg::m_spoint_z
MuonVal::VectorBranch< float > & m_spoint_z
Definition: SPIdDumperAlg.h:54
NSWL1::PadTriggerAdapter::segment
Muon::NSW_PadTriggerSegment segment(const NSWL1::PadTrigger &data)
Definition: PadTriggerAdapter.cxx:5