33def _run(command: list[str], log_path: Path) -> int:
34 log_path.parent.mkdir(parents=
True, exist_ok=
True)
35 with log_path.open(
"w", encoding=
"utf-8")
as log:
36 process = subprocess.run(
39 stderr=subprocess.STDOUT,
42 return process.returncode
48 signal_type: int) -> dict[str, float | int]:
49 """Measure the same signal-muon efficiency used in the comparison notebook.
50 A denominator muon is a truth muon satisfying:
52 TruthMuons_truthOrigin == signal_origin
53 TruthMuons_truthType == signal_type
55 It is reconstructed when an ActsMuons track points to a seed via
56 ActsMuons_seedLink and that seed points back to the truth-muon index via
57 MsTrkSeed_truthLink. All links are local to one event.
62 input_file = ROOT.TFile.Open(str(root_file),
"READ")
63 if not input_file
or input_file.IsZombie():
64 raise RuntimeError(f
"Could not open ROOT file: {root_file}")
65 tree = input_file.Get(tree_name)
66 if not tree
or not tree.InheritsFrom(
"TTree"):
69 f
"Could not find TTree '{tree_name}' in {root_file}. "
70 "The reco chain must write the MsTrackValidTest tree."
74 branch
for branch
in REQUIRED_BRANCHES
if not tree.GetBranch(branch)
79 f
"Missing required branch(es) in {root_file}: {', '.join(missing_branches)}"
83 matched_signal_muons = 0
85 for entry_number
in range(tree.GetEntries()):
86 tree.GetEntry(entry_number)
89 int(value)
for value
in tree.TruthMuons_truthOrigin
92 int(value)
for value
in tree.TruthMuons_truthType
95 int(value)
for value
in tree.MsTrkSeed_truthLink
98 int(value)
for value
in tree.ActsMuons_seedLink
101 if len(truth_origins) != len(truth_types):
104 f
"Truth origin/type vector-size mismatch in entry {entry_number} "
108 signal_truth_indices = (
110 for truth_index, (origin, truth_type)
in enumerate(
111 zip(truth_origins, truth_types)
113 if origin == signal_origin
and truth_type == signal_type
116 for truth_index
in signal_truth_indices:
118 matching_seed_indices = {
120 for seed_index, linked_truth_index
in enumerate(seed_truth_links)
121 if linked_truth_index == truth_index
123 if matching_seed_indices
and any(
124 seed_index
in matching_seed_indices
125 for seed_index
in reco_seed_links
127 matched_signal_muons += 1
131 if signal_muons == 0:
133 f
"No signal truth muons found in {root_file}. "
134 f
"Selection is truthOrigin={signal_origin}, truthType={signal_type}."
138 "signalMuonCount": signal_muons,
139 "matchedSignalMuonCount": matched_signal_muons,
140 "signalMuonEfficiency": matched_signal_muons / signal_muons,
146 edge_threshold: float |
None =
None,
147 overlap_threshold: float |
None =
None) -> list[str]:
148 """Build a command equivalent to reco_chain_ec.sh or the no-ML baseline."""
149 recochain_script = Path(__file__).with_name(
"muonEdgeRecoChain.py")
153 str(recochain_script),
154 "--threads", str(args.threads),
155 "--nEvents", str(args.nEvents),
156 "--skipEvents", str(args.skipEvents),
157 "--inputFile", args.inputFile,
158 "--outRootFile", str(out_root),
159 "--defaultGeoFile", args.defaultGeoFile,
165 "--edgeModel", args.edgeModel,
166 "--edgeThreshold", str(edge_threshold),
167 "--overlapThreshold", str(overlap_threshold),
168 "--enableBucketFilter",
169 "--enableEdgeClassifier",
172 if args.bucketThreshold
is not None:
173 command += [
"--bucketThreshold", str(args.bucketThreshold)]
175 command += [
"--bucketModel", args.bucketModel]
179 "--disableBucketFilter",
180 "--disableEdgeClassifier",
187 command.append(
"--use-cpu")
192 overlap_threshold: float,
193 baseline: dict[str, float | int],
194 edge: dict[str, float | int],
195 target_relative_efficiency_loss: float,
197 log_file: Path) -> dict[str, float | int | str | bool]:
198 if edge[
"signalMuonCount"] != baseline[
"signalMuonCount"]:
200 "edgeThreshold": edge_threshold,
201 "overlapThreshold": overlap_threshold,
202 "status":
"truth_count_mismatch",
203 "baselineSignalMuonCount": baseline[
"signalMuonCount"],
204 "edgeSignalMuonCount": edge[
"signalMuonCount"],
205 "rootFile": str(root_file),
206 "log": str(log_file),
209 baseline_efficiency = float(baseline[
"signalMuonEfficiency"])
210 edge_efficiency = float(edge[
"signalMuonEfficiency"])
211 relative_efficiency_difference = (
212 (edge_efficiency - baseline_efficiency) / baseline_efficiency
214 relative_efficiency_loss =
max(0.0, -relative_efficiency_difference)
217 "edgeThreshold": edge_threshold,
218 "overlapThreshold": overlap_threshold,
220 "baselineSignalMuonCount": baseline[
"signalMuonCount"],
221 "baselineMatchedSignalMuons": baseline[
"matchedSignalMuonCount"],
222 "baselineSignalEfficiency": baseline_efficiency,
223 "edgeSignalMuonCount": edge[
"signalMuonCount"],
224 "edgeMatchedSignalMuons": edge[
"matchedSignalMuonCount"],
225 "edgeSignalEfficiency": edge_efficiency,
226 "absoluteEfficiencyDifference": edge_efficiency - baseline_efficiency,
227 "relativeEfficiencyDifference": relative_efficiency_difference,
228 "relativeEfficiencyLoss": relative_efficiency_loss,
229 "passesTarget": relative_efficiency_loss <= target_relative_efficiency_loss,
230 "rootFile": str(root_file),
231 "log": str(log_file),
235 parser = argparse.ArgumentParser(
236 description=(
"Tune SegmentEdge thresholds using signal-muon reconstruction "
237 "efficiency relative to the regular no-ML reconstruction."))
238 parser.add_argument(
"--inputFile", required=
True)
239 parser.add_argument(
"--bucketModel", default=
None, help=(
"Optional bucket-filter ONNX model."),)
240 parser.add_argument(
"--bucketThreshold",
"--score-threshold", dest=
"bucketThreshold", type=float,
241 default=
None, help=
"Bucket-filter score threshold")
242 parser.add_argument(
"--bucket-output-is-logit", dest=
"bucketOutputIsLogit", action=
"store_true", default=
False,
243 help=(
"Interpret the scalar bucket-model output as a logit"))
244 parser.add_argument(
"--edgeModel", required=
True)
245 parser.add_argument(
"--nEvents", type=int, default=100)
246 parser.add_argument(
"--skipEvents", type=int, default=0)
247 parser.add_argument(
"--threads", type=int, default=1)
248 parser.add_argument(
"--defaultGeoFile", default=
"RUN4")
249 parser.add_argument(
"--workDir", default=
"edge_threshold_tuning")
250 parser.add_argument(
"--edgeThresholds", default=
"0.08,0.10,0.119,0.14,0.16", help=
"Comma-separated recovery edge-probability thresholds to scan",)
251 parser.add_argument(
"--overlapThresholds", default=
"0.20,0.30,0.50,0.80", help=(
"Comma-separated high-purity core edge-probability thresholds to scan."))
252 parser.add_argument(
"--targetRelativeEfficiencyLoss",
"--targetLoss", dest=
"targetRelativeEfficiencyLoss",
253 type=float, default=0.05, help=(
"Maximum allowed relative loss in signal-muon efficiency"),)
254 parser.add_argument(
"--treeName", default=
"MsTrackValidTest")
255 parser.add_argument(
"--signalOrigin", type=int, default=13)
256 parser.add_argument(
"--signalType", type=int, default=6)
257 parser.add_argument(
"--use-cpu", dest=
"use_cpu", action=
"store_true", default=
False,)
258 parser.add_argument(
"--skipExisting", action=
"store_true")
259 args = parser.parse_args()
264 work_dir = Path(args.workDir).resolve()
265 work_dir.mkdir(parents=
True, exist_ok=
True)
267 baseline_root = work_dir /
"baseline_noml.root"
268 baseline_log = work_dir /
"baseline_noml.log"
269 baseline_command =
_chain_command(args, baseline_root, edge=
False)
270 if not args.skipExisting
or not baseline_root.exists():
271 return_code =
_run(baseline_command, baseline_log)
274 f
"No-ML baseline job failed with rc={return_code}. See {baseline_log}"
276 if not baseline_root.exists():
278 f
"No-ML baseline finished but ROOT output is missing: {baseline_root}. "
279 f
"See {baseline_log}"
284 rows: list[dict[str, float | int | str | bool]] = []
287 for edge_threshold
in edge_thresholds:
288 for overlap_threshold
in overlap_thresholds:
290 f
"edge{edge_threshold:.6f}_overlap{overlap_threshold:.6f}"
293 out_root = work_dir / f
"{tag}.root"
294 out_log = work_dir / f
"{tag}.log"
299 edge_threshold=edge_threshold,
300 overlap_threshold=overlap_threshold,
302 if not args.skipExisting
or not out_root.exists():
303 return_code =
_run(command, out_log)
306 "edgeThreshold": edge_threshold,
307 "overlapThreshold": overlap_threshold,
309 "rootFile": str(out_root),
313 if not out_root.exists():
315 "edgeThreshold": edge_threshold,
316 "overlapThreshold": overlap_threshold,
317 "status":
"missing_output",
318 "rootFile": str(out_root),
334 args.targetRelativeEfficiencyLoss,
341 if row[
"status"] ==
"ok" and row[
"passesTarget"]:
342 key = (edge_threshold, overlap_threshold)
343 if best
is None or key > (
344 best[
"edgeThreshold"],
345 best[
"overlapThreshold"],
349 csv_path = work_dir /
"edge_threshold_scan.csv"
354 "baselineSignalMuonCount",
355 "baselineMatchedSignalMuons",
356 "baselineSignalEfficiency",
357 "edgeSignalMuonCount",
358 "edgeMatchedSignalMuons",
359 "edgeSignalEfficiency",
360 "absoluteEfficiencyDifference",
361 "relativeEfficiencyDifference",
362 "relativeEfficiencyLoss",
367 with csv_path.open(
"w", newline=
"", encoding=
"utf-8")
as csv_file:
368 writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
370 writer.writerows(rows)
373 "treeName": args.treeName,
375 "truthOrigin": args.signalOrigin,
376 "truthType": args.signalType,
379 "Truth muon <- MsTrkSeed_truthLink -> seed <- "
380 "ActsMuons_seedLink -> reconstructed track"
382 "baseline": baseline,
384 "model": args.bucketModel,
385 "scoreThreshold": args.bucketThreshold,
386 "scoreThresholdSource": (
387 "CLI override" if args.bucketThreshold
is not None
388 else "GraphBucketFilterToolCfg default"
391 "segmentEdgeGraph": {
392 "ReadSpacePoints":
"FilteredMlBuckets",
393 "OrderingSpacePoints":
"MuonSpacePoints",
395 "Applied by muonEdgeRecoChain.py when bucket filtering and "
396 "edge inference are enabled."
399 "targetRelativeEfficiencyLoss": args.targetRelativeEfficiencyLoss,
401 "largest (edgeThreshold, overlapThreshold) pair with relative "
402 "signal-efficiency loss at or below the target"
405 "scanCsv": str(csv_path),
407 (work_dir /
"summary.json").write_text(
408 json.dumps(summary, indent=2),
411 print(json.dumps(summary, indent=2))
dict[str, float|int|str|bool] _result_row(float edge_threshold, float overlap_threshold, dict[str, float|int] baseline, dict[str, float|int] edge, float target_relative_efficiency_loss, Path root_file, Path log_file)