from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg
from MuonConfig.MuonConfigUtils import executeTest, setupHistSvcCfg
from AthenaConfiguration.AllConfigFlags import initConfigFlags

flags = initConfigFlags()

# NOTE(review): ``args`` is not created in this part of the script; it is
# presumably the parsed command-line namespace - confirm upstream.

# --skip-onnx vetoes every ONNX-based stage, whatever its own enable switch says.
onnx_allowed = not args.skip_onnx
run_bucket_filter = args.enableBucketFilter and onnx_allowed
run_edge_classifier = args.enableEdgeClassifier and onnx_allowed
run_ml_seeder = args.useMlSeeder and onnx_allowed

# Tell the user which requested stages the veto is dropping.
if args.skip_onnx:
    if args.enableBucketFilter or args.enableEdgeClassifier:
        print("INFO: --skip-onnx requested. Disabling bucket filter and edge classifier inference stages.")
    if args.useMlSeeder:
        print("INFO: --skip-onnx requested. Switching to legacy seeder for a non-ONNX baseline.")
# Components for which debug-level messages are requested: the ML bucket
# filter and the segment-edge classifier (each with its CPU and CUDA ONNX
# Runtime session tools), the track-candidate builder and the ML seeder.
# NOTE(review): the parent algorithms GraphInferenceAlg / MSTrackFinderAlg are
# not listed themselves; add them if their own debug output is wanted.
flags.Exec.DebugMessageComponents = [
    # Bucket-filter inference tool and its ONNX Runtime sessions.
    "GraphInferenceAlg.GraphBucketFilterTool",
    "GraphInferenceAlg.GraphBucketFilterTool.OnnxRuntimeSessionToolCPU",
    "GraphInferenceAlg.GraphBucketFilterTool.OnnxRuntimeSessionToolCUDA",
    # Segment-edge classification and track-candidate building.
    "SegmentEdgeInferenceAlg",
    "SegmentEdgeInferenceAlg.SegmentEdgeClassifierTool",
    "SegmentEdgeInferenceAlg.SegmentEdgeClassifierTool.OnnxRuntimeSessionToolCPU",
    "SegmentEdgeInferenceAlg.SegmentEdgeClassifierTool.OnnxRuntimeSessionToolCUDA",
    "SegmentEdgeInferenceAlg.SegmentTrackCandidateBuilderTool",
    # ML seeding inside the track finder.
    "MSTrackFinderAlg.MlMsTrackSeeder",
]
from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType

# Select the ONNX Runtime execution provider. CUDA is used only when:
#   * at least one ONNX inference stage actually runs,
#   * the user did not opt out (an unset ``args.use_gpu`` counts as "requested"),
#   * and a CUDA backend is detected.
# In every other case the CPU provider is used.
if run_bucket_filter or run_edge_classifier:
    use_gpu_requested = args.use_gpu if args.use_gpu is not None else True

    # Ask onnxruntime first. Fall back to torch when the Python onnxruntime
    # bindings are not installed, and to "no GPU" when neither is importable
    # (previously torch was used without being imported and always overwrote
    # the onnxruntime answer).
    # NOTE(review): "CUDAExecutionProvider" being listed only says the runtime
    # build supports CUDA, not that a device is actually present.
    try:
        import onnxruntime as ort
        gpu_available = "CUDAExecutionProvider" in ort.get_available_providers()
    except ImportError:
        try:
            import torch
            gpu_available = torch.cuda.is_available()
        except ImportError:
            gpu_available = False

    if use_gpu_requested and gpu_available:
        flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CUDA
    else:
        if use_gpu_requested:
            print("INFO: No CUDA execution provider detected. Running ONNX inference on CPU.")
        flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU
else:
    # No ONNX stage is scheduled; keep the provider on CPU.
    flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU
from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg
from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg

# Base geometry/test configuration. ``flags`` is rebound to the instance
# returned by setupGeoR4TestCfg.
flags, cfg = setupGeoR4TestCfg(args, flags)

# Histogram service writing to args.outRootFile on the
# "MuonEtaHoughTransformTest" stream.
cfg.merge(setupHistSvcCfg(flags,
                          outFile=args.outRootFile,
                          outStream="MuonEtaHoughTransformTest"))

# Measurement preparation, then space-point formation, merged in that order.
for prep_step in (xAODUncalibMeasPrepCfg, MuonSpacePointFormationCfg):
    cfg.merge(prep_step(flags))
# Message level handed to the ML tools: 1 with --athenaDebug, 3 otherwise
# (presumably Gaudi VERBOSE / INFO).
output_level = 1 if args.athenaDebug else 3

# Optional ML bucket filter: GraphBucketFilterTool (model args.bucketModel,
# threshold args.bucketThreshold) run inside GraphInferenceAlg. It is gated on
# run_bucket_filter so that --skip-onnx really disables it, as the notice
# printed earlier promises.
if run_bucket_filter:
    from MuonInference.InferenceConfig import GraphBucketFilterToolCfg, GraphInferenceAlgCfg
    bucketTool = cfg.popToolsAndMerge(GraphBucketFilterToolCfg(flags,
                                                               ModelPath=args.bucketModel,
                                                               ScoreThreshold=args.bucketThreshold,
                                                               OutputLevel=output_level))
    cfg.merge(GraphInferenceAlgCfg(flags, InferenceTools=[bucketTool]))

