ATLAS Offline Software
Loading...
Searching...
No Matches
SiGNNTrackFinderTool.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef SiGNNTrackFinder_H
6#define SiGNNTrackFinder_H
7
8// System include(s).
9#include <list>
10#include <iostream>
11#include <memory>
12
16
17// ONNX Runtime include(s).
19#include <onnxruntime_cxx_api.h>
20
21class MsgStream;
22
23namespace InDet{
30
/**
 * @class InDet::SiGNNTrackFinderTool
 * @brief Tool that produces track candidates from space points with a
 *        graph-neural-network-based pipeline.
 *
 * Model inference goes through three AthOnnx::IOnnxRuntimeInferenceTool
 * handles, configured by the properties "Embedding", "Filtering" and "GNN".
 * An ISpacepointFeatureTool handle is also held; presumably it builds the
 * per-space-point input features (TODO confirm in the .cxx).
 *
 * NOTE(review): this listing was extracted from a Doxygen source page, and
 * some original source lines are missing from it:
 *   24-29, 37, 39, 41-47, 53, 55, 61-63, 65, 75, 79 and 102.
 * The page's member index lists a deleted copy constructor and a deleted
 * copy-assignment operator that are not visible here; they are presumably
 * among lines 61-63. Consult the repository header, not this text, for the
 * compilable declaration.
 */
31 class SiGNNTrackFinderTool: public extends<AthAlgTool, IGNNTrackFinder>
32 {
33 public:
  /// Standard Athena tool constructor (type, instance name, parent component).
34 SiGNNTrackFinderTool(const std::string& type, const std::string& name, const IInterface* parent);
  /// Tool initialisation (implementation not visible in this header).
35 virtual StatusCode initialize() override;
36
38 // Main methods for local track finding required by the IGNNTrackFinder base interface
40
  /**
   * @brief Get track candidates from a list of space points.
   * @param spacepoints input space points for the GNN-based track finder.
   * @param tracks      output track candidates, one inner vector per candidate.
   *                    The uint32_t entries are presumably indices into
   *                    @p spacepoints (TODO confirm against the .cxx).
   * @param edgeMap     optional output, defaults to nullptr. The nested
   *                    int -> int -> float map is presumably keyed by a pair
   *                    of space-point indices and holds an edge score
   *                    (TODO confirm).
   * @return StatusCode reporting whether track finding succeeded.
   */
48 virtual StatusCode getTracks(
49 const std::vector<const Trk::SpacePoint*>& spacepoints,
50 std::vector<std::vector<uint32_t> >& tracks,
51 std::unordered_map<int, std::unordered_map<int, float>>* edgeMap = nullptr) const override;
52
54 // Print internal tool parameters and status
  // Both overloads write to the stream passed in and return a reference to a
  // stream of the same type (presumably @p out itself).
56 virtual MsgStream& dump(MsgStream& out) const override;
57 virtual std::ostream& dump(std::ostream& out) const override;
58
59 protected:
60
64
  // Property key "inputMLModelDir", default empty string; presumably the
  // directory holding the input ML model files.
  // NOTE(review): the member is spelled "Module" but the property key says
  // "Model". Job options use the key, so do not rename the key.
66 StringProperty m_inputMLModuleDir{this, "inputMLModelDir", ""};
  // Pipeline tuning parameters. None of them has a property description
  // string, and their meanings below are inferred from the names only
  // (TODO confirm against the .cxx):
  //  - embeddingDim (8):    presumably the dimension of the embedding space.
  //  - rVal (0.12):         presumably the neighbour-search radius in that space.
  //  - knnVal (1000):       presumably the maximum number of neighbours kept.
  //  - filterCut (0.05):    presumably the score cut on the Filtering output.
  //  - ccCut (0.01):        presumably the edge-score cut for connected components.
  //  - walkMin/walkMax (0.1/0.6): presumably lower/upper edge-score cuts used
  //    when walking the graph to form tracks.
67 UnsignedIntegerProperty m_embeddingDim{this, "embeddingDim", 8};
68 FloatProperty m_rVal{this, "rVal", 0.12};
69 UnsignedIntegerProperty m_knnVal{this, "knnVal", 1000};
70 FloatProperty m_filterCut{this, "filterCut", 0.05};
71 FloatProperty m_ccCut{this, "ccCut", 0.01};
72 FloatProperty m_walkMin{this, "walkMin", 0.1};
73 FloatProperty m_walkMax{this, "walkMax", 0.6};
74
  // Comma-separated feature names and per-feature scale factors for each model
  // stage (Embedding, Filtering, GNN). In the defaults, each name list has the
  // same length as its scale list: 37 entries for Embedding and Filtering,
  // 12 for GNN. Presumably each raw feature is divided by the scale at the
  // same position (TODO confirm in initialize()/getTracks()).
  // NOTE(review): Embedding/Filtering use 3.14 as the angular scale, while GNN
  // uses 3.14159265359. The scales presumably have to match the normalisation
  // used at training time, so do not "correct" either list without checking
  // the model training configuration.
  // NOTE(review): original line 75, which opens the next declaration, is
  // missing from this extract. Judging from the sibling m_filterFeatureNames
  // declaration and the m_embeddingFeatureNamesVec member, it is presumably
  // `StringProperty m_embeddingFeatureNames{`.
76 this, "EmbeddingFeatureNames",
77 "r, phi, z, cluster_x_1, cluster_y_1, cluster_z_1, cluster_x_2, cluster_y_2, cluster_z_2, count_1, charge_count_1, loc_eta_1, loc_phi_1, localDir0_1, localDir1_1, localDir2_1, lengthDir0_1, lengthDir1_1, lengthDir2_1, glob_eta_1, glob_phi_1, eta_angle_1, phi_angle_1, count_2, charge_count_2, loc_eta_2, loc_phi_2, localDir0_2, localDir1_2, localDir2_2, lengthDir0_2, lengthDir1_2, lengthDir2_2, glob_eta_2, glob_phi_2, eta_angle_2, phi_angle_2",
78 "Feature names for the Embedding model"};
  // NOTE(review): original line 79 is also missing here. It is presumably
  // `StringProperty m_embeddingFeatureScales{`.
80 this, "EmbeddingFeatureScales",
81 "1000, 3.14, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1, 1, 3.14, 3.14, 1, 1, 1, 1, 1, 1, 3.14, 3.14, 3.14, 3.14, 1, 1, 3.14, 3.14, 1, 1, 1, 1, 1, 1, 3.14, 3.14, 3.14, 3.14",
82 "Feature scales for the Embedding model"};
83
84 StringProperty m_filterFeatureNames{
85 this, "FilterFeatureNames",
86 "r, phi, z, cluster_x_1, cluster_y_1, cluster_z_1, cluster_x_2, cluster_y_2, cluster_z_2, count_1, charge_count_1, loc_eta_1, loc_phi_1, localDir0_1, localDir1_1, localDir2_1, lengthDir0_1, lengthDir1_1, lengthDir2_1, glob_eta_1, glob_phi_1, eta_angle_1, phi_angle_1, count_2, charge_count_2, loc_eta_2, loc_phi_2, localDir0_2, localDir1_2, localDir2_2, lengthDir0_2, lengthDir1_2, lengthDir2_2, glob_eta_2, glob_phi_2, eta_angle_2, phi_angle_2",
87 "Feature names for the Filtering model"};
88 StringProperty m_filterFeatureScales{
89 this, "FilterFeatureScales",
90 "1000, 3.14, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1, 1, 3.14, 3.14, 1, 1, 1, 1, 1, 1, 3.14, 3.14, 3.14, 3.14, 1, 1, 3.14, 3.14, 1, 1, 1, 1, 1, 1, 3.14, 3.14, 3.14, 3.14",
91 "Feature scales for the Filtering model"};
92
93 StringProperty m_gnnFeatureNames{
94 this, "GNNFeatureNames",
95 "r, phi, z, eta, cluster_r_1, cluster_phi_1, cluster_z_1, cluster_eta_1, cluster_r_2, cluster_phi_2, cluster_z_2, cluster_eta_2",
96 "Feature names for the GNN model"};
97 StringProperty m_gnnFeatureScales{
98 this, "GNNFeatureScales",
99 "1000.0, 3.14159265359, 1000.0, 1.0, 1000.0, 3.14159265359, 1000.0, 1.0, 1000.0, 3.14159265359, 1000.0, 1.0",
100 "Feature scales for the GNN model"};
101
  /// Writes event-level information to @p out and returns a MsgStream&.
  /// The contents are inferred from the name only; the implementation is not
  /// visible here.
