ATLAS Offline Software
InDet::SiGNNTrackFinderTool Class Reference

InDet::SiGNNTrackFinderTool is a tool that produces track candidates with a graph-neural-network-based pipeline, using 3D space points as inputs.

#include <SiGNNTrackFinderTool.h>


Public Member Functions

 SiGNNTrackFinderTool (const std::string &type, const std::string &name, const IInterface *parent)
 
virtual StatusCode initialize () override
 
virtual StatusCode getTracks (const std::vector< const Trk::SpacePoint * > &spacepoints, std::vector< std::vector< uint32_t > > &tracks) const override
 Get track candidates from a list of space points.
 
virtual MsgStream & dump (MsgStream &out) const override
 
virtual std::ostream & dump (std::ostream &out) const override
 

Protected Member Functions

 SiGNNTrackFinderTool ()=delete
 
 SiGNNTrackFinderTool (const SiGNNTrackFinderTool &)=delete
 
SiGNNTrackFinderTool & operator= (const SiGNNTrackFinderTool &)=delete
 

Exa.TrkX pipeline configurations, which will not be changed after construction

UnsignedIntegerProperty m_embeddingDim {this, "embeddingDim", 8}
 
FloatProperty m_rVal {this, "rVal", 1.7}
 
UnsignedIntegerProperty m_knnVal {this, "knnVal", 500}
 
FloatProperty m_filterCut {this, "filterCut", 0.21}
 
StringProperty m_inputMLModuleDir {this, "inputMLModelDir", ""}
 
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_embedSessionTool
 
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_filterSessionTool
 
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_gnnSessionTool
 
void initTrainedModels ()
 
MsgStream & dumpevent (MsgStream &out) const
 

Detailed Description

InDet::SiGNNTrackFinderTool is a tool that produces track candidates with a graph-neural-network-based pipeline, using 3D space points as inputs.
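The input-feature step at the start of getTracks() can be sketched as a standalone function. This is a minimal sketch, not the tool's code: the Point struct is a hypothetical stand-in for Trk::SpacePoint, and it assumes that SpacePoint::r() and SpacePoint::phi() return the transverse radius and azimuthal angle, with lengths in millimetres. The [r, phi, z] column order and the 1/1000 and 1/pi scale factors are taken from the getTracks() listing.

```cpp
#include <cmath>
#include <vector>

// Hypothetical stand-in for Trk::SpacePoint: global Cartesian position (assumed mm).
struct Point {
  float x, y, z;
};

// Flatten space points into a row-major [numPoints x 3] buffer with columns
// (r/1000, phi/pi, z/1000), mirroring the scaling used in getTracks().
// Assumption: r = sqrt(x^2 + y^2) and phi = atan2(y, x).
std::vector<float> buildFeatures(const std::vector<Point>& points) {
  const float pi = static_cast<float>(std::acos(-1.0));
  std::vector<float> features;
  features.reserve(points.size() * 3);
  for (const Point& p : points) {
    features.push_back(std::hypot(p.x, p.y) / 1000.f);  // r
    features.push_back(std::atan2(p.y, p.x) / pi);      // phi, in [-1, 1]
    features.push_back(p.z / 1000.f);                   // z
  }
  return features;
}
```

The buffer has numPoints * 3 entries, which matches the {numSpacepoints, spacepointFeatures} shape the listing prepares for the embedding model.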

Author
xiangyang.ju@cern.ch

Definition at line 30 of file SiGNNTrackFinderTool.h.

Constructor & Destructor Documentation

◆ SiGNNTrackFinderTool() [1/3]

InDet::SiGNNTrackFinderTool::SiGNNTrackFinderTool ( const std::string &  type,
const std::string &  name,
const IInterface *  parent 
)

Definition at line 13 of file SiGNNTrackFinderTool.cxx.

    : base_class(type, name, parent)
{
  declareInterface<IGNNTrackFinder>(this);
}

◆ SiGNNTrackFinderTool() [2/3]

InDet::SiGNNTrackFinderTool::SiGNNTrackFinderTool ( )
protected, delete

◆ SiGNNTrackFinderTool() [3/3]

InDet::SiGNNTrackFinderTool::SiGNNTrackFinderTool ( const SiGNNTrackFinderTool & )
protected, delete

Member Function Documentation

◆ dump() [1/2]

MsgStream & InDet::SiGNNTrackFinderTool::dump ( MsgStream &  out) const
override, virtual

Definition at line 27 of file SiGNNTrackFinderTool.cxx.

{
  out<<std::endl;
  return dumpevent(out);
}

◆ dump() [2/2]

std::ostream & InDet::SiGNNTrackFinderTool::dump ( std::ostream &  out) const
override, virtual

Definition at line 33 of file SiGNNTrackFinderTool.cxx.

{
  // writes nothing; the stream is returned unchanged
  return out;
}

◆ dumpevent()

MsgStream & InDet::SiGNNTrackFinderTool::dumpevent ( MsgStream &  out) const
protected

Definition at line 38 of file SiGNNTrackFinderTool.cxx.

{
  out<<"|---------------------------------------------------------------------|"
     <<std::endl;
  // no track count is streamed here: std::setw(12) pads the following " |" literal
  out<<"| Number output tracks | "<<std::setw(12)
     <<" |"<<std::endl;
  out<<"|---------------------------------------------------------------------|"
     <<std::endl;
  return out;
}
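In the dumpevent() listing, std::setw(12) is followed directly by the " |" literal, so the summary row contains no number. A minimal sketch of how a track count could be right-aligned in that row is shown below; the helper formatTrackCountRow and its nTracks argument are hypothetical, since none of the members documented on this page stores a per-event track count.

```cpp
#include <cstddef>
#include <iomanip>
#include <sstream>
#include <string>

// Hypothetical helper: builds the summary row with the track count
// right-aligned in a 12-character field. std::setw applies only to the
// next formatted output (the count), not to the trailing " |".
std::string formatTrackCountRow(std::size_t nTracks) {
  std::ostringstream out;
  out << "| Number output tracks | " << std::setw(12) << nTracks << " |";
  return out.str();
}
```

For example, formatTrackCountRow(42) yields the fixed prefix, ten padding spaces, "42" and then " |".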

◆ getTracks()

StatusCode InDet::SiGNNTrackFinderTool::getTracks ( const std::vector< const Trk::SpacePoint * > &  spacepoints,
std::vector< std::vector< uint32_t > > &  tracks 
) const
override, virtual

Get track candidates from a list of space points.

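Parameters
spacepoints  the list of space points used as input to the GNN-based track finder.
tracks  output list of track candidates; each candidate is a list of indices into spacepoints.
Returns
StatusCode::SUCCESS on success; if any ONNX Runtime step (addInput, addOutput or inference) fails, ATH_CHECK returns its failure code.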

Definition at line 49 of file SiGNNTrackFinderTool.cxx.

{
  int64_t numSpacepoints = (int64_t)spacepoints.size();
  std::vector<float> inputValues;
  std::vector<uint32_t> spacepointIDs;

  int64_t spacepointFeatures = 3;
  int sp_idx = 0;
  for(const auto& sp: spacepoints){
    // depending on the trained embedding and GNN models, the input features
    // may need to be updated.

    float z = sp->globalPosition().z() / 1000.;
    float r = sp->r() / 1000.;
    float phi = sp->phi() / M_PI;
    inputValues.push_back(r);
    inputValues.push_back(phi);
    inputValues.push_back(z);

    spacepointIDs.push_back(sp_idx++);
  }

  // ************
  // Embedding
  // ************

  std::vector<int64_t> eInputShape{numSpacepoints, spacepointFeatures};
  std::vector<Ort::Value> eInputTensor;
  ATH_CHECK( m_embedSessionTool->addInput(eInputTensor, inputValues, 0, numSpacepoints) );

  std::vector<Ort::Value> eOutputTensor;
  std::vector<float> eOutputData;
  ATH_CHECK( m_embedSessionTool->addOutput(eOutputTensor, eOutputData, 0, numSpacepoints) );

  ATH_CHECK( m_embedSessionTool->inference(eInputTensor, eOutputTensor) );

  // ************
  // Building Edges
  // ************
  std::vector<int64_t> edgeList;
  buildEdges(eOutputData, edgeList, numSpacepoints, m_embeddingDim, m_rVal, m_knnVal);
  int64_t numEdges = edgeList.size() / 2;

  // ************
  // Filtering
  // ************
  std::vector<Ort::Value> fInputTensor;
  fInputTensor.push_back(
      std::move(eInputTensor[0])
  );
  ATH_CHECK( m_filterSessionTool->addInput(fInputTensor, edgeList, 1, numEdges) );
  std::vector<int64_t> fEdgeShape{2, numEdges};

  std::vector<float> fOutputData;
  std::vector<Ort::Value> fOutputTensor;
  ATH_CHECK( m_filterSessionTool->addOutput(fOutputTensor, fOutputData, 0, numEdges) );

  ATH_CHECK( m_filterSessionTool->inference(fInputTensor, fOutputTensor) );

  // apply sigmoid to the filtering output data
  // and keep only edges with score > filterCut
  std::vector<int64_t> rowIndices;
  std::vector<int64_t> colIndices;
  for (int64_t i = 0; i < numEdges; i++){
    float v = 1.f / (1.f + std::exp(-fOutputData[i])); // sigmoid, float type
    if (v > m_filterCut){
      rowIndices.push_back(edgeList[i]);
      colIndices.push_back(edgeList[numEdges + i]);
    };
  };
  std::vector<int64_t> edgesAfterFiltering;
  edgesAfterFiltering.insert(edgesAfterFiltering.end(), rowIndices.begin(), rowIndices.end());
  edgesAfterFiltering.insert(edgesAfterFiltering.end(), colIndices.begin(), colIndices.end());

  int64_t numEdgesAfterF = edgesAfterFiltering.size() / 2;

  // ************
  // GNN
  // ************
  std::vector<Ort::Value> gInputTensor;
  gInputTensor.push_back(
      std::move(fInputTensor[0])
  );
  ATH_CHECK( m_gnnSessionTool->addInput(gInputTensor, edgesAfterFiltering, 1, numEdgesAfterF) );

  // gnn outputs
  std::vector<float> gOutputData;
  std::vector<Ort::Value> gOutputTensor;
  ATH_CHECK( m_gnnSessionTool->addOutput(gOutputTensor, gOutputData, 0, numEdgesAfterF) );

  ATH_CHECK( m_gnnSessionTool->inference(gInputTensor, gOutputTensor) );
  // apply sigmoid to the gnn output data
  for(auto& v : gOutputData){
    v = 1.f / (1.f + std::exp(-v));
  };

  // ************
  // Track labeling with weakly connected components
  // ************
  std::vector<int32_t> trackLabels(numSpacepoints);
  weaklyConnectedComponents<int64_t,float,int32_t>(numSpacepoints, rowIndices, colIndices, gOutputData, trackLabels);

  if (trackLabels.size() == 0) return StatusCode::SUCCESS;

  tracks.clear();

  int existTrkIdx = 0;
  // map each connected-component label to a sequential track id
  std::map<int32_t, int32_t> trackLableToIds;

  for(int32_t idx=0; idx < numSpacepoints; ++idx) {
    int32_t trackLabel = trackLabels[idx];
    uint32_t spacepointID = spacepointIDs[idx];

    int trkId;
    if(trackLableToIds.find(trackLabel) != trackLableToIds.end()) {
      trkId = trackLableToIds[trackLabel];
      tracks[trkId].push_back(spacepointID);
    } else {
      // a new track, assign the track id
      // and create a vector
      trkId = existTrkIdx;
      tracks.push_back(std::vector<uint32_t>{spacepointID});
      trackLableToIds[trackLabel] = trkId;
      existTrkIdx++;
    }
  }
  return StatusCode::SUCCESS;
}
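The last two steps of getTracks(), labelling connected components and then grouping space-point indices by label, can be sketched with a union-find. This is a swapped-in technique for illustration: the tool itself calls weaklyConnectedComponents, whose implementation (including any use of the GNN edge scores) is not shown on this page. The sketch ignores scores and links every edge it receives; labelComponents and groupTracks are hypothetical names.

```cpp
#include <cstddef>
#include <cstdint>
#include <map>
#include <numeric>
#include <vector>

// Union-find root lookup with path halving.
static std::int64_t findRoot(std::vector<std::int64_t>& parent, std::int64_t i) {
  while (parent[i] != i) {
    parent[i] = parent[parent[i]];
    i = parent[i];
  }
  return i;
}

// Hypothetical: label connected components of an undirected graph with
// numNodes nodes and edges (rows[e], cols[e]). Isolated nodes get their own label.
std::vector<std::int32_t> labelComponents(std::int64_t numNodes,
                                          const std::vector<std::int64_t>& rows,
                                          const std::vector<std::int64_t>& cols) {
  std::vector<std::int64_t> parent(numNodes);
  std::iota(parent.begin(), parent.end(), 0);  // each node starts as its own root
  for (std::size_t e = 0; e < rows.size(); ++e) {
    const std::int64_t a = findRoot(parent, rows[e]);
    const std::int64_t b = findRoot(parent, cols[e]);
    if (a != b) parent[b] = a;  // merge the two components
  }
  std::vector<std::int32_t> labels(numNodes);
  for (std::int64_t i = 0; i < numNodes; ++i)
    labels[i] = static_cast<std::int32_t>(findRoot(parent, i));
  return labels;
}

// Group node indices by label; tracks are numbered in order of first
// appearance, like the std::map loop at the end of getTracks().
std::vector<std::vector<std::uint32_t>> groupTracks(const std::vector<std::int32_t>& labels) {
  std::vector<std::vector<std::uint32_t>> tracks;
  std::map<std::int32_t, std::size_t> labelToTrack;
  for (std::size_t idx = 0; idx < labels.size(); ++idx) {
    auto it = labelToTrack.find(labels[idx]);
    if (it == labelToTrack.end()) {
      labelToTrack.emplace(labels[idx], tracks.size());
      tracks.push_back({static_cast<std::uint32_t>(idx)});
    } else {
      tracks[it->second].push_back(static_cast<std::uint32_t>(idx));
    }
  }
  return tracks;
}
```

As in getTracks(), every space point ends up in some candidate, so an isolated point becomes a single-point candidate.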

◆ initialize()

StatusCode InDet::SiGNNTrackFinderTool::initialize ( )
override, virtual

Definition at line 20 of file SiGNNTrackFinderTool.cxx.

{
  ATH_CHECK( m_embedSessionTool.retrieve() );
  ATH_CHECK( m_filterSessionTool.retrieve() );
  ATH_CHECK( m_gnnSessionTool.retrieve() );
  return StatusCode::SUCCESS;
}

◆ initTrainedModels()

void InDet::SiGNNTrackFinderTool::initTrainedModels ( )
protected

◆ operator=()

SiGNNTrackFinderTool& InDet::SiGNNTrackFinderTool::operator= ( const SiGNNTrackFinderTool & )
protected, delete

Member Data Documentation

◆ m_embeddingDim

UnsignedIntegerProperty InDet::SiGNNTrackFinderTool::m_embeddingDim {this, "embeddingDim", 8}
protected

Definition at line 64 of file SiGNNTrackFinderTool.h.

◆ m_embedSessionTool

ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > InDet::SiGNNTrackFinderTool::m_embedSessionTool
private
Initial value:
{
this, "Embedding", "AthOnnx::OnnxRuntimeInferenceTool"
}

Definition at line 74 of file SiGNNTrackFinderTool.h.

◆ m_filterCut

FloatProperty InDet::SiGNNTrackFinderTool::m_filterCut {this, "filterCut", 0.21}
protected

Definition at line 67 of file SiGNNTrackFinderTool.h.
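The getTracks() listing uses filterCut by passing each raw filter-model output (a logit) through the logistic sigmoid and keeping an edge only if the score is strictly above the cut (0.21 by default). The edge list is stored flat as [sources..., targets...], that is, a 2 x numEdges row-major array, and the surviving edges are repacked in the same layout. A minimal standalone sketch; applyEdgeCut is a hypothetical name.

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Logistic sigmoid: maps a logit to a score in (0, 1).
inline float sigmoid(float x) { return 1.f / (1.f + std::exp(-x)); }

// Hypothetical: keep edges whose sigmoid score is strictly above cut.
// edgeList layout: [row_0 .. row_{n-1}, col_0 .. col_{n-1}].
// Returns the surviving edges in the same flat layout.
std::vector<std::int64_t> applyEdgeCut(const std::vector<std::int64_t>& edgeList,
                                       const std::vector<float>& logits,
                                       float cut) {
  const std::size_t numEdges = edgeList.size() / 2;
  std::vector<std::int64_t> rows, cols;
  for (std::size_t i = 0; i < numEdges; ++i) {
    if (sigmoid(logits[i]) > cut) {
      rows.push_back(edgeList[i]);             // source node of edge i
      cols.push_back(edgeList[numEdges + i]);  // target node of edge i
    }
  }
  rows.insert(rows.end(), cols.begin(), cols.end());  // append targets after sources
  return rows;
}
```

Because the sigmoid is monotonic, sigmoid(x) > c is equivalent to x > ln(c / (1 - c)), about -1.32 for c = 0.21, so the cut could also be applied directly to the logits without evaluating an exponential per edge.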

◆ m_filterSessionTool

ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > InDet::SiGNNTrackFinderTool::m_filterSessionTool
private
Initial value:
{
this, "Filtering", "AthOnnx::OnnxRuntimeInferenceTool"
}

Definition at line 77 of file SiGNNTrackFinderTool.h.

◆ m_gnnSessionTool

ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > InDet::SiGNNTrackFinderTool::m_gnnSessionTool
private
Initial value:
{
this, "GNN", "AthOnnx::OnnxRuntimeInferenceTool"
}

Definition at line 80 of file SiGNNTrackFinderTool.h.

◆ m_inputMLModuleDir

StringProperty InDet::SiGNNTrackFinderTool::m_inputMLModuleDir {this, "inputMLModelDir", ""}
protected

Definition at line 68 of file SiGNNTrackFinderTool.h.

◆ m_knnVal

UnsignedIntegerProperty InDet::SiGNNTrackFinderTool::m_knnVal {this, "knnVal", 500}
protected

Definition at line 66 of file SiGNNTrackFinderTool.h.

◆ m_rVal

FloatProperty InDet::SiGNNTrackFinderTool::m_rVal {this, "rVal", 1.7}
protected

Definition at line 65 of file SiGNNTrackFinderTool.h.
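getTracks() passes embeddingDim, rVal and knnVal to buildEdges, whose implementation is not shown on this page. The brute-force sketch below assumes a fixed-radius nearest-neighbour construction in embedding space, as is common in Exa.TrkX-style pipelines: rVal is taken to be a Euclidean radius and knnVal a cap on the neighbours kept per space point. buildEdgesBruteForce is a hypothetical name, and the O(N^2) search is used only for clarity; whether the real function emits both edge directions or excludes self-loops is not documented here.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical radius + k-nearest-neighbour edge builder.
// embedding: row-major [numPoints x dim]. For each point i, every j != i with
// squared distance < rVal^2 is a candidate; the k closest become edges i -> j.
// Output uses the flat layout [rows..., cols...].
std::vector<std::int64_t> buildEdgesBruteForce(const std::vector<float>& embedding,
                                               std::int64_t numPoints, int dim,
                                               float rVal, int k) {
  std::vector<std::int64_t> rows, cols;
  const float r2 = rVal * rVal;
  for (std::int64_t i = 0; i < numPoints; ++i) {
    std::vector<std::pair<float, std::int64_t>> neighbours;  // (squared distance, index)
    for (std::int64_t j = 0; j < numPoints; ++j) {
      if (j == i) continue;  // no self-loops
      float d2 = 0.f;
      for (int c = 0; c < dim; ++c) {
        const float diff = embedding[i * dim + c] - embedding[j * dim + c];
        d2 += diff * diff;
      }
      if (d2 < r2) neighbours.emplace_back(d2, j);
    }
    // Keep at most k nearest neighbours within the radius.
    std::sort(neighbours.begin(), neighbours.end());
    if (neighbours.size() > static_cast<std::size_t>(k)) neighbours.resize(k);
    for (const auto& nb : neighbours) {
      rows.push_back(i);
      cols.push_back(nb.second);
    }
  }
  rows.insert(rows.end(), cols.begin(), cols.end());
  return rows;
}
```

In this sketch a pair of close points produces edges in both directions (i -> j and j -> i), since each point searches its own neighbourhood.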


The documentation for this class was generated from the following files:
SiGNNTrackFinderTool.h
SiGNNTrackFinderTool.cxx