ATLAS Offline Software
Loading...
Searching...
No Matches
muonEdgeTuner.py
Go to the documentation of this file.
1#!/usr/bin/env python
2# Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3#
4# External tuning harness:
5# 1. run the regular, non-ML reconstruction baseline
6# 2. run the edge-classifier reconstruction for a threshold grid
7# 3. measure signal-muon reconstruction efficiency in MsTrackValidTest
8# 4. report threshold points with relative efficiency loss below target
9
10from __future__ import annotations
11
12import argparse
13import csv
14import json
15import subprocess
16import sys
17from pathlib import Path
18
# Branches of the validation tree that the efficiency measurement reads.
# _signal_muon_efficiency raises when any of them is absent from the tree.
REQUIRED_BRANCHES = (
    "TruthMuons_truthOrigin",
    "TruthMuons_truthType",
    "MsTrkSeed_truthLink",
    "ActsMuons_seedLink",
)
25
def _parse_float_list(raw: str) -> list[float]:
    """Parse a comma-separated string of numbers into a list of floats.

    Tokens that are empty or whitespace-only (for example the one produced by
    a trailing comma) are skipped. Whitespace around the remaining tokens is
    ignored.

    Args:
        raw: Text such as ``"0.08,0.10,0.119"``.

    Returns:
        The parsed values, in input order.

    Raises:
        argparse.ArgumentTypeError: If a token is not a valid float, or if no
            value remains after skipping the empty tokens.
    """
    values: list[float] = []
    for token in raw.split(","):
        token = token.strip()
        if not token:
            continue
        try:
            values.append(float(token))
        except ValueError:
            # Name the offending token and use the same exception type as the
            # empty-list case, rather than leaking a bare ValueError.
            raise argparse.ArgumentTypeError(
                f"invalid threshold value: {token!r}"
            ) from None
    if not values:
        raise argparse.ArgumentTypeError("threshold list must contain at least one value")
    return values
31
32
def _run(command: list[str], log_path: Path) -> int:
    """Execute *command* and write its combined output to *log_path*.

    The parent directory of the log file is created when it does not exist
    and the log file is opened for writing, replacing previous content.
    stderr is redirected into the same stream as stdout.

    Args:
        command: Program and arguments, passed to ``subprocess.run`` as a list.
        log_path: File receiving the output of the child process.

    Returns:
        The return code of the child process.
    """
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with log_path.open("w", encoding="utf-8") as log_stream:
        completed = subprocess.run(
            command,
            stdout=log_stream,
            stderr=subprocess.STDOUT,
            text=True,
        )
        return completed.returncode
43
44
def _count_signal_matches(truth_origins: list[int],
                          truth_types: list[int],
                          seed_truth_links: list[int],
                          reco_seed_links: list[int],
                          signal_origin: int,
                          signal_type: int) -> tuple[int, int]:
    """Return (signal muons, reconstructed signal muons) for one event."""
    seed_count = len(seed_truth_links)
    # Truth indices reachable from any reconstructed track through its seed.
    # A link outside [0, seed_count) addresses no seed and so matches nothing.
    reconstructed_truth_indices = {
        seed_truth_links[seed_index]
        for seed_index in reco_seed_links
        if 0 <= seed_index < seed_count
    }

    signal_muons = 0
    matched_signal_muons = 0
    for truth_index, (origin, truth_type) in enumerate(
        zip(truth_origins, truth_types)
    ):
        if origin != signal_origin or truth_type != signal_type:
            continue
        signal_muons += 1
        if truth_index in reconstructed_truth_indices:
            matched_signal_muons += 1
    return signal_muons, matched_signal_muons


