ATLAS Offline Software
Loading...
Searching...
No Matches
muonEdgeTuner Namespace Reference

Functions

list[float] _parse_float_list (str raw)
int _run (list[str] cmd, Path log_path)
 _iter_root_trees (Path root_file)
float _pt_to_gev (float pt, str units)
float _matched_ms_track_metric (tree, str truth_link_branch, str|None pt_branch, int truth_link_threshold, float min_pt_gev, str pt_units)
float _metric_from_root (Path root_file, str|None preferred_tree, str|None preferred_branch, args)
list[str] _chain_cmd (args, Path out_root, bool edge, float|None edge_threshold=None, float|None overlap_threshold=None)
 main ()

Function Documentation

◆ _chain_cmd()

list[str] muonEdgeTuner._chain_cmd ( args,
Path out_root,
bool edge,
float | None edge_threshold = None,
float | None overlap_threshold = None )
protected

Definition at line 160 of file muonEdgeTuner.py.

162 overlap_threshold: float | None = None) -> list[str]:
163 recochain_script = Path(__file__).with_name("muonEdgeRecoChain.py")
164 cmd = [
165 sys.executable,
166 str(recochain_script),
167 "--inputFile", args.inputFile,
168 "--nEvents", str(args.nEvents),
169 "--outRootFile", str(out_root),
170 "--bucketModel", args.bucketModel,
171 "--bucketThreshold", str(args.bucketThreshold),
172 ]
173 if args.extraRecoArgs:
174 cmd += args.extraRecoArgs.split()
175
176 # GPU passthrough to muonEdgeRecoChain
177 if getattr(args, 'use_gpu', None) is True:
178 cmd.append("--use-gpu")
179 elif getattr(args, 'use_gpu', None) is False:
180 cmd.append("--use-cpu")
181
182 if edge:
183 cmd += [
184 "--edgeModel", args.edgeModel,
185 "--enableEdgeClassifier",
186 "--useMlSeeder",
187 "--edgeThreshold", str(edge_threshold),
188 "--overlapThreshold", str(overlap_threshold),
189 ]
190 else:
191 # Same upstream chain, but no edge inference and old seeder.
192 cmd += [
193 "--disableEdgeClassifier",
194 "--useOldSeeder",
195 ]
196 return cmd
197
198

◆ _iter_root_trees()

muonEdgeTuner._iter_root_trees ( Path root_file)
protected

Definition at line 31 of file muonEdgeTuner.py.

def _iter_root_trees(root_file: Path):
    """Yield ``(path, tree)`` for every TTree in *root_file*, recursing into directories.

    ``path`` is the slash-joined key path inside the file (``"dir/tree"``).
    When a name appears with several key cycles, only the first listed one is
    read (ROOT lists cycles highest-first, i.e. most recent first).

    The TFile is intentionally left open and handed to ROOT's ownership.
    ROOT deletes the TTrees read from a file when that file is closed, and
    callers may keep using the yielded trees after this generator has been
    exhausted (e.g. after wrapping it in ``list()``). ROOT closes any files
    still open at process exit.

    Raises:
        RuntimeError: if the file cannot be opened or contains no keys.
    """
    import ROOT

    f = ROOT.TFile.Open(str(root_file), "READ")
    if not f or f.IsZombie():
        raise RuntimeError(f"Could not open ROOT file: {root_file}")
    keys = f.GetListOfKeys()
    if not keys or keys.GetEntries() == 0:
        f.Close()
        raise RuntimeError(
            f"ROOT file has no keys: {root_file}. "
            "The Athena job finished, but no monitoring/tester tree was written. "
            "Run muonEdgeRecoChain.py with --enableRecoChainTester, or pass an "
            "output configuration that writes a tree containing the track-count metric."
        )

    # Do not let Python garbage collection delete (and thereby close) the file
    # while yielded trees may still be in use.
    ROOT.SetOwnership(f, False)

    def walk(directory, prefix=""):
        seen = set()
        for key in directory.GetListOfKeys():
            name = key.GetName()
            # Skip older cycles (name;1 after name;2) of an object already read.
            if name in seen:
                continue
            seen.add(name)
            obj = key.ReadObj()
            if not obj:
                # ReadObj() can fail (e.g. unknown class); nothing to inspect.
                continue
            full = f"{prefix}/{name}" if prefix else name
            if obj.InheritsFrom("TTree"):
                yield full, obj
            elif obj.InheritsFrom("TDirectory"):
                yield from walk(obj, full)

    yield from walk(f)
TGraphErrors * GetEntries(TH2F *histo)

◆ _matched_ms_track_metric()

