3"""Run/debug the MuonLearning DisplacedVertex event classifier.
6python -m MuonInference.muonDisplacedVertexInference \
7 --inputFile my.RDO.pool.root \
8 --model-path MuonInference/models/edge_class_dv_mu200.onnx \
13The ONNX model is expected to embed feature normalization and consume raw
14DisplacedVertex graph tensors: x, edge_index, edge_attr, n_muon_nodes.
# Keep ONNX Runtime quiet for the whole job. Both settings are applied when
# this module is loaded, independent of the command-line options.
#
# setdefault() only fills the variable when it is absent, so a value already
# exported in the caller's environment takes precedence.
# NOTE(review): "3" presumably selects ONNX Runtime's ERROR severity - confirm.
os.environ.setdefault("ORT_LOGGING_LEVEL", "3")

# Cap the Python-side "onnxruntime" logger at ERROR.
logging.getLogger("onnxruntime").setLevel(logging.ERROR)
if __name__ == "__main__":
    # Script entry point. This part builds the command line, initialises the
    # configuration flags, creates the R4 test component accumulator and
    # merges the upstream muon configurations that run before the
    # DisplacedVertex inference algorithm is configured.
    from MuonGeoModelTestR4.testGeoModel import (
        setupGeoR4TestCfg,
        SetupArgParser,
        MuonPhaseIITestDefaults,
    )
    from MuonConfig.MuonConfigUtils import executeTest

    parser = SetupArgParser()

    # --- DV classifier model and graph-building inputs ----------------------
    parser.add_argument(
        "--model-path",
        default="MuonInference/models/edge_class_dv_mu200.onnx",
    )
    parser.add_argument("--segment-key", default="MuonSegmentsFromR4")
    parser.add_argument("--tower-container-key", default="CombinedTower")
    parser.add_argument("--min-tower-energy-mev", type=float, default=1000.0)
    parser.add_argument("--max-tower-segment-dr", type=float, default=0.4)
    parser.add_argument("--calo-r-max-mm", type=float, default=4250.0)
    parser.add_argument("--calo-z-max-mm", type=float, default=6500.0)
    parser.add_argument("--sector-modulo", type=int, default=16)
    # store_true actions already default to False, so no explicit default.
    parser.add_argument("--require-edges", action="store_true")

    # --- Decision threshold and output interpretation -----------------------
    parser.add_argument("--score-threshold", type=float, default=0.5)
    parser.add_argument(
        "--threshold-mode",
        choices=["score", "raw"],
        default="score",
        # Adjacent literals are concatenated verbatim; the separating space
        # after the semicolon has to be part of the text.
        help=(
            "'score' compares the post-processed signal score/probability; "
            "'raw' compares the raw ONNX output/logit."
        ),
    )
    parser.add_argument(
        "--single-output-mode",
        default="logit",
        choices=["auto", "logit", "prob"],
    )

    # --- Execution provider and scheduling switches -------------------------
    parser.add_argument("--use-cpu", action="store_true")
    parser.add_argument(
        "--no-reco",
        action="store_true",
        help="Do not schedule MuonReconstructionConfig; use this when segments/towers already exist in the input",
    )
    parser.add_argument(
        "--use-filtered-buckets-for-dv-graph",
        action="store_true",
        help="Feed FilteredMlBuckets to the DV graph builder",
    )

    # --- Bucket prefilter ----------------------------------------------------
    # Both switches write the same destination. The prefilter is on by
    # default, so --doMLBucketFilter only restates the default and
    # --no-ml-bucket-filter is the option that actually changes it.
    parser.add_argument(
        "--doMLBucketFilter",
        dest="do_ml_bucket_filter",
        action="store_true",
        default=True,
        help="Run the same bucket prefilter used when producing the DV training ROOT/H5 files",
    )
    parser.add_argument(
        "--no-ml-bucket-filter",
        dest="do_ml_bucket_filter",
        action="store_false",
        help="Disable the bucket prefilter. This no longer matches the default DV training production.",
    )
    parser.add_argument(
        "--bucketModel",
        "--bucket-model",
        dest="bucket_model",
        default="dev/MuonRecRTT/edgecnn_mu200.onnx",
        help="Bucket prefilter ONNX model, matching the muonBucketDump --bucketModel option",
    )
    parser.add_argument(
        "--bucketThreshold",
        "--bucket-threshold",
        dest="bucket_threshold",
        type=float,
        default=0.160,
        help="Bucket prefilter working point, matching the muonBucketDump --bucketThreshold option",
    )
    parser.add_argument(
        "--filtered-bucket-key",
        default="FilteredMlBuckets",
        help="StoreGate key written by the bucket prefilter and read by the DV graph builder",
    )
    parser.add_argument(
        "--no-calo-towers",
        action="store_true",
        help="Do not schedule CaloRecoCfg/CaloTowerMakerCfg. Use only when TowerContainerKey already exists in the input/event store.",
    )

    # --- Debugging and validation output ------------------------------------
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Set the DV tool and algorithm OutputLevel to DEBUG and dump the first graph entries",
    )
    parser.add_argument("--debug-nodes", type=int, default=5)
    parser.add_argument("--debug-edges", type=int, default=10)
    parser.add_argument("--print-every-event", action="store_true")
    parser.add_argument(
        "--decorate-event-info",
        action="store_true",
        help="Enable optional EventInfo DV decorations for validation/debug output",
    )

    # Override defaults of options that SetupArgParser is expected to define.
    # NOTE(review): assumes nEvents/inputFile/defaultGeoFile are destinations
    # declared by SetupArgParser - confirm against testGeoModel.
    parser.set_defaults(
        nEvents=10,
        inputFile=MuonPhaseIITestDefaults.RDO_R4,
        defaultGeoFile="RUN4",
    )

    args = parser.parse_args()

    from AthenaConfiguration.AllConfigFlags import initConfigFlags
    flags = initConfigFlags()
    flags.Trigger.Muon.useNewRegionSelector = False

    # ONNX inference runs on CUDA unless --use-cpu was given.
    from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
    flags.AthOnnx.ExecutionProvider = (
        OnnxRuntimeType.CPU if args.use_cpu else OnnxRuntimeType.CUDA
    )

    # setupGeoR4TestCfg hands back the flags together with the accumulator
    # that every following configuration is merged into.
    flags, cfg = setupGeoR4TestCfg(args, flags)

    # Upstream steps, merged in this order: uncalibrated measurement
    # preparation, space point formation, pattern recognition.
    from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg
    cfg.merge(xAODUncalibMeasPrepCfg(flags))

    from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg
    cfg.merge(MuonSpacePointFormationCfg(flags))

    from MuonPatternRecognitionAlgs.MuonPatternRecognitionConfig import MuonPatternRecognitionCfg
    cfg.merge(MuonPatternRecognitionCfg(flags))
92 from MuonInference.InferenceConfig
import DisplacedVertexInferenceAlgCfg
94 output_level = 2
if args.debug
else 3
95 cfg.merge(DisplacedVertexInferenceAlgCfg(
97 ModelPath=args.model_path,
98 SegmentKey=args.segment_key,
99 TowerContainerKey=args.tower_container_key,
100 MinTowerEnergyMeV=args.min_tower_energy_mev,
101 MaxTowerSegmentDR=args.max_tower_segment_dr,
102 CaloRMaxMm=args.calo_r_max_mm,
103 CaloZMaxMm=args.calo_z_max_mm,
104 SectorModulo=args.sector_modulo,
105 DoMLBucketFilter=args.do_ml_bucket_filter,
106 BucketModelPath=args.bucket_model,
107 BucketThreshold=args.bucket_threshold,
108 FilteredBucketKey=args.filtered_bucket_key,
109 UseFilteredBucketsForDVGraph=args.use_filtered_buckets_for_dv_graph,
110 RequireEdges=args.require_edges,
111 DoCaloTowerBuild=
not args.no_calo_towers,
112 SingleOutputMode=args.single_output_mode,
113 ScoreThreshold=args.score_threshold,
114 ThresholdMode=args.threshold_mode,
115 OutputLevel=output_level,
116 DebugDumpFirstNNodes=args.debug_nodes
if args.debug
else 0,
117 DebugDumpFirstNEdges=args.debug_edges
if args.debug
else 0,
118 PrintEveryEvent=args.print_every_event,
119 DecorateEventInfo=args.decorate_event_info,