ATLAS Offline Software
Loading...
Searching...
No Matches
muonSegmentDump.py
Go to the documentation of this file.
1# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
3
4
# Event-store key written by the ML bucket filter and read by the downstream
# pattern recognition and segment dump when filtering is enabled.
_FILTERED_BUCKET_KEY = "FilteredMlBuckets"


def _onnx_cuda_available():
    """Best-effort probe for a usable CUDA backend for ONNX inference.

    Asks the onnxruntime Python module whether ``CUDAExecutionProvider`` is
    available. If onnxruntime cannot be imported or queried, falls back to
    ``torch.cuda.is_available()``. Any failure while probing counts as
    "no GPU"; this is deliberate, because the probe only decides between a
    GPU and a CPU default and must never abort job configuration.

    Returns:
        bool: True if a CUDA device appears usable, False otherwise.
    """
    try:
        import onnxruntime as ort
        return "CUDAExecutionProvider" in ort.get_available_providers()
    except Exception:
        pass  # onnxruntime python module missing/broken -> try torch below
    try:
        import torch
        return bool(torch.cuda.is_available())
    except Exception:
        return False


def main(args):
    """Configure and execute the muon segment dump job.

    The job chain is:
      1. Set up the Phase-II geometry test configuration.
      2. Open the ``MuonSegmentDump`` output stream in ``args.outRootFile``.
      3. Prepare xAOD uncalibrated measurements and form space points.
      4. Optionally run an ML (ONNX) bucket filter.
      5. Run pattern recognition and, for MC input, the truth algorithms.
      6. Schedule the segment dump and run it via ``executeTest``.

    Args:
        args: Parsed command-line namespace. Besides the options consumed by
            ``setupGeoR4TestCfg``, it must provide ``outRootFile``. It may
            provide ``use_gpu`` (default True), ``doMLBucketFilter``,
            ``bucketModel`` and ``bucketThreshold``. Giving a model or a
            threshold implicitly enables the ML bucket filter.
    """
    from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg
    from MuonConfig.MuonConfigUtils import executeTest, setupHistSvcCfg
    from AthenaConfiguration.AllConfigFlags import initConfigFlags
    from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType

    flags = initConfigFlags()
    flags.PerfMon.doFullMonMT = False

    # Only probe for CUDA when GPU use is requested: the probe may import
    # heavy modules (onnxruntime / torch) that a CPU-only run does not need.
    use_gpu = bool(getattr(args, "use_gpu", True)) and _onnx_cuda_available()
    flags.AthOnnx.ExecutionProvider = (OnnxRuntimeType.CUDA if use_gpu
                                       else OnnxRuntimeType.CPU)

    # Pass the pre-configured flags on. Without them the test setup builds a
    # fresh flag container and the settings made above are silently lost.
    flags, cfg = setupGeoR4TestCfg(args, flags)

    cfg.merge(setupHistSvcCfg(flags, outFile=args.outRootFile,
                              outStream="MuonSegmentDump"))

    from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg
    cfg.merge(xAODUncalibMeasPrepCfg(flags))

    from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg
    cfg.merge(MuonSpacePointFormationCfg(flags))

    bucket_model = getattr(args, "bucketModel", None)
    bucket_threshold = getattr(args, "bucketThreshold", None)
    do_ml_bucket_filter = bool(getattr(args, "doMLBucketFilter", False) or
                               bucket_model is not None or
                               bucket_threshold is not None)

    if do_ml_bucket_filter:
        from MuonInference.InferenceConfig import GraphBucketFilterToolCfg, GraphInferenceAlgCfg
        bucket_tool_kwargs = {"WriteSpacePointKey": _FILTERED_BUCKET_KEY}
        if bucket_model is not None:
            bucket_tool_kwargs["ModelPath"] = bucket_model
        if bucket_threshold is not None:
            bucket_tool_kwargs["ScoreThreshold"] = bucket_threshold
        bucket_tool = cfg.popToolsAndMerge(
            GraphBucketFilterToolCfg(flags, **bucket_tool_kwargs))
        cfg.merge(GraphInferenceAlgCfg(flags, InferenceTools=[bucket_tool]))

    from MuonPatternRecognitionAlgs.MuonPatternRecognitionConfig import MuonPatternRecognitionCfg
    cfg.merge(MuonPatternRecognitionCfg(flags))
    if do_ml_bucket_filter:
        # Point pattern recognition at the filtered buckets, so that the dumped
        # segments are built from the same container the dump reads.
        cfg.getEventAlgo("MuonEtaHoughTransformAlg").SpacePointContainer = _FILTERED_BUCKET_KEY

    # Truth information if MC
    if flags.Input.isMC:
        from MuonTruthAlgsR4.MuonTruthAlgsConfig import MuonTruthAlgsCfg
        cfg.merge(MuonTruthAlgsCfg(flags))

    from MuonBucketDump.MuonBucketDumpConfig import MuonSegmentDumpCfg
    dump_kwargs = {"SpacePointKeys": [_FILTERED_BUCKET_KEY]} if do_ml_bucket_filter else {}
    cfg.merge(MuonSegmentDumpCfg(flags, **dump_kwargs))

    executeTest(cfg)
83
if __name__ == "__main__":
    from MuonGeoModelTestR4.testGeoModel import SetupArgParser, MuonPhaseIITestDefaults

    # Start from the shared Phase-II test parser and override its defaults:
    # process every event of the R3 particle-gun HITS sample and write the
    # dump to a dedicated ROOT file.
    cli = SetupArgParser()
    cli.set_defaults(nEvents=-1,
                     outRootFile="MuonSegmentDump_R3SimHits.root",
                     inputFile=MuonPhaseIITestDefaults.HITS_PG_R3)

    # Options controlling the optional ML bucket filter.
    cli.add_argument("--doMLBucketFilter", action="store_true", default=False,
                     help="Run ML bucket filtering and dump segments from filtered buckets.")
    cli.add_argument("--bucketModel", type=str, default=None,
                     help="Path to ONNX model used by the ML bucket filter.")
    cli.add_argument("--bucketThreshold", type=float, default=None,
                     help="Score threshold for single-output bucket filtering.")

    # --use-gpu and --use-cpu both write the destination ``use_gpu``. Its
    # default is True, and whichever flag appears last on the command line wins.
    cli.add_argument("--use-gpu", action="store_true", default=True,
                     help="Use GPU for ONNX inference when available (default: True)")
    cli.add_argument("--use-cpu", dest="use_gpu", action="store_false",
                     help="Force CPU for ONNX inference")

    main(cli.parse_args())
104
105
int main()
Definition hello.cxx:18