float muonEdgeTuner._matched_ms_track_metric ( tree,
str truth_link_branch,
str | None pt_branch,
int truth_link_threshold,
float min_pt_gev,
str pt_units )
protected

Definition at line 69 of file muonEdgeTuner.py.

74 pt_units: str) -> float:
75 branches = {b.GetName() for b in tree.GetListOfBranches()}
76 if truth_link_branch not in branches:
77 raise RuntimeError(f"Missing truth-link branch '{truth_link_branch}'")
78 if pt_branch and pt_branch not in branches:
79 raise RuntimeError(f"Missing pT branch '{pt_branch}'")
80
81 total = 0
82 for entry in tree:
83 links = list(getattr(entry, truth_link_branch))
84 pts = list(getattr(entry, pt_branch)) if pt_branch else [None] * len(links)
85
86 for i, link in enumerate(links):
87 if int(link) < truth_link_threshold:
88 continue
89 if pt_branch:
90 pt_gev = _pt_to_gev(float(pts[i]), pt_units)
91 if pt_gev < min_pt_gev:
92 continue
93 total += 1
94 return float(total)
95
96

◆ _metric_from_root()

float muonEdgeTuner._metric_from_root ( Path root_file,
str | None preferred_tree,
str | None preferred_branch,
args )
protected
Generic ROOT metric reader.

Preferred usage:
  --metricTree <tree> --metricBranch <branch>

If not provided, the script tries common branch names. This is intentionally
external to Athena so the tuning loop can run many complete jobs.

Definition at line 97 of file muonEdgeTuner.py.

def _metric_from_root(root_file: Path,
                      preferred_tree: str | None,
                      preferred_branch: str | None,
                      args) -> float:
    """
    Generic ROOT metric reader.

    Preferred usage:
      --metricTree <tree> --metricBranch <branch>

    If not provided, the script tries common branch names. This is intentionally
    external to Athena so the tuning loop can run many complete jobs.

    Trees are visited in file order. When ``preferred_tree`` is given, only
    trees whose path equals it or ends with ``"/" + preferred_tree`` are used.

    * ``args.metricMode == "matchedTruthTracks"``: the first tree that has the
      ``args.truthLinkBranch`` branch is evaluated with
      ``_matched_ms_track_metric``.
    * any other mode: the first tree containing one of the candidate branches
      (``preferred_branch`` first) is summed over its entries. Scalar values
      are added directly; values that ``float()`` rejects with TypeError
      contribute their ``len()``.

    Raises:
        RuntimeError: if no tree matches, or no matching tree has the
            required branch.
    """
    branch_candidates = []
    if preferred_branch:
        branch_candidates.append(preferred_branch)
    branch_candidates += [
        "nMsTracks",
        "nMSTracks",
        "nTracks",
        "nRecoTracks",
        "nMuonTracks",
    ]
    matched_mode = args.metricMode == "matchedTruthTracks"

    available = {}
    # Each tree is read before the generator is advanced. A closed TFile
    # deletes the TTrees read from it, so trees must not be used after the
    # generator has finished.
    for tree_name, tree in _iter_root_trees(root_file):
        if preferred_tree and not (tree_name == preferred_tree
                                   or tree_name.endswith("/" + preferred_tree)):
            continue
        branch_names = [b.GetName() for b in tree.GetListOfBranches()]
        available[tree_name] = branch_names
        present = set(branch_names)

        if matched_mode:
            # Skip trees without truth links instead of failing on the first one.
            if args.truthLinkBranch not in present:
                continue
            return _matched_ms_track_metric(tree,
                                            args.truthLinkBranch,
                                            args.trackPtBranch,
                                            args.truthLinkThreshold,
                                            args.minPtGeV,
                                            args.ptUnits)

        for branch in branch_candidates:
            if branch not in present:
                continue
            total = 0.0
            for entry in tree:
                val = getattr(entry, branch)
                try:
                    total += float(val)
                except TypeError:
                    # Container-valued branch: count its elements.
                    total += float(len(val))
            return total

    if not available:
        raise RuntimeError(f"No matching TTree found in {root_file}")
    if matched_mode:
        raise RuntimeError(
            f"Missing truth-link branch '{args.truthLinkBranch}' in {root_file}. "
            f"Available branches: {json.dumps(available, indent=2)}"
        )
    raise RuntimeError(
        "Could not find a metric branch. Pass --metricTree/--metricBranch. "
        f"Available branches: {json.dumps(available, indent=2)}"
    )

◆ _parse_float_list()

list[float] muonEdgeTuner._parse_float_list ( str raw)
protected

Definition at line 20 of file muonEdgeTuner.py.

