def _add_runtime_options(parser):
    """Register the perf-monitoring and ONNX-device command-line options.

    Adds ``--noPerfMon`` (dest ``noPerfMon``, default ``False``) and the
    ``--use-gpu`` / ``--use-cpu`` pair, which both write ``use_gpu``
    (default ``True``).

    :param parser: an ``argparse.ArgumentParser`` (or compatible) instance.
    :returns: the same *parser*, for convenience.
    """
    # The code below sets doFullMonMT = not noPerfMon, so this flag *disables*
    # full PerfMonMT monitoring.
    parser.add_argument("--noPerfMon", default=False, action="store_true",
                        help="Disable full PerfMonMT monitoring (enabled by default)")
    # --use-gpu only restates the default (store_true with default=True). It is
    # kept so that existing command lines passing it keep working; --use-cpu
    # is the switch that actually changes the outcome.
    parser.add_argument("--use-gpu", dest="use_gpu", action="store_true", default=True,
                        help="Use GPU for ONNX inference (default: True)")
    parser.add_argument("--use-cpu", dest="use_gpu", action="store_false",
                        help="Use CPU for ONNX inference")
    return parser


def _cuda_available_for_onnx():
    """Return True if a CUDA device appears usable for ONNX inference.

    ONNX Runtime is asked first, because its list of available execution
    providers decides whether the CUDA provider exists. Only if the
    ``onnxruntime`` Python package cannot be imported is ``torch`` consulted
    as a fallback heuristic. Returns False when neither package is importable.
    """
    try:
        import onnxruntime as ort
    except ImportError:
        pass
    else:
        return "CUDAExecutionProvider" in ort.get_available_providers()
    try:
        import torch
    except ImportError:
        return False
    return bool(torch.cuda.is_available())


if __name__ == "__main__":
    from MuonGeoModelTestR4.testGeoModel import SetupArgParser, MuonPhaseIITestDefaults

    parser = SetupArgParser()
    parser.set_defaults(nEvents=-1,
                        outRootFile="InferenceHoughTest.root",
                        noMM=True,
                        noSTGC=True,
                        inputFile=MuonPhaseIITestDefaults.HITS_PG_R3)
    _add_runtime_options(parser)
    args = parser.parse_args()

    from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg
    from MuonConfig.MuonConfigUtils import executeTest, setupHistSvcCfg
    from AthenaConfiguration.AllConfigFlags import initConfigFlags
    from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType

    flags = initConfigFlags()
    flags.PerfMon.doFullMonMT = not args.noPerfMon

    # Use CUDA only when it was requested AND is actually usable; otherwise
    # fall back to CPU. The probe is skipped entirely when --use-cpu is given.
    if args.use_gpu and _cuda_available_for_onnx():
        flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CUDA
    else:
        flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU

    flags, cfg = setupGeoR4TestCfg(args, flags)

    from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg
    cfg.merge(xAODUncalibMeasPrepCfg(flags))

    from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg
    cfg.merge(MuonSpacePointFormationCfg(flags))

    from MuonInference.InferenceConfig import GraphBucketFilterToolCfg, GraphInferenceAlgCfg
    # NOTE(review): the pasted source lost the arguments of both calls below
    # (original lines 59-64 and 66-68); only `InferenceTools=[bucketTool]`
    # survived. This is a minimal reconstruction -- restore the original
    # tool/algorithm settings. The Hough transform below reads the
    # "FilteredMlBuckets" container, which presumably this inference
    # algorithm is meant to write; confirm its output key is configured.
    bucketTool = cfg.popToolsAndMerge(GraphBucketFilterToolCfg(flags))
    cfg.merge(GraphInferenceAlgCfg(flags, InferenceTools=[bucketTool]))

    from MuonPatternRecognitionAlgs.MuonPatternRecognitionConfig import MuonPatternRecognitionCfg
    cfg.merge(MuonPatternRecognitionCfg(flags))
    # Run the eta Hough transform on the ML-filtered buckets.
    cfg.getEventAlgo("MuonEtaHoughTransformAlg").SpacePointContainer = "FilteredMlBuckets"

    cfg.merge(setupHistSvcCfg(flags, outFile=args.outRootFile,
                              outStream="MuonEtaHoughTransformTest"))

    from MuonPatternRecognitionTest.PatternTestConfig import (
        MuonHoughTransformTesterCfg,
        PatternVisualizationToolCfg,
    )
    visualizationTool = cfg.popToolsAndMerge(
        PatternVisualizationToolCfg(flags, CanvasLimits=0))
    cfg.merge(MuonHoughTransformTesterCfg(flags, VisualizationTool=visualizationTool))