ATLAS Offline Software
Loading...
Searching...
No Matches
FPGATrackSimGNNEdgeClassifierTool.cxx
Go to the documentation of this file.
1// Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
4
6
8// AthAlgTool
9
10FPGATrackSimGNNEdgeClassifierTool::FPGATrackSimGNNEdgeClassifierTool(const std::string& algname, const std::string &name, const IInterface *ifc)
11 : AthAlgTool(algname, name, ifc) {}
12
14{
15 ATH_CHECK( m_GNNInferenceTool.retrieve() );
16 m_GNNInferenceTool->printModelInfo();
17 assert(m_gnnFeatureNamesVec.size() == m_gnnFeatureScalesVec.size());
18
19 return StatusCode::SUCCESS;
20}
21
23// Functions
24
25StatusCode FPGATrackSimGNNEdgeClassifierTool::scoreEdges(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits, std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges)
26{
27 std::vector<float> edge_scores;
28
29 std::vector<float> gNodeFeatures = getNodeFeatures(hits);
30 std::vector<int64_t> edgeList = getEdgeList(edges);
31 std::vector<float> gEdgeFeatures = getEdgeFeatures(edges, gNodeFeatures);
32
33 std::vector<Ort::Value> gInputTensor;
34 ATH_CHECK( m_GNNInferenceTool->addInput(gInputTensor, gNodeFeatures, 0, hits.size()) );
35 ATH_CHECK( m_GNNInferenceTool->addInput(gInputTensor, edgeList, 1, edges.size()) );
36 ATH_CHECK( m_GNNInferenceTool->addInput(gInputTensor, gEdgeFeatures, 2, edges.size()) );
37
38 std::vector<float> gOutputData;
39 std::vector<Ort::Value> gOutputTensor;
40 ATH_CHECK( m_GNNInferenceTool->addOutput(gOutputTensor, edge_scores, 0, edges.size()) );
41
42 ATH_CHECK( m_GNNInferenceTool->inference(gInputTensor, gOutputTensor) );
43 // apply sigmoid to the gnn output data
44 for(auto& v : edge_scores) {
45 v = 1.f / (1.f + std::exp(-v));
46 };
47
48 for (size_t i = 0; i < edges.size(); i++) {
49 edges[i]->setEdgeScore(edge_scores[i]);
50 }
51
52 return StatusCode::SUCCESS;
53}
54
55std::vector<float> FPGATrackSimGNNEdgeClassifierTool::getNodeFeatures(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits)
56{
57 std::vector<float> gNodeFeatures;
58
59 for(const auto& hit : hits) {
60 std::map<std::string, float> features;
61 features["r"] = hit->getR();
62 features["phi"] = hit->getPhi();
63 features["z"] = hit->getZ();
64 features["eta"] = hit->getEta();
65 features["cluster_r_1"] = hit->getCluster1R();
66 features["cluster_phi_1"] = hit->getCluster1Phi();
67 features["cluster_z_1"] = hit->getCluster1Z();
68 features["cluster_eta_1"] = hit->getCluster1Eta();
69 features["cluster_r_2"] = hit->getCluster2R();
70 features["cluster_phi_2"] = hit->getCluster2Phi();
71 features["cluster_z_2"] = hit->getCluster2Z();
72 features["cluster_eta_2"] = hit->getCluster2Eta();
73
74 for(size_t i = 0; i < m_gnnFeatureNamesVec.size(); i++){
75 gNodeFeatures.push_back(
77 }
78 }
79
80 return gNodeFeatures;
81}
82
83std::vector<int64_t> FPGATrackSimGNNEdgeClassifierTool::getEdgeList(const std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges)
84{
85 std::vector<int64_t> rowIndices;
86 std::vector<int64_t> colIndices;
87 std::vector<int64_t> edgesList(edges.size() * 2);
88
89 for(const auto& edge : edges) {
90 rowIndices.push_back(edge->getEdgeIndex1());
91 colIndices.push_back(edge->getEdgeIndex2());
92 }
93
94 std::copy(rowIndices.begin(), rowIndices.end(), edgesList.begin());
95 std::copy(colIndices.begin(), colIndices.end(), edgesList.begin() + edges.size());
96
97 return edgesList;
98}
99
100std::vector<float> FPGATrackSimGNNEdgeClassifierTool::getEdgeFeatures(std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges, const std::vector<float> & gNodeFeatures)
101{
102 std::vector<float> gEdgeFeatures;
103
104 for (auto& edge : edges) {
105 computeEdgeFeatures(edge, edge->getEdgeIndex1(), edge->getEdgeIndex2(), gNodeFeatures);
106
107 gEdgeFeatures.push_back(edge->getEdgeDR());
108 gEdgeFeatures.push_back(edge->getEdgeDPhi());
109 gEdgeFeatures.push_back(edge->getEdgeDZ());
110 gEdgeFeatures.push_back(edge->getEdgeDEta());
111 gEdgeFeatures.push_back(edge->getEdgePhiSlope());
112 gEdgeFeatures.push_back(edge->getEdgeRPhiSlope());
113 }
114
115 return gEdgeFeatures;
116}
117
118void FPGATrackSimGNNEdgeClassifierTool::computeEdgeFeatures(std::shared_ptr<FPGATrackSimGNNEdge>& edge, const int& hit1_index, const int& hit2_index, const std::vector<float>& gNodeFeatures)
119{
120 size_t num_features = m_gnnFeatureNamesVec.size();
121
122 std::map<std::string, float> hit1_features;
123 std::map<std::string, float> hit2_features;
124
125 for(size_t i = 0; i < num_features; i++){
126 hit1_features[m_gnnFeatureNamesVec[i]] = gNodeFeatures[hit1_index * num_features + i];
127 hit2_features[m_gnnFeatureNamesVec[i]] = gNodeFeatures[hit2_index * num_features + i];
128 }
129
130 float deta = hit2_features["eta"] - hit1_features["eta"];
131 float dz = hit2_features["z"] - hit1_features["z"];
132 float dr = hit2_features["r"] - hit1_features["r"];
133 float dphi = P4Helpers::deltaPhi(hit2_features["phi"],hit1_features["phi"]);
134 float phislope = dr==0. ? 0. : dphi / dr;
135 float rphislope = 0.5 * (hit2_features["r"] + hit1_features["r"]) * phislope;
136
137 edge->setEdgeDR(dr);
138 edge->setEdgeDPhi(dphi);
139 edge->setEdgeDZ(dz);
140 edge->setEdgeDEta(deta);
141 edge->setEdgePhiSlope(phislope);
142 edge->setEdgeRPhiSlope(rphislope);
143}
#define ATH_CHECK
Evaluate an expression and check for errors.
Implements edge classification by inferencing on an Interaction Network GNN.
AthAlgTool(const std::string &type, const std::string &name, const IInterface *parent)
Constructor with parameters:
void computeEdgeFeatures(std::shared_ptr< FPGATrackSimGNNEdge > &edge, const int &hit1_index, const int &hit2_index, const std::vector< float > &gNodeFeatures)
std::vector< int64_t > getEdgeList(const std::vector< std::shared_ptr< FPGATrackSimGNNEdge > > &edges)
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_GNNInferenceTool
std::vector< float > getEdgeFeatures(std::vector< std::shared_ptr< FPGATrackSimGNNEdge > > &edges, const std::vector< float > &gNodeFeatures)
virtual StatusCode scoreEdges(const std::vector< std::shared_ptr< FPGATrackSimGNNHit > > &hits, std::vector< std::shared_ptr< FPGATrackSimGNNEdge > > &edges)
std::vector< float > getNodeFeatures(const std::vector< std::shared_ptr< FPGATrackSimGNNHit > > &hits)
FPGATrackSimGNNEdgeClassifierTool(const std::string &, const std::string &, const IInterface *)
double deltaPhi(double phiA, double phiB)
delta Phi in range [-pi,pi[
Definition P4Helpers.h:34