# Silence onnxruntime before it is imported further down: its Python-side
# logger only lets ERROR (and above) through.
logging.getLogger("onnxruntime").setLevel(level=logging.ERROR)
# "3" is onnxruntime's ERROR severity level.
# NOTE(review): confirm the runtime actually honours ORT_LOGGING_LEVEL;
# onnxruntime.set_default_logger_severity(3) is the documented API for this.
os.environ.update(ORT_LOGGING_LEVEL="3")
if __name__ == "__main__":
    # Runs the muon Hough-transform pattern-recognition test with an ONNX
    # bucket-filter inference step in front of it. Optionally prints per-bucket
    # labels: truth-based on MC, segment-based on recorded data.
    from MuonGeoModelTestR4.testGeoModel import SetupArgParser, MuonPhaseIITestDefaults

    # ------------------------------------------------------------------ CLI
    parser = SetupArgParser()
    parser.set_defaults(nEvents=-1)
    parser.set_defaults(outRootFile="InferenceHoughTest.root")
    parser.set_defaults(noMM=True)
    parser.set_defaults(noSTGC=True)
    parser.set_defaults(inputFile=MuonPhaseIITestDefaults.HITS_PG_R3)

    parser.add_argument("--noPerfMon", help="If set, full perfmonMT is disabled",
                        default=False, action='store_true')
    # --use-gpu / --use-cpu write the same 'use_gpu' destination; whichever
    # appears last on the command line wins.
    parser.add_argument("--use-gpu", action="store_true", dest="use_gpu", default=True,
                        help="Use GPU for ONNX inference (default: True)")
    parser.add_argument("--use-cpu", dest="use_gpu", action="store_false",
                        help="Use CPU for ONNX inference")
    parser.add_argument("--athenaDebug", action="store_true", default=False,
                        help="Enable DEBUG verbosity for bucket inference components in MessageSvc")
    parser.add_argument("--athenaVerbose", action="store_true", default=False,
                        help="Enable VERBOSE verbosity for bucket inference components in MessageSvc")
    parser.add_argument("--bucket-model-path", dest="bucket_model_path",
                        default="/eos/project-i01/f/fcc-ml/ddicroce/ATLAS_MuonSpectrometer/KubeFlow/Inference_EdgeClassifier/athena/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/models/edgecnn_bucket_sparse_best.onnx",
                        help="Absolute path (or PathResolver key) for the bucket ONNX model")
    parser.add_argument("--score-threshold", type=float, default=0.2, dest="score_threshold",
                        help="Keep bucket if single-output score > threshold (default: 0.2)")
    parser.add_argument("--bucket-debug-dump-file", default="", dest="bucket_debug_dump_file",
                        help=("Optional JSONL output file with Athena-side bucket ONNX."))
    parser.add_argument("--bucket-debug-dump-max-events", type=int, default=0,
                        dest="bucket_debug_dump_max_events",
                        help=("Maximum number of events to write to --bucket-debug-dump-file."))
    parser.add_argument("--bucket-print-labels", action="store_true", default=False,
                        dest="bucket_print_labels",
                        help=("Print per-event good-bucket efficiency using the training label."))
    parser.add_argument("--bucket-label-print-first-n-buckets", type=int, default=20,
                        dest="bucket_label_print_first_n_buckets",
                        help=("When --bucket-print-labels is enabled, print the detailed "
                              "label/decision for the first N buckets."))
    parser.add_argument("--bucket-label-segment-key", default="MuonSegmentsFromR4",
                        dest="bucket_label_segment_key",
                        help=("xAOD::MuonSegmentContainer key used for the "
                              "bucket_segments > 0 label on recorded data."))

    args = parser.parse_args()

    from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg
    from MuonConfig.MuonConfigUtils import executeTest, setupHistSvcCfg
    from AthenaConfiguration.AllConfigFlags import initConfigFlags

    # ---------------------------------------------------------------- flags
    flags = initConfigFlags()
    flags.PerfMon.doFullMonMT = not args.noPerfMon
    if args.athenaDebug or args.athenaVerbose:
        # Switch off message suppression when extra verbosity was requested.
        flags.Common.MsgSuppression = False

    from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType

    # --------------------------------------------- ONNX execution provider
    # Defensive: args.use_gpu is never None with the parser defaults above.
    use_gpu_requested = args.use_gpu if args.use_gpu is not None else True

    # Detect CUDA support: prefer onnxruntime's own provider list and fall
    # back to torch only if onnxruntime's Python module cannot be imported.
    gpu_available = False
    try:
        import onnxruntime as ort
        gpu_available = "CUDAExecutionProvider" in ort.get_available_providers()
    except ImportError:
        try:
            import torch
            gpu_available = torch.cuda.is_available()
        except ImportError:
            gpu_available = False

    if use_gpu_requested and gpu_available:
        flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CUDA
    else:
        if use_gpu_requested:
            print("GPU requested for ONNX inference but CUDA is not available; "
                  "falling back to CPU")
        flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU

    flags, cfg = setupGeoR4TestCfg(args, flags)

    # ------------------------------------------------- data preparation
    from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg
    cfg.merge(xAODUncalibMeasPrepCfg(flags))

    from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg
    cfg.merge(MuonSpacePointFormationCfg(flags))

    from MuonPatternRecognitionAlgs.MuonPatternRecognitionConfig import MuonPatternRecognitionCfg

    # --------------------------------------------------------- label mode
    # MC -> truth-based labels; recorded data -> labels from reconstructed segments.
    sample_is_mc = bool(flags.Input.isMC)
    use_truth_labels = bool(args.bucket_print_labels and sample_is_mc)
    use_segment_labels = bool(args.bucket_print_labels and not sample_is_mc)

    if use_segment_labels and not args.bucket_label_segment_key:
        raise ValueError(
            "--bucket-print-labels was requested for recorded data, but "
            "--bucket-label-segment-key is empty. Set it to MuonSegmentsFromR4 "
            "or disable label printing.")

    # In segment-label mode the standard pattern recognition is merged up front
    # and left on the unfiltered space points, since it provides the segments
    # used as labels.
    build_label_segments_first = use_segment_labels
    if build_label_segments_first:
        cfg.merge(MuonPatternRecognitionCfg(flags))

    # ---------------------------------------------------- bucket inference
    from MuonInference.InferenceConfig import GraphBucketFilterToolCfg, GraphInferenceAlgCfg
    # Gaudi message levels: 1 = VERBOSE, 2 = DEBUG, 3 = INFO.
    output_level = 1 if args.athenaVerbose else (2 if args.athenaDebug else 3)

    bucket_tool_kwargs = dict(
        ModelPath=args.bucket_model_path,
        ScoreThreshold=args.score_threshold,
        DebugDumpFile=args.bucket_debug_dump_file,
        DebugDumpMaxEvents=args.bucket_debug_dump_max_events,
        PrintLabels=args.bucket_print_labels,
        LabelPrintFirstNBuckets=args.bucket_label_print_first_n_buckets,
        OutputLevel=output_level,
    )

    # Only one label source is configured. use_truth_labels and
    # use_segment_labels are mutually exclusive, so the data branch is reachable.
    if use_truth_labels:
        from MuonPatternRecognitionTest.PatternTestConfig import PatternVisualizationToolCfg
        bucket_tool_kwargs["LabelVisualizationTool"] = cfg.popToolsAndMerge(
            PatternVisualizationToolCfg(flags, CanvasLimits=0))
        print("Bucket label mode: MC input sample -> label = bucket with truth muons only")
    elif use_segment_labels:
        bucket_tool_kwargs["LabelSegmentKey"] = args.bucket_label_segment_key
        print("Bucket label mode: recorded-data input sample -> label = bucket_segments > 0 "
              f"using {args.bucket_label_segment_key}")

    bucketTool = cfg.popToolsAndMerge(GraphBucketFilterToolCfg(flags, **bucket_tool_kwargs))
    cfg.merge(GraphInferenceAlgCfg(flags, InferenceTools=[bucketTool]))

    if not build_label_segments_first:
        cfg.merge(MuonPatternRecognitionCfg(flags))
        # Run the Hough transform on the filter's output container. This is not
        # done in segment-label mode: there the pattern recognition produces the
        # label segments the filter reads, so redirecting it would (presumably)
        # form a data-dependency cycle.
        cfg.getEventAlgo("MuonEtaHoughTransformAlg").SpacePointContainer = "FilteredMlBuckets"

    # ------------------------------------------------------------ outputs
    cfg.merge(setupHistSvcCfg(flags, outFile=args.outRootFile,
                              outStream="MuonEtaHoughTransformTest"))

    from MuonPatternRecognitionTest.PatternTestConfig import MuonHoughTransformTesterCfg, PatternVisualizationToolCfg

    cfg.merge(MuonHoughTransformTesterCfg(
        flags,
        VisualizationTool=cfg.popToolsAndMerge(PatternVisualizationToolCfg(flags, CanvasLimits=0))))

    # executeTest was imported above; without this call the job is only configured, never run.
    executeTest(cfg)
void print(char *figname, TCanvas *c1)