ATLAS Offline Software
Loading...
Searching...
No Matches
muonBucketRecoChain.py
Go to the documentation of this file.
1# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
3import os
4import logging
5
# Quieten ONNX Runtime as early as possible, i.e. before the Athena
# configuration modules below are imported.
#
# (a) Python logging: raise the threshold of the "onnxruntime" logger so that
#     only records at ERROR level or above pass through it.
logging.getLogger("onnxruntime").setLevel(logging.ERROR)
# (b) Environment: export ORT_LOGGING_LEVEL=3 (3 is intended as ERROR
#     severity) ahead of the imports, as an early-suppression attempt.
#     NOTE(review): this overwrites any value already set by the caller;
#     confirm that the ONNX Runtime build in use actually reads this variable.
os.environ.update(ORT_LOGGING_LEVEL="3")
11
12from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
13from AthenaConfiguration.ComponentFactory import CompFactory
14
def MsTrackTesterCfg(flags, name = "MsTrackTester", scheduleLegacy = True, **kwargs):
    """Schedule the MuonValR4.MsTrackTester event algorithm.

    Arguments:
        flags          -- configuration flags; ``flags.Input.isMC`` supplies the
                          default of the ``isMC`` property.
        name           -- instance name given to the algorithm.
        scheduleLegacy -- when False, the LegacySegmentKey, LegacyTrackKey and
                          LegacyMuonKey properties default to the empty string
                          (presumably switching off the comparison with the
                          legacy chain -- TODO confirm in the C++ algorithm).
        kwargs         -- further algorithm properties; values given here take
                          precedence over every default set below.

    Note: the default segment-selection and track-summary tools are always
    configured and merged into the returned accumulator, even when the caller
    provides its own ``SegmentSelectionTool`` / ``SummaryTool``.

    Returns a ComponentAccumulator holding the tester as primary event algorithm.
    """
    acc = ComponentAccumulator()
    kwargs.setdefault("isMC", flags.Input.isMC)

    from MuonTrackFindingAlgs.TrackFindingConfig import SegmentSelectorCfg, TrackSummaryToolCfg
    selectorTool = acc.popToolsAndMerge(SegmentSelectorCfg(flags))
    kwargs.setdefault("SegmentSelectionTool", selectorTool)
    summaryTool = acc.popToolsAndMerge(TrackSummaryToolCfg(flags))
    kwargs.setdefault("SummaryTool", summaryTool)

    if not scheduleLegacy:
        for legacyKey in ("LegacySegmentKey", "LegacyTrackKey", "LegacyMuonKey"):
            kwargs.setdefault(legacyKey, "")

    acc.addEventAlgo(CompFactory.MuonValR4.MsTrackTester(name = name, **kwargs),
                     primary = True)
    return acc
28
def MsTrackVisualizationToolCfg(flags, name = "VisualizationTool", **kwargs):
    """Configure a private MuonValR4.TrackVisualizationTool.

    Arguments:
        flags  -- configuration flags; ``flags.Input.isMC`` selects the branch.
        name   -- instance name given to the tool.
        kwargs -- further tool properties; values given here take precedence
                  over the defaults set below.

    On non-MC input the legacy muon reconstruction chain is merged in and
    ``TruthSegkey`` defaults to "MuonSegments". Presumably this is the segment
    container of that legacy chain, standing in for truth segments -- confirm.

    ``ExtrapolationTool`` defaults to an Acts extrapolation tool configured
    with MaxSteps=10000. That tool is always built and merged, even when the
    caller supplies its own extrapolator.

    Returns a ComponentAccumulator whose private tool is the visualization tool.
    """
    acc = ComponentAccumulator()
    isData = not flags.Input.isMC
    if isData:
        from MuonPatternRecognitionTest.PatternTestConfig import LegacyMuonRecoChainCfg
        acc.merge(LegacyMuonRecoChainCfg(flags))
        kwargs.setdefault("TruthSegkey", "MuonSegments")

    from ActsConfig.ActsGeometryConfig import ActsExtrapolationToolCfg
    extrapolator = acc.popToolsAndMerge(ActsExtrapolationToolCfg(flags, MaxSteps=10000))
    kwargs.setdefault("ExtrapolationTool", extrapolator)

    acc.setPrivateTools(CompFactory.MuonValR4.TrackVisualizationTool(name, **kwargs))
    return acc