def _signal_muon_efficiency(root_file: Path,
                            tree_name: str,
                            signal_origin: int,
                            signal_type: int) -> dict[str, float | int]:
    """Measure the same signal-muon efficiency used in the comparison notebook.

    A denominator muon is a truth muon satisfying:

        TruthMuons_truthOrigin == signal_origin
        TruthMuons_truthType == signal_type

    It is reconstructed when an ActsMuons track points to a seed via
    ActsMuons_seedLink and that seed points back to the truth-muon index via
    MsTrkSeed_truthLink. All links are local to one event.

    Args:
        root_file: ROOT file written by the reconstruction job.
        tree_name: Name of the TTree to read inside ``root_file``.
        signal_origin: Required value of TruthMuons_truthOrigin.
        signal_type: Required value of TruthMuons_truthType.

    Returns:
        A dict with the keys ``signalMuonCount``, ``matchedSignalMuonCount``
        and ``signalMuonEfficiency`` (matched / signal).

    Raises:
        RuntimeError: If the file cannot be opened, the tree or one of
            REQUIRED_BRANCHES is missing, the truth origin/type vectors of an
            entry differ in size, or no signal truth muon is found.
    """

    # Imported here rather than at module scope, so importing this module
    # does not itself require PyROOT.
    import ROOT

    input_file = ROOT.TFile.Open(str(root_file), "READ")
    if not input_file or input_file.IsZombie():
        raise RuntimeError(f"Could not open ROOT file: {root_file}")

    signal_muons = 0
    matched_signal_muons = 0

    # try/finally closes the file on every exit path, including exceptions
    # raised while reading entries.
    try:
        tree = input_file.Get(tree_name)
        if not tree or not tree.InheritsFrom("TTree"):
            raise RuntimeError(
                f"Could not find TTree '{tree_name}' in {root_file}. "
                "The reco chain must write the MsTrackValidTest tree."
            )

        missing_branches = [
            branch for branch in REQUIRED_BRANCHES if not tree.GetBranch(branch)
        ]
        if missing_branches:
            raise RuntimeError(
                f"Missing required branch(es) in {root_file}: {', '.join(missing_branches)}"
            )

        for entry_number in range(tree.GetEntries()):
            tree.GetEntry(entry_number)

            # Copy the branch contents into plain Python ints for this entry.
            truth_origins = [int(value) for value in tree.TruthMuons_truthOrigin]
            truth_types = [int(value) for value in tree.TruthMuons_truthType]
            seed_truth_links = [int(value) for value in tree.MsTrkSeed_truthLink]
            reco_seed_links = [int(value) for value in tree.ActsMuons_seedLink]

            if len(truth_origins) != len(truth_types):
                raise RuntimeError(
                    f"Truth origin/type vector-size mismatch in entry {entry_number} "
                    f"of {root_file}"
                )

            event_signal, event_matched = _count_signal_matches(
                truth_origins,
                truth_types,
                seed_truth_links,
                reco_seed_links,
                signal_origin,
                signal_type,
            )
            signal_muons += event_signal
            matched_signal_muons += event_matched
    finally:
        input_file.Close()

    if signal_muons == 0:
        raise RuntimeError(
            f"No signal truth muons found in {root_file}. "
            f"Selection is truthOrigin={signal_origin}, truthType={signal_type}."
        )

    return {
        "signalMuonCount": signal_muons,
        "matchedSignalMuonCount": matched_signal_muons,
        "signalMuonEfficiency": matched_signal_muons / signal_muons,
    }
142
def _chain_command(args,
                   out_root: Path,
                   edge: bool,
                   edge_threshold: float | None = None,
                   overlap_threshold: float | None = None) -> list[str]:
    """Build a command equivalent to reco_chain_ec.sh or the no-ML baseline.

    The command runs ``muonEdgeRecoChain.py``, looked up next to this file,
    with the current Python interpreter.

    Args:
        args: Parsed command-line namespace. Reads ``threads``, ``nEvents``,
            ``skipEvents``, ``inputFile``, ``defaultGeoFile`` and ``use_cpu``;
            with ``edge=True`` also ``edgeModel``, ``bucketThreshold`` and
            ``bucketModel``.
        out_root: ROOT output file passed as ``--outRootFile``.
        edge: True for the edge-classifier chain, False for the no-ML baseline.
        edge_threshold: Value for ``--edgeThreshold``; required when ``edge``.
        overlap_threshold: Value for ``--overlapThreshold``; required when
            ``edge``.

    Returns:
        The command as an argument list suitable for ``subprocess.run``.

    Raises:
        ValueError: If ``edge`` is True and a threshold is missing.
    """
    if edge and (edge_threshold is None or overlap_threshold is None):
        # str(None) would otherwise put the literal "None" on the command line.
        raise ValueError(
            "edge_threshold and overlap_threshold are required when edge=True"
        )

    recochain_script = Path(__file__).with_name("muonEdgeRecoChain.py")

    command = [
        sys.executable,
        str(recochain_script),
        "--threads", str(args.threads),
        "--nEvents", str(args.nEvents),
        "--skipEvents", str(args.skipEvents),
        "--inputFile", args.inputFile,
        "--outRootFile", str(out_root),
        "--defaultGeoFile", args.defaultGeoFile,
        "--noPerfMon",
    ]

    if edge:
        command += [
            "--edgeModel", args.edgeModel,
            "--edgeThreshold", str(edge_threshold),
            "--overlapThreshold", str(overlap_threshold),
            "--enableBucketFilter",
            "--enableEdgeClassifier",
            "--useMlSeeder",
        ]
        # Bucket-filter overrides are forwarded only when given.
        if args.bucketThreshold is not None:
            command += ["--bucketThreshold", str(args.bucketThreshold)]
        if args.bucketModel:
            command += ["--bucketModel", args.bucketModel]
    else:
        # Same upstream chain, but no edge inference and old seeder.
        command += [
            "--disableBucketFilter",
            "--disableEdgeClassifier",
            "--useOldSeeder",
            "--skip-onnx",
        ]

    # muonEdgeRecoChain.py defaults to CUDA when ONNX inference is enabled.
    if args.use_cpu:
        command.append("--use-cpu")

    return command
