23 ATH_MSG_ERROR(
"Path to 1st stage NN-based fake track removal ONNX file is empty! If you want to run this pipeline, you need to provide an input file.");
24 return StatusCode::FAILURE;
37 return StatusCode::SUCCESS;
52 return StatusCode::SUCCESS;
58 std::unique_ptr<TTree>
tree(
static_cast<TTree*
>(
file->Get(
"TreeModuleDoublet")));
60 unsigned int mid1_value = 0;
61 unsigned int mid2_value = 0;
62 float z0min_12_value = 0.0;
63 float dphimin_12_value = 0.0;
64 float phislopemin_12_value = 0.0;
65 float detamin_12_value = 0.0;
66 float z0max_12_value = 0.0;
67 float dphimax_12_value = 0.0;
68 float phislopemax_12_value = 0.0;
69 float detamax_12_value = 0.0;
71 tree->SetBranchAddress(
"Module1", &mid1_value);
72 tree->SetBranchAddress(
"Module2", &mid2_value);
73 tree->SetBranchAddress(
"z0min_12", &z0min_12_value);
74 tree->SetBranchAddress(
"dphimin_12", &dphimin_12_value);
75 tree->SetBranchAddress(
"phiSlopemin_12", &phislopemin_12_value);
76 tree->SetBranchAddress(
"detamin_12", &detamin_12_value);
77 tree->SetBranchAddress(
"z0max_12", &z0max_12_value);
78 tree->SetBranchAddress(
"dphimax_12", &dphimax_12_value);
79 tree->SetBranchAddress(
"phiSlopemax_12", &phislopemax_12_value);
80 tree->SetBranchAddress(
"detamax_12", &detamax_12_value);
82 int64_t nEntries =
tree->GetEntries();
83 for (int64_t i = 0; i < nEntries; ++i) {
85 m_mid1.emplace_back(mid1_value);
86 m_mid2.emplace_back(mid2_value);
115 for (
size_t i = 0; i <
m_mid2.size(); i++) {
116 std::vector<std::shared_ptr<FPGATrackSimGNNHit>> hit1_matches;
117 std::vector<std::shared_ptr<FPGATrackSimGNNHit>> hit2_matches;
119 std::vector<int> hit1_indices;
120 std::vector<int> hit2_indices;
122 for (
size_t j = 0; j < hits.size(); j++) {
123 if (hits[j]->getIdentifier() ==
m_mid1[i]) {
124 hit1_matches.emplace_back(hits[j]);
125 hit1_indices.emplace_back(j);
127 if (hits[j]->getIdentifier() ==
m_mid2[i]) {
128 hit2_matches.emplace_back(hits[j]);
129 hit2_indices.emplace_back(j);
133 for (
size_t h1 = 0; h1 < hit1_matches.size(); h1++) {
134 for (
size_t h2 = 0; h2 < hit2_matches.size(); h2++) {
135 applyDoubletCuts(hit1_matches[h1], hit2_matches[h2], edges, hit1_indices[h1], hit2_indices[h2], i);
147 float deta = hit1->getEta() - hit2->getEta();
151 float dz = hit2->getZ() - hit1->getZ();
152 float dr = hit2->getR() - hit1->getR();
153 float z0 = dr==0. ? 0. : hit1->getZ() - (hit1->getR() * dz / dr);
161 float phislope = dr==0. ? 0. : dphi / dr;
165 std::shared_ptr<FPGATrackSimGNNEdge> edge = std::make_shared<FPGATrackSimGNNEdge>();
166 edge->setEdgeIndex1(hit1_index);
167 edge->setEdgeIndex2(hit2_index);
168 edges.emplace_back(edge);
195 if(feature < 0.0) {
return -1.0; }
196 else if(feature > 0.0) {
return 1.0; }
205 std::vector<float> gEmbedded =
embed(hits);
211 std::vector<float> gNodeFeatures;
213 for(
auto hit : hits) {
214 std::map<std::string, float> features;
215 features[
"r"] = hit->getR();
216 features[
"phi"] = hit->getPhi();
217 features[
"z"] = hit->getZ();
220 gNodeFeatures.push_back(
224 return gNodeFeatures;
231 std::vector<float> gEmbedded;
233 std::vector<Ort::Value> gInputTensor;
234 StatusCode s =
m_MLInferenceTool->addInput(gInputTensor, gNodeFeatures, 0, hits.size());
235 std::vector<Ort::Value> gOutputTensor;
243 std::vector<float> & gEmbedded)
248 int size = hits.size();
254 std::vector<float> start(n_dim);
257 for(
int k = 0; k < size; ++k){
260 for(
int j = 0; j < n_dim; ++j){
261 start[j] = gEmbedded[k*n_dim + j];
264 for (
int i = k + 1; i < size; ++i){
266 for(
int d = 0; d < n_dim; ++d){
267 distance += (start[d] - gEmbedded[i*n_dim + d]) * (start[d] - gEmbedded[i*n_dim + d]);
270 if(distance < r_squared){
271 std::shared_ptr<FPGATrackSimGNNEdge> edge = std::make_shared<FPGATrackSimGNNEdge>();
273 float d_i_sq = (hits[i]->getR() * hits[i]->getR()) + (hits[i]->getZ() * hits[i]->getZ());
274 float d_k_sq = (hits[k]->getR() * hits[k]->getR()) + (hits[k]->getZ() * hits[k]->getZ());
275 if (d_i_sq < d_k_sq){
283 edge->setEdgeIndex1(index1);
284 edge->setEdgeIndex2(index2);
285 edges.emplace_back(edge);
#define ATH_CHECK
Evaluate an expression and check for errors.
int count(std::string s, const std::string &regx)
count how many occurrences of a regx are in a string
double deltaPhi(double phiA, double phiB)
delta Phi in range [-pi,pi[