# Command-line interface for the SegmentEdge threshold scan. The resulting
# namespace is forwarded as a whole to _chain_cmd and _metric_from_root.
parser = argparse.ArgumentParser(
    description="Tune SegmentEdge thresholds by comparing edge-chain track loss to baseline."
)
parser.add_argument("--inputFile", required=True)
parser.add_argument("--bucketModel", required=True)
parser.add_argument("--bucketThreshold", "--score-threshold", dest="bucketThreshold",
                    type=float, default=0.0,
                    help="Threshold on bucket filter score")
parser.add_argument("--edgeModel", required=True)
parser.add_argument("--nEvents", type=int, default=100)
parser.add_argument("--workDir", default="edge_threshold_tuning",
                    help="Directory for per-configuration ROOT files and logs, "
                         "the scan CSV and summary.json")
parser.add_argument("--edgeThresholds", default="0.10,0.15,0.20,0.25,0.30,0.35,0.40,0.50",
                    help="Comma-separated edge thresholds to scan")
parser.add_argument("--overlapThresholds", default="0.60,0.70,0.80,0.90",
                    help="Comma-separated overlap thresholds to scan")
parser.add_argument("--targetLoss", type=float, default=0.001,
                    help="Maximum allowed relative loss, default 0.001 = 0.1 percent")
parser.add_argument("--metricTree", default=None)
parser.add_argument("--metricBranch", default=None)
parser.add_argument("--metricMode", default="matchedTruthTracks",
                    choices=["matchedTruthTracks", "rawTrackCount"],
                    help="matchedTruthTracks counts MS tracks with truthLink >= threshold and pT cut")
parser.add_argument("--truthLinkBranch", default="MSTrksR4_truthLink")
parser.add_argument("--trackPtBranch", default="MSTrksR4_pt")
parser.add_argument("--truthLinkThreshold", type=int, default=1)
parser.add_argument("--minPtGeV", type=float, default=2.0)
parser.add_argument("--ptUnits", default="auto",
                    choices=["auto", "MeV", "GeV"])
parser.add_argument("--extraRecoArgs", default="",
                    help="Extra args forwarded to muonEdgeRecoChain.py")
# --use-gpu and --use-cpu write the same dest. None means "auto-detect"
# (see the --use-gpu help text).
parser.add_argument("--use-gpu", action="store_true", dest="use_gpu", default=None,
                    help="Use GPU for ONNX inference in the reco chain (default: auto-detect)")
parser.add_argument("--use-cpu", dest="use_gpu", action="store_false",
                    help="Force CPU for ONNX inference in the reco chain")
parser.add_argument("--skipExisting", action="store_true",
                    help="Reuse existing output ROOT files instead of re-running those jobs")
# Pin the shared default explicitly. Otherwise the None default only survives
# because --use-gpu happens to be registered before --use-cpu, whose
# store_false action would otherwise contribute a default of True.
parser.set_defaults(use_gpu=None)
args = parser.parse_args()

# Working directory receiving every job output, log, the scan CSV and summary.
work = Path(args.workDir).resolve()
work.mkdir(parents=True, exist_ok=True)

# Baseline: the reco chain run with edge=False. Its metric is the reference
# every scanned threshold pair is compared against.
baseline_root = work / "baseline.root"
baseline_log = work / "baseline.log"
baseline_cmd = _chain_cmd(args, baseline_root, edge=False)
if not args.skipExisting or not baseline_root.exists():
    # Remove output left by an earlier invocation first. Otherwise a job that
    # exits with rc 0 but writes nothing would silently reuse the stale file
    # and the "missing output" check below could never fire.
    if baseline_root.exists():
        baseline_root.unlink()
    rc = _run(baseline_cmd, baseline_log)
    if rc != 0:
        raise SystemExit(f"Baseline job failed with rc={rc}. See {baseline_log}")
if not baseline_root.exists():
    raise SystemExit(
        f"Baseline job finished but output ROOT file is missing: {baseline_root}. "
        f"Ensure reco args produce this output. "
        f"See {baseline_log}"
    )

