ATLAS Offline Software
Loading...
Searching...
No Matches
FPGATrackSimGNNEdgeClassifierTool.h
Go to the documentation of this file.
1// Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
3#ifndef FPGATRACKSIMGNNEDGECLASSIFIERTOOL_H
4#define FPGATRACKSIMGNNEDGECLASSIFIERTOOL_H
5
16
18
21
23#include <onnxruntime_cxx_api.h>
24
26{
27 public:
28
30 // AthAlgTool
31
// Constructor. The three parameters are unnamed in this declaration;
// presumably they are (type, name, parent) as in the AthAlgTool constructor
// this tool is built on -- confirm against the implementation file.
32 FPGATrackSimGNNEdgeClassifierTool(const std::string&, const std::string&, const IInterface*);
33
// Overrides the base-class initialize() hook and reports the outcome as a
// StatusCode. The body lives in the implementation file.
// NOTE(review): presumably this is where m_GNNInferenceTool is retrieved -- confirm.
34 virtual StatusCode initialize() override;
35
37 // Functions
38
// Public entry point of the tool.
//   hits  : GNN hits, taken by const reference (read-only here).
//   edges : GNN edges, taken by non-const reference, so the callee may modify them.
// Returns a StatusCode.
// NOTE(review): presumably the per-edge classifier score is written onto each
// edge -- confirm against the implementation file.
39 virtual StatusCode scoreEdges(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits, std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges);
// Returns the current value of the "regionNum" property (m_regionNum) as an int.
40 int regionNum() const { return m_regionNum; }
41
42 private:
43
45 // Handles
46
// Handle to an AthOnnx::IOnnxRuntimeInferenceTool, configurable through the
// "GNNInferenceTool" property; defaults to "AthOnnx::OnnxRuntimeInferenceTool".
47 ToolHandle<AthOnnx::IOnnxRuntimeInferenceTool> m_GNNInferenceTool {this, "GNNInferenceTool", "AthOnnx::OnnxRuntimeInferenceTool"};
// Integer property "regionNum", default -1; exposed through regionNum().
// NOTE(review): -1 presumably means "no region assigned" -- confirm with the job configuration.
48 Gaudi::Property<int> m_regionNum{this, "regionNum", -1, "Region number for this GNNEdgeClassifierTool"};
49
51 // Helpers
52
// Builds a float vector from the given hits (read-only input).
// NOTE(review): presumably a flattened per-hit feature array following
// GNNFeatureNames / GNNFeatureScales -- the layout is not visible in this header; confirm.
53 std::vector<float> getNodeFeatures(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits);
// Builds an int64_t vector from the given edges (read-only input).
// NOTE(review): presumably the hit indices of each edge, flattened for the
// inference input -- ordering is not visible in this header; confirm.
54 std::vector<int64_t> getEdgeList(const std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges);
// Builds a float vector of edge features from the edges and the node-feature array.
//   edges         : taken by non-const reference, so the edges may be modified.
//   gNodeFeatures : node-feature array, read-only.
// NOTE(review): presumably gNodeFeatures is the output of getNodeFeatures() and
// the per-edge work is delegated to computeEdgeFeatures() -- confirm.
55 std::vector<float> getEdgeFeatures(std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges, const std::vector<float> & gNodeFeatures);
// Per-edge helper: receives one edge (non-const reference, so it may be
// modified), two hit indices, and the node-feature array (read-only).
// NOTE(review): presumably hit1_index / hit2_index address the two hits of the
// edge inside gNodeFeatures -- confirm the indexing convention.
// NOTE(review): the indices are passed as const int&; by-value int is the usual
// choice, but changing it requires updating the definition as well.
56 void computeEdgeFeatures(std::shared_ptr<FPGATrackSimGNNEdge>& edge, const int& hit1_index, const int& hit2_index, const std::vector<float> & gNodeFeatures);
57
// String-array property "GNNFeatureNames". The default lists 12 names:
// r/phi/z/eta, followed by the same four quantities with prefix "cluster_"
// and suffixes "_1" and "_2".
58 StringArrayProperty m_gnnFeatureNamesVec{
59 this, "GNNFeatureNames",
60 {"r", "phi", "z", "eta", "cluster_r_1", "cluster_phi_1", "cluster_z_1", "cluster_eta_1", "cluster_r_2", "cluster_phi_2", "cluster_z_2", "cluster_eta_2"},
61 "Feature names for the GNN model"};
// Float-array property "GNNFeatureScales". The default lists 12 values that
// line up positionally with the 12 default GNNFeatureNames entries:
// 1000.0 for the r and z entries, 3.14159265359 for phi, 1.0 for eta.
// NOTE(review): whether features are divided or multiplied by these scales is
// not visible in this header -- confirm in the implementation file.
62 FloatArrayProperty m_gnnFeatureScalesVec{
63 this, "GNNFeatureScales",
64 {1000.0, 3.14159265359, 1000.0, 1.0, 1000.0, 3.14159265359, 1000.0, 1.0, 1000.0, 3.14159265359, 1000.0, 1.0},
65 "Feature scales for the GNN model"};
66};
67
68
69#endif // FPGATRACKSIMGNNEDGECLASSIFIERTOOL_H
FPGATrackSim-specific class to represent an edge as a connection between two hits in the detector use...
FPGATrackSim-specific class to represent a hit in the detector used for GNN pattern recognition.
AthAlgTool(const std::string &type, const std::string &name, const IInterface *parent)
Constructor with parameters:
void computeEdgeFeatures(std::shared_ptr< FPGATrackSimGNNEdge > &edge, const int &hit1_index, const int &hit2_index, const std::vector< float > &gNodeFeatures)
std::vector< int64_t > getEdgeList(const std::vector< std::shared_ptr< FPGATrackSimGNNEdge > > &edges)
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_GNNInferenceTool
std::vector< float > getEdgeFeatures(std::vector< std::shared_ptr< FPGATrackSimGNNEdge > > &edges, const std::vector< float > &gNodeFeatures)
virtual StatusCode scoreEdges(const std::vector< std::shared_ptr< FPGATrackSimGNNHit > > &hits, std::vector< std::shared_ptr< FPGATrackSimGNNEdge > > &edges)
std::vector< float > getNodeFeatures(const std::vector< std::shared_ptr< FPGATrackSimGNNHit > > &hits)
FPGATrackSimGNNEdgeClassifierTool(const std::string &, const std::string &, const IInterface *)