ATLAS Offline Software
Loading...
Searching...
No Matches
muonBucketInference.py
Go to the documentation of this file.
# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration

# Test job: run ML bucket-filter inference on muon space points and feed the
# filtered buckets into the eta Hough-transform pattern recognition, writing
# the tester's histograms to a ROOT file.

if __name__ == "__main__":
    from MuonGeoModelTestR4.testGeoModel import SetupArgParser, MuonPhaseIITestDefaults
    parser = SetupArgParser()
    # Run over all events of the single-particle-gun Run-3 HITS sample,
    # with the NSW (MM + sTGC) technologies switched off.
    parser.set_defaults(nEvents = -1)
    parser.set_defaults(outRootFile="InferenceHoughTest.root")
    parser.set_defaults(noMM=True)
    parser.set_defaults(noSTGC=True)
    parser.set_defaults(inputFile=MuonPhaseIITestDefaults.HITS_PG_R3)

    # NOTE: the flag *disables* perfmonMT (doFullMonMT = not args.noPerfMon);
    # the previous help text claimed the opposite.
    parser.add_argument("--noPerfMon", help="If set, full perfmonMT monitoring is disabled",
                        default=False, action='store_true')
    # --use-gpu / --use-cpu share one destination: default None means
    # "no explicit choice", which later resolves to GPU-preferred.
    parser.add_argument("--use-gpu", action="store_true", dest="use_gpu", default=None,
                        help="Use GPU for ONNX inference (default: True)")
    parser.add_argument("--use-cpu", dest="use_gpu", action="store_false",
                        help="Use CPU for ONNX inference")

    args = parser.parse_args()

    from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg
    from MuonConfig.MuonConfigUtils import executeTest, setupHistSvcCfg
    from AthenaConfiguration.AllConfigFlags import initConfigFlags

    flags = initConfigFlags()
    flags.PerfMon.doFullMonMT = not args.noPerfMon

    from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
    # Use command line argument if provided, otherwise default to True
    use_gpu_requested = args.use_gpu if args.use_gpu is not None else True
    # Runtime check for GPU availability. Prefer ONNXRuntime provider list,
    # fall back to PyTorch if ONNX runtime isn't available.
    gpu_available = False
    try:
        import onnxruntime as ort
        gpu_available = "CUDAExecutionProvider" in ort.get_available_providers()
    except Exception:
        try:
            import torch
            gpu_available = torch.cuda.is_available()
        except Exception:
            # Neither probe library importable/usable: assume CPU only.
            gpu_available = False
    # Only use CUDA when it was (implicitly) requested AND is actually
    # available; otherwise fall back to the CPU execution provider.
    if use_gpu_requested and gpu_available:
        flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CUDA
    else:
        flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU

    flags, cfg = setupGeoR4TestCfg(args,flags)

    # Uncalibrated measurement preparation -> space point formation.
    from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg
    cfg.merge(xAODUncalibMeasPrepCfg(flags))

    from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg
    cfg.merge(MuonSpacePointFormationCfg(flags))

    # ML bucket filtering: the graph-inference alg writes the container
    # consumed below by the Hough transform.
    from MuonInference.InferenceConfig import GraphBucketFilterToolCfg, GraphInferenceAlgCfg
    bucketTool = cfg.popToolsAndMerge(GraphBucketFilterToolCfg(flags))
    cfg.merge(GraphInferenceAlgCfg(flags, InferenceTools=[bucketTool]))

    from MuonPatternRecognitionAlgs.MuonPatternRecognitionConfig import MuonPatternRecognitionCfg
    cfg.merge(MuonPatternRecognitionCfg(flags))
    # Point the eta Hough transform at the ML-filtered buckets instead of
    # the default space-point container.
    cfg.getEventAlgo("MuonEtaHoughTransformAlg").SpacePointContainer = "FilteredMlBuckets"

    cfg.merge(setupHistSvcCfg( flags, outFile=args.outRootFile,
                               outStream="MuonEtaHoughTransformTest"))

    from MuonPatternRecognitionTest.PatternTestConfig import MuonHoughTransformTesterCfg, PatternVisualizationToolCfg

    # CanvasLimits=0 disables per-event display canvases in the tester.
    cfg.merge(MuonHoughTransformTesterCfg( flags,
                                           VisualizationTool = cfg.popToolsAndMerge(PatternVisualizationToolCfg(flags, CanvasLimits =0))))

    executeTest(cfg)