def _parse_float_list(raw: str) -> list[float]:
    """Split a comma-separated string into floats.

    Items that are empty or contain only whitespace are ignored, so trailing
    or doubled commas are harmless. Any other item that ``float()`` cannot
    parse raises ValueError.
    """
    values: list[float] = []
    for token in raw.split(","):
        if not token.strip():
            continue
        values.append(float(token))
    return values

◆ _pt_to_gev()

float muonEdgeTuner._pt_to_gev ( float pt,
str units )
protected

Definition at line 60 of file muonEdgeTuner.py.

def _pt_to_gev(pt: float, units: str) -> float:
    """Express a transverse momentum in GeV.

    ``units`` selects the input unit. ``"MeV"`` divides by 1000, and ``"GeV"``
    returns the value unchanged. Any other value (the tuner passes ``"auto"``)
    guesses per value: magnitudes above 200 are treated as MeV, and the rest
    as GeV already.
    """
    if units == "GeV":
        return pt
    if units != "MeV" and not abs(pt) > 200.0:
        # auto: a small magnitude is assumed to already be in GeV. Note this
        # per-value guess misreads genuine MeV values <= 200 as GeV.
        return pt
    # Explicit MeV, or auto with a large magnitude (ATLAS branches are usually MeV).
    return pt / 1000.0

◆ _run()

int muonEdgeTuner._run ( list[str] cmd,
Path log_path )
protected

Definition at line 24 of file muonEdgeTuner.py.

def _run(cmd: list[str], log_path: Path) -> int:
    """Execute *cmd* and write its combined stdout/stderr to *log_path*.

    Missing parent directories of the log are created, and an existing log
    file is overwritten. The child's exit status is returned; a non-zero
    status is not turned into an exception.
    """
    log_dir = log_path.parent
    log_dir.mkdir(parents=True, exist_ok=True)
    with open(log_path, "w", encoding="utf-8") as sink:
        completed = subprocess.run(
            cmd,
            stdout=sink,
            stderr=subprocess.STDOUT,
            text=True,
        )
    return completed.returncode

◆ main()

muonEdgeTuner.main ( )

Definition at line 199 of file muonEdgeTuner.py.