# Standard R4 muon reconstruction. It was imported but never merged, yet
# MuonEtaHoughTransformAlg is looked up in cfg just below; presumably this
# configuration is what provides that algorithm.
from MuonConfig.ReconstructionConfigR4 import MuonReconstructionConfig
cfg.merge(MuonReconstructionConfig(flags))

# Reroute the Hough transform onto the filtered buckets only when the filter
# runs. "FilteredMlBuckets" is presumably written by the bucket filter;
# without it the algorithm would read a container nothing produces.
if run_bucket_filter:
    cfg.getEventAlgo("MuonEtaHoughTransformAlg").SpacePointContainer = "FilteredMlBuckets"
# Optional segment-edge classification stage (SegmentEdgeInferenceAlg). It is
# configured from the edge model, the edge/overlap thresholds and the
# recovery-components switch given on the command line.
if run_edge_classifier:
    from MuonInference.InferenceConfig import SegmentEdgeInferenceAlgCfg
    edge_alg_settings = dict(EdgeModelPath=args.edgeModel,
                             EdgeThreshold=args.edgeThreshold,
                             OverlapThreshold=args.overlapThreshold,
                             UseRecoveryComponents=args.useRecoveryComponents,
                             OutputLevel=output_level)
    cfg.merge(SegmentEdgeInferenceAlgCfg(flags, **edge_alg_settings))
from MuonTrackFindingAlgs.TrackFindingConfig import MSTrackFinderAlgCfg
from ActsConfig.ActsGeometryConfig import ActsTrackingGeometryToolCfg

# Decoration name handed to the track finder as MlCandidateDecoration.
ml_candidate_decoration = "trackCandidateIds"

# The ML seeder reads candidates from that decoration; warn when the edge
# classifier stage is off.
if run_ml_seeder and not run_edge_classifier:
    print("WARNING: ML seeder enabled while edge classifier is disabled."
          f" The decoration '{ml_candidate_decoration}' may be missing.")

# The ACTS tracking-geometry tool is merged into cfg before the track finder
# itself is configured, matching the argument-evaluation order of a single
# nested call.
tracking_geometry_tool = cfg.getPrimaryAndMerge(ActsTrackingGeometryToolCfg(flags))
cfg.merge(MSTrackFinderAlgCfg(flags,
                              UseMlSeeder=run_ml_seeder,
                              MlCandidateDecoration=ml_candidate_decoration,
                              TrackingGeometryTool=tracking_geometry_tool,
                              # Judging by the property names: fall back to
                              # baseline seeding when the decoration is absent,
                              # but not when it yields no candidates - confirm.
                              MlFallbackToBaselineIfUndecorated=True,
                              MlFallbackToBaselineIfNoCandidates=False))
# Optional validation chain. It converts the track-finder output to xAOD
# track particles, attaches truth associations and runs the reco-chain tester
# on the result. Every step below consumes containers produced inside this
# branch, so the whole chain is gated on the switch.
if args.enableRecoChainTester:
    from MuonTrackFindingAlgs.TrackFindingConfig import MuonActsToTrkConvCfg
    from xAODTrackingCnv.xAODTrackingCnvConfig import MuonStandaloneTrackParticleCnvAlgCfg
    from MuonTruthAlgsR4.MuonTruthAlgsConfig import RecoSegmentTruthAssocCfg, TrackToTruthPartAssocCfg
    from MuonPatternRecognitionTest.PatternTestConfig import MuonRecoChainTesterCfg

    # Container names shared between the steps of the chain.
    segment_key = "MuonSegmentsFromR4"
    converted_tracks_key = "MsTracksConv"
    track_particles_key = "MuonSpectrometerTrackParticlesR4"

    # "MsTracks" -> converted Trk tracks -> xAOD track particles.
    cfg.merge(MuonActsToTrkConvCfg(flags,
                                   ACTSTracksLocation="MsTracks",
                                   TracksLocation=converted_tracks_key))
    cfg.merge(MuonStandaloneTrackParticleCnvAlgCfg(flags,
                                                   name="MuonXAODParticleConvR4",
                                                   TrackContainerName=converted_tracks_key,
                                                   xAODTrackParticlesFromTracksContainerName=track_particles_key))

    # Truth association for the segments and for the converted track particles.
    cfg.merge(RecoSegmentTruthAssocCfg(flags,
                                       name="MuonSegmentsFromR4TruthMatching",
                                       SegmentKey=segment_key))
    cfg.merge(TrackToTruthPartAssocCfg(flags,
                                       name="TrackToTruthMuonSpectrometerTrackParticlesR4",
                                       TrackCollection=track_particles_key))

    # NOTE(review): identical containers are passed as the "legacy" and the R4
    # inputs, so the tester compares the R4 chain with itself - confirm intent.
    cfg.merge(MuonRecoChainTesterCfg(flags,
                                     LegacySegmentKey=segment_key,
                                     SegmentFromR4HoughKey="",
                                     R4SegmentKey=segment_key,
                                     LegacyTrackKey=track_particles_key,
                                     TrackKeyR4=track_particles_key))
# Dump the final configuration, with component details and a property summary.
# NOTE(review): the job is not launched in this section (executeTest is not
# called here) - confirm it is run after this point.
cfg.printConfig(summariseProps=True, withDetails=True)