ATLAS Offline Software
Loading...
Searching...
No Matches
muonBucketInference.py
Go to the documentation of this file.
1# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
"""Test job: ML (graph/ONNX) bucket filtering in front of the Muon eta Hough transform.

The job prepares xAOD uncalibrated measurements, forms muon space points, runs
the graph-inference algorithm configured with the bucket-filter tool, and then
points ``MuonEtaHoughTransformAlg`` at the ``FilteredMlBuckets`` container
(presumably the inference output -- confirm in MuonInference.InferenceConfig).
Results are written by the Hough transform tester to the output ROOT file.
"""


def addInferenceArgs(parser):
    """Register the command-line options specific to this inference test.

    :param parser: an ``argparse.ArgumentParser`` (or compatible) to extend.
    :returns: the same parser, to allow chaining.
    """
    parser.add_argument("--noPerfMon", help="If set, full perfmonMT is disabled",
                        default=False, action='store_true')
    # --use-gpu and --use-cpu write the same ``use_gpu`` destination: GPU is the
    # default, and when both are given the last one on the command line wins.
    parser.add_argument("--use-gpu", dest="use_gpu", action="store_true", default=True,
                        help="Use GPU for ONNX inference (default: True)")
    parser.add_argument("--use-cpu", dest="use_gpu", action="store_false",
                        help="Use CPU for ONNX inference")
    return parser


def cudaProviderAvailable():
    """Best-effort check whether CUDA can be used for ONNX inference.

    Prefers the provider list reported by the Python ``onnxruntime`` package.
    If that package cannot be imported or queried, falls back to
    ``torch.cuda.is_available()``. Any failure of both probes means "no GPU".

    :returns: True if a CUDA-capable setup was detected, False otherwise.
    """
    try:
        import onnxruntime as ort
        return "CUDAExecutionProvider" in ort.get_available_providers()
    except Exception:
        # onnxruntime missing or broken: deliberately fall through to PyTorch.
        pass
    try:
        import torch
        return torch.cuda.is_available()
    except Exception:
        return False


def resolveUseGpu(requested):
    """Decide whether ONNX inference should run on the GPU.

    The (potentially slow) GPU probe only runs when GPU was requested. If GPU
    was requested but is unavailable, a warning is logged so that the CPU
    fallback is visible in the job output instead of happening silently.

    :param requested: whether the user asked for GPU inference.
    :returns: True only if GPU was requested and a CUDA provider was detected.
    """
    if not requested:
        return False
    if cudaProviderAvailable():
        return True
    import logging
    logging.getLogger("muonBucketInference").warning(
        "GPU inference requested but no CUDA provider was found; falling back to CPU")
    return False


if __name__ == "__main__":
    from MuonGeoModelTestR4.testGeoModel import SetupArgParser, MuonPhaseIITestDefaults
    parser = SetupArgParser()
    parser.set_defaults(nEvents=-1)
    parser.set_defaults(outRootFile="InferenceHoughTest.root")
    parser.set_defaults(noMM=True)
    parser.set_defaults(noSTGC=True)
    parser.set_defaults(inputFile=MuonPhaseIITestDefaults.HITS_PG_R3)
    addInferenceArgs(parser)

    args = parser.parse_args()

    from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg
    from MuonConfig.MuonConfigUtils import executeTest, setupHistSvcCfg
    from AthenaConfiguration.AllConfigFlags import initConfigFlags

    flags = initConfigFlags()
    flags.PerfMon.doFullMonMT = not args.noPerfMon

    from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
    # The execution provider must be chosen before setupGeoR4TestCfg consumes the flags.
    flags.AthOnnx.ExecutionProvider = (OnnxRuntimeType.CUDA
                                       if resolveUseGpu(args.use_gpu)
                                       else OnnxRuntimeType.CPU)

    flags, cfg = setupGeoR4TestCfg(args, flags)

    from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg
    cfg.merge(xAODUncalibMeasPrepCfg(flags))

    from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg
    cfg.merge(MuonSpacePointFormationCfg(flags))

    from MuonInference.InferenceConfig import GraphBucketFilterToolCfg, GraphInferenceAlgCfg
    bucketTool = cfg.popToolsAndMerge(GraphBucketFilterToolCfg(flags))
    cfg.merge(GraphInferenceAlgCfg(flags, InferenceTools=[bucketTool]))

    from MuonPatternRecognitionAlgs.MuonPatternRecognitionConfig import MuonPatternRecognitionCfg
    cfg.merge(MuonPatternRecognitionCfg(flags))
    # Run the eta Hough transform on the ML-filtered buckets instead of its default input.
    cfg.getEventAlgo("MuonEtaHoughTransformAlg").SpacePointContainer = "FilteredMlBuckets"

    cfg.merge(setupHistSvcCfg(flags, outFile=args.outRootFile,
                              outStream="MuonEtaHoughTransformTest"))

    from MuonPatternRecognitionTest.PatternTestConfig import MuonHoughTransformTesterCfg, PatternVisualizationToolCfg

    visualTool = cfg.popToolsAndMerge(PatternVisualizationToolCfg(flags, CanvasLimits=0))
    cfg.merge(MuonHoughTransformTesterCfg(flags, VisualizationTool=visualTool))

    executeTest(cfg)
82
83