190
def _result_row(edge_threshold: float,
                overlap_threshold: float,
                baseline: dict[str, float | int],
                edge: dict[str, float | int],
                target_relative_efficiency_loss: float,
                root_file: Path,
                log_file: Path) -> dict[str, float | int | str | bool]:
    """Compare one edge-classifier scan point with the baseline.

    Args:
        edge_threshold: Edge threshold of the scan point.
        overlap_threshold: Overlap threshold of the scan point.
        baseline: Efficiency dict of the baseline job, with the keys
            ``signalMuonCount``, ``matchedSignalMuonCount`` and
            ``signalMuonEfficiency``.
        edge: Efficiency dict of the scan-point job, with the same keys.
        target_relative_efficiency_loss: Largest relative efficiency loss for
            which ``passesTarget`` is True.
        root_file: ROOT output of the scan point, recorded in the row.
        log_file: Log file of the scan point, recorded in the row.

    Returns:
        A row whose ``status`` is one of:

        * ``"truth_count_mismatch"`` - the two jobs saw different numbers of
          signal truth muons; only the counts are reported.
        * ``"zero_baseline_efficiency"`` - the baseline efficiency is 0, so
          the relative quantities are undefined and omitted.
        * ``"ok"`` - all efficiency fields and ``passesTarget`` are filled.
    """
    if edge["signalMuonCount"] != baseline["signalMuonCount"]:
        return {
            "edgeThreshold": edge_threshold,
            "overlapThreshold": overlap_threshold,
            "status": "truth_count_mismatch",
            "baselineSignalMuonCount": baseline["signalMuonCount"],
            "edgeSignalMuonCount": edge["signalMuonCount"],
            "rootFile": str(root_file),
            "log": str(log_file),
        }

    baseline_efficiency = float(baseline["signalMuonEfficiency"])
    edge_efficiency = float(edge["signalMuonEfficiency"])

    if baseline_efficiency == 0.0:
        # A relative difference cannot be formed against a zero baseline;
        # report the absolute numbers instead of raising ZeroDivisionError.
        return {
            "edgeThreshold": edge_threshold,
            "overlapThreshold": overlap_threshold,
            "status": "zero_baseline_efficiency",
            "baselineSignalMuonCount": baseline["signalMuonCount"],
            "baselineMatchedSignalMuons": baseline["matchedSignalMuonCount"],
            "baselineSignalEfficiency": baseline_efficiency,
            "edgeSignalMuonCount": edge["signalMuonCount"],
            "edgeMatchedSignalMuons": edge["matchedSignalMuonCount"],
            "edgeSignalEfficiency": edge_efficiency,
            "absoluteEfficiencyDifference": edge_efficiency - baseline_efficiency,
            "rootFile": str(root_file),
            "log": str(log_file),
        }

    relative_efficiency_difference = (
        (edge_efficiency - baseline_efficiency) / baseline_efficiency
    )
    # Only a drop counts as loss; a gain is clamped to zero loss.
    relative_efficiency_loss = max(0.0, -relative_efficiency_difference)

    return {
        "edgeThreshold": edge_threshold,
        "overlapThreshold": overlap_threshold,
        "status": "ok",
        "baselineSignalMuonCount": baseline["signalMuonCount"],
        "baselineMatchedSignalMuons": baseline["matchedSignalMuonCount"],
        "baselineSignalEfficiency": baseline_efficiency,
        "edgeSignalMuonCount": edge["signalMuonCount"],
        "edgeMatchedSignalMuons": edge["matchedSignalMuonCount"],
        "edgeSignalEfficiency": edge_efficiency,
        "absoluteEfficiencyDifference": edge_efficiency - baseline_efficiency,
        "relativeEfficiencyDifference": relative_efficiency_difference,
        "relativeEfficiencyLoss": relative_efficiency_loss,
        "passesTarget": relative_efficiency_loss <= target_relative_efficiency_loss,
        "rootFile": str(root_file),
        "log": str(log_file),
    }
