ATLAS Offline Software
FPGATrackSimGNNEdgeClassifierTool.h
Go to the documentation of this file.
1 // Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2 
3 #ifndef FPGATRACKSIMGNNEDGECLASSIFIERTOOL_H
4 #define FPGATRACKSIMGNNEDGECLASSIFIERTOOL_H
5 
18 
21 
23 #include <onnxruntime_cxx_api.h>
24 
26 {
 // Body of the FPGATrackSimGNNEdgeClassifierTool class: declares a scoreEdges() entry point
 // that works on GNN hits and edges, an ONNX Runtime inference tool handle, and the
 // configurable properties / private helpers listed below.
 // NOTE(review): the class header (name and base classes) is not part of this listing;
 // the "AthAlgTool" comment below suggests an AthAlgTool-derived tool -- confirm.
27  public:
28 
30  // AthAlgTool
31 
 // Constructor taking two strings and an IInterface pointer.
 // NOTE(review): presumably the usual Gaudi tool arguments (type, name, parent) -- confirm.
32  FPGATrackSimGNNEdgeClassifierTool(const std::string&, const std::string&, const IInterface*);
33 
 // Overrides the base-class initialize(); returns a StatusCode.
34  virtual StatusCode initialize() override;
35 
37  // Functions
38 
 // Entry point for edge scoring. `hits` is taken by const reference (read-only), while
 // `edges` is taken by non-const reference, so the implementation is allowed to modify it.
 // NOTE(review): presumably a GNN score is written onto each edge -- confirm in the .cxx.
39  virtual StatusCode scoreEdges(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits, std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges);
 // Returns the current value of the "regionNum" property (defaults to -1, see below).
40  int regionNum() const { return m_regionNum; }
41 
42  private:
43 
45  // Handles
46 
 // Handle to the ONNX Runtime inference tool; configurable through the "GNNInferenceTool"
 // property, defaulting to "AthOnnx::OnnxRuntimeInferenceTool".
47  ToolHandle<AthOnnx::IOnnxRuntimeInferenceTool> m_GNNInferenceTool {this, "GNNInferenceTool", "AthOnnx::OnnxRuntimeInferenceTool"};
 // Integer property "regionNum" with default -1.
 // NOTE(review): -1 looks like an "unset" sentinel -- confirm how callers treat it.
48  Gaudi::Property<int> m_regionNum{this, "regionNum", -1, "Region number for this GNNEdgeClassifierTool"};
49 
51  // Helpers
52 
 // Returns a flat vector of floats built from the hits.
 // NOTE(review): layout (features per hit, ordering) is defined in the .cxx -- confirm.
53  std::vector<float> getNodeFeatures(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits);
 // Returns a flat vector of int64_t built from the edges.
 // NOTE(review): presumably hit indices for each edge endpoint -- confirm layout in the .cxx.
54  std::vector<int64_t> getEdgeList(const std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges);
 // Returns a flat vector of floats built from the edges and the node-feature vector.
 // `edges` is non-const here, so the edges themselves may be updated as a side effect.
55  std::vector<float> getEdgeFeatures(std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges, const std::vector<float> & gNodeFeatures);
 // Works on a single edge (passed by non-const reference) together with two hit indices.
 // NOTE(review): hit1_index / hit2_index presumably index into gNodeFeatures -- confirm.
56  void computeEdgeFeatures(std::shared_ptr<FPGATrackSimGNNEdge>& edge, const int& hit1_index, const int& hit2_index, const std::vector<float> & gNodeFeatures);
57 
 // "GNNFeatureNames" property: 12 feature names by default (r/phi/z/eta for the hit and for
 // cluster_*_1 and cluster_*_2).
58  StringArrayProperty m_gnnFeatureNamesVec{
59  this, "GNNFeatureNames",
60  {"r", "phi", "z", "eta", "cluster_r_1", "cluster_phi_1", "cluster_z_1", "cluster_eta_1", "cluster_r_2", "cluster_phi_2", "cluster_z_2", "cluster_eta_2"},
61  "Feature names for the GNN model"};
 // "GNNFeatureScales" property: 12 scale values by default, the same count as the default
 // feature names above (1000.0 at the r/z positions, pi at phi, 1.0 at eta).
 // NOTE(review): presumably applied position-by-position to GNNFeatureNames, so both lists
 // must stay the same length when overridden -- confirm how the .cxx pairs them.
62  FloatArrayProperty m_gnnFeatureScalesVec{
63  this, "GNNFeatureScales",
64  {1000.0, 3.14159265359, 1000.0, 1.0, 1000.0, 3.14159265359, 1000.0, 1.0, 1000.0, 3.14159265359, 1000.0, 1.0},
65  "Feature scales for the GNN model"};
66 };
67 
68 
69 #endif // FPGATRACKSIMGNNEDGECLASSIFIERTOOL_H
FPGATrackSimGNNEdgeClassifierTool::FPGATrackSimGNNEdgeClassifierTool
FPGATrackSimGNNEdgeClassifierTool(const std::string &, const std::string &, const IInterface *)
Definition: FPGATrackSimGNNEdgeClassifierTool.cxx:10
FPGATrackSimGNNEdgeClassifierTool::scoreEdges
virtual StatusCode scoreEdges(const std::vector< std::shared_ptr< FPGATrackSimGNNHit >> &hits, std::vector< std::shared_ptr< FPGATrackSimGNNEdge >> &edges)
Definition: FPGATrackSimGNNEdgeClassifierTool.cxx:25
TRTCalib_Extractor.hits
hits
Definition: TRTCalib_Extractor.py:35
IOnnxRuntimeInferenceTool.h
FPGATrackSimGNNEdgeClassifierTool::computeEdgeFeatures
void computeEdgeFeatures(std::shared_ptr< FPGATrackSimGNNEdge > &edge, const int &hit1_index, const int &hit2_index, const std::vector< float > &gNodeFeatures)
Definition: FPGATrackSimGNNEdgeClassifierTool.cxx:118
FPGATrackSimGNNEdgeClassifierTool::regionNum
int regionNum() const
Definition: FPGATrackSimGNNEdgeClassifierTool.h:40
FPGATrackSimGNNEdgeClassifierTool::m_gnnFeatureNamesVec
StringArrayProperty m_gnnFeatureNamesVec
Definition: FPGATrackSimGNNEdgeClassifierTool.h:58
FPGATrackSimGNNEdge.h
FPGATrackSim-specific class to represent an edge as a connection between two hits in the detector used for GNN pattern recognition.
FPGATrackSimGNNEdgeClassifierTool::m_gnnFeatureScalesVec
FloatArrayProperty m_gnnFeatureScalesVec
Definition: FPGATrackSimGNNEdgeClassifierTool.h:62
FPGATrackSimGNNEdgeClassifierTool::getNodeFeatures
std::vector< float > getNodeFeatures(const std::vector< std::shared_ptr< FPGATrackSimGNNHit >> &hits)
Definition: FPGATrackSimGNNEdgeClassifierTool.cxx:55
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
AthAlgTool.h
FPGATrackSimGNNEdgeClassifierTool
Definition: FPGATrackSimGNNEdgeClassifierTool.h:26
FPGATrackSimGNNHit.h
FPGATrackSim-specific class to represent a hit in the detector used for GNN pattern recognition.
FPGATrackSimGNNEdgeClassifierTool::getEdgeList
std::vector< int64_t > getEdgeList(const std::vector< std::shared_ptr< FPGATrackSimGNNEdge >> &edges)
Definition: FPGATrackSimGNNEdgeClassifierTool.cxx:83
FPGATrackSimGNNEdgeClassifierTool::initialize
virtual StatusCode initialize() override
Definition: FPGATrackSimGNNEdgeClassifierTool.cxx:13
FPGATrackSimGNNEdgeClassifierTool::m_regionNum
Gaudi::Property< int > m_regionNum
Definition: FPGATrackSimGNNEdgeClassifierTool.h:48
FPGATrackSimGNNEdgeClassifierTool::m_GNNInferenceTool
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_GNNInferenceTool
Definition: FPGATrackSimGNNEdgeClassifierTool.h:47
FPGATrackSimGNNEdgeClassifierTool::getEdgeFeatures
std::vector< float > getEdgeFeatures(std::vector< std::shared_ptr< FPGATrackSimGNNEdge >> &edges, const std::vector< float > &gNodeFeatures)
Definition: FPGATrackSimGNNEdgeClassifierTool.cxx:100
AthAlgTool
Definition: AthAlgTool.h:26