3#ifndef FPGATRACKSIMGNNEDGECLASSIFIERTOOL_H
4#define FPGATRACKSIMGNNEDGECLASSIFIERTOOL_H
23#include <onnxruntime_cxx_api.h>
/// Scores the edges of a hit graph.
///
/// @param hits  Hits of the graph, shared-owned and passed read-only.
/// @param edges Edges connecting those hits. Taken by non-const reference, so
///              the implementation may modify them in place (presumably to
///              attach the classifier score to each edge -- confirm against
///              the definition, which is not visible from this header).
/// @return StatusCode reporting the outcome of the call.
39 virtual StatusCode
scoreEdges(
const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits, std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges);
/// Handle to the ONNX Runtime inference tool (AthOnnx::IOnnxRuntimeInferenceTool
/// interface). Exposed under the property name "GNNInferenceTool"; the third
/// constructor argument, "AthOnnx::OnnxRuntimeInferenceTool", is the default
/// tool the handle refers to when the job configuration does not override it.
47 ToolHandle<AthOnnx::IOnnxRuntimeInferenceTool>
m_GNNInferenceTool {
this,
"GNNInferenceTool",
"AthOnnx::OnnxRuntimeInferenceTool"};
/// Configurable integer property "regionNum": the region number this tool
/// instance is associated with. Defaults to -1.
/// NOTE(review): -1 presumably means "not configured" -- confirm how the
/// implementation treats a negative region number.
48 Gaudi::Property<int>
m_regionNum{
this,
"regionNum", -1,
"Region number for this GNNEdgeClassifierTool"};
/// Configurable boolean property "doGNNPixelSeeding": switches the tool into
/// its GNN pixel-seeding configuration. Defaults to false.
/// NOTE(review): what exactly changes when the flag is set is decided in the
/// implementation file, not in this header.
49 Gaudi::Property<bool>
m_doGNNPixelSeeding {
this,
"doGNNPixelSeeding",
false,
"Flag to configure for GNN Pixel Seeding" };
/// Builds the node-feature values for the given hits.
///
/// @param hits Hits to extract features from (read-only).
/// @return The features as a single std::vector<float>.
///         NOTE(review): presumably a flattened, hit-major array with one
///         group of feature values per hit -- the layout is defined by the
///         implementation; confirm there before indexing into it.
54 std::vector<float>
getNodeFeatures(
const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits);
/// Builds the edge list for the given edges as 64-bit integers.
///
/// @param edges Edges to convert (read-only).
/// @return The edge list as a single std::vector<int64_t>.
///         NOTE(review): presumably the hit indices of both endpoints of every
///         edge; the ordering (interleaved vs. two consecutive halves) is not
///         visible from this header -- confirm against the implementation.
55 std::vector<int64_t>
getEdgeList(
const std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges);
/// Builds the edge-feature values for the given edges.
///
/// @param edges         Edges to process. Taken by non-const reference, so the
///                      edges themselves may be updated as a side effect.
/// @param gNodeFeatures Node-feature values passed in read-only; presumably
///                      the output of getNodeFeatures() -- confirm at the
///                      call site.
/// @return The edge features as a single std::vector<float>; the layout is
///         defined by the implementation.
56 std::vector<float>
getEdgeFeatures(std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges,
const std::vector<float> & gNodeFeatures);
/// Computes the features of a single edge from the node features of its two
/// hits.
///
/// @param edge          Edge to process. Passed by non-const reference and the
///                      function returns void, so results are presumably
///                      stored on the edge object itself -- confirm in the
///                      implementation.
/// @param hit1_index    Index of the first hit of the edge.
/// @param hit2_index    Index of the second hit of the edge.
/// @param gNodeFeatures Node-feature values (read-only).
///                      NOTE(review): how the hit indices map to positions in
///                      this vector is not visible from this header.
57 void computeEdgeFeatures(std::shared_ptr<FPGATrackSimGNNEdge>& edge,
const int& hit1_index,
const int& hit2_index,
const std::vector<float> & gNodeFeatures);
60 this,
"GNNFeatureNames",
61 {
"r",
"phi",
"z",
"eta",
"cluster_r_1",
"cluster_phi_1",
"cluster_z_1",
"cluster_eta_1",
"cluster_r_2",
"cluster_phi_2",
"cluster_z_2",
"cluster_eta_2"},
62 "Feature names for the GNN model"};
64 this,
"GNNFeatureScales",
65 {1000.0, 3.14159265359, 1000.0, 1.0, 1000.0, 3.14159265359, 1000.0, 1.0, 1000.0, 3.14159265359, 1000.0, 1.0},
66 "Feature scales for the GNN model"};
69 this,
"GNNFeatureNames",
70 {
"r",
"phi",
"z",
"eta"},
71 "Feature names for the GNN model"};
73 this,
"GNNFeatureScales",
74 {1000.0, 3.14159265359, 1000.0, 1.0},
75 "Feature scales for the GNN model"};
FPGATrackSim-specific class to represent an edge as a connection between two hits in the detector use...
FPGATrackSim-specific class to represent a hit in the detector used for GNN pattern recognition.