26 auto notSpace = [](
unsigned char c) {
return !std::isspace(c); };
27 s.erase(s.begin(), std::find_if(s.begin(), s.end(), notSpace));
28 s.erase(std::find_if(s.rbegin(), s.rend(), notSpace).base(), s.end());
33 std::vector<std::string> out;
35 if (s.empty())
return out;
38 if (!s.empty() && s.front() ==
'[') {
44 if (!token.empty()) out.push_back(token);
50 if (inQuote) token.push_back(c);
52 if (!out.empty())
return out;
56 std::istringstream
ss(s);
58 while (std::getline(
ss, tok,
',')) {
60 if (!tok.empty()) out.push_back(tok);
80 <<
". I/O binding will be used.");
86 return StatusCode::SUCCESS;
91 graphData.
graph = std::make_unique<InferenceGraph>();
103 std::vector<BucketGraphUtils::NodeAux> nodes;
108 const int64_t numNodes =
static_cast<int64_t
>(nodes.size());
110 <<
" -> nodes (size>0): " << numNodes
114 ATH_MSG_WARNING(
"No valid buckets found (all have size 0.0). Skipping inference.");
115 return StatusCode::SUCCESS;
119 if (numNodes * nFeatPerNode !=
static_cast<int64_t
>(graphData.
featureLeaves.size())) {
120 ATH_MSG_ERROR(
"Feature size mismatch: expected " << (numNodes * nFeatPerNode)
122 return StatusCode::FAILURE;
125 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
126 std::vector<int64_t> featShape{numNodes, nFeatPerNode};
127 graphData.
graph->dataTensor.emplace_back(
128 Ort::Value::CreateTensor<float>(memInfo,
133 return StatusCode::SUCCESS;
146 ATH_MSG_WARNING(
"No valid features for transformer input. Skipping inference.");
147 return StatusCode::SUCCESS;
150 if (msgLvl(MSG::DEBUG)) {
152 ATH_MSG_DEBUG(
"=== DEBUGGING: Transformer input features for first 10 nodes ===");
153 const int64_t debugNodes = std::min(S,
static_cast<int64_t
>(10));
154 for (int64_t nodeIdx = 0; nodeIdx < debugNodes; ++nodeIdx) {
157 <<
"x=" << featuresFlat[baseIdx + 0] <<
", "
158 <<
"y=" << featuresFlat[baseIdx + 1] <<
", "
159 <<
"z=" << featuresFlat[baseIdx + 2] <<
", "
160 <<
"layers=" << featuresFlat[baseIdx + 3] <<
", "
161 <<
"nSp=" << featuresFlat[baseIdx + 4] <<
", "
162 <<
"bucketSize=" << featuresFlat[baseIdx + 5]);
168 graphData.
graph = std::make_unique<InferenceGraph>();
170 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
175 graphData.
graph->dataTensor.emplace_back(
176 Ort::Value::CreateTensor<float>(memInfo,
183 Ort::AllocatorWithDefaultOptions allocator;
184 std::vector<int64_t> mShape{1, S};
185 Ort::Value padVal = Ort::Value::CreateTensor(allocator,
188 ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL);
189 bool* maskPtr = padVal.GetTensorMutableData<
bool>();
190 for (int64_t i = 0; i < S; ++i) maskPtr[i] =
false;
191 graphData.
graph->dataTensor.emplace_back(std::move(padVal));
193 return StatusCode::SUCCESS;
206 std::vector<BucketGraphUtils::NodeAux> nodes;
207 std::vector<float> throwawayFeatures;
208 std::vector<int64_t> throwawaySp;
211 const int64_t numNodes =
static_cast<int64_t
>(nodes.size());
213 ATH_MSG_WARNING(
"No valid buckets found (all have size 0.0). Skipping graph building.");
214 return StatusCode::SUCCESS;
217 std::vector<int64_t> srcEdges, dstEdges;
227 std::vector<int64_t> newSrc;
228 std::vector<int64_t> newDst;
229 newSrc.reserve(srcEdges.size());
230 newDst.reserve(dstEdges.size());
231 for (
size_t k = 0; k < srcEdges.size(); ++k) {
232 const int64_t u = srcEdges[k];
233 const int64_t v = dstEdges[k];
234 const bool okU = (u >= 0 && u < numNodes);
235 const bool okV = (v >= 0 && v < numNodes);
241 ATH_MSG_DEBUG(
"Drop invalid edge " << k <<
": (" << u <<
"->" << v
242 <<
"), valid node range [0," << (numNodes-1) <<
"]");
248 srcEdges.swap(newSrc);
249 dstEdges.swap(newDst);
253 const size_t E = srcEdges.size();
255 if (msgLvl(MSG::DEBUG)) {
259 for (
unsigned int k = 0; k < dumpE; ++k) {
260 ATH_MSG_DEBUG(
"EDGE[" << k <<
"]: " << srcEdges[k] <<
" -> " << dstEdges[k]);
262 std::vector<int> nodeConnections(numNodes, 0);
263 for (
size_t k = 0; k < srcEdges.size(); ++k) {
264 const int64_t u = srcEdges[k];
265 const int64_t v = dstEdges[k];
266 if (u >= 0 && u < numNodes) nodeConnections[u]++;
267 if (v >= 0 && v < numNodes) nodeConnections[v]++;
270 ATH_MSG_INFO(
"=== DEBUGGING: Node Connections (first 10 nodes) ===");
271 const int64_t debugNodeCount = std::min(numNodes,
static_cast<int64_t
>(10));
272 for (int64_t i = 0; i < debugNodeCount; ++i) {
273 ATH_MSG_DEBUG(
"Node[" << i <<
"] connections: " << nodeConnections[i]);
278 ATH_MSG_DEBUG(
"=== DEBUGGING: Detailed Edge Connections (first 10 nodes) ===");
279 for (int64_t nodeIdx = 0; nodeIdx < debugNodeCount; ++nodeIdx) {
280 std::stringstream connections;
281 connections <<
"Node[" << nodeIdx <<
"] connected to: ";
282 bool foundAny =
false;
284 for (
size_t k = 0; k < srcEdges.size(); ++k) {
285 const int64_t u = srcEdges[k];
286 const int64_t v = dstEdges[k];
289 if (foundAny) connections <<
", ";
292 }
else if (v == nodeIdx) {
293 if (foundAny) connections <<
", ";
299 if (!foundAny) connections <<
"none";
308 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
309 std::vector<int64_t> edgeShape{2,
static_cast<int64_t
>(Efinal)};
310 graphData.
graph->dataTensor.emplace_back(
311 Ort::Value::CreateTensor<int64_t>(memInfo,
317 ATH_MSG_DEBUG(
"Built sparse bucket graph: N=" << numNodes <<
", E=" << Efinal);
318 return StatusCode::SUCCESS;
323 const std::vector<const char*>& inputNames,
324 const std::vector<const char*>& outputNames)
const
326 if (!graphData.
graph) {
328 return StatusCode::FAILURE;
330 if (graphData.
graph->dataTensor.empty()) {
332 return StatusCode::FAILURE;
335 if (msgLvl(MSG::DEBUG)) {
339 if (!graphData.
graph->dataTensor.empty()) {
340 const auto& featureTensor = graphData.
graph->dataTensor[0];
341 auto featShape = featureTensor.GetTensorTypeAndShapeInfo().GetShape();
343 << (featShape.size()>1 ? (
"," + std::to_string(featShape[1])) :
"")
344 << (featShape.size()>2 ? (
"," + std::to_string(featShape[2])) :
"") <<
"]");
346 float* featData =
const_cast<Ort::Value&
>(featureTensor).GetTensorMutableData<float>();
347 const size_t totalElements = featureTensor.GetTensorTypeAndShapeInfo().GetElementCount();
348 ATH_MSG_DEBUG(
"Features tensor total elements: " << totalElements);
351 const size_t nFeat = (featShape.size() > 1 && featShape[1] > 0) ?
static_cast<size_t>(featShape[1]) : 1;
352 const size_t nNodes = totalElements / nFeat;
353 const size_t debugNodes = std::min(nNodes,
static_cast<size_t>(10));
357 std::vector<std::string> featNames;
359 Ort::AllocatorWithDefaultOptions allocator;
360 Ort::ModelMetadata
meta =
model().GetModelMetadata();
361 auto keys =
meta.GetCustomMetadataMapKeysAllocated(allocator);
362 std::vector<std::string> keyNames;
363 keyNames.reserve(keys.size());
364 for (
const auto& k : keys) keyNames.emplace_back(k.get());
365 const std::array<std::string, 4> candidates{
366 "x_feature_names",
"node_feature_names",
"feature_names",
"input_feature_names"};
367 for (
const std::string& key : candidates) {
368 if (std::find(keyNames.begin(), keyNames.end(), key) != keyNames.end()) {
369 std::string val =
meta.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get();
374 if (featNames.empty()) {
375 ATH_MSG_DEBUG(
"No usable feature-name metadata key found in model; using generic fN labels.");
378 auto featLabel = [&](
size_t f) -> std::string {
379 if (f < featNames.size())
return featNames[f];
380 return "f" + std::to_string(f);
385 std::ostringstream legend;
386 legend <<
"Node feature legend (" << nFeat <<
" features):";
387 for (
size_t f = 0; f < nFeat; ++f) {
388 legend <<
" f" << f <<
"=" << featLabel(f);
389 if (f + 1 < nFeat) legend <<
",";
394 for (
size_t n = 0; n < debugNodes; ++n) {
395 std::ostringstream row;
396 row <<
"ONNXNode[" << n <<
"]:";
397 for (
size_t f = 0; f < nFeat; ++f) {
398 row <<
" f" << f <<
"=" << featData[n * nFeat + f];
399 if (f + 1 < nFeat) row <<
",";
407 Ort::RunOptions run_options;
408 run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR);
412 Ort::IoBinding binding(
model());
413 for (std::size_t i = 0; i < inputNames.size(); ++i) {
414 binding.BindInput(inputNames[i], graphData.
graph->dataTensor[i]);
417 Ort::MemoryInfo cpuOut = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
418 for (
const char* outName : outputNames) {
419 binding.BindOutput(outName, cpuOut);
422 model().Run(run_options, binding);
423 binding.SynchronizeOutputs();
425 std::vector<Ort::Value> outputs = binding.GetOutputValues();
426 if (outputs.empty()) {
428 return StatusCode::FAILURE;
431 float* outData = outputs[0].GetTensorMutableData<
float>();
432 const size_t outSize = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
433 ATH_MSG_DEBUG(
"ONNX (IoBinding) raw output elementCount = " << outSize);
436 std::span<float> preds(outData, outData + outSize);
437 for (
size_t i = 0; i < outSize; ++i) {
438 if (!std::isfinite(preds[i])) {
439 ATH_MSG_WARNING(
"Non-finite prediction detected at " << i <<
" -> set to -100.");
445 for (
auto& v : outputs) {
446 graphData.
graph->dataTensor.emplace_back(std::move(v));
448 return StatusCode::SUCCESS;
452 std::vector<Ort::Value> outputs =
453 model().Run(run_options,
455 graphData.
graph->dataTensor.data(),
456 graphData.
graph->dataTensor.size(),
460 if (outputs.empty()) {
462 return StatusCode::FAILURE;
465 float* outData = outputs[0].GetTensorMutableData<
float>();
466 const size_t outSize = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
470 std::span<float> preds(outData, outData + outSize);
471 for (
size_t i = 0; i < outSize; ++i) {
472 if (!std::isfinite(preds[i])) {
473 ATH_MSG_WARNING(
"Non-finite prediction detected at " << i <<
" -> set to -100.");
479 for (
auto& v : outputs) {
480 graphData.
graph->dataTensor.emplace_back(std::move(v));
482 return StatusCode::SUCCESS;
486 std::vector<const char*> inputNames = {
"features",
"edge_index"};
487 std::vector<const char*> outputNames = {
m_outputName.value().c_str()};
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_WARNING(x)
Handle class for reading from StoreGate.
size_type size() const noexcept
Returns the number of elements in the collection.
void buildNodesAndFeatures(const MuonR4::SpacePointContainer &buckets, const ActsTrk::GeometryContext &gctx, std::vector< NodeAux > &nodes, std::vector< float > &featuresLeaves, std::vector< int64_t > &spInBucket)
Build nodes + flat features (N,6) and number of SPs per kept bucket.
void buildSparseEdges(const std::vector< NodeAux > &nodes, int minLayers, int maxChamberDelta, int maxSectorDelta, double maxDistXY, double maxAbsDz, std::vector< int64_t > &srcEdges, std::vector< int64_t > &dstEdges)
size_t packEdgeIndex(const std::vector< int64_t > &srcEdges, const std::vector< int64_t > &dstEdges, std::vector< int64_t > &edgeIndexPacked)
DataVector< SpacePointBucket > SpacePointContainer
Abrivation of the space point container type.
const T * get(const ReadCondHandleKey< T > &key, const EventContext &ctx)
Convenience function to retrieve an object given a ReadCondHandleKey.
Helper struct to ship the Graph from the space point buckets to ONNX.
FeatureVec_t featureLeaves
Vector containing all features.
EdgeCounterVec_t edgeIndexPacked
Packed edge index buffer (kept alive for ONNX tensors that reference it) This stores [srcEdges,...
std::unique_ptr< InferenceGraph > graph
Pointer to the graph to be parsed to ONNX.
EdgeCounterVec_t srcEdges
Vector encoding the source index of the.
EdgeCounterVec_t desEdges
Vect.
NodeConnectVec_t spacePointsInBucket
Vector keeping track of how many space points are in each parsed bucket.