baseline_metric = _metric_from_root(baseline_root, args.metricTree, args.metricBranch, args)
# The relative loss divides by the baseline, so it must be strictly positive.
if baseline_metric <= 0:
    raise SystemExit(f"Baseline metric is non-positive: {baseline_metric}")

# Scan every (edgeThreshold, overlapThreshold) pair: run the chain with
# edge=True, read the metric, and record the loss relative to the baseline.
rows = []    # one dict per configuration; written to the scan CSV afterwards
best = None  # passing row with the largest (edgeThreshold, overlapThreshold)
# Parse both lists once. Re-parsing the overlap list inside the outer loop is
# redundant, and a list keeps the inner loop correct even if the parser returns
# a one-shot iterator.
edge_thresholds = list(_parse_float_list(args.edgeThresholds))
overlap_thresholds = list(_parse_float_list(args.overlapThresholds))
for edge_thr in edge_thresholds:
    for overlap_thr in overlap_thresholds:
        # e.g. "edge0p150_overlap0p800": dots become "p" so the only "." in the
        # output file names is the extension separator.
        tag = f"edge{edge_thr:.3f}_overlap{overlap_thr:.3f}".replace(".", "p")
        out_root = work / f"{tag}.root"
        out_log = work / f"{tag}.log"
        cmd = _chain_cmd(args, out_root, edge=True,
                         edge_threshold=edge_thr,
                         overlap_threshold=overlap_thr)
        if not args.skipExisting or not out_root.exists():
            # Drop stale output so "missing_output" below reflects this run only.
            if out_root.exists():
                out_root.unlink()
            rc = _run(cmd, out_log)
            if rc != 0:
                rows.append({
                    "edgeThreshold": edge_thr,
                    "overlapThreshold": overlap_thr,
                    "status": "failed",
                    "returnCode": rc,
                    "logFile": str(out_log),
                })
                continue
        if not out_root.exists():
            rows.append({
                "edgeThreshold": edge_thr,
                "overlapThreshold": overlap_thr,
                "status": "missing_output",
                "rootFile": str(out_root),
                "logFile": str(out_log),
            })
            continue

        metric = _metric_from_root(out_root, args.metricTree, args.metricBranch, args)
        # Fractional drop w.r.t. the baseline, clamped at 0 so a configuration
        # scoring above the baseline counts as lossless.
        loss = max(0.0, (baseline_metric - metric) / baseline_metric)
        row = {
            "edgeThreshold": edge_thr,
            "overlapThreshold": overlap_thr,
            "status": "ok",
            "baselineMetric": baseline_metric,
            "edgeMetric": metric,
            "relativeLoss": loss,
            # targetLoss is the maximum *allowed* loss, so the bound is
            # inclusive; with a strict "<", --targetLoss 0 could never pass.
            "passesTarget": loss <= args.targetLoss,
            "rootFile": str(out_root),
            "logFile": str(out_log),
        }
        rows.append(row)

        # Keep the passing row with the lexicographically largest
        # (edgeThreshold, overlapThreshold), presumably the most aggressive
        # setting that still meets the loss target.
        if row["passesTarget"]:
            key = (edge_thr, overlap_thr)
            if best is None or key > (best["edgeThreshold"], best["overlapThreshold"]):
                best = row

# Persist the full scan (one CSV row per configuration) and a JSON summary with
# the baseline metric, the loss target and the selected best configuration.
csv_path = work / "edge_threshold_scan.csv"
# Rows carry different keys depending on their status, so the header is the
# sorted union of all keys; DictWriter leaves absent columns empty.
fieldnames = sorted({key for r in rows for key in r})
with csv_path.open("w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(rows)

summary = {
    "baselineMetric": baseline_metric,
    "targetLoss": args.targetLoss,
    "best": best,  # None (JSON null) when no configuration met the target
    "scanCsv": str(csv_path),
}
# Serialise once and reuse the text for both the file and stdout.
summary_text = json.dumps(summary, indent=2)
(work / "summary.json").write_text(summary_text, encoding="utf-8")

print(summary_text)

void print(char *figname, TCanvas *c1)
std::string replace(std::string s, const std::string &s2, const std::string &s3)