ATLAS Offline Software
SiGNNTrackFinderTool.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef SiGNNTrackFinder_H
6 #define SiGNNTrackFinder_H
7 
8 // System include(s).
9 #include <list>
10 #include <iostream>
11 #include <memory>
12 
15 
16 // ONNX Runtime include(s).
18 #include <onnxruntime_cxx_api.h>
19 
20 class MsgStream;
21 
22 namespace InDet{
// InDet::SiGNNTrackFinderTool: tool that produces track candidates from space
// points using graph-neural-network-based inference. It derives from
// extends<AthAlgTool, IGNNTrackFinder> and holds three ONNX Runtime inference
// tool handles (embedding, filtering, GNN) plus a set of configurable properties.
//
// NOTE(review): the embedded listing numbers skip several lines inside this
// class (e.g. 58 -> 62 and 70 -> 72), so some declarations appear to be missing
// from this view. Do not treat the member list below as complete.
30  class SiGNNTrackFinderTool: public extends<AthAlgTool, IGNNTrackFinder>
31  {
32  public:
    // Standard tool constructor taking (type, name, parent).
33  SiGNNTrackFinderTool(const std::string& type, const std::string& name, const IInterface* parent);
    // Overrides the base-class initialize(); returns a StatusCode.
34  virtual StatusCode initialize() override;
35 
37  // Main track-finding entry point. It is marked `override`, so it is declared by a
    // base class (presumably IGNNTrackFinder rather than ISiMLTrackFinder - TODO confirm).
39 
    // Get track candidates from a list of space points.
    //   spacepoints : input space points (const pointers, not modified through this API).
    //   tracks      : output; one inner vector of uint32_t per track candidate.
    //                 Presumably each value is an index into `spacepoints` - verify in the .cxx.
    //   returns     : StatusCode.
    // Declared const, so it may not modify non-mutable tool state.
47  virtual StatusCode getTracks(
48  const std::vector<const Trk::SpacePoint*>& spacepoints,
49  std::vector<std::vector<uint32_t> >& tracks) const override;
50 
52  // Print internal tool parameters and status
    // Two overloads: one writing to a MsgStream, one to a std::ostream.
    // Each returns a reference of the same stream type it was given.
54  virtual MsgStream& dump(MsgStream& out) const override;
55  virtual std::ostream& dump(std::ostream& out) const override;
56 
57  protected:
58 
62 
    // Configurable properties: {owner, job-option name, default value[, description]}.
    // The meaning of each value is not visible in this header; see the .cxx for usage.
    // NOTE(review): the FloatProperty defaults are written as double literals (1.7, 0.21).
64  UnsignedIntegerProperty m_embeddingDim{this, "embeddingDim", 8};
65  FloatProperty m_rVal{this, "rVal", 1.7};
66  UnsignedIntegerProperty m_knnVal{this, "knnVal", 500};
67  FloatProperty m_filterCut{this, "filterCut", 0.21};
    // NOTE(review): the member is spelled "...ModuleDir" but the property name is
    // "inputMLModelDir"; renaming either side would change the interface, so it is only flagged here.
68  StringProperty m_inputMLModuleDir{this, "inputMLModelDir", ""};
    // Boolean switch described as "Use CUDA"; defaults to false.
69  BooleanProperty m_useCUDA {this, "UseCUDA", false, "Use CUDA"};
70 
    // Protected MsgStream helper; presumably prints per-event information - confirm in the .cxx.
72  MsgStream& dumpevent (MsgStream& out) const;
73 
74  private:
    // Handles to three inference tools sharing the AthOnnx::IOnnxRuntimeInferenceTool
    // interface. Each is configurable under the given property name ("Embedding",
    // "Filtering", "GNN") and defaults to the type "AthOnnx::OnnxRuntimeInferenceTool".
75  ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_embedSessionTool {
76  this, "Embedding", "AthOnnx::OnnxRuntimeInferenceTool"
77  };
78  ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_filterSessionTool {
79  this, "Filtering", "AthOnnx::OnnxRuntimeInferenceTool"
80  };
81  ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_gnnSessionTool {
82  this, "GNN", "AthOnnx::OnnxRuntimeInferenceTool"
83  };
84 
85  };
86 
// Stream-insertion operators for SiGNNTrackFinderTool, one per stream type
// (MsgStream and std::ostream). They are only declared here; presumably each
// forwards to the matching dump() overload - verify in the .cxx.
87  MsgStream& operator << (MsgStream& ,const SiGNNTrackFinderTool&);
88  std::ostream& operator << (std::ostream&,const SiGNNTrackFinderTool&);
89 
90 }
91 
92 #endif
InDet::operator<<
MsgStream & operator<<(MsgStream &, const GNNTrackReaderTool &)
InDet::SiGNNTrackFinderTool::SiGNNTrackFinderTool
SiGNNTrackFinderTool(const SiGNNTrackFinderTool &)=delete
InDet::SiGNNTrackFinderTool::initTrainedModels
void initTrainedModels()
InDet::SiGNNTrackFinderTool::dumpevent
MsgStream & dumpevent(MsgStream &out) const
Definition: SiGNNTrackFinderTool.cxx:38
InDet::SiGNNTrackFinderTool::m_useCUDA
BooleanProperty m_useCUDA
Definition: SiGNNTrackFinderTool.h:69
IOnnxRuntimeInferenceTool.h
InDet
DUMMY Primary Vertex Finder.
Definition: VP1ErrorUtils.h:36
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:71
InDet::SiGNNTrackFinderTool::operator=
SiGNNTrackFinderTool & operator=(const SiGNNTrackFinderTool &)=delete
InDet::SiGNNTrackFinderTool::m_gnnSessionTool
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_gnnSessionTool
Definition: SiGNNTrackFinderTool.h:81
InDet::SiGNNTrackFinderTool::m_filterCut
FloatProperty m_filterCut
Definition: SiGNNTrackFinderTool.h:67
IGNNTrackFinder.h
InDet::SiGNNTrackFinderTool::m_inputMLModuleDir
StringProperty m_inputMLModuleDir
Definition: SiGNNTrackFinderTool.h:68
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
AthAlgTool.h
InDet::SiGNNTrackFinderTool::m_embedSessionTool
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_embedSessionTool
Definition: SiGNNTrackFinderTool.h:75
InDet::SiGNNTrackFinderTool::m_knnVal
UnsignedIntegerProperty m_knnVal
Definition: SiGNNTrackFinderTool.h:66
InDet::SiGNNTrackFinderTool::getTracks
virtual StatusCode getTracks(const std::vector< const Trk::SpacePoint * > &spacepoints, std::vector< std::vector< uint32_t > > &tracks) const override
Get track candidates from a list of space points.
Definition: SiGNNTrackFinderTool.cxx:49
test_pyathena.parent
parent
Definition: test_pyathena.py:15
InDet::SiGNNTrackFinderTool::m_filterSessionTool
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_filterSessionTool
Definition: SiGNNTrackFinderTool.h:78
InDet::SiGNNTrackFinderTool::dump
virtual MsgStream & dump(MsgStream &out) const override
Definition: SiGNNTrackFinderTool.cxx:27
InDet::SiGNNTrackFinderTool::m_rVal
FloatProperty m_rVal
Definition: SiGNNTrackFinderTool.h:65
InDet::SiGNNTrackFinderTool::initialize
virtual StatusCode initialize() override
Definition: SiGNNTrackFinderTool.cxx:20
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:195
InDet::SiGNNTrackFinderTool
InDet::SiGNNTrackFinderTool is a tool that produces track candidates with graph neural networks-based...
Definition: SiGNNTrackFinderTool.h:31
python.CaloScaleNoiseConfig.type
type
Definition: CaloScaleNoiseConfig.py:78
InDet::SiGNNTrackFinderTool::SiGNNTrackFinderTool
SiGNNTrackFinderTool()=delete
InDet::SiGNNTrackFinderTool::m_embeddingDim
UnsignedIntegerProperty m_embeddingDim
Definition: SiGNNTrackFinderTool.h:64