ATLAS Offline Software
Loading...
Searching...
No Matches
InferenceConfig.py
Go to the documentation of this file.
1# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
3from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
4from AthenaConfiguration.ComponentFactory import CompFactory
5
def GraphInferenceAlgCfg(flags, name="GraphInferenceAlg", **kwargs):
    """Configure the MuonML graph inference algorithm.

    Args:
        flags: Configuration flags. Accepted for interface uniformity; this
            function does not read them itself.
        name: Instance name given to the algorithm.
        **kwargs: Forwarded unchanged to ``CompFactory.MuonML.InferenceAlg``.

    Returns:
        A ComponentAccumulator to which the algorithm has been added as an
        event algorithm with ``primary=True``.
    """
    acc = ComponentAccumulator()
    acc.addEventAlgo(CompFactory.MuonML.InferenceAlg(name, **kwargs), primary=True)
    return acc
11
def GraphSPFilterToolCfg(flags, name="GraphSPFilterTool", **kwargs):
    """Configure the MuonML graph space-point filter tool.

    Args:
        flags: Configuration flags, passed on to ``OnnxRuntimeSessionToolCfg``.
        name: Instance name given to the tool.
        **kwargs: Properties forwarded to ``CompFactory.MuonML.GraphSPFilterTool``.
            ``ModelSession`` and ``MLFilterCut`` receive defaults when the
            caller does not provide them.

    Returns:
        A ComponentAccumulator with the configured tool set as its private tool.
    """
    from AthOnnxComps.OnnxRuntimeSessionConfig import OnnxRuntimeSessionToolCfg

    acc = ComponentAccumulator()

    # The default session is always configured and popped from its accumulator,
    # but it only ends up on the tool when the caller gave no "ModelSession".
    session_cfg = OnnxRuntimeSessionToolCfg(
        flags,
        model_fname="/eos/atlas/atlascerngroupdisk/data-art/grid-input/MuonRecRTT/TestModel.onnx",
    )
    default_session = acc.popToolsAndMerge(session_cfg)
    kwargs.setdefault("ModelSession", default_session)

    # Working point cut
    kwargs.setdefault("MLFilterCut", -3.6)

    acc.setPrivateTools(CompFactory.MuonML.GraphSPFilterTool(name, **kwargs))
    return acc
23
def GraphBucketFilterToolCfg(flags, name="GraphBucketFilterTool", **kwargs):
    """Configure the MuonML graph bucket filter tool.

    Args:
        flags: Configuration flags, passed on to ``OnnxRuntimeSessionToolCfg``.
        name: Instance name given to the tool.
        **kwargs: Properties forwarded to
            ``CompFactory.MuonML.GraphBucketFilterTool``. The special key
            ``ModelPath`` is consumed here (not forwarded) and selects the
            ONNX model file used for the default ``ModelSession``.
            ``ModelSession``, ``BiasClass0`` and ``OutputLevel`` receive
            defaults when the caller does not provide them.

    Returns:
        A ComponentAccumulator with the configured tool set as its private tool.
    """
    from AthOnnxComps.OnnxRuntimeSessionConfig import OnnxRuntimeSessionToolCfg

    result = ComponentAccumulator()

    # Pop rather than read: "ModelPath" must not reach the tool constructor.
    model_path = kwargs.pop("ModelPath", "edgecnn_multi_bucket_sparse_meta.onnx")

    # NOTE(review): an earlier if/else on model_path.startswith('/') did nothing
    # in either arm and has been removed. The path is handed to the session
    # configuration exactly as given; if relative paths were meant to be
    # resolved against some data area, that logic still needs to be written.
    kwargs.setdefault("ModelSession", result.popToolsAndMerge(OnnxRuntimeSessionToolCfg(flags, model_fname=model_path)))

    # BiasClass0: Working point selection bias for multi-class comparison
    # Higher values make class 0 (reject) less likely, accepting more buckets
    # Typical values: 0.5-2.0. Default: 1.0 (no bias)
    kwargs.setdefault("BiasClass0", 1.0)
    kwargs.setdefault("OutputLevel", 3)  # INFO level (1=VERBOSE, 2=DEBUG, 3=INFO, 4=WARNING, 5=ERROR, 6=FATAL)

    the_tool = CompFactory.MuonML.GraphBucketFilterTool(name, **kwargs)
    result.setPrivateTools(the_tool)
    return result
46
GraphInferenceAlgCfg(flags, name="GraphInferenceAlg", **kwargs)
GraphBucketFilterToolCfg(flags, name="GraphBucketFilterTool", **kwargs)
GraphSPFilterToolCfg(flags, name="GraphSPFilterTool", **kwargs)