ATLAS Offline Software
Loading...
Searching...
No Matches
muonDisplacedVertexInference.py
Go to the documentation of this file.
1# Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
2
3"""Run/debug the MuonLearning DisplacedVertex event classifier.
4
5Example:
6python -m MuonInference.muonDisplacedVertexInference \
7 --inputFile my.RDO.pool.root \
8 --model-path MuonInference/models/edge_class_dv_mu200.onnx \
9 --nEvents 10 \
10 --debug \
11 --print-every-event
12
13The ONNX model is expected to embed feature normalization and consume raw
14DisplacedVertex graph tensors: x, edge_index, edge_attr, n_muon_nodes.
15"""
16
17import os
18import logging
19
# Keep ONNX Runtime quiet in job logs: export a default ORT_LOGGING_LEVEL of
# "3" only when the caller has not already chosen one, and limit the
# "onnxruntime" Python logger to ERROR and above.
# NOTE(review): whether the ONNX Runtime build used here actually reads
# ORT_LOGGING_LEVEL is not visible from this file — confirm.
if "ORT_LOGGING_LEVEL" not in os.environ:
    os.environ["ORT_LOGGING_LEVEL"] = "3"
logging.getLogger("onnxruntime").setLevel(logging.ERROR)
22
def addDisplacedVertexArgs(parser):
    """Register this script's DisplacedVertex inference options on ``parser``.

    Only the options specific to the DV classifier job are added here. The
    generic job options (input file, number of events, geometry, ...) are
    expected to come from the parser the caller passes in (``SetupArgParser``
    in the ``__main__`` section of this module). Most options are forwarded
    one-to-one to ``DisplacedVertexInferenceAlgCfg``.

    Args:
        parser: an ``argparse.ArgumentParser``-compatible parser.

    Returns:
        The same ``parser`` so the call can be chained.
    """
    parser.add_argument("--model-path", default="MuonInference/models/edge_class_dv_mu200.onnx",
                        help="DisplacedVertex classifier ONNX model (expected to embed the feature normalization)")
    parser.add_argument("--segment-key", default="MuonSegmentsFromR4")
    parser.add_argument("--tower-container-key", default="CombinedTower")
    parser.add_argument("--min-tower-energy-mev", type=float, default=1000.0)
    parser.add_argument("--max-tower-segment-dr", type=float, default=0.4)
    parser.add_argument("--calo-r-max-mm", type=float, default=4250.0)
    parser.add_argument("--calo-z-max-mm", type=float, default=6500.0)
    parser.add_argument("--sector-modulo", type=int, default=16)
    parser.add_argument("--require-edges", action="store_true", default=False)
    parser.add_argument("--score-threshold", type=float, default=0.5,
                        help="Classification threshold; --threshold-mode selects what it is compared against")
    # The two literals below are implicitly concatenated: keep the trailing
    # space so the rendered help reads "...probability; 'raw' compares...".
    parser.add_argument("--threshold-mode", choices=["score", "raw"], default="score",
                        help="'score' compares the post-processed signal score/probability; "
                             "'raw' compares the raw ONNX output/logit.")
    parser.add_argument("--single-output-mode", default="logit", choices=["auto", "logit", "prob"])
    parser.add_argument("--use-cpu", action="store_true", default=False,
                        help="Use the CPU ONNX Runtime execution provider instead of the default CUDA one")
    parser.add_argument("--no-reco", action="store_true", default=False,
                        help="Do not schedule the muon data preparation, space-point formation and pattern "
                             "recognition; use this when the segments already exist in the input "
                             "(calorimeter tower building is controlled separately by --no-calo-towers)")
    parser.add_argument("--use-filtered-buckets-for-dv-graph", action="store_true", default=False,
                        help="Feed FilteredMlBuckets to the DV graph builder")
    # --doMLBucketFilter / --no-ml-bucket-filter share one dest; the filter is
    # on unless explicitly disabled.
    parser.add_argument("--doMLBucketFilter", dest="do_ml_bucket_filter", action="store_true", default=True,
                        help="Run the same bucket prefilter used when producing the DV training ROOT/H5 files")
    parser.add_argument("--no-ml-bucket-filter", dest="do_ml_bucket_filter", action="store_false",
                        help="Disable the bucket prefilter. This no longer matches the default DV training production.")
    parser.add_argument("--bucketModel", "--bucket-model", dest="bucket_model",
                        default="dev/MuonRecRTT/edgecnn_mu200.onnx",
                        help="Bucket prefilter ONNX model, matching the muonBucketDump --bucketModel option")
    parser.add_argument("--bucketThreshold", "--bucket-threshold", dest="bucket_threshold", type=float, default=0.160,
                        help="Bucket prefilter working point, matching the muonBucketDump --bucketThreshold option")
    parser.add_argument("--filtered-bucket-key", default="FilteredMlBuckets",
                        help="StoreGate key written by the bucket prefilter and read by the DV graph builder")
    parser.add_argument("--no-calo-towers", action="store_true", default=False,
                        help="Do not schedule CaloRecoCfg/CaloTowerMakerCfg. Use only when TowerContainerKey already exists in the input/event store.")
    parser.add_argument("--debug", action="store_true", default=False,
                        help="Set the DV tool and algorithm OutputLevel to DEBUG and dump the first graph entries")
    parser.add_argument("--debug-nodes", type=int, default=5,
                        help="Number of leading graph nodes to dump when --debug is set")
    parser.add_argument("--debug-edges", type=int, default=10,
                        help="Number of leading graph edges to dump when --debug is set")
    parser.add_argument("--print-every-event", action="store_true", default=False)
    parser.add_argument("--decorate-event-info", action="store_true", default=False,
                        help="Enable optional EventInfo DV decorations for validation/debug output")
    return parser


