 |
ATLAS Offline Software
|
#include <BucketInferenceToolBase.h>
|
| StatusCode | setupModel () |
| |
| Ort::Session & | model () const |
| |
| StatusCode | buildFeaturesOnly (const EventContext &ctx, GraphRawData &graphData) const |
| | Build only features (N,6); attaches one tensor in graph.dataTensor[0]. More...
|
| |
| StatusCode | buildTransformerInputs (const EventContext &ctx, GraphRawData &graphData) const |
| | Build Transformer inputs: features [1,S,6] and pad_mask [1,S] (False = valid), as tensors 0 and 1. More...
|
| |
| StatusCode | runNamedInference (GraphRawData &graphData, const std::vector< const char * > &inputNames, const std::vector< const char * > &outputNames) const |
| | Generic named inference, for tools with different I/O conventions. More...
|
| |
|
| SG::ReadHandleKey< MuonR4::SpacePointContainer > | m_readKey {this, "ReadSpacePoints", "MuonSpacePoints"} |
| |
| SG::ReadHandleKey< ActsTrk::GeometryContext > | m_geoCtxKey {this, "AlignmentKey", "ActsAlignment", "cond handle key"} |
| |
| Gaudi::Property< int > | m_minLayers {this, "MinLayersValid", 3} |
| |
| Gaudi::Property< int > | m_maxChamberDelta {this, "MaxChamberDelta", 13} |
| |
| Gaudi::Property< int > | m_maxSectorDelta {this, "MaxSectorDelta", 1} |
| |
| Gaudi::Property< double > | m_maxDistXY {this, "MaxDistXY", 6800.0} |
| |
| Gaudi::Property< double > | m_maxAbsDz {this, "MaxAbsDz", 15000.0} |
| |
| Gaudi::Property< unsigned int > | m_debugDumpFirstNNodes {this, "DebugDumpFirstNNodes", 5} |
| |
| Gaudi::Property< unsigned int > | m_debugDumpFirstNEdges {this, "DebugDumpFirstNEdges", 12} |
| |
| Gaudi::Property< bool > | m_validateEdges {this, "ValidateEdges", true} |
| |
BucketInferenceToolBase
Common infra to:
- read buckets & (optionally) geometry
- build node features
- (optionally) build GNN sparse edges (via BucketGraphUtils)
- wrap tensors and run ONNX sessions
GNN-specific operations are in BucketGraphUtils.* Transformer tools reuse feature building without edges and add a pad mask.
Definition at line 36 of file BucketInferenceToolBase.h.
◆ ~BucketInferenceToolBase()
| MuonML::BucketInferenceToolBase::~BucketInferenceToolBase |
( |
| ) |
|
|
override default |
◆ buildFeaturesOnly()
| StatusCode BucketInferenceToolBase::buildFeaturesOnly |
( |
const EventContext & |
ctx, |
|
|
GraphRawData & |
graphData |
|
) |
| const |
|
protected |
Build only features (N,6); attaches one tensor in graph.dataTensor[0].
Definition at line 32 of file BucketInferenceToolBase.cxx.
34 graphData.
graph = std::make_unique<InferenceGraph>();
46 std::vector<BucketGraphUtils::NodeAux> nodes;
51 const int64_t numNodes =
static_cast<int64_t
>(nodes.size());
53 <<
" -> nodes (size>0): " << numNodes
57 ATH_MSG_WARNING(
"No valid buckets found (all have size 0.0). Skipping inference.");
58 return StatusCode::SUCCESS;
61 const int64_t nFeatPerNode = 6;
62 if (numNodes * nFeatPerNode !=
static_cast<int64_t
>(graphData.
featureLeaves.size())) {
63 ATH_MSG_ERROR(
"Feature size mismatch: expected " << (numNodes * nFeatPerNode)
65 return StatusCode::FAILURE;
68 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
69 std::vector<int64_t> featShape{numNodes, nFeatPerNode};
70 graphData.
graph->dataTensor.emplace_back(
71 Ort::Value::CreateTensor<float>(memInfo,
76 return StatusCode::SUCCESS;
◆ buildGraph()
| StatusCode BucketInferenceToolBase::buildGraph |
( |
const EventContext & |
ctx, |
|
|
GraphRawData & |
graphData |
|
) |
| const |
GNN-style graph builder (features + edges). Kept for tools that want it.
Definition at line 139 of file BucketInferenceToolBase.cxx.
149 std::vector<BucketGraphUtils::NodeAux> nodes;
150 std::vector<float> throwawayFeatures;
151 std::vector<int64_t> throwawaySp;
154 const int64_t numNodes =
static_cast<int64_t
>(nodes.size());
156 ATH_MSG_WARNING(
"No valid buckets found (all have size 0.0). Skipping graph building.");
157 return StatusCode::SUCCESS;
160 std::vector<int64_t> srcEdges, dstEdges;
170 std::vector<int64_t> newSrc;
171 std::vector<int64_t> newDst;
172 newSrc.reserve(srcEdges.size());
173 newDst.reserve(dstEdges.size());
174 for (
size_t k = 0;
k < srcEdges.size(); ++
k) {
175 const int64_t
u = srcEdges[
k];
176 const int64_t
v = dstEdges[
k];
177 const bool okU = (
u >= 0 &&
u < numNodes);
178 const bool okV = (
v >= 0 &&
v < numNodes);
185 <<
"), valid node range [0," << (numNodes-1) <<
"]");
191 srcEdges.swap(newSrc);
192 dstEdges.swap(newDst);
196 const size_t E = srcEdges.size();
202 for (
unsigned int k = 0;
k < dumpE; ++
k) {
203 ATH_MSG_DEBUG(
"EDGE[" <<
k <<
"]: " << srcEdges[
k] <<
" -> " << dstEdges[
k]);
205 std::vector<int> nodeConnections(numNodes, 0);
206 for (
size_t k = 0;
k < srcEdges.size(); ++
k) {
207 const int64_t
u = srcEdges[
k];
208 const int64_t
v = dstEdges[
k];
209 if (
u >= 0 &&
u < numNodes) nodeConnections[
u]++;
210 if (
v >= 0 &&
v < numNodes) nodeConnections[
v]++;
213 ATH_MSG_INFO(
"=== DEBUGGING: Node Connections (first 10 nodes) ===");
214 const int64_t debugNodeCount =
std::min(numNodes,
static_cast<int64_t
>(10));
215 for (int64_t
i = 0;
i < debugNodeCount; ++
i) {
216 ATH_MSG_DEBUG(
"Node[" <<
i <<
"] connections: " << nodeConnections[
i]);
221 ATH_MSG_DEBUG(
"=== DEBUGGING: Detailed Edge Connections (first 10 nodes) ===");
222 for (int64_t nodeIdx = 0; nodeIdx < debugNodeCount; ++nodeIdx) {
223 std::stringstream connections;
224 connections <<
"Node[" << nodeIdx <<
"] connected to: ";
225 bool foundAny =
false;
227 for (
size_t k = 0;
k < srcEdges.size(); ++
k) {
228 const int64_t
u = srcEdges[
k];
229 const int64_t
v = dstEdges[
k];
232 if (foundAny) connections <<
", ";
235 }
else if (
v == nodeIdx) {
236 if (foundAny) connections <<
", ";
242 if (!foundAny) connections <<
"none";
251 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
252 std::vector<int64_t> edgeShape{2,
static_cast<int64_t
>(Efinal)};
253 graphData.
graph->dataTensor.emplace_back(
254 Ort::Value::CreateTensor<int64_t>(memInfo,
260 ATH_MSG_DEBUG(
"Built sparse bucket graph: N=" << numNodes <<
", E=" << Efinal);
261 return StatusCode::SUCCESS;
◆ buildTransformerInputs()
| StatusCode BucketInferenceToolBase::buildTransformerInputs |
( |
const EventContext & |
ctx, |
|
|
GraphRawData & |
graphData |
|
) |
| const |
|
protected |
Build Transformer inputs: features [1,S,6] and pad_mask [1,S] (False = valid), as tensors 0 and 1.
Definition at line 79 of file BucketInferenceToolBase.cxx.
86 const int64_t
S =
static_cast<int64_t
>(featuresFlat.size() / 6);
89 ATH_MSG_WARNING(
"No valid features for transformer input. Skipping inference.");
90 return StatusCode::SUCCESS;
95 ATH_MSG_DEBUG(
"=== DEBUGGING: Transformer input features for first 10 nodes ===");
96 const int64_t debugNodes =
std::min(
S,
static_cast<int64_t
>(10));
97 for (int64_t nodeIdx = 0; nodeIdx < debugNodes; ++nodeIdx) {
98 const int64_t baseIdx = nodeIdx * 6;
100 <<
"x=" << featuresFlat[baseIdx + 0] <<
", "
101 <<
"y=" << featuresFlat[baseIdx + 1] <<
", "
102 <<
"z=" << featuresFlat[baseIdx + 2] <<
", "
103 <<
"layers=" << featuresFlat[baseIdx + 3] <<
", "
104 <<
"nSp=" << featuresFlat[baseIdx + 4] <<
", "
105 <<
"bucketSize=" << featuresFlat[baseIdx + 5]);
111 graphData.
graph = std::make_unique<InferenceGraph>();
113 Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
116 std::vector<int64_t> fShape{1,
S, 6};
118 graphData.
graph->dataTensor.emplace_back(
119 Ort::Value::CreateTensor<float>(memInfo,
126 Ort::AllocatorWithDefaultOptions allocator;
127 std::vector<int64_t> mShape{1,
S};
128 Ort::Value padVal = Ort::Value::CreateTensor(allocator,
131 ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL);
132 bool* maskPtr = padVal.GetTensorMutableData<
bool>();
133 for (int64_t
i = 0;
i <
S; ++
i) maskPtr[
i] =
false;
134 graphData.
graph->dataTensor.emplace_back(std::move(padVal));
136 return StatusCode::SUCCESS;
◆ model()
| Ort::Session & BucketInferenceToolBase::model |
( |
| ) |
const |
|
protected |
◆ runInference()
| StatusCode BucketInferenceToolBase::runInference |
( |
GraphRawData & |
graphData | ) |
const |
Default ONNX run for GNN case: inputs {"features","edge_index"} -> outputs {"output"}.
Definition at line 344 of file BucketInferenceToolBase.cxx.
345 std::vector<const char*> inputNames = {
"features",
"edge_index"};
◆ runNamedInference()
| StatusCode BucketInferenceToolBase::runNamedInference |
( |
GraphRawData & |
graphData, |
|
|
const std::vector< const char * > & |
inputNames, |
|
|
const std::vector< const char * > & |
outputNames |
|
) |
| const |
|
protected |
Generic named inference, for tools with different I/O conventions.
Definition at line 264 of file BucketInferenceToolBase.cxx.
269 if (!graphData.
graph) {
271 return StatusCode::FAILURE;
273 if (graphData.
graph->dataTensor.empty()) {
275 return StatusCode::FAILURE;
282 if (!graphData.
graph->dataTensor.empty()) {
283 const auto& featureTensor = graphData.
graph->dataTensor[0];
284 auto featShape = featureTensor.GetTensorTypeAndShapeInfo().GetShape();
286 << (featShape.size()>1 ? (
"," +
std::to_string(featShape[1])) :
"")
287 << (featShape.size()>2 ? (
"," +
std::to_string(featShape[2])) :
"") <<
"]");
289 float* featData =
const_cast<Ort::Value&
>(featureTensor).GetTensorMutableData<float>();
290 const size_t totalElements = featureTensor.GetTensorTypeAndShapeInfo().GetElementCount();
291 ATH_MSG_DEBUG(
"Features tensor total elements: " << totalElements);
294 const size_t debugElements =
std::min(totalElements,
static_cast<size_t>(60));
295 for (
size_t i = 0;
i < debugElements;
i += 6) {
296 if (
i + 5 < totalElements) {
298 <<
"x=" << featData[
i+0] <<
", "
299 <<
"y=" << featData[
i+1] <<
", "
300 <<
"z=" << featData[
i+2] <<
", "
301 <<
"layers=" << featData[
i+3] <<
", "
302 <<
"nSp=" << featData[
i+4] <<
", "
303 <<
"bucketSize=" << featData[
i+5]);
310 Ort::RunOptions run_options;
311 run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_WARNING);
313 std::vector<Ort::Value>
outputs =
314 model().Run(run_options,
316 graphData.
graph->dataTensor.data(),
317 graphData.
graph->dataTensor.size(),
323 return StatusCode::FAILURE;
326 float* outData =
outputs[0].GetTensorMutableData<
float>();
327 const size_t outSize =
outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
330 std::span<float> preds(outData, outData + outSize);
331 for (
size_t i = 0;
i < outSize; ++
i) {
332 if (!std::isfinite(preds[
i])) {
333 ATH_MSG_WARNING(
"Non-finite prediction detected at " <<
i <<
" -> set to -100.");
339 graphData.
graph->dataTensor.emplace_back(std::move(
v));
341 return StatusCode::SUCCESS;
◆ setupModel()
| StatusCode BucketInferenceToolBase::setupModel |
( |
| ) |
|
|
protected |
◆ m_debugDumpFirstNEdges
| Gaudi::Property<unsigned int> MuonML::BucketInferenceToolBase::m_debugDumpFirstNEdges {this, "DebugDumpFirstNEdges", 12} |
|
protected |
◆ m_debugDumpFirstNNodes
| Gaudi::Property<unsigned int> MuonML::BucketInferenceToolBase::m_debugDumpFirstNNodes {this, "DebugDumpFirstNNodes", 5} |
|
protected |
◆ m_geoCtxKey
◆ m_maxAbsDz
| Gaudi::Property<double> MuonML::BucketInferenceToolBase::m_maxAbsDz {this, "MaxAbsDz", 15000.0} |
|
protected |
◆ m_maxChamberDelta
| Gaudi::Property<int> MuonML::BucketInferenceToolBase::m_maxChamberDelta {this, "MaxChamberDelta", 13} |
|
protected |
◆ m_maxDistXY
| Gaudi::Property<double> MuonML::BucketInferenceToolBase::m_maxDistXY {this, "MaxDistXY", 6800.0} |
|
protected |
◆ m_maxSectorDelta
| Gaudi::Property<int> MuonML::BucketInferenceToolBase::m_maxSectorDelta {this, "MaxSectorDelta", 1} |
|
protected |
◆ m_minLayers
| Gaudi::Property<int> MuonML::BucketInferenceToolBase::m_minLayers {this, "MinLayersValid", 3} |
|
protected |
◆ m_onnxSessionTool
◆ m_readKey
◆ m_validateEdges
| Gaudi::Property<bool> MuonML::BucketInferenceToolBase::m_validateEdges {this, "ValidateEdges", true} |
|
protected |
The documentation for this class was generated from the following files:
NodeConnectVec_t spacePointsInBucket
Vector keeping track of how many space points are in each parsed bucket.
std::unique_ptr< InferenceGraph > graph
Pointer to the graph to be parsed to ONNX.
std::string to_string(const SectorProjector proj)
@ u
Enums for curvilinear frames.
EdgeCounterVec_t desEdges
Vector encoding the destination index of the edges.
const T * get(const ReadCondHandleKey< T > &key, const EventContext &ctx)
Convenience function to retrieve an object given a ReadCondHandleKey.
EdgeCounterVec_t srcEdges
Vector encoding the source index of the edges.
FeatureVec_t featureLeaves
Vector containing all features.
StatusCode initialize(bool used=true)
If this object is used as a property, then this should be called during the initialize phase.
EdgeCounterVec_t edgeIndexPacked
Packed edge index buffer (kept alive for ONNX tensors that reference it). This stores [srcEdges, dstEdges] packed back-to-back.
void buildSparseEdges(const std::vector< NodeAux > &nodes, int minLayers, int maxChamberDelta, int maxSectorDelta, double maxDistXY, double maxAbsDz, std::vector< int64_t > &srcEdges, std::vector< int64_t > &dstEdges)
size_t packEdgeIndex(const std::vector< int64_t > &srcEdges, const std::vector< int64_t > &dstEdges, std::vector< int64_t > &edgeIndexPacked)
#define ATH_MSG_WARNING(x)
void buildNodesAndFeatures(const MuonR4::SpacePointContainer &buckets, const ActsTrk::GeometryContext &gctx, std::vector< NodeAux > &nodes, std::vector< float > &featuresLeaves, std::vector< int64_t > &spInBucket)
Build nodes + flat features (N,6) and number of SPs per kept bucket.