40
if __name__ == "__main__":
    # Stand-alone test job: R4 muon reconstruction, optionally preceded by an
    # ONNX graph-based bucket filter, plus track/pattern validation algorithms.
    from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg, SetupArgParser, MuonPhaseIITestDefaults
    from MuonConfig.MuonConfigUtils import executeTest, setupHistSvcCfg

    # ---- Command line ------------------------------------------------------
    parser = SetupArgParser()
    parser.add_argument("--noMonitorPlots", default = False, action='store_true', help="If set to true, there're no monitoring plots")
    # NOTE(review): ``args.writeSpacePoints`` is never read in this script. The
    # option is kept so that existing command lines still parse. Wire it to the
    # component that should dump the bucket space points, or remove it.
    parser.add_argument("--writeSpacePoints", default=False, action='store_true', help="If set to true, the spacepoints in the bucket are saved to disk")
    parser.add_argument("--noPerfMon", default=False, action='store_true', help="If set to true, disable performance monitoring")
    # Setting this flag merges LegacyMuonRecoChainCfg below and passes
    # scheduleLegacy=True to MsTrackTesterCfg. The old help text claimed the
    # opposite ("the legacy chain is not scheduled").
    parser.add_argument("--LegacyChain", default = False, action = 'store_true',
                        help="If set to true, the legacy muon reconstruction chain is scheduled as well")
    parser.add_argument("--use-cpu", default = False, action = 'store_true', help="Use CPU for ONNX inference")
    parser.add_argument("--skip-onnx", action="store_true", default=False, help="Skip ONNX inference step")
    parser.add_argument("--bucket-model-path", dest="bucket_model_path",
                        default="/eos/project-i01/f/fcc-ml/ddicroce/ATLAS_MuonSpectrometer/KubeFlow/Inference_EdgeClassifier/athena/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/models/edgecnn_bucket_sparse_best.onnx",
                        help="ONNX model file passed as ModelPath to GraphBucketFilterTool")
    parser.add_argument("--score-threshold", type=float, default=0.0, dest="score_threshold",
                        help="ScoreThreshold passed to GraphBucketFilterTool")
    parser.add_argument("--output-name", default="logits", dest="output_name",
                        help="OutputName passed to GraphBucketFilterTool")
    parser.add_argument("--graph-bucket-output-level", type=int, default=3, dest="graph_bucket_output_level", help="OutputLevel for GraphBucketFilterTool")
    parser.add_argument("--is-logit", dest="is_logit", default=False, action="store_true", help="Interpret the single output directly and do not apply sigmoid")
    parser.set_defaults(nEvents = -1)
    parser.set_defaults(outRootFile="MsTrkTester.root")
    parser.set_defaults(inputFile=MuonPhaseIITestDefaults.HITS_PG_R3)

    args = parser.parse_args()

    # ---- Configuration flags -----------------------------------------------
    from AthenaConfiguration.AllConfigFlags import initConfigFlags
    flags = initConfigFlags()
    flags.PerfMon.doFullMonMT = not args.noPerfMon
    flags.PerfMon.OutputJSON = "perfmonmt_MuonR4Reco.json"
    flags.Trigger.Muon.useNewRegionSelector = False

    from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
    # CUDA is the default execution provider; --use-cpu switches to CPU.
    flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU if args.use_cpu else OnnxRuntimeType.CUDA

    flags, cfg = setupGeoR4TestCfg(args, flags)

    cfg.merge(setupHistSvcCfg(flags, outFile=args.outRootFile,
                              outStream="MuonTrackTester"))

    # ---- Reconstruction ----------------------------------------------------
    from MuonConfig.ReconstructionConfigR4 import MuonReconstructionConfig
    cfg.merge(MuonReconstructionConfig(flags))

    if not args.skip_onnx:
        from MuonInference.InferenceConfig import GraphBucketFilterToolCfg, GraphInferenceAlgCfg
        bucketTool = cfg.popToolsAndMerge(
            GraphBucketFilterToolCfg(
                flags,
                ModelPath=args.bucket_model_path,
                ScoreThreshold=args.score_threshold,
                OutputName=args.output_name,
                OutputLevel=args.graph_bucket_output_level,
                # --is-logit is always registered above, so the attribute
                # always exists and needs no hasattr() guard.
                SingleOutputIsLogit=args.is_logit,
            )
        )
        cfg.merge(GraphInferenceAlgCfg(flags, InferenceTools=[bucketTool]))
        # Feed the eta Hough transform from the "FilteredMlBuckets" container.
        # NOTE(review): assumes the inference step above writes that
        # container -- confirm against GraphInferenceAlgCfg.
        cfg.getEventAlgo("MuonEtaHoughTransformAlg").SpacePointContainer = "FilteredMlBuckets"

    from MuonPatternRecognitionTest.PatternTestConfig import (LegacyMuonRecoChainCfg,
                                                              MuonHoughTransformTesterCfg,
                                                              PatternVisualizationToolCfg)

    if args.LegacyChain:
        cfg.merge(LegacyMuonRecoChainCfg(flags))

    # ---- Validation and monitoring -----------------------------------------
    cfg.merge(MsTrackTesterCfg(flags, scheduleLegacy = args.LegacyChain))

    cfg.merge(setupHistSvcCfg(flags, outFile=args.outRootFile,
                              outStream="MuonEtaHoughTransformTest"))

    cfg.merge(MuonHoughTransformTesterCfg(flags,
              VisualizationTool = cfg.popToolsAndMerge(PatternVisualizationToolCfg(flags, CanvasLimits =0))))

    if not args.noMonitorPlots:
        cfg.getEventAlgo("MSTrackFinderAlg").VisualizationTool = cfg.popToolsAndMerge(MsTrackVisualizationToolCfg(flags))
        cfg.getEventAlgo("MuonSegmentFittingAlg").VisualizationTool = cfg.popToolsAndMerge(
            PatternVisualizationToolCfg(flags,
                                        CanvasPreFix="SegmentPlotValid", outSubDir="SegmentValidPlots",
                                        displayTruthOnly = True, saveSinglePDFs = False, saveSummaryPDF= False))

        cfg.getEventAlgo("MuonNswSegmentFinderAlg").VisualizationTool = cfg.popToolsAndMerge(
            PatternVisualizationToolCfg(flags,
                                        CanvasPreFix="NswSegmentFitPlotValid", outSubDir="SegmentValidPlots",
                                        doPhiBucketViews = False, saveSinglePDFs = False,
                                        saveSummaryPDF= False, CanvasLimits=10000))

    executeTest(cfg)
MsTrackVisualizationToolCfg(flags, name="VisualizationTool", **kwargs)
MsTrackTesterCfg(flags, name="MsTrackTester", scheduleLegacy=True, **kwargs)