ATLAS Offline Software
Loading...
Searching...
No Matches
muonBucketInference.py
Go to the documentation of this file.
1# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
3import os
4import logging
5
# Quiet ONNX Runtime before Athena (and any ONNX-using module) is initialised:
#  * the Python-side "onnxruntime" logger only lets ERROR and above through;
#  * ORT_LOGGING_LEVEL is exported as "3" (ERROR) so that anything reading it
#    from the environment later in this process sees the reduced verbosity.
# NOTE(review): onnxruntime's native (C++) logger does not route through
# Python `logging`, and it is unconfirmed that ORT_LOGGING_LEVEL is honoured
# by onnxruntime / AthOnnx -- check that these actually silence the warnings.
logging.getLogger(name="onnxruntime").setLevel(level=logging.ERROR)
os.environ.update(ORT_LOGGING_LEVEL="3")
11
12if __name__ == "__main__":
13 from MuonGeoModelTestR4.testGeoModel import SetupArgParser, MuonPhaseIITestDefaults
14 parser = SetupArgParser()
15 parser.set_defaults(nEvents = -1)
16 parser.set_defaults(outRootFile="InferenceHoughTest.root")
17 parser.set_defaults(noMM=True)
18 parser.set_defaults(noSTGC=True)
19 parser.set_defaults(inputFile=MuonPhaseIITestDefaults.HITS_PG_R3)
20
21 parser.add_argument("--noPerfMon", help="If set to true, full perfmonMT is enabled",
22 default=False, action='store_true')
23 parser.add_argument("--use-gpu", action="store_true", dest="use_gpu", default=True,
24 help="Use GPU for ONNX inference (default: True)")
25 parser.add_argument("--use-cpu", dest="use_gpu", action="store_false",
26 help="Use CPU for ONNX inference")
27 parser.add_argument("--athenaDebug", action="store_true", default=False,
28 help="Enable DEBUG verbosity for bucket inference components in MessageSvc")
29 parser.add_argument("--athenaVerbose", action="store_true", default=False,
30 help="Enable VERBOSE verbosity for bucket inference components in MessageSvc")
31 parser.add_argument("--bucket-model-path", dest="bucket_model_path",
32 default="/eos/project-i01/f/fcc-ml/ddicroce/ATLAS_MuonSpectrometer/KubeFlow/Inference_EdgeClassifier/athena/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/models/edgecnn_bucket_sparse_best.onnx",
33 help="Absolute path (or PathResolver key) for the bucket ONNX model")
34 parser.add_argument("--score-threshold", type=float, default=0.2, dest="score_threshold",
35 help="Keep bucket if single-output score > threshold (default: 0.2)")
36 parser.add_argument("--bucket-debug-dump-file", default="", dest="bucket_debug_dump_file",
37 help=("Optional JSONL output file with Athena-side bucket ONNX."))
38 parser.add_argument("--bucket-debug-dump-max-events", type=int, default=0, dest="bucket_debug_dump_max_events",
39 help=("Maximum number of events to write to --bucket-debug-dump-file."))
40 parser.add_argument("--bucket-print-labels", action="store_true", default=False, dest="bucket_print_labels",
41 help=("Print per-event good-bucket efficiency using the training label."))
42 parser.add_argument("--bucket-label-print-first-n-buckets", type=int, default=20, dest="bucket_label_print_first_n_buckets",
43 help=("When --bucket-print-labels is enabled, print detailed label/decision."))
44 parser.add_argument("--bucket-label-segment-key", default="MuonSegmentsFromR4", dest="bucket_label_segment_key",
45 help=("Optional xAOD::MuonSegmentContainer key used for the bucket_segments."))
46
47 args = parser.parse_args()
48
49 from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg
50 from MuonConfig.MuonConfigUtils import executeTest, setupHistSvcCfg
51 from AthenaConfiguration.AllConfigFlags import initConfigFlags
52
53 flags = initConfigFlags()
54 flags.PerfMon.doFullMonMT = not args.noPerfMon
55 if args.athenaDebug or args.athenaVerbose:
56 flags.Common.MsgSuppression = False
57
58 from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
59 # Use command line argument if provided, otherwise default to True
60 use_gpu_requested = args.use_gpu if args.use_gpu is not None else True
61 # Runtime check for GPU availability. Prefer ONNXRuntime provider list,
62 # fall back to PyTorch if ONNX runtime isn't available.
63 gpu_available = False
64 try:
65 import onnxruntime as ort
66 gpu_available = "CUDAExecutionProvider" in ort.get_available_providers()
67 except Exception:
68 try:
69 import torch
70 gpu_available = torch.cuda.is_available()
71 except Exception:
72 gpu_available = False
73 if use_gpu_requested and gpu_available:
74 flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CUDA
75 else:
76 flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU
77
78 flags, cfg = setupGeoR4TestCfg(args,flags)
79
80 from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg
81 cfg.merge(xAODUncalibMeasPrepCfg(flags))
82
83 from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg
84 cfg.merge(MuonSpacePointFormationCfg(flags))
85
86 from MuonPatternRecognitionAlgs.MuonPatternRecognitionConfig import MuonPatternRecognitionCfg
87
88
89 sample_is_mc = bool(flags.Input.isMC)
90 use_truth_labels = bool(args.bucket_print_labels and sample_is_mc)
91 use_segment_labels = bool(args.bucket_print_labels and (not sample_is_mc))
92
93 if use_segment_labels and not args.bucket_label_segment_key:
94 raise RuntimeError(
95 "--bucket-print-labels was requested for recorded data, but "
96 "--bucket-label-segment-key is empty. Set it to MuonSegmentsFromR4 "
97 "or disable label printing."
98 )
99
100 build_label_segments_first = bool(use_segment_labels)
101
102 if build_label_segments_first:
103 cfg.merge(MuonPatternRecognitionCfg(flags))
104
105 from MuonInference.InferenceConfig import GraphBucketFilterToolCfg, GraphInferenceAlgCfg
106 output_level = 1 if args.athenaVerbose else (2 if args.athenaDebug else 3)
107
108 bucket_tool_kwargs = dict(
109 ModelPath=args.bucket_model_path,
110 ScoreThreshold=args.score_threshold,
111 DebugDumpFile=args.bucket_debug_dump_file,
112 DebugDumpMaxEvents=args.bucket_debug_dump_max_events,
113 PrintLabels=args.bucket_print_labels,
114 LabelPrintFirstNBuckets=args.bucket_label_print_first_n_buckets,
115 OutputLevel=output_level,
116 )
117
118 if args.bucket_print_labels:
119 if use_truth_labels:
120 from MuonPatternRecognitionTest.PatternTestConfig import PatternVisualizationToolCfg
121 bucket_tool_kwargs["LabelVisualizationTool"] = cfg.popToolsAndMerge(
122 PatternVisualizationToolCfg(flags, CanvasLimits=0))
123 print("Bucket label mode: MC input sample -> label = bucket with truth muons only")
124 elif use_segment_labels:
125 bucket_tool_kwargs["LabelSegmentKey"] = args.bucket_label_segment_key
126 print("Bucket label mode: recorded-data input sample -> label = bucket_segments > 0 "
127 f"using {args.bucket_label_segment_key}")
128
129 bucketTool = cfg.popToolsAndMerge(GraphBucketFilterToolCfg(flags, **bucket_tool_kwargs))
130 cfg.merge(GraphInferenceAlgCfg(flags, InferenceTools=[bucketTool]))
131
132 if not build_label_segments_first:
133 cfg.merge(MuonPatternRecognitionCfg(flags))
134 cfg.getEventAlgo("MuonEtaHoughTransformAlg").SpacePointContainer = "FilteredMlBuckets"
135
136 cfg.merge(setupHistSvcCfg( flags, outFile=args.outRootFile,
137 outStream="MuonEtaHoughTransformTest"))
138
139 from MuonPatternRecognitionTest.PatternTestConfig import MuonHoughTransformTesterCfg, PatternVisualizationToolCfg
140
141 cfg.merge(MuonHoughTransformTesterCfg( flags,
142 VisualizationTool = cfg.popToolsAndMerge(PatternVisualizationToolCfg(flags, CanvasLimits =0))))
143
144 executeTest(cfg)
145
146
void print(char *figname, TCanvas *c1)