3#ifndef FPGATRACKSIMGNNEDGECLASSIFIERTOOL_H
4#define FPGATRACKSIMGNNEDGECLASSIFIERTOOL_H
23#include <onnxruntime_cxx_api.h>
39 virtual StatusCode
scoreEdges(
const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits, std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges);
47 ToolHandle<AthOnnx::IOnnxRuntimeInferenceTool>
m_GNNInferenceTool {
this,
"GNNInferenceTool",
"AthOnnx::OnnxRuntimeInferenceTool"};
48 Gaudi::Property<int>
m_regionNum{
this,
"regionNum", -1,
"Region number for this GNNEdgeClassifierTool"};
53 std::vector<float>
getNodeFeatures(
const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits);
54 std::vector<int64_t>
getEdgeList(
const std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges);
55 std::vector<float>
getEdgeFeatures(std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges,
const std::vector<float> & gNodeFeatures);
56 void computeEdgeFeatures(std::shared_ptr<FPGATrackSimGNNEdge>& edge,
const int& hit1_index,
const int& hit2_index,
const std::vector<float> & gNodeFeatures);
59 this,
"GNNFeatureNames",
60 {
"r",
"phi",
"z",
"eta",
"cluster_r_1",
"cluster_phi_1",
"cluster_z_1",
"cluster_eta_1",
"cluster_r_2",
"cluster_phi_2",
"cluster_z_2",
"cluster_eta_2"},
61 "Feature names for the GNN model"};
63 this,
"GNNFeatureScales",
64 {1000.0, 3.14159265359, 1000.0, 1.0, 1000.0, 3.14159265359, 1000.0, 1.0, 1000.0, 3.14159265359, 1000.0, 1.0},
65 "Feature scales for the GNN model"};
FPGATrackSim-specific class to represent an edge as a connection between two hits in the detector use...
FPGATrackSim-specific class to represent a hit in the detector used for GNN pattern recognition.