def main():
    """Command-line entry point for the SegmentEdge threshold scan.

    Runs the reco chain once without the edge classifier (the baseline). Then
    runs it once for every (edgeThreshold, overlapThreshold) pair and compares
    each job's metric with the baseline. A point passes when its relative
    loss is at most ``--targetLoss``.

    Per-point results are written to ``edge_threshold_scan.csv`` inside
    ``--workDir``. The summary, including the best passing point, is written
    to ``summary.json`` there and printed. Failures of individual scan points
    are recorded as rows rather than aborting the scan.

    Raises:
        SystemExit: if the baseline job fails, writes no ROOT file, or yields
            a non-positive metric.
    """
    parser = argparse.ArgumentParser(
        description="Tune SegmentEdge thresholds by comparing edge-chain track loss to baseline."
    )
    parser.add_argument("--inputFile", required=True)
    parser.add_argument("--bucketModel", required=True)
    parser.add_argument("--bucketThreshold", "--score-threshold", dest="bucketThreshold", type=float, default=0.0,
                        help="Threshold on bucket filter score")
    parser.add_argument("--edgeModel", required=True)
    parser.add_argument("--nEvents", type=int, default=100)
    parser.add_argument("--workDir", default="edge_threshold_tuning")
    parser.add_argument("--edgeThresholds", default="0.10,0.15,0.20,0.25,0.30,0.35,0.40,0.50")
    parser.add_argument("--overlapThresholds", default="0.60,0.70,0.80,0.90")
    parser.add_argument("--targetLoss", type=float, default=0.001,
                        help="Maximum allowed relative loss, default 0.001 = 0.1 percent")
    parser.add_argument("--metricTree", default=None)
    parser.add_argument("--metricBranch", default=None)
    parser.add_argument("--metricMode", default="matchedTruthTracks",
                        choices=["matchedTruthTracks", "rawTrackCount"],
                        help="matchedTruthTracks counts MS tracks with truthLink >= threshold and pT cut")
    parser.add_argument("--truthLinkBranch", default="MSTrksR4_truthLink")
    parser.add_argument("--trackPtBranch", default="MSTrksR4_pt")
    parser.add_argument("--truthLinkThreshold", type=int, default=1)
    parser.add_argument("--minPtGeV", type=float, default=2.0)
    parser.add_argument("--ptUnits", default="auto",
                        choices=["auto", "MeV", "GeV"])
    parser.add_argument("--extraRecoArgs", default="",
                        help="Extra args forwarded to muonEdgeRecoChain.py")
    # --use-gpu / --use-cpu share dest "use_gpu"; it stays None when neither is given.
    parser.add_argument("--use-gpu", action="store_true", dest="use_gpu", default=None,
                        help="Use GPU for ONNX inference in the reco chain (default: auto-detect)")
    parser.add_argument("--use-cpu", dest="use_gpu", action="store_false",
                        help="Force CPU for ONNX inference in the reco chain")
    parser.add_argument("--skipExisting", action="store_true")
    args = parser.parse_args()

    # Parse the scan grid up front so a malformed list fails before any job runs.
    edge_values = _parse_float_list(args.edgeThresholds)
    overlap_values = _parse_float_list(args.overlapThresholds)

    work = Path(args.workDir).resolve()
    work.mkdir(parents=True, exist_ok=True)

    baseline_root = work / "baseline.root"
    baseline_log = work / "baseline.log"
    baseline_cmd = _chain_cmd(args, baseline_root, edge=False)
    if not args.skipExisting or not baseline_root.exists():
        rc = _run(baseline_cmd, baseline_log)
        if rc != 0:
            raise SystemExit(f"Baseline job failed with rc={rc}. See {baseline_log}")
    if not baseline_root.exists():
        raise SystemExit(
            f"Baseline job finished but output ROOT file is missing: {baseline_root}. "
            f"Ensure reco args produce this output. "
            f"See {baseline_log}"
        )

    baseline_metric = _metric_from_root(baseline_root, args.metricTree, args.metricBranch, args)
    if baseline_metric <= 0:
        raise SystemExit(f"Baseline metric is non-positive: {baseline_metric}")

    rows = []
    best = None
    for edge_thr in edge_values:
        for overlap_thr in overlap_values:
            tag = f"edge{edge_thr:.3f}_overlap{overlap_thr:.3f}".replace(".", "p")
            out_root = work / f"{tag}.root"
            out_log = work / f"{tag}.log"
            cmd = _chain_cmd(args, out_root, edge=True,
                             edge_threshold=edge_thr,
                             overlap_threshold=overlap_thr)
            if not args.skipExisting or not out_root.exists():
                rc = _run(cmd, out_log)
                if rc != 0:
                    rows.append({
                        "edgeThreshold": edge_thr,
                        "overlapThreshold": overlap_thr,
                        "status": "failed",
                        "log": str(out_log),
                    })
                    continue
            if not out_root.exists():
                rows.append({
                    "edgeThreshold": edge_thr,
                    "overlapThreshold": overlap_thr,
                    "status": "missing_output",
                    "log": str(out_log),
                    "rootFile": str(out_root),
                })
                continue

            try:
                metric = _metric_from_root(out_root, args.metricTree, args.metricBranch, args)
            except RuntimeError as err:
                # One unreadable/unsuitable output must not discard the rest of the scan.
                rows.append({
                    "edgeThreshold": edge_thr,
                    "overlapThreshold": overlap_thr,
                    "status": "metric_error",
                    "error": str(err),
                    "log": str(out_log),
                    "rootFile": str(out_root),
                })
                continue

            # Gains over the baseline are clamped to zero loss.
            loss = max(0.0, (baseline_metric - metric) / baseline_metric)
            row = {
                "edgeThreshold": edge_thr,
                "overlapThreshold": overlap_thr,
                "status": "ok",
                "baselineMetric": baseline_metric,
                "edgeMetric": metric,
                "relativeLoss": loss,
                # targetLoss is the maximum *allowed* loss, so the bound is inclusive
                # (otherwise --targetLoss 0 would reject even lossless points).
                "passesTarget": loss <= args.targetLoss,
                "rootFile": str(out_root),
                "log": str(out_log),
            }
            rows.append(row)

            if row["passesTarget"]:
                # Prefer the largest EdgeThreshold that still passes, then largest OverlapThreshold.
                key = (edge_thr, overlap_thr)
                if best is None or key > (best["edgeThreshold"], best["overlapThreshold"]):
                    best = row

    csv_path = work / "edge_threshold_scan.csv"
    with csv_path.open("w", newline="", encoding="utf-8") as f:
        # Rows of different statuses carry different keys; use their union as header.
        writer = csv.DictWriter(f, fieldnames=sorted({k for r in rows for k in r}))
        writer.writeheader()
        writer.writerows(rows)

    summary = {
        "baselineMetric": baseline_metric,
        "targetLoss": args.targetLoss,
        "best": best,
        "scanCsv": str(csv_path),
    }
    (work / "summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")

    print(json.dumps(summary, indent=2))
void print(char *figname, TCanvas *c1)
#define max(a, b)
Definition cfImp.cxx:41
std::string replace(std::string s, const std::string &s2, const std::string &s3)
Definition hcg.cxx:312
int main()
Definition hello.cxx:18