Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
FPGATrackSimGNNEdgeClassifierTool.cxx
Go to the documentation of this file.
1 // Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2 
4 
6 // AthAlgTool
7 
8 FPGATrackSimGNNEdgeClassifierTool::FPGATrackSimGNNEdgeClassifierTool(const std::string& algname, const std::string &name, const IInterface *ifc)
9  : AthAlgTool(algname, name, ifc) {}
10 
12 {
// initialize(): retrieve the ONNX inference tool, print its model
// information, and check that each configured node-feature name has a
// corresponding scale entry.
13  ATH_CHECK( m_GNNInferenceTool.retrieve() );
14  m_GNNInferenceTool->printModelInfo();
// NOTE(review): assert() is compiled out when NDEBUG is defined, so in
// optimised builds a length mismatch between the feature-name and
// feature-scale properties is not detected here; consider a runtime check
// that returns StatusCode::FAILURE instead.
15  assert(m_gnnFeatureNamesVec.size() == m_gnnFeatureScalesVec.size());
16 
17  return StatusCode::SUCCESS;
18 }
19 
21 // Functions
22 
23 StatusCode FPGATrackSimGNNEdgeClassifierTool::scoreEdges(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits, std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges)
24 {
25  std::vector<float> edge_scores;
26 
27  std::vector<float> gNodeFeatures = getNodeFeatures(hits);
28  std::vector<int64_t> edgeList = getEdgeList(edges);
29  std::vector<float> gEdgeFeatures = getEdgeFeatures(edges);
30 
31  std::vector<Ort::Value> gInputTensor;
32  ATH_CHECK( m_GNNInferenceTool->addInput(gInputTensor, gNodeFeatures, 0, hits.size()) );
33  ATH_CHECK( m_GNNInferenceTool->addInput(gInputTensor, edgeList, 1, edges.size()) );
34  ATH_CHECK( m_GNNInferenceTool->addInput(gInputTensor, gEdgeFeatures, 2, edges.size()) );
35 
36  std::vector<float> gOutputData;
37  std::vector<Ort::Value> gOutputTensor;
38  ATH_CHECK( m_GNNInferenceTool->addOutput(gOutputTensor, edge_scores, 0, edges.size()) );
39 
40  ATH_CHECK( m_GNNInferenceTool->inference(gInputTensor, gOutputTensor) );
41  // apply sigmoid to the gnn output data
42  for(auto& v : edge_scores) {
43  v = 1.f / (1.f + std::exp(-v));
44  };
45 
46  for (size_t i = 0; i < edges.size(); i++) {
47  edges[i]->setEdgeScore(edge_scores[i]);
48  }
49 
50  return StatusCode::SUCCESS;
51 }
52 
// Build the flattened node-feature input for the GNN.
//
// For every hit (in the given order), a name -> value table of twelve
// candidate features is filled from the hit's getters. One value is then
// appended per entry of m_gnnFeatureNamesVec, in that order. The result is
// hit-major: hits.size() * m_gnnFeatureNamesVec.size() floats.
// NOTE(review): presumably each appended value is the named feature combined
// with m_gnnFeatureScalesVec[i] (initialize() asserts the two property vectors
// have equal length) — confirm against the push_back argument.
53 std::vector<float> FPGATrackSimGNNEdgeClassifierTool::getNodeFeatures(const std::vector<std::shared_ptr<FPGATrackSimGNNHit>> & hits)
54 {
55  std::vector<float> gNodeFeatures;
56 
57  for(const auto& hit : hits) {
// NOTE(review): a std::map keyed by std::string is rebuilt for every hit,
// which costs allocations on a per-hit path. Resolving the configured names
// once (e.g. to getter indices) would avoid this. If the lookup below uses
// operator[], a misspelled configured name silently yields 0 rather than an
// error — confirm.
58  std::map<std::string, float> features;
59  features["r"] = hit->getR();
60  features["phi"] = hit->getPhi();
61  features["z"] = hit->getZ();
62  features["eta"] = hit->getEta();
63  features["cluster_r_1"] = hit->getCluster1R();
64  features["cluster_phi_1"] = hit->getCluster1Phi();
65  features["cluster_z_1"] = hit->getCluster1Z();
66  features["cluster_eta_1"] = hit->getCluster1Eta();
67  features["cluster_r_2"] = hit->getCluster2R();
68  features["cluster_phi_2"] = hit->getCluster2Phi();
69  features["cluster_z_2"] = hit->getCluster2Z();
70  features["cluster_eta_2"] = hit->getCluster2Eta();
71 
// Emit the configured features for this hit, in property order.
72  for(size_t i = 0; i < m_gnnFeatureNamesVec.size(); i++){
73  gNodeFeatures.push_back(
75  }
76  }
77 
78  return gNodeFeatures;
79 }
80 
81 std::vector<int64_t> FPGATrackSimGNNEdgeClassifierTool::getEdgeList(const std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges)
82 {
83  std::vector<int64_t> rowIndices;
84  std::vector<int64_t> colIndices;
85  std::vector<int64_t> edgesList(edges.size() * 2);
86 
87  for(const auto& edge : edges) {
88  rowIndices.push_back(edge->getEdgeIndex1());
89  colIndices.push_back(edge->getEdgeIndex2());
90  }
91 
92  std::copy(rowIndices.begin(), rowIndices.end(), edgesList.begin());
93  std::copy(colIndices.begin(), colIndices.end(), edgesList.begin() + edges.size());
94 
95  return edgesList;
96 }
97 
98 std::vector<float> FPGATrackSimGNNEdgeClassifierTool::getEdgeFeatures(const std::vector<std::shared_ptr<FPGATrackSimGNNEdge>> & edges)
99 {
100  std::vector<float> gEdgeFeatures;
101 
102  for(const auto& edge : edges) {
103  gEdgeFeatures.push_back(edge->getEdgeDR());
104  gEdgeFeatures.push_back(edge->getEdgeDPhi());
105  gEdgeFeatures.push_back(edge->getEdgeDZ());
106  gEdgeFeatures.push_back(edge->getEdgeDEta());
107  gEdgeFeatures.push_back(edge->getEdgePhiSlope());
108  gEdgeFeatures.push_back(edge->getEdgeRPhiSlope());
109  }
110 
111  return gEdgeFeatures;
112 }
FPGATrackSimGNNEdgeClassifierTool::FPGATrackSimGNNEdgeClassifierTool
FPGATrackSimGNNEdgeClassifierTool(const std::string &, const std::string &, const IInterface *)
Definition: FPGATrackSimGNNEdgeClassifierTool.cxx:8
FPGATrackSimGNNEdgeClassifierTool::scoreEdges
virtual StatusCode scoreEdges(const std::vector< std::shared_ptr< FPGATrackSimGNNHit >> &hits, std::vector< std::shared_ptr< FPGATrackSimGNNEdge >> &edges)
Definition: FPGATrackSimGNNEdgeClassifierTool.cxx:23
getMenu.algname
algname
Definition: getMenu.py:54
TRTCalib_Extractor.hits
hits
Definition: TRTCalib_Extractor.py:35
FPGATrackSimGNNEdgeClassifierTool::m_gnnFeatureNamesVec
StringArrayProperty m_gnnFeatureNamesVec
Definition: FPGATrackSimGNNEdgeClassifierTool.h:55
drawFromPickle.exp
exp
Definition: drawFromPickle.py:36
FPGATrackSimGNNEdgeClassifierTool.h
Implements edge classification by inferencing on an Interaction Network GNN.
FPGATrackSimGNNEdgeClassifierTool::m_gnnFeatureScalesVec
FloatArrayProperty m_gnnFeatureScalesVec
Definition: FPGATrackSimGNNEdgeClassifierTool.h:59
lumiFormat.i
int i
Definition: lumiFormat.py:85
FPGATrackSimGNNEdgeClassifierTool::getNodeFeatures
std::vector< float > getNodeFeatures(const std::vector< std::shared_ptr< FPGATrackSimGNNHit >> &hits)
Definition: FPGATrackSimGNNEdgeClassifierTool.cxx:53
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
FPGATrackSimGNNEdgeClassifierTool::getEdgeFeatures
std::vector< float > getEdgeFeatures(const std::vector< std::shared_ptr< FPGATrackSimGNNEdge >> &edges)
Definition: FPGATrackSimGNNEdgeClassifierTool.cxx:98
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:228
FPGATrackSimGNNEdgeClassifierTool::getEdgeList
std::vector< int64_t > getEdgeList(const std::vector< std::shared_ptr< FPGATrackSimGNNEdge >> &edges)
Definition: FPGATrackSimGNNEdgeClassifierTool.cxx:81
FPGATrackSimGNNEdgeClassifierTool::initialize
virtual StatusCode initialize() override
Definition: FPGATrackSimGNNEdgeClassifierTool.cxx:11
FPGATrackSimGNNEdgeClassifierTool::m_GNNInferenceTool
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_GNNInferenceTool
Definition: FPGATrackSimGNNEdgeClassifierTool.h:46
python.PyAthena.v
v
Definition: PyAthena.py:154
calibdata.copy
bool copy
Definition: calibdata.py:27
AthAlgTool
Definition: AthAlgTool.h:26