233
def main() -> None:
    """Run the baseline and the threshold scan, then write the results.

    Steps:
      1. Run the no-ML baseline job and measure its signal-muon efficiency.
      2. For every (edgeThreshold, overlapThreshold) pair run the
         edge-classifier job and compare its efficiency with the baseline.
      3. Write ``edge_threshold_scan.csv`` and ``summary.json`` into the work
         directory and print the summary as JSON.

    With ``--skipExisting`` a job is not rerun when its ROOT output already
    exists.

    Raises:
        SystemExit: If the command line is invalid, the baseline job fails or
            the baseline ROOT output is missing.
    """
    parser = argparse.ArgumentParser(
        description=("Tune SegmentEdge thresholds using signal-muon reconstruction "
                     "efficiency relative to the regular no-ML reconstruction."))
    parser.add_argument("--inputFile", required=True)
    parser.add_argument("--bucketModel", default=None,
                        help=("Optional bucket-filter ONNX model."))
    parser.add_argument("--bucketThreshold", "--score-threshold",
                        dest="bucketThreshold", type=float, default=None,
                        help="Bucket-filter score threshold")
    # NOTE(review): bucketOutputIsLogit is parsed but not referenced anywhere
    # else in the visible code - confirm whether it should be forwarded.
    parser.add_argument("--bucket-output-is-logit", dest="bucketOutputIsLogit",
                        action="store_true", default=False,
                        help=("Interpret the scalar bucket-model output as a logit"))
    parser.add_argument("--edgeModel", required=True)
    parser.add_argument("--nEvents", type=int, default=100)
    parser.add_argument("--skipEvents", type=int, default=0)
    parser.add_argument("--threads", type=int, default=1)
    parser.add_argument("--defaultGeoFile", default="RUN4")
    parser.add_argument("--workDir", default="edge_threshold_tuning")
    parser.add_argument("--edgeThresholds", default="0.08,0.10,0.119,0.14,0.16",
                        help="Comma-separated recovery edge-probability thresholds to scan")
    parser.add_argument("--overlapThresholds", default="0.20,0.30,0.50,0.80",
                        help=("Comma-separated high-purity core edge-probability "
                              "thresholds to scan."))
    parser.add_argument("--targetRelativeEfficiencyLoss", "--targetLoss",
                        dest="targetRelativeEfficiencyLoss", type=float, default=0.05,
                        help=("Maximum allowed relative loss in signal-muon efficiency"))
    parser.add_argument("--treeName", default="MsTrackValidTest")
    parser.add_argument("--signalOrigin", type=int, default=13)
    parser.add_argument("--signalType", type=int, default=6)
    parser.add_argument("--use-cpu", dest="use_cpu", action="store_true", default=False)
    parser.add_argument("--skipExisting", action="store_true")
    args = parser.parse_args()

    # Report a malformed threshold list as a usage error, not a traceback.
    try:
        edge_thresholds = _parse_float_list(args.edgeThresholds)
        overlap_thresholds = _parse_float_list(args.overlapThresholds)
    except (argparse.ArgumentTypeError, ValueError) as error:
        parser.error(str(error))

    work_dir = Path(args.workDir).resolve()
    work_dir.mkdir(parents=True, exist_ok=True)

    # --- Baseline: regular reconstruction without ML -------------------------
    baseline_root = work_dir / "baseline_noml.root"
    baseline_log = work_dir / "baseline_noml.log"
    baseline_command = _chain_command(args, baseline_root, edge=False)
    if not args.skipExisting or not baseline_root.exists():
        return_code = _run(baseline_command, baseline_log)
        if return_code != 0:
            raise SystemExit(
                f"No-ML baseline job failed with rc={return_code}. See {baseline_log}"
            )
    if not baseline_root.exists():
        raise SystemExit(
            f"No-ML baseline finished but ROOT output is missing: {baseline_root}. "
            f"See {baseline_log}"
        )

    baseline = _signal_muon_efficiency(
        baseline_root, args.treeName, args.signalOrigin, args.signalType,
    )

    rows: list[dict[str, float | int | str | bool]] = []

    # --- Threshold scan ------------------------------------------------------
    best = None
    for edge_threshold in edge_thresholds:
        for overlap_threshold in overlap_thresholds:
            # File-name tag with '.' replaced by 'p'.
            tag = (
                f"edge{edge_threshold:.6f}_overlap{overlap_threshold:.6f}"
                .replace(".", "p")
            )
            out_root = work_dir / f"{tag}.root"
            out_log = work_dir / f"{tag}.log"
            command = _chain_command(
                args,
                out_root,
                edge=True,
                edge_threshold=edge_threshold,
                overlap_threshold=overlap_threshold,
            )
            if not args.skipExisting or not out_root.exists():
                return_code = _run(command, out_log)
                if return_code != 0:
                    # A failed scan point is recorded; the scan continues.
                    rows.append({
                        "edgeThreshold": edge_threshold,
                        "overlapThreshold": overlap_threshold,
                        "status": "failed",
                        "rootFile": str(out_root),
                        "log": str(out_log),
                    })
                    continue
            if not out_root.exists():
                rows.append({
                    "edgeThreshold": edge_threshold,
                    "overlapThreshold": overlap_threshold,
                    "status": "missing_output",
                    "rootFile": str(out_root),
                    "log": str(out_log),
                })
                continue

            edge = _signal_muon_efficiency(
                out_root,
                args.treeName,
                args.signalOrigin,
                args.signalType,
            )
            row = _result_row(
                edge_threshold,
                overlap_threshold,
                baseline,
                edge,
                args.targetRelativeEfficiencyLoss,
                out_root,
                out_log,
            )

            rows.append(row)

            # Keep the passing point with the largest threshold pair
            # (tuple comparison: edge threshold first, then overlap).
            if row["status"] == "ok" and row["passesTarget"]:
                key = (edge_threshold, overlap_threshold)
                if best is None or key > (
                    best["edgeThreshold"],
                    best["overlapThreshold"],
                ):
                    best = row

    # --- Outputs -------------------------------------------------------------
    csv_path = work_dir / "edge_threshold_scan.csv"
    fieldnames = [
        "edgeThreshold",
        "overlapThreshold",
        "status",
        "baselineSignalMuonCount",
        "baselineMatchedSignalMuons",
        "baselineSignalEfficiency",
        "edgeSignalMuonCount",
        "edgeMatchedSignalMuons",
        "edgeSignalEfficiency",
        "absoluteEfficiencyDifference",
        "relativeEfficiencyDifference",
        "relativeEfficiencyLoss",
        "passesTarget",
        "rootFile",
        "log",
    ]
    with csv_path.open("w", newline="", encoding="utf-8") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()
        # Rows that lack some columns are written with those cells empty.
        writer.writerows(rows)

    summary = {
        "treeName": args.treeName,
        "signalSelection": {
            "truthOrigin": args.signalOrigin,
            "truthType": args.signalType,
        },
        "matching": (
            "Truth muon <- MsTrkSeed_truthLink -> seed <- "
            "ActsMuons_seedLink -> reconstructed track"
        ),
        "baseline": baseline,
        "bucketFilter": {
            "model": args.bucketModel,
            "scoreThreshold": args.bucketThreshold,
            "scoreThresholdSource": (
                "CLI override" if args.bucketThreshold is not None
                else "GraphBucketFilterToolCfg default"
            ),
        },
        "segmentEdgeGraph": {
            "ReadSpacePoints": "FilteredMlBuckets",
            "OrderingSpacePoints": "MuonSpacePoints",
            "note": (
                "Applied by muonEdgeRecoChain.py when bucket filtering and "
                "edge inference are enabled."
            ),
        },
        "targetRelativeEfficiencyLoss": args.targetRelativeEfficiencyLoss,
        "selectionPolicy": (
            "largest (edgeThreshold, overlapThreshold) pair with relative "
            "signal-efficiency loss at or below the target"
        ),
        "best": best,
        "scanCsv": str(csv_path),
    }
    (work_dir / "summary.json").write_text(
        json.dumps(summary, indent=2),
        encoding="utf-8",
    )
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()
void print(char *figname, TCanvas *c1)
#define max(a, b)
Definition cfImp.cxx:41
std::string replace(std::string s, const std::string &s2, const std::string &s3)
Definition hcg.cxx:312
dict[str, float|int|str|bool] _result_row(float edge_threshold, float overlap_threshold, dict[str, float|int] baseline, dict[str, float|int] edge, float target_relative_efficiency_loss, Path root_file, Path log_file)
int _run(list[str] command, Path log_path)
list[str] _chain_command(args, Path out_root, bool edge, float|None edge_threshold=None, float|None overlap_threshold=None)
list[float] _parse_float_list(str raw)
dict[str, float|int] _signal_muon_efficiency(Path root_file, str tree_name, int signal_origin, int signal_type)