from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg
from MuonConfig.MuonConfigUtils import executeTest, setupHistSvcCfg
from AthenaConfiguration.AllConfigFlags import initConfigFlags
from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType

# Create the job flags up front so that they can be customised before the
# geometry test setup consumes them.
flags = initConfigFlags()
flags.PerfMon.doFullMonMT = True

# Decide which ONNX execution provider to request. The GPU is used only if it
# was asked for (default: yes) AND a CUDA-capable runtime can be detected.
use_gpu_requested = getattr(args, "use_gpu", True)
try:
    import onnxruntime as ort
    gpu_available = "CUDAExecutionProvider" in ort.get_available_providers()
except Exception:
    # Best-effort probe: the onnxruntime Python module could not be imported
    # or queried, so ask torch instead. Any failure there means "no GPU".
    try:
        import torch
        gpu_available = torch.cuda.is_available()
    except Exception:
        gpu_available = False

flags.AthOnnx.ExecutionProvider = (
    OnnxRuntimeType.CUDA
    if use_gpu_requested and gpu_available
    else OnnxRuntimeType.CPU
)

# Hand the customised flags over to the setup helper. Calling it with `args`
# only rebinds `flags` to whatever the helper returns, silently dropping the
# PerfMon and AthOnnx settings made above.
# NOTE(review): relies on setupGeoR4TestCfg accepting an existing flags
# container as its second positional argument - confirm against testGeoModel.
flags, cfg = setupGeoR4TestCfg(args, flags)

# Histogram service writing the "MuonBucketDump" stream to the requested file.
cfg.merge(
    setupHistSvcCfg(flags, outFile=args.outRootFile, outStream="MuonBucketDump")
)
32
from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg
from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg
from MuonPatternRecognitionAlgs.MuonPatternRecognitionConfig import MuonPatternRecognitionCfg

# Merge the upstream configuration fragments into the job, keeping the
# original order of the three steps.
for recoStepCfg in (
    xAODUncalibMeasPrepCfg,
    MuonSpacePointFormationCfg,
    MuonPatternRecognitionCfg,
):
    cfg.merge(recoStepCfg(flags))
41
# Optional graph-inference bucket filter; skipped unless args.doMLBucketFilter
# is present and truthy.
if getattr(args, "doMLBucketFilter", False):
    from MuonInference.InferenceConfig import GraphBucketFilterToolCfg, GraphInferenceAlgCfg

    # Configure the filter tool; the bias falls back to 1.0 when the
    # command-line option is absent.
    filterToolCfg = GraphBucketFilterToolCfg(
        flags,
        BiasClass0=getattr(args, "mlBucketBias", 1.0),
        WriteSpacePointKey="FilteredMlBuckets",
        ModelPath="/cvmfs/atlas-nightlies.cern.ch/repo/data/data-art/MuonRecRTT/edgecnn_multi_bucket_sparse_meta.onnx",
    )
    mlFilterTool = cfg.popToolsAndMerge(filterToolCfg)

    # Schedule the inference algorithm with the filter as its only tool.
    cfg.merge(GraphInferenceAlgCfg(flags, InferenceTools=[mlFilterTool]))
59
from MuonBucketDump.MuonBucketDumpConfig import MuonBucketDumpCfg
from MuonPatternRecognitionTest.PatternTestConfig import PatternVisualizationToolCfg

# Visualization tool handed to the dumper, configured with CanvasLimits=0.
visualTool = cfg.popToolsAndMerge(PatternVisualizationToolCfg(flags, CanvasLimits=0))

# Properties of the bucket dump; optional command-line switches fall back to
# their defaults when the argument parser does not define them.
dumpOptions = {
    "DoCaloDump": getattr(args, "doCaloDump", False),
    "DoMLBucketScore": getattr(args, "doMLBucketScore", False),
    "DoMLBucketFilter": getattr(args, "doMLBucketFilter", False),
    "MLBucketBias": getattr(args, "mlBucketBias", 1.0),
    "VisualizationTool": visualTool,
}
cfg.merge(MuonBucketDumpCfg(flags, **dumpOptions))

# The truth muon vertex dump is added only on request.
if args.doTruthMuonVertexDump:
    from MuonBucketDump.MuonBucketDumpConfig import TruthMuonVertexDumpCfg
    cfg.merge(TruthMuonVertexDumpCfg(flags))

# Run the fully assembled job.
executeTest(cfg)
73