ATLAS Offline Software
Loading...
Searching...
No Matches
muonSPInference.py
Go to the documentation of this file.
# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration


if __name__ == "__main__":
    # Standalone test job: builds the R4 muon test geometry configuration,
    # prepares the uncalibrated measurements, forms space points, runs the
    # graph-based ML space point filter, feeds the filtered space points into
    # the eta Hough transform and writes the tester output to a ROOT file.
    import logging

    from MuonGeoModelTestR4.testGeoModel import SetupArgParser, MuonPhaseIITestDefaults

    log = logging.getLogger("muonSPInference")

    # Start from the common test argument parser and override its defaults
    # for this particular job.
    parser = SetupArgParser()
    parser.set_defaults(nEvents=-1)
    parser.set_defaults(outRootFile="InferenceHoughTest.root")
    parser.set_defaults(noMM=True)
    parser.set_defaults(noSTGC=True)
    parser.set_defaults(inputFile=MuonPhaseIITestDefaults.HITS_PG_R3)

    # The flag is negated below, i.e. passing --noPerfMon switches the full
    # PerfMonMT monitoring OFF (it is on by default).
    parser.add_argument("--noPerfMon", help="If set, full PerfMonMT monitoring is disabled",
                        default=False, action='store_true')
    parser.add_argument("--use-cpu", action="store_true", default=False,
                        help="Use CPU for ONNX inference")

    args = parser.parse_args()

    from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg
    from MuonConfig.MuonConfigUtils import executeTest, setupHistSvcCfg
    from AthenaConfiguration.AllConfigFlags import initConfigFlags

    flags = initConfigFlags()
    flags.PerfMon.doFullMonMT = not args.noPerfMon

    from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
    use_gpu_requested = not args.use_cpu
    # Runtime check for GPU availability. Prefer the ONNXRuntime provider list,
    # fall back to PyTorch if ONNX runtime isn't available. The probe is only
    # performed when a GPU was actually requested, so that a pure CPU job does
    # not import onnxruntime / torch for nothing.
    gpu_available = False
    if use_gpu_requested:
        # Deliberately best-effort: any failure while importing or querying
        # the runtimes is treated as "no GPU" instead of aborting the job.
        try:
            import onnxruntime as ort
            gpu_available = "CUDAExecutionProvider" in ort.get_available_providers()
        except Exception:
            try:
                import torch
                gpu_available = torch.cuda.is_available()
            except Exception:
                gpu_available = False

    if use_gpu_requested and gpu_available:
        flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CUDA
    else:
        if use_gpu_requested:
            # Make the downgrade visible instead of silently running on CPU.
            log.warning("No CUDA execution provider detected - "
                        "falling back to CPU for ONNX inference")
        flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU

    flags, cfg = setupGeoR4TestCfg(args, flags)

    from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg
    cfg.merge(xAODUncalibMeasPrepCfg(flags))

    from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg
    cfg.merge(MuonSpacePointFormationCfg(flags))

    # Graph inference algorithm, equipped with the space point filter tool.
    from MuonInference.InferenceConfig import GraphSPFilterToolCfg, GraphInferenceAlgCfg
    cfg.merge(GraphInferenceAlgCfg(flags, InferenceTools=[cfg.popToolsAndMerge(GraphSPFilterToolCfg(flags))]))

    # Run the pattern recognition, but let the eta Hough transform read the
    # "FilteredMlSpacePoints" container instead of its configured default.
    from MuonPatternRecognitionAlgs.MuonHoughTransformAlgConfig import MuonPatternRecognitionCfg
    cfg.merge(MuonPatternRecognitionCfg(flags))
    cfg.getEventAlgo("MuonEtaHoughTransformAlg").SpacePointContainer = "FilteredMlSpacePoints"

    cfg.merge(setupHistSvcCfg(flags, outFile=args.outRootFile,
                              outStream="MuonEtaHoughTransformTest"))

    from MuonPatternRecognitionTest.PatternTestConfig import MuonHoughTransformTesterCfg, PatternVisualizationToolCfg

    cfg.merge(MuonHoughTransformTesterCfg(flags,
                                          VisualizationTool=cfg.popToolsAndMerge(PatternVisualizationToolCfg(flags, CanvasLimits=0))))

    executeTest(cfg)
69
70