ATLAS Offline Software
Loading...
Searching...
No Matches
FPGATrackSimGNNEdgeClassifierTool.cxx
Go to the documentation of this file.
1// Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
4
6#include <cmath>
7
9// AthAlgTool
10
// Constructor: forwards its three arguments unchanged to the AthAlgTool base class and does no other work.
11FPGATrackSimGNNEdgeClassifierTool::FPGATrackSimGNNEdgeClassifierTool(const std::string& algname, const std::string &name, const IInterface *ifc)
12 : AthAlgTool(algname, name, ifc) {}
13
15{
16 ATH_CHECK( m_GNNInferenceTool.retrieve() );
17 m_GNNInferenceTool->printModelInfo();
18 assert(m_gnnFeatureNamesVec.size() == m_gnnFeatureScalesVec.size());
20
21 return StatusCode::SUCCESS;
22}
23
25// Functions
26
27StatusCode FPGATrackSimGNNEdgeClassifierTool::scoreEdges(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits, std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges)
28{
29 std::vector<float> edge_scores;
30
31 std::vector<float> gNodeFeatures = getNodeFeatures(hits);
32 std::vector<int64_t> edgeList = getEdgeList(edges);
33 std::vector<float> gEdgeFeatures = getEdgeFeatures(edges, gNodeFeatures);
34
35 std::vector<Ort::Value> gInputTensor;
36 ATH_CHECK( m_GNNInferenceTool->addInput(gInputTensor, gNodeFeatures, 0, hits.size()) );
37 ATH_CHECK( m_GNNInferenceTool->addInput(gInputTensor, edgeList, 1, edges.size()) );
38 ATH_CHECK( m_GNNInferenceTool->addInput(gInputTensor, gEdgeFeatures, 2, edges.size()) );
39
40 std::vector<float> gOutputData;
41 std::vector<Ort::Value> gOutputTensor;
42 ATH_CHECK( m_GNNInferenceTool->addOutput(gOutputTensor, edge_scores, 0, edges.size()) );
43
44 ATH_CHECK( m_GNNInferenceTool->inference(gInputTensor, gOutputTensor) );
45 // apply sigmoid to the gnn output data
46 for(auto& v : edge_scores) {
47 v = 1.f / (1.f + std::exp(-v));
48 };
49
50 for (size_t i = 0; i < edges.size(); i++) {
51 edges[i]->setEdgeScore(edge_scores[i]);
52 }
53
54 return StatusCode::SUCCESS;
55}
56
// Build the flat node-feature vector that scoreEdges() passes to the GNN as input 0.
// For every hit, r/phi/z/eta (and, outside pixel-seeding mode, the same four
// quantities for cluster 1 and cluster 2) are collected into a name -> value
// map, after rotating phi by a region-dependent offset and optionally
// mirroring z and eta. One value per configured feature name is then appended
// to the output, hit after hit.
// NOTE(review): the argument lines of both push_back( calls are missing from
// this listing, so what exactly is appended per feature name (presumably the
// named feature, possibly scaled) cannot be confirmed here.
57std::vector<float> FPGATrackSimGNNEdgeClassifierTool::getNodeFeatures(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits)
58{
59 std::vector<float> gNodeFeatures;
60
61 // For the GNN model, we use phi-folding to [2pi/16,3pi/16], so every hit phi needs to be rotated into that region
62 const float phi_binSize = M_PI / 16.0;
// The low five bits of m_regionNum select the phi bin.
63 int phiBin = m_regionNum & 0x1f;
64 float regionMin = phi_binSize * phiBin;
65 float referenceMin = 2.0 * M_PI / 16.0;
// Offset added to every phi so that this region's lower edge lands on the reference lower edge.
66 float deltaPhi = referenceMin - regionMin;
67
68 // For the GNN model, eta symmetry is assumed for testing purposes, so negative eta is flipped to positive eta
69 int etaSide = (m_regionNum >> 5) & 0x1; // bit 5 of m_regionNum: 1 is positive side, 0 negative side
70 bool flipEta = (etaSide == 0);
71
72 for(const auto& hit : hits) {
// Per-hit scratch map from feature name to value; rebuilt for every hit.
73 std::map<std::string, float> features;
74 features["r"] = hit->getR();
75 features["phi"] = std::remainder(hit->getPhi() + deltaPhi, 2*M_PI); // Add deltaPhi to shift phi to the trained region; std::remainder wraps the result into [-pi, pi]
76 features["z"] = flipEta ? -hit->getZ() : hit->getZ();
77 features["eta"] = flipEta ? -hit->getEta() : hit->getEta();
78 if (m_doGNNPixelSeeding) { // Do not use cluster features for pixelOnly
79 for(size_t i = 0; i < m_gnnFeatureNamesVec_pixelOnly.size(); i++){
// NOTE(review): push_back argument lost in this listing - restore from the upstream source.
80 gNodeFeatures.push_back(
82 }
83 }
84 else { // Use cluster features and the standard feature vectors and scales
// Cluster quantities receive the same phi rotation and z/eta mirroring as the hit itself.
85 features["cluster_r_1"] = hit->getCluster1R();
86 features["cluster_phi_1"] = std::remainder(hit->getCluster1Phi() + deltaPhi, 2*M_PI);
87 features["cluster_z_1"] = flipEta ? -hit->getCluster1Z() : hit->getCluster1Z();
88 features["cluster_eta_1"] = flipEta ? -hit->getCluster1Eta() : hit->getCluster1Eta();
89 features["cluster_r_2"] = hit->getCluster2R();
90 features["cluster_phi_2"] = std::remainder(hit->getCluster2Phi() + deltaPhi, 2*M_PI);
91 features["cluster_z_2"] = flipEta ? -hit->getCluster2Z() : hit->getCluster2Z();
92 features["cluster_eta_2"] = flipEta ? -hit->getCluster2Eta() : hit->getCluster2Eta();
93
94 for(size_t i = 0; i < m_gnnFeatureNamesVec.size(); i++){
// NOTE(review): push_back argument lost in this listing - restore from the upstream source.
95 gNodeFeatures.push_back(
97 }
98 }
99 }
100
101 return gNodeFeatures;
102}
103
104std::vector<int64_t> FPGATrackSimGNNEdgeClassifierTool::getEdgeList(const std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges)
105{
106 std::vector<int64_t> rowIndices;
107 std::vector<int64_t> colIndices;
108 std::vector<int64_t> edgesList(edges.size() * 2);
109
110 for(const auto& edge : edges) {
111 rowIndices.push_back(edge->getEdgeIndex1());
112 colIndices.push_back(edge->getEdgeIndex2());
113 }
114
115 std::copy(rowIndices.begin(), rowIndices.end(), edgesList.begin());
116 std::copy(colIndices.begin(), colIndices.end(), edgesList.begin() + edges.size());
117
118 return edgesList;
119}
120
121std::vector<float> FPGATrackSimGNNEdgeClassifierTool::getEdgeFeatures(std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges, const std::vector<float> & gNodeFeatures)
122{
123 std::vector<float> gEdgeFeatures;
124
125 for (auto& edge : edges) {
126 computeEdgeFeatures(edge, edge->getEdgeIndex1(), edge->getEdgeIndex2(), gNodeFeatures);
127
128 gEdgeFeatures.push_back(edge->getEdgeDR());
129 gEdgeFeatures.push_back(edge->getEdgeDPhi());
130 gEdgeFeatures.push_back(edge->getEdgeDZ());
131 gEdgeFeatures.push_back(edge->getEdgeDEta());
132 gEdgeFeatures.push_back(edge->getEdgePhiSlope());
133 gEdgeFeatures.push_back(edge->getEdgeRPhiSlope());
134 }
135
136 return gEdgeFeatures;
137}
138
139void FPGATrackSimGNNEdgeClassifierTool::computeEdgeFeatures(std::shared_ptr<FPGATrackSimGNNEdge>& edge,const int& hit1_index, const int& hit2_index,const std::vector<float>& gNodeFeatures)
140{
141 size_t num_features = m_doGNNPixelSeeding ? m_gnnFeatureNamesVec_pixelOnly.size() : m_gnnFeatureNamesVec.size();
142
143 std::map<std::string, float> hit1_features;
144 std::map<std::string, float> hit2_features;
145
147 // Fill only pixel-only features
148 for (size_t i = 0; i < m_gnnFeatureNamesVec_pixelOnly.size(); i++) {
149 hit1_features[m_gnnFeatureNamesVec_pixelOnly[i]] = gNodeFeatures[hit1_index * num_features + i];
150 hit2_features[m_gnnFeatureNamesVec_pixelOnly[i]] = gNodeFeatures[hit2_index * num_features + i];
151 }
152 } else {
153 // Fill full feature set (including clusters)
154 for (size_t i = 0; i < m_gnnFeatureNamesVec.size(); i++) {
155 hit1_features[m_gnnFeatureNamesVec[i]] = gNodeFeatures[hit1_index * num_features + i];
156 hit2_features[m_gnnFeatureNamesVec[i]] = gNodeFeatures[hit2_index * num_features + i];
157 }
158 }
159
160 // Now compute differences using only the common "physics" features
161 float deta = hit2_features["eta"] - hit1_features["eta"];
162 float dz = hit2_features["z"] - hit1_features["z"];
163 float dr = hit2_features["r"] - hit1_features["r"];
164 float dphi = P4Helpers::deltaPhi(hit2_features["phi"], hit1_features["phi"]);
165 float phislope = dr == 0. ? 0. : dphi / dr;
166 float rphislope = 0.5 * (hit2_features["r"] + hit1_features["r"]) * phislope;
167
168 edge->setEdgeDR(dr);
169 edge->setEdgeDPhi(dphi);
170 edge->setEdgeDZ(dz);
171 edge->setEdgeDEta(deta);
172 edge->setEdgePhiSlope(phislope);
173 edge->setEdgeRPhiSlope(rphislope);
174}
#define M_PI
Scalar deltaPhi(const MatrixBase< Derived > &vec) const
#define ATH_CHECK
Evaluate an expression and check for errors.
Implements edge classification by inferencing on an Interaction Network GNN.
AthAlgTool(const std::string &type, const std::string &name, const IInterface *parent)
Constructor with parameters:
void computeEdgeFeatures(std::shared_ptr< FPGATrackSimGNNEdge > &edge, const int &hit1_index, const int &hit2_index, const std::vector< float > &gNodeFeatures)
std::vector< int64_t > getEdgeList(const std::vector< std::shared_ptr< FPGATrackSimGNNEdge > > &edges)
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_GNNInferenceTool
std::vector< float > getEdgeFeatures(std::vector< std::shared_ptr< FPGATrackSimGNNEdge > > &edges, const std::vector< float > &gNodeFeatures)
virtual StatusCode scoreEdges(const std::vector< std::shared_ptr< FPGATrackSimGNNHit > > &hits, std::vector< std::shared_ptr< FPGATrackSimGNNEdge > > &edges)
std::vector< float > getNodeFeatures(const std::vector< std::shared_ptr< FPGATrackSimGNNHit > > &hits)
FPGATrackSimGNNEdgeClassifierTool(const std::string &, const std::string &, const IInterface *)
double deltaPhi(double phiA, double phiB)
delta Phi in range [-pi,pi[
Definition P4Helpers.h:34