103 MsgStream& dumpevent (MsgStream& out) const;
104
105 private:
  // ONNX Runtime inference tools for the three model stages. They are
  // configured through the properties "Embedding", "Filtering" and "GNN",
  // and each defaults to AthOnnx::OnnxRuntimeInferenceTool.
106 ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_embedSessionTool {
107 this, "Embedding", "AthOnnx::OnnxRuntimeInferenceTool"
108 };
109 ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_filterSessionTool {
110 this, "Filtering", "AthOnnx::OnnxRuntimeInferenceTool"
111 };
112 ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_gnnSessionTool {
113 this, "GNN", "AthOnnx::OnnxRuntimeInferenceTool"
114 };
  // Space-point feature tool (property "SpacepointFeatureTool", default
  // InDet::SpacepointFeatureTool). Presumably it supplies the raw features
  // named in the *FeatureNames properties (TODO confirm).
115 ToolHandle<ISpacepointFeatureTool> m_spacepointFeatureTool{
116 this, "SpacepointFeatureTool", "InDet::SpacepointFeatureTool"};
117
  // Parsed (vector) forms of the comma-separated name/scale properties above.
  // They are presumably filled in initialize() (TODO confirm).
118 std::vector<std::string> m_embeddingFeatureNamesVec;
119 std::vector<float> m_embeddingFeatureScalesVec;
120 std::vector<std::string> m_filterFeatureNamesVec;
121 std::vector<float> m_filterFeatureScalesVec;
122 std::vector<std::string> m_gnnFeatureNamesVec;
123 std::vector<float> m_gnnFeatureScalesVec;
124
125 };
126
127 MsgStream& operator << (MsgStream& ,const SiGNNTrackFinderTool&);
128 std::ostream& operator << (std::ostream&,const SiGNNTrackFinderTool&);
129
130}
131
132#endif
InDet::SiGNNTrackFinderTool is a tool that produces track candidates with a graph-neural-network-based...
virtual StatusCode getTracks(const std::vector< const Trk::SpacePoint * > &spacepoints, std::vector< std::vector< uint32_t > > &tracks, std::unordered_map< int, std::unordered_map< int, float > > *edgeMap=nullptr) const override
Get track candidates from a list of space points.
std::vector< float > m_gnnFeatureScalesVec
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_filterSessionTool
ToolHandle< ISpacepointFeatureTool > m_spacepointFeatureTool
virtual StatusCode initialize() override
UnsignedIntegerProperty m_knnVal
UnsignedIntegerProperty m_embeddingDim
std::vector< float > m_filterFeatureScalesVec
SiGNNTrackFinderTool(const SiGNNTrackFinderTool &)=delete
std::vector< std::string > m_embeddingFeatureNamesVec
MsgStream & dumpevent(MsgStream &out) const
SiGNNTrackFinderTool & operator=(const SiGNNTrackFinderTool &)=delete
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_embedSessionTool
std::vector< std::string > m_gnnFeatureNamesVec
std::vector< float > m_embeddingFeatureScalesVec
SiGNNTrackFinderTool(const std::string &type, const std::string &name, const IInterface *parent)
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_gnnSessionTool
std::vector< std::string > m_filterFeatureNamesVec
Primary Vertex Finder.
MsgStream & operator<<(MsgStream &, const GNNTrackFinderTritonTool &)
-event-from-file