6 #include "ExaTrkXUtils.hpp"
16 const std::string&
type,
const std::string&
name,
const IInterface*
parent):
19 declareInterface<IGNNTrackFinder>(
this);
33 auto split_fn = [](
const std::string&
s,
auto convert_fn) {
34 using ReturnType = std::decay_t<decltype(convert_fn(std::declval<std::string>()))>;
35 std::vector<ReturnType>
tokens;
37 std::istringstream tokenStream(
s);
38 while (std::getline(tokenStream, token,
',')) {
39 token = token.substr(token.find_first_not_of(
" "), token.find_last_not_of(
" ") + 1);
40 tokens.push_back(convert_fn(token));
// Converter handed to split_fn: parses one token as a float via std::stof.
auto convert_to_float = [](const std::string& text) -> float {
  const float parsed = std::stof(text);
  return parsed;
};
// Converter handed to split_fn: returns a copy of the token unchanged.
auto convert_to_str = [](const std::string& text) -> std::string {
  return std::string(text);
};
59 return StatusCode::SUCCESS;
75 out<<
"|---------------------------------------------------------------------|"
77 out<<
"| Number output tracks | "<<std::setw(12)
79 out<<
"|---------------------------------------------------------------------|"
85 const std::vector<const Trk::SpacePoint*>& spacepoints,
86 std::vector<std::vector<uint32_t> >& tracks)
const
88 int64_t numSpacepoints = (int64_t)spacepoints.size();
89 std::vector<float> eNodeFeatures;
90 std::vector<float> fNodeFeatures;
91 std::vector<float> gNodeFeatures;
92 std::vector<uint32_t> spacepointIDs;
93 std::vector<int> regions;
96 for(
const auto& sp: spacepoints){
98 regions.push_back(featureMap[
"region"]);
101 eNodeFeatures.push_back(
107 fNodeFeatures.push_back(
113 gNodeFeatures.push_back(
117 spacepointIDs.push_back(sp_idx++);
123 std::vector<Ort::Value> eInputTensor;
126 std::vector<Ort::Value> eOutputTensor;
127 std::vector<float> eOutputData;
135 std::vector<int64_t> senders;
136 std::vector<int64_t> receivers;
138 int64_t numEdges = senders.size();
141 eNodeFeatures.clear();
142 eInputTensor.clear();
144 eOutputTensor.clear();
147 std::vector<std::pair<int64_t, int64_t>> edgePairs;
148 for(int64_t
idx = 0;
idx < numEdges;
idx ++ ) {
149 edgePairs.push_back({senders[
idx], receivers[
idx]});
151 std::sort(edgePairs.begin(), edgePairs.end());
152 edgePairs.erase(std::unique(edgePairs.begin(), edgePairs.end()), edgePairs.end());
155 std::random_device rd;
156 std::mt19937 rdm_gen(rd());
157 std::random_shuffle(edgePairs.begin(), edgePairs.end());
160 std::sort(edgePairs.begin(), edgePairs.end(),
161 [numSpacepoints](
const std::pair<int64_t, int64_t>&
a,
const std::pair<int64_t, int64_t>&
b){
162 return a.first * numSpacepoints + a.second < b.first * numSpacepoints + b.second;
168 for(
const auto& edge: edgePairs){
169 senders.push_back(edge.first);
170 receivers.push_back(edge.second);
178 std::vector<Ort::Value> fInputTensor;
181 std::vector<int64_t> edgeList(numEdges * 2);
182 std::copy(senders.begin(), senders.end(), edgeList.begin());
183 std::copy(receivers.begin(), receivers.end(), edgeList.begin() + senders.size());
188 std::vector<float> fOutputData;
189 std::vector<Ort::Value> fOutputTensor;
197 std::vector<int64_t> rowIndices;
198 std::vector<int64_t> colIndices;
199 for (int64_t
i = 0;
i < numEdges;
i++){
200 float v = 1.f / (1.f +
std::exp(-fOutputData[
i]));
202 auto src = edgeList[
i];
203 auto dst = edgeList[numEdges +
i];
207 rowIndices.push_back(
src);
208 colIndices.push_back(dst);
211 int64_t numEdgesAfterF = rowIndices.size();
214 fNodeFeatures.clear();
215 fInputTensor.clear();
217 fOutputTensor.clear();
222 std::vector<int64_t> edgesAfterFiltering(numEdgesAfterF * 2);
223 std::copy(rowIndices.begin(), rowIndices.end(), edgesAfterFiltering.begin());
224 std::copy(colIndices.begin(), colIndices.end(), edgesAfterFiltering.begin() + senders.size());
231 for(
size_t idx = 0; idx < static_cast<size_t>(numSpacepoints);
idx++){
232 if (regions[
idx] == 2 || regions[
idx] == 6){
239 std::vector<Ort::Value> gInputTensor;
244 std::vector<float> gnnEdgeFeatures;
249 std::vector<float> gOutputData;
250 std::vector<Ort::Value> gOutputTensor;
255 for(
auto&
v : gOutputData){
260 gNodeFeatures.clear();
261 gInputTensor.clear();
262 edgesAfterFiltering.clear();
270 rowIndices, colIndices, gOutputData,
274 return StatusCode::SUCCESS;