if __name__ == "__main__":
    from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg, SetupArgParser, MuonPhaseIITestDefaults
    from MuonConfig.MuonConfigUtils import executeTest

    parser = addDisplacedVertexArgs(SetupArgParser())
    parser.set_defaults(nEvents=10)
    parser.set_defaults(inputFile=MuonPhaseIITestDefaults.RDO_R4)
    parser.set_defaults(defaultGeoFile="RUN4")

    args = parser.parse_args()

    from AthenaConfiguration.AllConfigFlags import initConfigFlags
    flags = initConfigFlags()
    # NOTE(review): the reason for forcing the old region selector is not
    # documented in this file — confirm before changing.
    flags.Trigger.Muon.useNewRegionSelector = False

    from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
    # CUDA is the default execution provider; --use-cpu opts out of it.
    flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU if args.use_cpu else OnnxRuntimeType.CUDA

    flags, cfg = setupGeoR4TestCfg(args, flags)

    if not args.no_reco:
        # Schedule the R4 muon chain up to pattern recognition (presumably what
        # produces the segments read via --segment-key). Skipped with --no-reco.
        from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg
        cfg.merge(xAODUncalibMeasPrepCfg(flags))

        from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg
        cfg.merge(MuonSpacePointFormationCfg(flags))

        from MuonPatternRecognitionAlgs.MuonPatternRecognitionConfig import MuonPatternRecognitionCfg
        cfg.merge(MuonPatternRecognitionCfg(flags))

    from MuonInference.InferenceConfig import DisplacedVertexInferenceAlgCfg

    # Gaudi message levels: 2 is DEBUG (what --debug promises), 3 is INFO.
    output_level = 2 if args.debug else 3
    cfg.merge(DisplacedVertexInferenceAlgCfg(
        flags,
        ModelPath=args.model_path,
        SegmentKey=args.segment_key,
        TowerContainerKey=args.tower_container_key,
        MinTowerEnergyMeV=args.min_tower_energy_mev,
        MaxTowerSegmentDR=args.max_tower_segment_dr,
        CaloRMaxMm=args.calo_r_max_mm,
        CaloZMaxMm=args.calo_z_max_mm,
        SectorModulo=args.sector_modulo,
        DoMLBucketFilter=args.do_ml_bucket_filter,
        BucketModelPath=args.bucket_model,
        BucketThreshold=args.bucket_threshold,
        FilteredBucketKey=args.filtered_bucket_key,
        UseFilteredBucketsForDVGraph=args.use_filtered_buckets_for_dv_graph,
        RequireEdges=args.require_edges,
        DoCaloTowerBuild=not args.no_calo_towers,
        SingleOutputMode=args.single_output_mode,
        ScoreThreshold=args.score_threshold,
        ThresholdMode=args.threshold_mode,
        OutputLevel=output_level,
        # Graph dumps are only requested together with --debug.
        DebugDumpFirstNNodes=args.debug_nodes if args.debug else 0,
        DebugDumpFirstNEdges=args.debug_edges if args.debug else 0,
        PrintEveryEvent=args.print_every_event,
        DecorateEventInfo=args.decorate_event_info,
    ))

    executeTest(cfg)