ATLAS Offline Software
Loading...
Searching...
No Matches
InferenceConfig Namespace Reference

Functions

 MuonLearningOnnxRuntimeSvcCfg (flags, name="OnnxRuntimeSvc", **kwargs)
 GraphInferenceAlgCfg (flags, name="GraphInferenceAlg", **kwargs)
 GraphSPFilterToolCfg (flags, name="GraphSPFilterTool", **kwargs)
 GraphBucketFilterToolCfg (flags, name="GraphBucketFilterTool", **kwargs)
 SegmentEdgeClassifierToolCfg (flags, name="SegmentEdgeClassifierTool", **kwargs)
 SegmentTrackCandidateBuilderToolCfg (flags, name="SegmentTrackCandidateBuilderTool", **kwargs)
 SegmentEdgeInferenceAlgCfg (flags, name="SegmentEdgeInferenceAlg", **kwargs)

Function Documentation

◆ GraphBucketFilterToolCfg()

InferenceConfig.GraphBucketFilterToolCfg ( flags,
name = "GraphBucketFilterTool",
** kwargs )

Definition at line 32 of file InferenceConfig.py.

def GraphBucketFilterToolCfg(flags, name="GraphBucketFilterTool", **kwargs):
    """Configure MuonML.GraphBucketFilterTool together with its ONNX session tool.

    Args:
        flags: Configuration flags, forwarded to the ONNX service and session
            tool configurations.
        name: Instance name given to the tool.
        **kwargs: Properties forwarded to the tool. The extra key ``ModelPath``
            is consumed here and handed to ``OnnxRuntimeSessionToolCfg`` as
            ``model_fname``; it is not passed on to the tool itself.

    Returns:
        A ComponentAccumulator carrying the tool as its private tool, with the
        MuonLearning ONNX runtime service merged in.
    """
    from AthOnnxComps.OnnxRuntimeSessionConfig import OnnxRuntimeSessionToolCfg

    result = ComponentAccumulator()
    model_path = kwargs.pop("ModelPath", "dev/MuonRecRTT/edgecnn_multi_bucket_sparse_meta.onnx")

    # The service has to be merged first: the session tool below looks it up
    # by name in this accumulator.
    result.merge(MuonLearningOnnxRuntimeSvcCfg(flags))
    kwargs.setdefault("ModelSession", result.popToolsAndMerge(
        OnnxRuntimeSessionToolCfg(flags, model_fname=model_path,
                                  OnnxRuntimeSvc=result.getService("OnnxRuntimeSvc"))))
    # BiasClass0: Working point selection bias for multi-class comparison
    # Higher values make class 0 (reject) less likely, accepting more buckets
    kwargs.setdefault("BiasClass0", 1.0)
    kwargs.setdefault("OutputLevel", 3)  # INFO level (1=VERBOSE, 2=DEBUG, 3=INFO, 4=WARNING, 5=ERROR, 6=FATAL)

    the_tool = CompFactory.MuonML.GraphBucketFilterTool(name, **kwargs)
    result.setPrivateTools(the_tool)
    return result
56
57

◆ GraphInferenceAlgCfg()

InferenceConfig.GraphInferenceAlgCfg ( flags,
name = "GraphInferenceAlg",
** kwargs )

Definition at line 14 of file InferenceConfig.py.

def GraphInferenceAlgCfg(flags, name="GraphInferenceAlg", **kwargs):
    """Configure the MuonML.InferenceAlg event algorithm.

    Args:
        flags: Configuration flags (not read by this function).
        name: Instance name given to the algorithm.
        **kwargs: Properties forwarded unchanged to the algorithm.

    Returns:
        A ComponentAccumulator with the algorithm registered as its primary
        event algorithm.
    """
    acc = ComponentAccumulator()
    acc.addEventAlgo(CompFactory.MuonML.InferenceAlg(name, **kwargs), primary=True)
    return acc
19

◆ GraphSPFilterToolCfg()

InferenceConfig.GraphSPFilterToolCfg ( flags,
name = "GraphSPFilterTool",
** kwargs )

Definition at line 20 of file InferenceConfig.py.

def GraphSPFilterToolCfg(flags, name="GraphSPFilterTool", **kwargs):
    """Configure MuonML.GraphSPFilterTool together with its ONNX session tool.

    Args:
        flags: Configuration flags, forwarded to ``OnnxRuntimeSessionToolCfg``.
        name: Instance name given to the tool.
        **kwargs: Properties forwarded to the tool; ``ModelSession`` and
            ``MLFilterCut`` receive defaults when absent.

    Returns:
        A ComponentAccumulator carrying the tool as its private tool.
    """
    from AthOnnxComps.OnnxRuntimeSessionConfig import OnnxRuntimeSessionToolCfg

    acc = ComponentAccumulator()

    # NOTE(review): the model location is a fixed absolute path and, unlike the
    # other tool configs in this module, no OnnxRuntimeSvc is passed to the
    # session tool - confirm whether that is intentional.
    session_cfg = OnnxRuntimeSessionToolCfg(
        flags,
        model_fname="/eos/atlas/atlascerngroupdisk/data-art/grid-input/MuonRecRTT/TestModel.onnx")
    session_tool = acc.popToolsAndMerge(session_cfg)

    kwargs.setdefault("ModelSession", session_tool)
    kwargs.setdefault("MLFilterCut", -3.6)  # Working point cut

    acc.setPrivateTools(CompFactory.MuonML.GraphSPFilterTool(name, **kwargs))
    return acc
