3#ifndef FPGATRACKSIMGNNGRAPHCONSTRUCTIONTOOL_H
4#define FPGATRACKSIMGNNGRAPHCONSTRUCTIONTOOL_H
26#include <onnxruntime_cxx_api.h>
32#include <unordered_map>
// Graph-construction entry point (per its name): produces edges for the
// given hits.
// @param hits  Input hits, read-only.
// @param edges Output container filled by this call (non-const reference).
//              Whether it is cleared first or appended to is not visible
//              here; confirm in the implementation.
// @return StatusCode reporting whether construction succeeded.
// NOTE(review): presumably dispatches on the "graphTool" property to
// doModuleMap() or doMetricLearning(); confirm in the implementation.
// NOTE(review): if this redeclares a base-class virtual, consider marking it
// `override` so the compiler checks the signature.
48 virtual StatusCode
getEdges(
const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits,
49 std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges);
// Handle to an ONNX Runtime inference tool (interface
// AthOnnx::IOnnxRuntimeInferenceTool).
// Configurable through the "MLInferenceTool" property; the default
// implementation type is AthOnnx::OnnxRuntimeInferenceTool.
// NOTE(review): presumably used to evaluate the Metric Learning model in
// embed(); confirm in the implementation.
56 ToolHandle<AthOnnx::IOnnxRuntimeInferenceTool>
m_MLInferenceTool {
this,
"MLInferenceTool",
"AthOnnx::OnnxRuntimeInferenceTool"};
// Selects the graph-construction method (property "graphTool", default
// empty string).
// NOTE(review): the accepted values are not visible in this header;
// presumably they choose between the Module Map and Metric Learning paths.
62 Gaudi::Property<std::string>
m_graphTool {
this,
"graphTool",
"",
"Tool for graph construction" };
// Module Map type used for graph construction (property "moduleMapType",
// default empty string). Accepted values are defined by the implementation,
// not in this header.
63 Gaudi::Property<std::string>
m_moduleMapType {
this,
"moduleMapType",
"",
"Type for Module Map for graph construction" };
// Module Map function used for graph construction (property
// "moduleMapFunc", default empty string). Accepted values are defined by the
// implementation, not in this header.
64 Gaudi::Property<std::string>
m_moduleMapFunc {
this,
"moduleMapFunc",
"",
"Function for Module Map for graph construction" };
// Tolerance applied in the Module Map cut calculations (property
// "moduleMapTol", default 0.0).
// NOTE(review): units and how the tolerance widens the cuts are not visible
// here; confirm in the implementation.
65 Gaudi::Property<float>
m_moduleMapTol {
this,
"moduleMapTol", 0.0,
"Tolerance value for Module Map cut calculations" };
// RMS threshold factor for the Module Map cut calculations (property
// "moduleMapRMSThresholdFactor", default 0.0).
// NOTE(review): exact role in the cut formula is not visible here; confirm.
66 Gaudi::Property<float>
m_moduleMapRMSThresholdFactor {
this,
"moduleMapRMSThresholdFactor", 0.0,
"RMS Threshold value for Module Map cut calculations" };
// Clustering radius for Metric Learning (property "metricLearningR",
// default 0.0).
// NOTE(review): presumably a distance in the learned embedding space;
// confirm units in the implementation.
67 Gaudi::Property<float>
m_metricLearningR {
this,
"metricLearningR", 0.0,
"Clustering radius for Metric Learning"};
// Maximum number of neighbours considered by Metric Learning (property
// "metricLearningMaxN", default 1).
68 Gaudi::Property<int>
m_metricLearningMaxN {
this,
"metricLearningMaxN", 1,
"Max number of neighbours for Metric Learning"};
113 return mid1 == other.mid1 &&
mid2 == other.mid2 &&
mid3 == other.mid3;
119 return std::hash<unsigned>()(k.mid1) ^
120 (std::hash<unsigned>()(k.mid2) << 1) ^
121 (std::hash<unsigned>()(k.mid3) << 2);
// Lookup from a module triplet (TripletKey: mid1, mid2, mid3) to its Module
// Map configuration, hashed with TripletKeyHash.
// The mapped values are non-owning const pointers. The ModuleMapConfig
// objects must stay valid for as long as the map is used; their owner is not
// visible in this excerpt.
// NOTE(review): TripletKeyHash combines the three hashes with XOR and
// 1-/2-bit shifts, which can collide for permuted or nearby ids. Consider a
// stronger combine if these lookups are hot.
126 std::unordered_map<TripletKey, const ModuleMapConfig*, TripletKeyHash>
m_tripletMap;
// Module Map graph construction (per its name): reads `hits` and fills
// `edges`, a non-const reference output.
// NOTE(review): presumably uses m_tripletMap and the moduleMap* properties;
// confirm in the implementation.
133 void doModuleMap(
const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits,
134 std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges);
// Edge finding from hit triplets (per its name): reads `hits` and writes the
// resulting edges into `edges`.
// NOTE(review): assumed from naming to be a step of the Module Map path
// (doModuleMap); confirm in the implementation.
135 void getTripletEdges(
const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits,
136 std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges);
137 void getDoubletEdges(
const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits,
138 std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges,
141 const std::shared_ptr<FPGATrackSimGNNHit> & hit2,
144 const std::shared_ptr<FPGATrackSimGNNHit> & hit2,
145 const std::shared_ptr<FPGATrackSimGNNHit> & hit3,
// Metric Learning graph construction (per its name): reads `hits` and fills
// `edges`.
// NOTE(review): presumably chains getNodeFeatures() -> embed() ->
// doClustering(); confirm in the implementation.
151 void doMetricLearning(
const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits, std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges);
// Returns the node features for all `hits` as a single std::vector<float>.
// NOTE(review): the per-hit layout is not visible here, nor whether the
// MLFeatureNames / MLFeatureScales properties are applied in this function.
// Confirm in the implementation.
152 std::vector<float>
getNodeFeatures(
const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits);
// Returns the embedding of `hits` (per its name) as a std::vector<float>.
// NOTE(review): presumably runs m_MLInferenceTool on the node features. The
// embedding dimension and memory layout are not visible here.
153 std::vector<float>
embed(
const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits);
// Fills `edges` by clustering `hits` using the embedded values in
// `gEmbedded`. This reading comes from the names and the metricLearningR /
// metricLearningMaxN properties; confirm in the implementation.
// NOTE(review): `gEmbedded` is taken by non-const reference. If it is not
// modified, `const std::vector<float> &` would document that.
154 void doClustering(
const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits, std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges, std::vector<float> & gEmbedded);
157 this,
"MLFeatureNames",
159 "Feature names for the Metric Learning model"};
161 this,
"MLFeatureScales",
162 {1000.0, 3.14159265359, 1000.0},
163 "Feature scales for the Metric Learning model"};
FPGATrackSim-specific class to represent an edge as a connection between two hits in the detector use...
FPGATrackSim-specific class to represent a hit in the detector used for GNN pattern recognition.