ATLAS Offline Software
Trigger
EFTracking
FPGATrackSim
FPGATrackSimGNN
src
FPGATrackSimGNNEdgeClassifierTool.h
Go to the documentation of this file.
1
// Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
3
#ifndef FPGATRACKSIMGNNEDGECLASSIFIERTOOL_H
4
#define FPGATRACKSIMGNNEDGECLASSIFIERTOOL_H
5
17
#include "
AthenaBaseComps/AthAlgTool.h
"
18
19
#include "
FPGATrackSimObjects/FPGATrackSimGNNHit.h
"
20
#include "
FPGATrackSimObjects/FPGATrackSimGNNEdge.h
"
21
22
#include "
AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h
"
23
#include <onnxruntime_cxx_api.h>
24
25
class
FPGATrackSimGNNEdgeClassifierTool
:
public
AthAlgTool
26
{
27
public
:
28
30
// AthAlgTool
31
32
FPGATrackSimGNNEdgeClassifierTool
(
const
std::string&,
const
std::string&,
const
IInterface*);
33
34
virtual
StatusCode
initialize
()
override
;
35
37
// Functions
38
39
virtual
StatusCode
scoreEdges
(
const
std::vector<std::shared_ptr<FPGATrackSimGNNHit>> &
hits
, std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges);
40
int
regionNum
()
const
{
return
m_regionNum
; }
41
42
private
:
43
45
// Handles
46
47
ToolHandle<AthOnnx::IOnnxRuntimeInferenceTool>
m_GNNInferenceTool
{
this
,
"GNNInferenceTool"
,
"AthOnnx::OnnxRuntimeInferenceTool"
};
48
Gaudi::Property<int>
m_regionNum
{
this
,
"regionNum"
, -1,
"Region number for this GNNEdgeClassifierTool"
};
49
51
// Helpers
52
53
std::vector<float>
getNodeFeatures
(
const
std::vector<std::shared_ptr<FPGATrackSimGNNHit>> &
hits
);
54
std::vector<int64_t>
getEdgeList
(
const
std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges);
55
std::vector<float>
getEdgeFeatures
(std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges,
const
std::vector<float> & gNodeFeatures);
56
void
computeEdgeFeatures
(std::shared_ptr<FPGATrackSimGNNEdge>& edge,
const
int
& hit1_index,
const
int
& hit2_index,
const
std::vector<float> & gNodeFeatures);
57
58
StringArrayProperty
m_gnnFeatureNamesVec
{
59
this
,
"GNNFeatureNames"
,
60
{
"r"
,
"phi"
,
"z"
,
"eta"
,
"cluster_r_1"
,
"cluster_phi_1"
,
"cluster_z_1"
,
"cluster_eta_1"
,
"cluster_r_2"
,
"cluster_phi_2"
,
"cluster_z_2"
,
"cluster_eta_2"
},
61
"Feature names for the GNN model"
};
62
FloatArrayProperty
m_gnnFeatureScalesVec
{
63
this
,
"GNNFeatureScales"
,
64
{1000.0, 3.14159265359, 1000.0, 1.0, 1000.0, 3.14159265359, 1000.0, 1.0, 1000.0, 3.14159265359, 1000.0, 1.0},
65
"Feature scales for the GNN model"
};
66
};
67
68
69
#endif // FPGATRACKSIMGNNEDGECLASSIFIERTOOL_H
FPGATrackSimGNNEdgeClassifierTool::FPGATrackSimGNNEdgeClassifierTool
FPGATrackSimGNNEdgeClassifierTool(const std::string &, const std::string &, const IInterface *)
Definition:
FPGATrackSimGNNEdgeClassifierTool.cxx:10
FPGATrackSimGNNEdgeClassifierTool::scoreEdges
virtual StatusCode scoreEdges(const std::vector< std::shared_ptr< FPGATrackSimGNNHit >> &hits, std::vector< std::shared_ptr< FPGATrackSimGNNEdge >> &edges)
Definition:
FPGATrackSimGNNEdgeClassifierTool.cxx:25
TRTCalib_Extractor.hits
hits
Definition:
TRTCalib_Extractor.py:35
IOnnxRuntimeInferenceTool.h
FPGATrackSimGNNEdgeClassifierTool::computeEdgeFeatures
void computeEdgeFeatures(std::shared_ptr< FPGATrackSimGNNEdge > &edge, const int &hit1_index, const int &hit2_index, const std::vector< float > &gNodeFeatures)
Definition:
FPGATrackSimGNNEdgeClassifierTool.cxx:118
FPGATrackSimGNNEdgeClassifierTool::regionNum
int regionNum() const
Definition:
FPGATrackSimGNNEdgeClassifierTool.h:40
FPGATrackSimGNNEdgeClassifierTool::m_gnnFeatureNamesVec
StringArrayProperty m_gnnFeatureNamesVec
Definition:
FPGATrackSimGNNEdgeClassifierTool.h:58
FPGATrackSimGNNEdge.h
FPGATrackSim-specific class to represent an edge as a connection between two hits in the detector use...
FPGATrackSimGNNEdgeClassifierTool::m_gnnFeatureScalesVec
FloatArrayProperty m_gnnFeatureScalesVec
Definition:
FPGATrackSimGNNEdgeClassifierTool.h:62
FPGATrackSimGNNEdgeClassifierTool::getNodeFeatures
std::vector< float > getNodeFeatures(const std::vector< std::shared_ptr< FPGATrackSimGNNHit >> &hits)
Definition:
FPGATrackSimGNNEdgeClassifierTool.cxx:55
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition:
PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
AthAlgTool.h
FPGATrackSimGNNEdgeClassifierTool
Definition:
FPGATrackSimGNNEdgeClassifierTool.h:26
FPGATrackSimGNNHit.h
FPGATrackSim-specific class to represent a hit in the detector used for GNN pattern recognition.
FPGATrackSimGNNEdgeClassifierTool::getEdgeList
std::vector< int64_t > getEdgeList(const std::vector< std::shared_ptr< FPGATrackSimGNNEdge >> &edges)
Definition:
FPGATrackSimGNNEdgeClassifierTool.cxx:83
FPGATrackSimGNNEdgeClassifierTool::initialize
virtual StatusCode initialize() override
Definition:
FPGATrackSimGNNEdgeClassifierTool.cxx:13
FPGATrackSimGNNEdgeClassifierTool::m_regionNum
Gaudi::Property< int > m_regionNum
Definition:
FPGATrackSimGNNEdgeClassifierTool.h:48
FPGATrackSimGNNEdgeClassifierTool::m_GNNInferenceTool
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_GNNInferenceTool
Definition:
FPGATrackSimGNNEdgeClassifierTool.h:47
FPGATrackSimGNNEdgeClassifierTool::getEdgeFeatures
std::vector< float > getEdgeFeatures(std::vector< std::shared_ptr< FPGATrackSimGNNEdge >> &edges, const std::vector< float > &gNodeFeatures)
Definition:
FPGATrackSimGNNEdgeClassifierTool.cxx:100
AthAlgTool
Definition:
AthAlgTool.h:26
Generated on Tue Sep 30 2025 21:10:05 for ATLAS Offline Software by
1.8.18