31

◆ MuonLearningOnnxRuntimeSvcCfg()

InferenceConfig.MuonLearningOnnxRuntimeSvcCfg ( flags,
name = "OnnxRuntimeSvc",
** kwargs )
Configure the shared ONNX Runtime service used by MuonLearning tools.

Definition at line 6 of file InferenceConfig.py.

def MuonLearningOnnxRuntimeSvcCfg(flags, name="OnnxRuntimeSvc", **kwargs):
    """Set up the ONNX Runtime service shared by the MuonLearning tools.

    Args:
        flags: Configuration flags (not read by this function).
        name: Instance name of the service; the tool configs in this module
            look the service up as "OnnxRuntimeSvc".
        **kwargs: Properties forwarded to the service; ``LogLevel`` defaults
            to 3 when absent.

    Returns:
        A ComponentAccumulator holding the service, added as non-primary with
        ``create=True``.
    """
    kwargs.setdefault("LogLevel", 3)
    acc = ComponentAccumulator()
    acc.addService(CompFactory.AthOnnx.OnnxRuntimeSvc(name, **kwargs),
                   primary=False, create=True)
    return acc
13

◆ SegmentEdgeClassifierToolCfg()

InferenceConfig.SegmentEdgeClassifierToolCfg ( flags,
name = "SegmentEdgeClassifierTool",
** kwargs )

Definition at line 58 of file InferenceConfig.py.

def SegmentEdgeClassifierToolCfg(flags, name="SegmentEdgeClassifierTool", **kwargs):
    """Configure MuonML.SegmentEdgeClassifierTool together with its ONNX session tool.

    Args:
        flags: Configuration flags, forwarded to the ONNX service and session
            tool configurations.
        name: Instance name given to the tool.
        **kwargs: Properties forwarded to the tool. The extra key ``ModelPath``
            is consumed here and handed to ``OnnxRuntimeSessionToolCfg`` as
            ``model_fname``; it is not passed on to the tool itself.

    Returns:
        A ComponentAccumulator carrying the tool as its private tool, with the
        MuonLearning ONNX runtime service merged in.
    """
    from AthOnnxComps.OnnxRuntimeSessionConfig import OnnxRuntimeSessionToolCfg

    acc = ComponentAccumulator()
    onnx_model = kwargs.pop("ModelPath", "MuonInference/models/edge_gnn_refit_top01_from_t0020.onnx")

    # Merge the service before building the session tool, which fetches it
    # from this accumulator by name.
    acc.merge(MuonLearningOnnxRuntimeSvcCfg(flags))
    session_cfg = OnnxRuntimeSessionToolCfg(flags, model_fname=onnx_model,
                                            OnnxRuntimeSvc=acc.getService("OnnxRuntimeSvc"))
    kwargs.setdefault("ModelSession", acc.popToolsAndMerge(session_cfg))

    # Defaults for the model input/output names and the edge-selection
    # properties; values supplied by the caller take precedence.
    property_defaults = (
        ("InputNodeName", "x"),
        ("InputEdgeIndexName", "edge_index"),
        ("InputEdgeAttrName", "edge_attr"),
        ("OutputName", "logits"),
        ("MaxDeltaThetaDeg", 35.0),
        ("MaxDeltaSector", 1),
        ("SectorModulo", 16),
    )
    for prop, default in property_defaults:
        kwargs.setdefault(prop, default)

    acc.setPrivateTools(CompFactory.MuonML.SegmentEdgeClassifierTool(name, **kwargs))
    return acc
78
79

◆ SegmentEdgeInferenceAlgCfg()

InferenceConfig.SegmentEdgeInferenceAlgCfg ( flags,
name = "SegmentEdgeInferenceAlg",
** kwargs )

Definition at line 97 of file InferenceConfig.py.

