ATLAS Offline Software
Loading...
Searching...
No Matches
muonSPInference.py
Go to the documentation of this file.
1# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
3
if __name__ == "__main__":
    # Stand-alone test job. In order, it configures: the R4 muon geometry test
    # setup, uncalibrated measurement preparation, space point formation, the
    # graph inference algorithm with its space point filter tool, the Hough
    # pattern recognition (reading the "FilteredMlSpacePoints" container) and
    # the Hough transform tester writing to a ROOT file.
    from MuonGeoModelTestR4.testGeoModel import SetupArgParser, MuonPhaseIITestDefaults

    parser = SetupArgParser()
    # Job-specific defaults layered on top of the common test argument parser.
    parser.set_defaults(
        nEvents=-1,
        outRootFile="InferenceHoughTest.root",
        noMM=True,
        noSTGC=True,
        inputFile=MuonPhaseIITestDefaults.HITS_PG_R3,
    )

    # The switch *disables* monitoring: doFullMonMT is set to `not args.noPerfMon`.
    parser.add_argument("--noPerfMon", default=False, action="store_true",
                        help="Disable the full PerfMonMT monitoring (enabled by default)")
    # --use-gpu / --use-cpu write to the same destination; GPU is the default.
    parser.add_argument("--use-gpu", action="store_true", default=True,
                        help="Use GPU for ONNX inference (default: True)")
    parser.add_argument("--use-cpu", dest="use_gpu", action="store_false",
                        help="Use CPU for ONNX inference")

    args = parser.parse_args()

    from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg
    from MuonConfig.MuonConfigUtils import executeTest, setupHistSvcCfg
    from AthenaConfiguration.AllConfigFlags import initConfigFlags

    flags = initConfigFlags()
    flags.PerfMon.doFullMonMT = not args.noPerfMon

    from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType

    # Runtime check for GPU availability. Prefer the ONNX Runtime provider
    # list and fall back to PyTorch if the onnxruntime module cannot be used.
    # The probes are deliberately best-effort: any failure means "no GPU".
    gpu_available = False
    try:
        import onnxruntime as ort
        gpu_available = "CUDAExecutionProvider" in ort.get_available_providers()
    except Exception:
        try:
            import torch
            gpu_available = torch.cuda.is_available()
        except Exception:
            gpu_available = False

    if args.use_gpu and gpu_available:
        flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CUDA
    else:
        if args.use_gpu:
            # Make the downgrade visible instead of silently running on CPU.
            import logging
            logging.getLogger("muonSPInference").warning(
                "GPU inference was requested but no CUDA device/provider was "
                "detected; falling back to the CPU execution provider.")
        flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU

    flags, cfg = setupGeoR4TestCfg(args, flags)

    from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg
    cfg.merge(xAODUncalibMeasPrepCfg(flags))

    from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg
    cfg.merge(MuonSpacePointFormationCfg(flags))

    from MuonInference.InferenceConfig import GraphSPFilterToolCfg, GraphInferenceAlgCfg
    filter_tool = cfg.popToolsAndMerge(GraphSPFilterToolCfg(flags))
    cfg.merge(GraphInferenceAlgCfg(flags, InferenceTools=[filter_tool]))

    from MuonPatternRecognitionAlgs.MuonHoughTransformAlgConfig import MuonPatternRecognitionCfg
    cfg.merge(MuonPatternRecognitionCfg(flags))
    # Point the eta Hough transform at the "FilteredMlSpacePoints" container
    # (presumably the output of the inference filter above - confirm against
    # the GraphInferenceAlg configuration).
    cfg.getEventAlgo("MuonEtaHoughTransformAlg").SpacePointContainer = "FilteredMlSpacePoints"

    cfg.merge(setupHistSvcCfg(flags, outFile=args.outRootFile,
                              outStream="MuonEtaHoughTransformTest"))

    from MuonPatternRecognitionTest.PatternTestConfig import MuonHoughTransformTesterCfg, PatternVisualizationToolCfg

    visualization_tool = cfg.popToolsAndMerge(PatternVisualizationToolCfg(flags, CanvasLimits=0))
    cfg.merge(MuonHoughTransformTesterCfg(flags, VisualizationTool=visualization_tool))

    executeTest(cfg)
72
73