29 return StatusCode::SUCCESS;
34 graphData.
graph = std::make_unique<InferenceGraph>();
46 std::vector<BucketGraphUtils::NodeAux> nodes;
51 const int64_t numNodes =
static_cast<int64_t
>(nodes.size());
53 <<
" -> nodes (size>0): " << numNodes
57 ATH_MSG_WARNING(
"No valid buckets found (all have size 0.0). Skipping inference.");
58 return StatusCode::SUCCESS;
61 const int64_t nFeatPerNode = 6;
62 if (numNodes * nFeatPerNode !=
static_cast<int64_t
>(graphData.
featureLeaves.size())) {
63 ATH_MSG_ERROR(
"Feature size mismatch: expected " << (numNodes * nFeatPerNode)
65 return StatusCode::FAILURE;
68 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
69 std::vector<int64_t> featShape{numNodes, nFeatPerNode};
70 graphData.
graph->dataTensor.emplace_back(
71 Ort::Value::CreateTensor<float>(memInfo,
76 return StatusCode::SUCCESS;
86 const int64_t S =
static_cast<int64_t
>(featuresFlat.size() / 6);
89 ATH_MSG_WARNING(
"No valid features for transformer input. Skipping inference.");
90 return StatusCode::SUCCESS;
93 if (msgLvl(MSG::DEBUG)) {
95 ATH_MSG_DEBUG(
"=== DEBUGGING: Transformer input features for first 10 nodes ===");
96 const int64_t debugNodes = std::min(S,
static_cast<int64_t
>(10));
97 for (int64_t nodeIdx = 0; nodeIdx < debugNodes; ++nodeIdx) {
98 const int64_t baseIdx = nodeIdx * 6;
100 <<
"x=" << featuresFlat[baseIdx + 0] <<
", "
101 <<
"y=" << featuresFlat[baseIdx + 1] <<
", "
102 <<
"z=" << featuresFlat[baseIdx + 2] <<
", "
103 <<
"layers=" << featuresFlat[baseIdx + 3] <<
", "
104 <<
"nSp=" << featuresFlat[baseIdx + 4] <<
", "
105 <<
"bucketSize=" << featuresFlat[baseIdx + 5]);
111 graphData.
graph = std::make_unique<InferenceGraph>();
113 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
116 std::vector<int64_t> fShape{1, S, 6};
118 graphData.
graph->dataTensor.emplace_back(
119 Ort::Value::CreateTensor<float>(memInfo,
126 Ort::AllocatorWithDefaultOptions allocator;
127 std::vector<int64_t> mShape{1, S};
128 Ort::Value padVal = Ort::Value::CreateTensor(allocator,
131 ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL);
132 bool* maskPtr = padVal.GetTensorMutableData<
bool>();
133 for (int64_t i = 0; i < S; ++i) maskPtr[i] =
false;
134 graphData.
graph->dataTensor.emplace_back(std::move(padVal));
136 return StatusCode::SUCCESS;
149 std::vector<BucketGraphUtils::NodeAux> nodes;
150 std::vector<float> throwawayFeatures;
151 std::vector<int64_t> throwawaySp;
154 const int64_t numNodes =
static_cast<int64_t
>(nodes.size());
156 ATH_MSG_WARNING(
"No valid buckets found (all have size 0.0). Skipping graph building.");
157 return StatusCode::SUCCESS;
160 std::vector<int64_t> srcEdges, dstEdges;
170 std::vector<int64_t> newSrc;
171 std::vector<int64_t> newDst;
172 newSrc.reserve(srcEdges.size());
173 newDst.reserve(dstEdges.size());
174 for (
size_t k = 0; k < srcEdges.size(); ++k) {
175 const int64_t u = srcEdges[k];
176 const int64_t v = dstEdges[k];
177 const bool okU = (u >= 0 && u < numNodes);
178 const bool okV = (v >= 0 && v < numNodes);
184 ATH_MSG_DEBUG(
"Drop invalid edge " << k <<
": (" << u <<
"->" << v
185 <<
"), valid node range [0," << (numNodes-1) <<
"]");
191 srcEdges.swap(newSrc);
192 dstEdges.swap(newDst);
196 const size_t E = srcEdges.size();
198 if (msgLvl(MSG::DEBUG)) {
202 for (
unsigned int k = 0; k < dumpE; ++k) {
203 ATH_MSG_DEBUG(
"EDGE[" << k <<
"]: " << srcEdges[k] <<
" -> " << dstEdges[k]);
205 std::vector<int> nodeConnections(numNodes, 0);
206 for (
size_t k = 0; k < srcEdges.size(); ++k) {
207 const int64_t u = srcEdges[k];
208 const int64_t v = dstEdges[k];
209 if (u >= 0 && u < numNodes) nodeConnections[u]++;
210 if (v >= 0 && v < numNodes) nodeConnections[v]++;
213 ATH_MSG_INFO(
"=== DEBUGGING: Node Connections (first 10 nodes) ===");
214 const int64_t debugNodeCount = std::min(numNodes,
static_cast<int64_t
>(10));
215 for (int64_t i = 0; i < debugNodeCount; ++i) {
216 ATH_MSG_DEBUG(
"Node[" << i <<
"] connections: " << nodeConnections[i]);
221 ATH_MSG_DEBUG(
"=== DEBUGGING: Detailed Edge Connections (first 10 nodes) ===");
222 for (int64_t nodeIdx = 0; nodeIdx < debugNodeCount; ++nodeIdx) {
223 std::stringstream connections;
224 connections <<
"Node[" << nodeIdx <<
"] connected to: ";
225 bool foundAny =
false;
227 for (
size_t k = 0; k < srcEdges.size(); ++k) {
228 const int64_t u = srcEdges[k];
229 const int64_t v = dstEdges[k];
232 if (foundAny) connections <<
", ";
235 }
else if (v == nodeIdx) {
236 if (foundAny) connections <<
", ";
242 if (!foundAny) connections <<
"none";
251 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
252 std::vector<int64_t> edgeShape{2,
static_cast<int64_t
>(Efinal)};
253 graphData.
graph->dataTensor.emplace_back(
254 Ort::Value::CreateTensor<int64_t>(memInfo,
260 ATH_MSG_DEBUG(
"Built sparse bucket graph: N=" << numNodes <<
", E=" << Efinal);
261 return StatusCode::SUCCESS;
266 const std::vector<const char*>& inputNames,
267 const std::vector<const char*>& outputNames)
const
269 if (!graphData.
graph) {
271 return StatusCode::FAILURE;
273 if (graphData.
graph->dataTensor.empty()) {
275 return StatusCode::FAILURE;
278 if (msgLvl(MSG::DEBUG)) {
282 if (!graphData.
graph->dataTensor.empty()) {
283 const auto& featureTensor = graphData.
graph->dataTensor[0];
284 auto featShape = featureTensor.GetTensorTypeAndShapeInfo().GetShape();
286 << (featShape.size()>1 ? (
"," + std::to_string(featShape[1])) :
"")
287 << (featShape.size()>2 ? (
"," + std::to_string(featShape[2])) :
"") <<
"]");
289 float* featData =
const_cast<Ort::Value&
>(featureTensor).GetTensorMutableData<float>();
290 const size_t totalElements = featureTensor.GetTensorTypeAndShapeInfo().GetElementCount();
291 ATH_MSG_DEBUG(
"Features tensor total elements: " << totalElements);
294 const size_t debugElements = std::min(totalElements,
static_cast<size_t>(60));
295 for (
size_t i = 0; i < debugElements; i += 6) {
296 if (i + 5 < totalElements) {
298 <<
"x=" << featData[i+0] <<
", "
299 <<
"y=" << featData[i+1] <<
", "
300 <<
"z=" << featData[i+2] <<
", "
301 <<
"layers=" << featData[i+3] <<
", "
302 <<
"nSp=" << featData[i+4] <<
", "
303 <<
"bucketSize=" << featData[i+5]);
310 Ort::RunOptions run_options;
311 run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_WARNING);
313 std::vector<Ort::Value> outputs =
314 model().Run(run_options,
316 graphData.
graph->dataTensor.data(),
317 graphData.
graph->dataTensor.size(),
321 if (outputs.empty()) {
323 return StatusCode::FAILURE;
326 float* outData = outputs[0].GetTensorMutableData<
float>();
327 const size_t outSize = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
330 std::span<float> preds(outData, outData + outSize);
331 for (
size_t i = 0; i < outSize; ++i) {
332 if (!std::isfinite(preds[i])) {
333 ATH_MSG_WARNING(
"Non-finite prediction detected at " << i <<
" -> set to -100.");
338 for (
auto& v : outputs) {
339 graphData.
graph->dataTensor.emplace_back(std::move(v));
341 return StatusCode::SUCCESS;
345 std::vector<const char*> inputNames = {
"features",
"edge_index"};
346 std::vector<const char*> outputNames = {
"output"};
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_WARNING(x)
Handle class for reading from StoreGate.
size_type size() const noexcept
Returns the number of elements in the collection.
void buildNodesAndFeatures(const MuonR4::SpacePointContainer &buckets, const ActsTrk::GeometryContext &gctx, std::vector< NodeAux > &nodes, std::vector< float > &featuresLeaves, std::vector< int64_t > &spInBucket)
Build nodes + flat features (N,6) and number of SPs per kept bucket.
void buildSparseEdges(const std::vector< NodeAux > &nodes, int minLayers, int maxChamberDelta, int maxSectorDelta, double maxDistXY, double maxAbsDz, std::vector< int64_t > &srcEdges, std::vector< int64_t > &dstEdges)
size_t packEdgeIndex(const std::vector< int64_t > &srcEdges, const std::vector< int64_t > &dstEdges, std::vector< int64_t > &edgeIndexPacked)
DataVector< SpacePointBucket > SpacePointContainer
Abbreviation of the space point container type.
const T * get(const ReadCondHandleKey< T > &key, const EventContext &ctx)
Convenience function to retrieve an object given a ReadCondHandleKey.
Helper struct to ship the Graph from the space point buckets to ONNX.
FeatureVec_t featureLeaves
Vector containing all features.
EdgeCounterVec_t edgeIndexPacked
Packed edge index buffer (kept alive for ONNX tensors that reference it) This stores [srcEdges,...
std::unique_ptr< InferenceGraph > graph
Pointer to the graph to be parsed to ONNX.
EdgeCounterVec_t srcEdges
Vector encoding the source index of each edge.
EdgeCounterVec_t desEdges
Vector encoding the destination index of each edge.
NodeConnectVec_t spacePointsInBucket
Vector keeping track of how many space points are in each parsed bucket.