def SegmentEdgeInferenceAlgCfg(flags, name="SegmentEdgeInferenceAlg", **kwargs):
    """Configure MuonML.SegmentEdgeInferenceAlg and its two private tools.

    Args:
        flags: Configuration flags, forwarded to the tool configurations.
        name: Instance name given to the algorithm.
        **kwargs: Properties forwarded to the algorithm, plus these
            convenience keys which are consumed here:

            * ``EdgeModelPath``: passed as ``ModelPath`` to
              ``SegmentEdgeClassifierToolCfg``.
            * ``EdgeClassifierTool`` / ``CandidateBuilderTool`` given as a
              ``dict``: treated as keyword arguments for the respective tool
              configuration (legacy call style).
            * ``EdgeThreshold``, ``OverlapThreshold``,
              ``UseRecoveryComponents``, ``SymmetrizeDirectedEdges``,
              ``AddAllSegmentsRecoveryCandidate``, ``KeepIsolatedSegments``,
              ``MinCandidateSize``: passed to
              ``SegmentTrackCandidateBuilderToolCfg``.

            If a ready-made tool object is supplied for one of the tools, the
            convenience keys aimed at that tool are still consumed but have
            no effect.

    Returns:
        A ComponentAccumulator with the algorithm registered as its primary
        event algorithm.
    """
    result = ComponentAccumulator()

    # Accept EdgeModelPath as a convenience shortcut so callers don't need to
    # build the tool object themselves; a raw dict is also unwrapped for
    # backwards-compatibility with call-sites that used dict syntax.
    edge_tool_kwargs = {}
    if "EdgeModelPath" in kwargs:
        edge_tool_kwargs["ModelPath"] = kwargs.pop("EdgeModelPath")
    if isinstance(kwargs.get("EdgeClassifierTool"), dict):
        edge_tool_kwargs.update(kwargs.pop("EdgeClassifierTool"))

    # These keys belong to the candidate builder tool, so they are moved out
    # of the algorithm's property set.
    candidate_builder_kwargs = {}
    for key in ("EdgeThreshold",
                "OverlapThreshold",
                "UseRecoveryComponents",
                "SymmetrizeDirectedEdges",
                "AddAllSegmentsRecoveryCandidate",
                "KeepIsolatedSegments",
                "MinCandidateSize"):
        if key in kwargs:
            candidate_builder_kwargs[key] = kwargs.pop(key)
    if isinstance(kwargs.get("CandidateBuilderTool"), dict):
        candidate_builder_kwargs.update(kwargs.pop("CandidateBuilderTool"))

    # Build each default tool only when the caller did not provide one, so no
    # throw-away tool configuration is created.
    if "EdgeClassifierTool" not in kwargs:
        kwargs["EdgeClassifierTool"] = result.popToolsAndMerge(
            SegmentEdgeClassifierToolCfg(flags, **edge_tool_kwargs))
    if "CandidateBuilderTool" not in kwargs:
        kwargs["CandidateBuilderTool"] = result.popToolsAndMerge(
            SegmentTrackCandidateBuilderToolCfg(flags, **candidate_builder_kwargs))

    # The decoration default follows the segment container actually in use,
    # so overriding SegmentKey alone keeps the two properties in agreement.
    segment_key = kwargs.setdefault("SegmentKey", "MuonSegmentsFromR4")
    kwargs.setdefault("CandidateDecoration", f"{segment_key}.trackCandidateIds")

    alg = CompFactory.MuonML.SegmentEdgeInferenceAlg(name=name, **kwargs)
    result.addEventAlgo(alg, primary=True)
    return result

◆ SegmentTrackCandidateBuilderToolCfg()

InferenceConfig.SegmentTrackCandidateBuilderToolCfg ( flags,
name = "SegmentTrackCandidateBuilderTool",
** kwargs )

Definition at line 80 of file InferenceConfig.py.

def SegmentTrackCandidateBuilderToolCfg(flags, name="SegmentTrackCandidateBuilderTool", **kwargs):
    """Configure MuonML.SegmentTrackCandidateBuilderTool with default working points.

    Args:
        flags: Configuration flags (not read by this function).
        name: Instance name given to the tool.
        **kwargs: Properties forwarded to the tool; any property listed in
            the defaults below is only set when the caller did not supply it.

    Returns:
        A ComponentAccumulator carrying the tool as its private tool.
    """
    # OverlapThreshold is used to build the high-purity candidate cores, while
    # EdgeThreshold drives a second, low-threshold recovery pass intended to
    # reduce candidate loss from borderline true edges.
    property_defaults = (
        ("EdgeThreshold", 0.25),
        ("OverlapThreshold", 0.8),
        ("UseRecoveryComponents", True),
        ("SymmetrizeDirectedEdges", True),
        ("AddAllSegmentsRecoveryCandidate", False),
        ("KeepIsolatedSegments", False),
        ("MinCandidateSize", 2),
    )
    for prop, default in property_defaults:
        kwargs.setdefault(prop, default)

    acc = ComponentAccumulator()
    acc.setPrivateTools(CompFactory.MuonML.SegmentTrackCandidateBuilderTool(name, **kwargs))
    return acc
95
96