def _run(cmd: list[str], log_path: Path) -> int:
    """Run *cmd* as a child process and capture its output in *log_path*.

    The parent directory of the log file is created when missing, and the
    log file is truncated on every call. The child's stderr is redirected
    into its stdout, so the log holds both streams.

    Args:
        cmd: Argument vector handed to ``subprocess.run`` (no shell involved).
        log_path: File that receives the child's stdout and stderr.

    Returns:
        The child's exit code. A non-zero code is returned to the caller,
        not raised, so the caller decides how to react to a failed job.
    """
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with log_path.open("w", encoding="utf-8") as log:
        # The child writes straight to the log's file descriptor; nothing is
        # buffered through this process.
        proc = subprocess.run(cmd, stdout=log, stderr=subprocess.STDOUT, text=True)
    return proc.returncode
70 truth_link_branch: str,
71 pt_branch: str |
None,
72 truth_link_threshold: int,
74 pt_units: str) -> float:
75 branches = {b.GetName()
for b
in tree.GetListOfBranches()}
76 if truth_link_branch
not in branches:
77 raise RuntimeError(f
"Missing truth-link branch '{truth_link_branch}'")
78 if pt_branch
and pt_branch
not in branches:
79 raise RuntimeError(f
"Missing pT branch '{pt_branch}'")
83 links = list(getattr(entry, truth_link_branch))
84 pts = list(getattr(entry, pt_branch))
if pt_branch
else [
None] * len(links)
86 for i, link
in enumerate(links):
87 if int(link) < truth_link_threshold:
91 if pt_gev < min_pt_gev:
98 preferred_tree: str |
None,
99 preferred_branch: str |
None,
102 Generic ROOT metric reader.
105 --metricTree <tree> --metricBranch <branch>
107 If not provided, the script tries common branch names. This is intentionally
108 external to Athena so the tuning loop can run many complete jobs.
111 branch_candidates = []
113 branch_candidates.append(preferred_branch)
114 branch_candidates += [
124 trees = [(name, tree)
for name, tree
in trees
if name == preferred_tree
or name.endswith(
"/" + preferred_tree)]
126 raise RuntimeError(f
"No matching TTree found in {root_file}")
128 if args.metricMode ==
"matchedTruthTracks":
129 for _, tree
in trees:
131 args.truthLinkBranch,
133 args.truthLinkThreshold,
137 for tree_name, tree
in trees:
138 branches = {b.GetName()
for b
in tree.GetListOfBranches()}
139 for branch
in branch_candidates:
140 if branch
not in branches:
144 val = getattr(entry, branch)
148 total += float(len(val))
152 for tree_name, tree
in trees:
153 available[tree_name] = [b.GetName()
for b
in tree.GetListOfBranches()]
155 "Could not find a metric branch. Pass --metricTree/--metricBranch. "
156 f
"Available branches: {json.dumps(available, indent=2)}"
161 edge_threshold: float |
None =
None,
162 overlap_threshold: float |
None =
None) -> list[str]:
163 recochain_script = Path(__file__).with_name(
"muonEdgeRecoChain.py")
166 str(recochain_script),
167 "--inputFile", args.inputFile,
168 "--nEvents",
str(args.nEvents),
169 "--outRootFile",
str(out_root),
170 "--bucketModel", args.bucketModel,
171 "--bucketThreshold",
str(args.bucketThreshold),
173 if args.extraRecoArgs:
174 cmd += args.extraRecoArgs.split()
177 if getattr(args,
'use_gpu',
None)
is True:
178 cmd.append(
"--use-gpu")
179 elif getattr(args,
'use_gpu',
None)
is False:
180 cmd.append(
"--use-cpu")
184 "--edgeModel", args.edgeModel,
185 "--enableEdgeClassifier",
187 "--edgeThreshold",
str(edge_threshold),
188 "--overlapThreshold",
str(overlap_threshold),
193 "--disableEdgeClassifier",
200 parser = argparse.ArgumentParser(
201 description=
"Tune SegmentEdge thresholds by comparing edge-chain track loss to baseline."
203 parser.add_argument(
"--inputFile", required=
True)
204 parser.add_argument(
"--bucketModel", required=
True)
205 parser.add_argument(
"--bucketThreshold",
"--score-threshold", dest=
"bucketThreshold", type=float, default=0.0,
206 help=
"Threshold on bucket filter score")
207 parser.add_argument(
"--edgeModel", required=
True)
208 parser.add_argument(
"--nEvents", type=int, default=100)
209 parser.add_argument(
"--workDir", default=
"edge_threshold_tuning")
210 parser.add_argument(
"--edgeThresholds", default=
"0.10,0.15,0.20,0.25,0.30,0.35,0.40,0.50")
211 parser.add_argument(
"--overlapThresholds", default=
"0.60,0.70,0.80,0.90")
212 parser.add_argument(
"--targetLoss", type=float, default=0.001,
213 help=
"Maximum allowed relative loss, default 0.001 = 0.1 percent")
214 parser.add_argument(
"--metricTree", default=
None)
215 parser.add_argument(
"--metricBranch", default=
None)
216 parser.add_argument(
"--metricMode", default=
"matchedTruthTracks",
217 choices=[
"matchedTruthTracks",
"rawTrackCount"],
218 help=
"matchedTruthTracks counts MS tracks with truthLink >= threshold and pT cut")
219 parser.add_argument(
"--truthLinkBranch", default=
"MSTrksR4_truthLink")
220 parser.add_argument(
"--trackPtBranch", default=
"MSTrksR4_pt")
221 parser.add_argument(
"--truthLinkThreshold", type=int, default=1)
222 parser.add_argument(
"--minPtGeV", type=float, default=2.0)
223 parser.add_argument(
"--ptUnits", default=
"auto",
224 choices=[
"auto",
"MeV",
"GeV"])
225 parser.add_argument(
"--extraRecoArgs", default=
"",
226 help=
"Extra args forwarded to muonEdgeRecoChain.py")
227 parser.add_argument(
"--use-gpu", action=
"store_true", dest=
"use_gpu", default=
None,
228 help=
"Use GPU for ONNX inference in the reco chain (default: auto-detect)")
229 parser.add_argument(
"--use-cpu", dest=
"use_gpu", action=
"store_false",
230 help=
"Force CPU for ONNX inference in the reco chain")
231 parser.add_argument(
"--skipExisting", action=
"store_true")
232 args = parser.parse_args()
234 work = Path(args.workDir).resolve()
235 work.mkdir(parents=
True, exist_ok=
True)
237 baseline_root = work /
"baseline.root"
238 baseline_log = work /
"baseline.log"
239 baseline_cmd =
_chain_cmd(args, baseline_root, edge=
False)
240 if not args.skipExisting
or not baseline_root.exists():
241 rc =
_run(baseline_cmd, baseline_log)
243 raise SystemExit(f
"Baseline job failed with rc={rc}. See {baseline_log}")
244 if not baseline_root.exists():
246 f
"Baseline job finished but output ROOT file is missing: {baseline_root}. "
247 f
"Ensure reco args produce this output. "
248 f
"See {baseline_log}"
251 baseline_metric =
_metric_from_root(baseline_root, args.metricTree, args.metricBranch, args)
252 if baseline_metric <= 0:
253 raise SystemExit(f
"Baseline metric is non-positive: {baseline_metric}")
259 tag = f
"edge{edge_thr:.3f}_overlap{overlap_thr:.3f}".
replace(
".",
"p")
260 out_root = work / f
"{tag}.root"
261 out_log = work / f
"{tag}.log"
263 edge_threshold=edge_thr,
264 overlap_threshold=overlap_thr)
265 if not args.skipExisting
or not out_root.exists():
266 rc =
_run(cmd, out_log)
269 "edgeThreshold": edge_thr,
270 "overlapThreshold": overlap_thr,
275 if not out_root.exists():
277 "edgeThreshold": edge_thr,
278 "overlapThreshold": overlap_thr,
279 "status":
"missing_output",
281 "rootFile":
str(out_root),
286 loss =
max(0.0, (baseline_metric - metric) / baseline_metric)
288 "edgeThreshold": edge_thr,
289 "overlapThreshold": overlap_thr,
291 "baselineMetric": baseline_metric,
292 "edgeMetric": metric,
293 "relativeLoss": loss,
294 "passesTarget": loss < args.targetLoss,
295 "rootFile":
str(out_root),
300 if row[
"passesTarget"]:
302 key = (edge_thr, overlap_thr)
303 if best
is None or key > (best[
"edgeThreshold"], best[
"overlapThreshold"]):
306 csv_path = work /
"edge_threshold_scan.csv"
307 with csv_path.open(
"w", newline=
"", encoding=
"utf-8")
as f:
308 writer = csv.DictWriter(f, fieldnames=sorted({k
for r
in rows
for k
in r}))
310 writer.writerows(rows)
313 "baselineMetric": baseline_metric,
314 "targetLoss": args.targetLoss,
316 "scanCsv":
str(csv_path),
318 (work /
"summary.json").write_text(json.dumps(summary, indent=2), encoding=
"utf-8")
320 print(json.dumps(summary, indent=2))