ATLAS Offline Software
Loading...
Searching...
No Matches
muonSegmentDump.py
Go to the documentation of this file.
1# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
3
4
def _onnx_cuda_available():
    """Return True if a CUDA execution provider looks usable for ONNX inference.

    Best-effort probe. onnxruntime is asked for its available execution
    providers first. torch is consulted only if that import/query fails.
    Any failure of both probes is reported as "no GPU".
    """
    try:
        import onnxruntime as ort
        return "CUDAExecutionProvider" in ort.get_available_providers()
    except Exception:
        # Deliberately broad: a missing or broken onnxruntime install must not
        # abort job configuration; it only disables the GPU path.
        pass
    try:
        import torch
        return torch.cuda.is_available()
    except Exception:
        return False


def main(args):
    """Configure and run the muon segment dump job.

    The job does the following, in order:

    - Sets up uncalibrated measurement preparation and space-point formation.
    - Optionally runs an ONNX ML bucket filter in front of pattern recognition.
    - Runs pattern recognition.
    - Schedules the truth algorithms when the input is MC.
    - Attaches the segment dumper, whose histogram stream is written to
      ``args.outRootFile``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line from ``SetupArgParser``, extended with the options
        registered in ``__main__``. The optional attributes
        ``includeG4TrackTruth``, ``use_gpu``, ``doMLBucketFilter``,
        ``bucketModel`` and ``bucketThreshold`` are read with ``getattr``.
        A namespace lacking them falls back to the defaults used below.
    """
    from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg
    from MuonConfig.MuonConfigUtils import executeTest, setupHistSvcCfg
    from AthenaConfiguration.AllConfigFlags import initConfigFlags
    from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType

    # Container written by the ML bucket filter and read by pattern recognition
    # and the dumper when filtering is enabled.
    filtered_key = "FilteredMlBuckets"

    flags = initConfigFlags()
    flags.PerfMon.doFullMonMT = False

    include_g4_track_truth = bool(getattr(args, "includeG4TrackTruth", False))
    if include_g4_track_truth:
        # MR88449 schedules the G4/pile-up truth segment association from this flag.
        flags.Muon.includePileUpTruth = True

    # Probe for CUDA only when the GPU was requested, so --use-cpu does not pay
    # for importing onnxruntime/torch.
    use_gpu = getattr(args, "use_gpu", True) and _onnx_cuda_available()
    flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CUDA if use_gpu else OnnxRuntimeType.CPU

    # Pass the customised flags in. Called with args alone, setupGeoR4TestCfg
    # returns a flag container other than the one configured above, and every
    # setting above is silently dropped.
    # NOTE(review): relies on the setupGeoR4TestCfg(args, flags=None) signature
    # used by the other MuonPhaseII test scripts - confirm.
    flags, cfg = setupGeoR4TestCfg(args, flags)

    cfg.merge(setupHistSvcCfg(flags, outFile=args.outRootFile, outStream="MuonSegmentDump"))

    from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg
    cfg.merge(xAODUncalibMeasPrepCfg(flags))

    from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg
    cfg.merge(MuonSpacePointFormationCfg(flags))

    # Supplying a model or a threshold implies the filter is wanted.
    do_ml_bucket_filter = bool(getattr(args, "doMLBucketFilter", False)
                               or getattr(args, "bucketModel", None) is not None
                               or getattr(args, "bucketThreshold", None) is not None)
    if do_ml_bucket_filter:
        from MuonInference.InferenceConfig import GraphBucketFilterToolCfg, GraphInferenceAlgCfg
        bucket_tool_kwargs = {"WriteSpacePointKey": filtered_key}
        if getattr(args, "bucketModel", None) is not None:
            bucket_tool_kwargs["ModelPath"] = args.bucketModel
        if getattr(args, "bucketThreshold", None) is not None:
            bucket_tool_kwargs["ScoreThreshold"] = args.bucketThreshold
        bucket_tool = cfg.popToolsAndMerge(GraphBucketFilterToolCfg(flags, **bucket_tool_kwargs))
        cfg.merge(GraphInferenceAlgCfg(flags, InferenceTools=[bucket_tool]))

    from MuonPatternRecognitionAlgs.MuonPatternRecognitionConfig import MuonPatternRecognitionCfg
    cfg.merge(MuonPatternRecognitionCfg(flags))
    if do_ml_bucket_filter:
        # Pattern recognition runs once, but on the filtered container, so the
        # dumped segments come from the same buckets the dumper reads.
        cfg.getEventAlgo("MuonEtaHoughTransformAlg").SpacePointContainer = filtered_key

    # Truth information if MC
    if flags.Input.isMC:
        from MuonTruthAlgsR4.MuonTruthAlgsConfig import MuonTruthAlgsCfg
        cfg.merge(MuonTruthAlgsCfg(flags))

    from MuonBucketDump.MuonBucketDumpConfig import MuonSegmentDumpCfg
    dumper_kwargs = {"IncludeG4TrackTruth": include_g4_track_truth}
    if do_ml_bucket_filter:
        dumper_kwargs["SpacePointKeys"] = [filtered_key]
    cfg.merge(MuonSegmentDumpCfg(flags, **dumper_kwargs))

    executeTest(cfg)
89
90
if __name__ == "__main__":
    from MuonGeoModelTestR4.testGeoModel import SetupArgParser, MuonPhaseIITestDefaults

    # Start from the shared Phase-II test options, then override the defaults
    # for this job:
    #   - run over all events (nEvents=-1)
    #   - write MuonSegmentDump_R3SimHits.root
    #   - read MuonPhaseIITestDefaults.HITS_PG_R3 as input
    arg_parser = SetupArgParser()
    arg_parser.set_defaults(nEvents=-1,
                            outRootFile="MuonSegmentDump_R3SimHits.root",
                            inputFile=MuonPhaseIITestDefaults.HITS_PG_R3)

    # ML bucket filter: passing a model or a threshold also enables the filter in main().
    arg_parser.add_argument("--doMLBucketFilter", action="store_true", default=False,
                            help="Run ML bucket filtering and dump segments from filtered buckets.")
    arg_parser.add_argument("--bucketModel", type=str, default=None,
                            help="Path to ONNX model used by the ML bucket filter.")
    arg_parser.add_argument("--bucketThreshold", type=float, default=None,
                            help="Score threshold for single-output bucket filtering.")

    # --use-gpu and --use-cpu both write to args.use_gpu. GPU is the default;
    # only --use-cpu can switch it off.
    arg_parser.add_argument("--use-gpu", action="store_true", default=True,
                            help="Use GPU for ONNX inference when available (default: True)")
    arg_parser.add_argument("--use-cpu", dest="use_gpu", action="store_false",
                            help="Force CPU for ONNX inference")

    arg_parser.add_argument("--includeG4TrackTruth", action="store_true", default=False,
                            help="Use sim-hit HepMC/G4 track identifiers for unmatched segment labels.")

    main(arg_parser.parse_args())
int main()
Definition hello.cxx:18