ATLAS Offline Software
SiGNNTrackFinderTool.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef SiGNNTrackFinder_H
6 #define SiGNNTrackFinder_H
7 
8 // System include(s).
9 #include <list>
10 #include <iostream>
11 #include <memory>
12 
15 #include "ISpacepointFeatureTool.h"
16 
17 // ONNX Runtime include(s).
19 #include <onnxruntime_cxx_api.h>
20 
21 class MsgStream;
22 
23 namespace InDet{
31  class SiGNNTrackFinderTool: public extends<AthAlgTool, IGNNTrackFinder>
32  {
33  public:
34  SiGNNTrackFinderTool(const std::string& type, const std::string& name, const IInterface* parent);
35  virtual StatusCode initialize() override;
36 
38  // Main methods for local track finding asked by the ISiMLTrackFinder
40 
48  virtual StatusCode getTracks(
49  const std::vector<const Trk::SpacePoint*>& spacepoints,
50  std::vector<std::vector<uint32_t> >& tracks) const override;
51 
53  // Print internal tool parameters and status
55  virtual MsgStream& dump(MsgStream& out) const override;
56  virtual std::ostream& dump(std::ostream& out) const override;
57 
58  protected:
59 
63 
65  StringProperty m_inputMLModuleDir{this, "inputMLModelDir", ""};
66  UnsignedIntegerProperty m_embeddingDim{this, "embeddingDim", 8};
67  FloatProperty m_rVal{this, "rVal", 0.12};
68  UnsignedIntegerProperty m_knnVal{this, "knnVal", 1000};
69  FloatProperty m_filterCut{this, "filterCut", 0.05};
70  FloatProperty m_ccCut{this, "ccCut", 0.01};
71  FloatProperty m_walkMin{this, "walkMin", 0.1};
72  FloatProperty m_walkMax{this, "walkMax", 0.6};
73 
74  StringProperty m_embeddingFeatureNames{
75  this, "EmbeddingFeatureNames",
76  "r, phi, z, cluster_x_1, cluster_y_1, cluster_z_1, cluster_x_2, cluster_y_2, cluster_z_2, count_1, charge_count_1, loc_eta_1, loc_phi_1, localDir0_1, localDir1_1, localDir2_1, lengthDir0_1, lengthDir1_1, lengthDir2_1, glob_eta_1, glob_phi_1, eta_angle_1, phi_angle_1, count_2, charge_count_2, loc_eta_2, loc_phi_2, localDir0_2, localDir1_2, localDir2_2, lengthDir0_2, lengthDir1_2, lengthDir2_2, glob_eta_2, glob_phi_2, eta_angle_2, phi_angle_2",
77  "Feature names for the Embedding model"};
78  StringProperty m_embeddingFeatureScales{
79  this, "EmbeddingFeatureScales",
80  "1000, 3.14, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1, 1, 3.14, 3.14, 1, 1, 1, 1, 1, 1, 3.14, 3.14, 3.14, 3.14, 1, 1, 3.14, 3.14, 1, 1, 1, 1, 1, 1, 3.14, 3.14, 3.14, 3.14",
81  "Feature scales for the Embedding model"};
82 
83  StringProperty m_filterFeatureNames{
84  this, "FilterFeatureNames",
85  "r, phi, z, cluster_x_1, cluster_y_1, cluster_z_1, cluster_x_2, cluster_y_2, cluster_z_2, count_1, charge_count_1, loc_eta_1, loc_phi_1, localDir0_1, localDir1_1, localDir2_1, lengthDir0_1, lengthDir1_1, lengthDir2_1, glob_eta_1, glob_phi_1, eta_angle_1, phi_angle_1, count_2, charge_count_2, loc_eta_2, loc_phi_2, localDir0_2, localDir1_2, localDir2_2, lengthDir0_2, lengthDir1_2, lengthDir2_2, glob_eta_2, glob_phi_2, eta_angle_2, phi_angle_2",
86  "Feature names for the Filtering model"};
87  StringProperty m_filterFeatureScales{
88  this, "FilterFeatureScales",
89  "1000, 3.14, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1, 1, 3.14, 3.14, 1, 1, 1, 1, 1, 1, 3.14, 3.14, 3.14, 3.14, 1, 1, 3.14, 3.14, 1, 1, 1, 1, 1, 1, 3.14, 3.14, 3.14, 3.14",
90  "Feature scales for the Filtering model"};
91 
92  StringProperty m_gnnFeatureNames{
93  this, "GNNFeatureNames",
94  "r, phi, z, eta, cluster_r_1, cluster_phi_1, cluster_z_1, cluster_eta_1, cluster_r_2, cluster_phi_2, cluster_z_2, cluster_eta_2",
95  "Feature names for the GNN model"};
96  StringProperty m_gnnFeatureScales{
97  this, "GNNFeatureScales",
98  "1000.0, 3.14159265359, 1000.0, 1.0, 1000.0, 3.14159265359, 1000.0, 1.0, 1000.0, 3.14159265359, 1000.0, 1.0",
99  "Feature scales for the GNN model"};
100 
102  MsgStream& dumpevent (MsgStream& out) const;
103 
104  private:
105  ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_embedSessionTool {
106  this, "Embedding", "AthOnnx::OnnxRuntimeInferenceTool"
107  };
108  ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_filterSessionTool {
109  this, "Filtering", "AthOnnx::OnnxRuntimeInferenceTool"
110  };
111  ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_gnnSessionTool {
112  this, "GNN", "AthOnnx::OnnxRuntimeInferenceTool"
113  };
114  ToolHandle<ISpacepointFeatureTool> m_spacepointFeatureTool{
115  this, "SpacepointFeatureTool", "InDet::SpacepointFeatureTool"};
116 
117  std::vector<std::string> m_embeddingFeatureNamesVec;
118  std::vector<float> m_embeddingFeatureScalesVec;
119  std::vector<std::string> m_filterFeatureNamesVec;
120  std::vector<float> m_filterFeatureScalesVec;
121  std::vector<std::string> m_gnnFeatureNamesVec;
122  std::vector<float> m_gnnFeatureScalesVec;
123 
124  };
125 
126  MsgStream& operator << (MsgStream& ,const SiGNNTrackFinderTool&);
127  std::ostream& operator << (std::ostream&,const SiGNNTrackFinderTool&);
128 
129 }
130 
131 #endif
InDet::SiGNNTrackFinderTool::m_embeddingFeatureNamesVec
std::vector< std::string > m_embeddingFeatureNamesVec
Definition: SiGNNTrackFinderTool.h:117
InDet::operator<<
MsgStream & operator<<(MsgStream &, const GNNTrackReaderTool &)
InDet::SiGNNTrackFinderTool::m_gnnFeatureScalesVec
std::vector< float > m_gnnFeatureScalesVec
Definition: SiGNNTrackFinderTool.h:122
InDet::SiGNNTrackFinderTool::SiGNNTrackFinderTool
SiGNNTrackFinderTool(const SiGNNTrackFinderTool &)=delete
InDet::SiGNNTrackFinderTool::initTrainedModels
void initTrainedModels()
InDet::SiGNNTrackFinderTool::dumpevent
MsgStream & dumpevent(MsgStream &out) const
Definition: SiGNNTrackFinderTool.cxx:73
IOnnxRuntimeInferenceTool.h
InDet
Primary Vertex Finder.
Definition: VP1ErrorUtils.h:36
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:71
InDet::SiGNNTrackFinderTool::operator=
SiGNNTrackFinderTool & operator=(const SiGNNTrackFinderTool &)=delete
InDet::SiGNNTrackFinderTool::m_filterFeatureScalesVec
std::vector< float > m_filterFeatureScalesVec
Definition: SiGNNTrackFinderTool.h:120
InDet::SiGNNTrackFinderTool::m_gnnSessionTool
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_gnnSessionTool
Definition: SiGNNTrackFinderTool.h:111
InDet::SiGNNTrackFinderTool::m_gnnFeatureNames
StringProperty m_gnnFeatureNames
Definition: SiGNNTrackFinderTool.h:92
InDet::SiGNNTrackFinderTool::m_walkMin
FloatProperty m_walkMin
Definition: SiGNNTrackFinderTool.h:71
InDet::SiGNNTrackFinderTool::m_filterCut
FloatProperty m_filterCut
Definition: SiGNNTrackFinderTool.h:69
InDet::SiGNNTrackFinderTool::m_embeddingFeatureNames
StringProperty m_embeddingFeatureNames
Definition: SiGNNTrackFinderTool.h:74
IGNNTrackFinder.h
InDet::SiGNNTrackFinderTool::m_inputMLModuleDir
StringProperty m_inputMLModuleDir
Definition: SiGNNTrackFinderTool.h:65
InDet::SiGNNTrackFinderTool::m_gnnFeatureScales
StringProperty m_gnnFeatureScales
Definition: SiGNNTrackFinderTool.h:96
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
AthAlgTool.h
InDet::SiGNNTrackFinderTool::m_embedSessionTool
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_embedSessionTool
Definition: SiGNNTrackFinderTool.h:105
InDet::SiGNNTrackFinderTool::m_knnVal
UnsignedIntegerProperty m_knnVal
Definition: SiGNNTrackFinderTool.h:68
InDet::SiGNNTrackFinderTool::getTracks
virtual StatusCode getTracks(const std::vector< const Trk::SpacePoint * > &spacepoints, std::vector< std::vector< uint32_t > > &tracks) const override
Get track candidates from a list of space points.
Definition: SiGNNTrackFinderTool.cxx:84
test_pyathena.parent
parent
Definition: test_pyathena.py:15
InDet::SiGNNTrackFinderTool::m_filterSessionTool
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_filterSessionTool
Definition: SiGNNTrackFinderTool.h:108
InDet::SiGNNTrackFinderTool::dump
virtual MsgStream & dump(MsgStream &out) const override
Definition: SiGNNTrackFinderTool.cxx:62
InDet::SiGNNTrackFinderTool::m_walkMax
FloatProperty m_walkMax
Definition: SiGNNTrackFinderTool.h:72
InDet::SiGNNTrackFinderTool::m_rVal
FloatProperty m_rVal
Definition: SiGNNTrackFinderTool.h:67
InDet::SiGNNTrackFinderTool::m_embeddingFeatureScales
StringProperty m_embeddingFeatureScales
Definition: SiGNNTrackFinderTool.h:78
InDet::SiGNNTrackFinderTool::m_filterFeatureNames
StringProperty m_filterFeatureNames
Definition: SiGNNTrackFinderTool.h:83
InDet::SiGNNTrackFinderTool::m_ccCut
FloatProperty m_ccCut
Definition: SiGNNTrackFinderTool.h:70
InDet::SiGNNTrackFinderTool::initialize
virtual StatusCode initialize() override
Definition: SiGNNTrackFinderTool.cxx:22
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:228
InDet::SiGNNTrackFinderTool::m_filterFeatureNamesVec
std::vector< std::string > m_filterFeatureNamesVec
Definition: SiGNNTrackFinderTool.h:119
InDet::SiGNNTrackFinderTool::m_filterFeatureScales
StringProperty m_filterFeatureScales
Definition: SiGNNTrackFinderTool.h:87
ISpacepointFeatureTool.h
InDet::SiGNNTrackFinderTool
InDet::SiGNNTrackFinderTool is a tool that produces track candidates with graph neural networks-based...
Definition: SiGNNTrackFinderTool.h:32
python.CaloScaleNoiseConfig.type
type
Definition: CaloScaleNoiseConfig.py:78
InDet::SiGNNTrackFinderTool::SiGNNTrackFinderTool
SiGNNTrackFinderTool()=delete
InDet::SiGNNTrackFinderTool::m_embeddingDim
UnsignedIntegerProperty m_embeddingDim
Definition: SiGNNTrackFinderTool.h:66
InDet::SiGNNTrackFinderTool::m_embeddingFeatureScalesVec
std::vector< float > m_embeddingFeatureScalesVec
Definition: SiGNNTrackFinderTool.h:118
InDet::SiGNNTrackFinderTool::m_spacepointFeatureTool
ToolHandle< ISpacepointFeatureTool > m_spacepointFeatureTool
Definition: SiGNNTrackFinderTool.h:114
InDet::SiGNNTrackFinderTool::m_gnnFeatureNamesVec
std::vector< std::string > m_gnnFeatureNamesVec
Definition: SiGNNTrackFinderTool.h:121