if __name__ == "__main__":
    # Test job: configure the geometry test setup, prepare measurements and
    # space points, run the graph-based space-point inference (GraphInferenceAlg
    # with a GraphSPFilterTool), then run the Hough-transform pattern
    # recognition and its tester, writing histograms to args.outRootFile.
    from MuonGeoModelTestR4.testGeoModel import SetupArgParser, MuonPhaseIITestDefaults

    parser = SetupArgParser()
    parser.set_defaults(nEvents=-1)
    parser.set_defaults(outRootFile="InferenceHoughTest.root")
    parser.set_defaults(noMM=True)
    parser.set_defaults(noSTGC=True)
    parser.set_defaults(inputFile=MuonPhaseIITestDefaults.HITS_PG_R3)

    # --noPerfMon switches full PerfMonMT OFF (doFullMonMT = not args.noPerfMon below).
    parser.add_argument("--noPerfMon", help="If set, full PerfMonMT is disabled",
                        default=False, action="store_true")
    # --use-gpu and --use-cpu share the destination args.use_gpu: GPU is the
    # default, --use-cpu turns it off, and the last flag given wins.
    parser.add_argument("--use-gpu", action="store_true", default=True,
                        help="Use GPU for ONNX inference (default: True)")
    parser.add_argument("--use-cpu", dest="use_gpu", action="store_false",
                        help="Use CPU for ONNX inference")

    args = parser.parse_args()

    from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg
    from MuonConfig.MuonConfigUtils import executeTest, setupHistSvcCfg
    from AthenaConfiguration.AllConfigFlags import initConfigFlags

    flags = initConfigFlags()
    flags.PerfMon.doFullMonMT = not args.noPerfMon

    from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType

    # Probe for CUDA only when the GPU was requested. onnxruntime's provider
    # list is checked first; torch is used as a fallback probe when
    # onnxruntime is not importable from Python. A missing package means
    # "no GPU" rather than a crash.
    gpu_available = False
    if args.use_gpu:
        try:
            import onnxruntime as ort
            gpu_available = "CUDAExecutionProvider" in ort.get_available_providers()
        except ImportError:
            try:
                import torch
                gpu_available = torch.cuda.is_available()
            except ImportError:
                gpu_available = False

    if args.use_gpu and gpu_available:
        flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CUDA
    else:
        if args.use_gpu:
            # Make the fallback visible instead of silently running on CPU.
            print("WARNING: GPU requested for ONNX inference but no CUDA "
                  "provider was found; falling back to CPU")
        flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU

    flags, cfg = setupGeoR4TestCfg(args, flags)

    from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg
    cfg.merge(xAODUncalibMeasPrepCfg(flags))

    from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg
    cfg.merge(MuonSpacePointFormationCfg(flags))

    from MuonInference.InferenceConfig import GraphSPFilterToolCfg, GraphInferenceAlgCfg
    cfg.merge(GraphInferenceAlgCfg(
        flags,
        InferenceTools=[cfg.popToolsAndMerge(GraphSPFilterToolCfg(flags))]))

    from MuonPatternRecognitionAlgs.MuonHoughTransformAlgConfig import MuonPatternRecognitionCfg
    cfg.merge(MuonPatternRecognitionCfg(flags))
    # The eta Hough transform reads "FilteredMlSpacePoints" -- presumably the
    # output of the graph filter configured above; confirm against InferenceConfig.
    cfg.getEventAlgo("MuonEtaHoughTransformAlg").SpacePointContainer = "FilteredMlSpacePoints"

    cfg.merge(setupHistSvcCfg(flags, outFile=args.outRootFile,
                              outStream="MuonEtaHoughTransformTest"))

    from MuonPatternRecognitionTest.PatternTestConfig import (
        MuonHoughTransformTesterCfg, PatternVisualizationToolCfg)
    cfg.merge(MuonHoughTransformTesterCfg(
        flags,
        VisualizationTool=cfg.popToolsAndMerge(
            PatternVisualizationToolCfg(flags, CanvasLimits=0))))
    # NOTE(review): executeTest is imported above but not called in this span;
    # the job is presumably launched with executeTest(cfg) after this point -- confirm.