from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg
from MuonConfig.MuonConfigUtils import (
    executeTest,
    setupHistSvcCfg,
)
from AthenaConfiguration.AllConfigFlags import initConfigFlags

# Start from a fresh set of configuration flags for this job and switch off
# the full multi-threaded performance monitoring.
flags = initConfigFlags()
flags.PerfMon.doFullMonMT = False
# Optional Geant4 track truth: read defensively from args (False when the
# attribute is absent). When enabled, pile-up truth is also requested on the
# local flags; the same boolean is later forwarded to the segment dumper via
# dumper_kwargs["IncludeG4TrackTruth"].
# NOTE(review): this span looks truncated in this extract (embedded numbering
# skips 13 and 15) — verify against the full script before editing.
# NOTE(review): these flag changes only matter if the local `flags` object is
# the one used by the later `flags, cfg = setupGeoR4TestCfg(args)` call — confirm.
11 includeG4TrackTruth = bool(getattr(args,
"includeG4TrackTruth",
False))
12 if includeG4TrackTruth:
14 flags.Muon.includePileUpTruth =
True
# Choose the ONNX Runtime execution provider for the inference tools.
# use_gpu_requested comes from args.use_gpu and defaults to True when absent.
# gpu_available is first derived from whether ONNX Runtime lists the
# "CUDAExecutionProvider"; CUDA is selected only when both are true.
# NOTE(review): this span is incomplete in this extract (embedded numbering
# skips 18-19, 22-24, 26-27 and 30). The torch-based assignment to
# gpu_available presumably lives in a fallback branch (e.g. if onnxruntime
# cannot be imported), and the CPU assignment presumably in the missing else
# branch — confirm against the full script. `torch` is not imported anywhere
# in the visible lines.
16 from AthOnnxComps.OnnxRuntimeFlags
import OnnxRuntimeType
17 use_gpu_requested = getattr(args,
"use_gpu",
True)
20 import onnxruntime
as ort
21 gpu_available =
"CUDAExecutionProvider" in ort.get_available_providers()
25 gpu_available = torch.cuda.is_available()
# Prefer CUDA only when the user asked for the GPU and a CUDA provider exists.
28 if use_gpu_requested
and gpu_available:
29 flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CUDA
31 flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU
# Build the geometry test configuration from the command-line arguments.
# The locally prepared flags are passed in. Otherwise the settings applied to
# them earlier (PerfMon, pile-up truth, ONNX execution provider) would be lost
# when `flags` is rebound to the container returned here.
# NOTE(review): assumes setupGeoR4TestCfg accepts an optional flags container
# as its second argument — confirm against MuonGeoModelTestR4.testGeoModel.
flags, cfg = setupGeoR4TestCfg(args, flags)

# Route the "MuonSegmentDump" output stream to the requested ROOT file.
cfg.merge(setupHistSvcCfg(flags, outFile=args.outRootFile,
                          outStream="MuonSegmentDump"))

# Data preparation: uncalibrated xAOD measurements, then space points.
from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg
cfg.merge(xAODUncalibMeasPrepCfg(flags))

from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg
cfg.merge(MuonSpacePointFormationCfg(flags))

# Pattern recognition is merged further below, with or without the ML bucket filter.
from MuonPatternRecognitionAlgs.MuonPatternRecognitionConfig import MuonPatternRecognitionCfg
# Turn on the ML bucket filter when it is requested explicitly, or implicitly
# as soon as a bucket model path or a bucket score threshold is supplied.
do_ml_bucket_filter = (
    bool(getattr(args, "doMLBucketFilter", False))
    or getattr(args, "bucketModel", None) is not None
    or getattr(args, "bucketThreshold", None) is not None
)
# Optional ML bucket filter. When enabled, a GraphBucketFilterTool is
# configured to write the "FilteredMlBuckets" space-point container. Its model
# path (args.bucketModel) and score threshold (args.bucketThreshold) are
# overridden only when supplied. The tool is handed on as an inference tool,
# and MuonEtaHoughTransformAlg is pointed at the filtered container.
# NOTE(review): this span is incomplete in this extract (embedded numbering
# skips 57-63, 65-68 and 71). The remaining GraphBucketFilterToolCfg
# arguments, the call receiving InferenceTools=[bucket_tool] (presumably
# GraphInferenceAlgCfg) and the presumed else branch are not visible — confirm
# against the full script.
48 if do_ml_bucket_filter:
49 from MuonInference.InferenceConfig
import GraphBucketFilterToolCfg, GraphInferenceAlgCfg
# Output key is fixed; optional overrides are added only when given on the CLI.
50 bucket_tool_kwargs = {
"WriteSpacePointKey":
"FilteredMlBuckets"}
51 if getattr(args,
"bucketModel",
None)
is not None:
52 bucket_tool_kwargs[
"ModelPath"] = args.bucketModel
53 if getattr(args,
"bucketThreshold",
None)
is not None:
54 bucket_tool_kwargs[
"ScoreThreshold"] = args.bucketThreshold
55 bucket_tool = cfg.popToolsAndMerge(
56 GraphBucketFilterToolCfg(
64 InferenceTools=[bucket_tool],
69 cfg.merge(MuonPatternRecognitionCfg(flags))
# Make the eta Hough transform consume the ML-filtered buckets instead of its default input.
70 cfg.getEventAlgo(
"MuonEtaHoughTransformAlg").SpacePointContainer =
"FilteredMlBuckets"
# Presumably the no-filter path: pattern recognition with its default inputs.
72 cfg.merge(MuonPatternRecognitionCfg(flags))
# Schedule the muon truth algorithms.
from MuonTruthAlgsR4.MuonTruthAlgsConfig import MuonTruthAlgsCfg
cfg.merge(MuonTruthAlgsCfg(flags))

# Keyword options forwarded to MuonSegmentDumpCfg below.
from MuonBucketDump.MuonBucketDumpConfig import MuonSegmentDumpCfg
dumper_kwargs = dict(IncludeG4TrackTruth=includeG4TrackTruth)
# Configure the segment dumper. With the ML bucket filter enabled, it reads the
# "FilteredMlBuckets" space points. The second call (presumably the missing
# else branch) uses the dumper's default space-point keys and dumper_kwargs.
# NOTE(review): this span is incomplete in this extract (embedded numbering
# skips 84-85). The rest of the filtered-branch call, presumably including
# **dumper_kwargs, is not visible — confirm against the full script.
82 if do_ml_bucket_filter:
83 cfg.merge(MuonSegmentDumpCfg(flags, SpacePointKeys=[
"FilteredMlBuckets"],
86 cfg.merge(MuonSegmentDumpCfg(flags, **dumper_kwargs))
92 from MuonGeoModelTestR4.testGeoModel
import SetupArgParser, MuonPhaseIITestDefaults