ATLAS Offline Software
InferenceConfig.py
Go to the documentation of this file.
1 # Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2 
3 from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
4 from AthenaConfiguration.ComponentFactory import CompFactory
5 
def GraphInferenceAlgCfg(flags, name = "GraphInferenceAlg", **kwargs):
    """Configure the MuonML::InferenceAlg event algorithm.

    Args:
        flags: Athena configuration flags (not read by this function).
        name: Instance name given to the algorithm.
        **kwargs: Properties forwarded unchanged to ``MuonML.InferenceAlg``.

    Returns:
        A ComponentAccumulator holding the algorithm, registered as the
        primary event algorithm of the accumulator.
    """
    inference_alg = CompFactory.MuonML.InferenceAlg(name, **kwargs)
    acc = ComponentAccumulator()
    acc.addEventAlgo(inference_alg, primary=True)
    return acc
11 
def GraphSPFilterToolCfg(flags, name ="GraphSPFilterTool", **kwargs):
    """Configure the MuonML::GraphSPFilterTool as a private tool.

    Args:
        flags: Athena configuration flags, forwarded to
            ``OnnxRuntimeSessionToolCfg`` when the default session is built.
        name: Instance name given to the tool.
        **kwargs: Properties forwarded to ``MuonML.GraphSPFilterTool``.
            Defaults applied here when the key is absent:
              - ``ModelSession``: an ONNX runtime session tool configured
                for the default model file (see below).
              - ``MLFilterCut``: -3.6, the working-point cut.

    Returns:
        A ComponentAccumulator whose private tool is the configured
        GraphSPFilterTool. When the default session is built, its
        configuration is merged into the returned accumulator.
    """
    from AthOnnxComps.OnnxRuntimeSessionConfig import OnnxRuntimeSessionToolCfg

    result = ComponentAccumulator()
    # Only build the default session when the caller did not supply one.
    # ``kwargs.setdefault(key, expr)`` evaluates ``expr`` unconditionally, so
    # the previous form configured (and merged into ``result``) an ONNX
    # session tool even when it was then discarded in favour of the caller's.
    if "ModelSession" not in kwargs:
        # NOTE(review): absolute EOS path to a test model; presumably only
        # resolvable where EOS is mounted. Unlike GraphBucketFilterToolCfg,
        # this does not use a relative (data-area) path — confirm intent.
        kwargs["ModelSession"] = result.popToolsAndMerge(
            OnnxRuntimeSessionToolCfg(flags, model_fname="/eos/atlas/atlascerngroupdisk/data-art/grid-input/MuonRecRTT/TestModel.onnx"))
    kwargs.setdefault("MLFilterCut", -3.6)  # Working point cut

    the_tool = CompFactory.MuonML.GraphSPFilterTool(name, **kwargs)
    result.setPrivateTools(the_tool)
    return result
23 
def GraphBucketFilterToolCfg(flags, name ="GraphBucketFilterTool", **kwargs):
    """Configure the MuonML::GraphBucketFilterTool as a private tool.

    Args:
        flags: Athena configuration flags, forwarded to
            ``OnnxRuntimeSessionToolCfg`` when the default session is built.
        name: Instance name given to the tool.
        **kwargs: Properties forwarded to ``MuonML.GraphBucketFilterTool``.
            Defaults applied here when the key is absent:
              - ``ModelSession``: an ONNX runtime session tool configured
                for ``MuonInference/edgecnn_multi_bucket_sparse_meta.onnx``.
              - ``BiasClass0``: 1.0, the working-point selection bias used
                in the multi-class comparison.
              - ``OutputLevel``: 3 (INFO).

    Returns:
        A ComponentAccumulator whose private tool is the configured
        GraphBucketFilterTool. When the default session is built, its
        configuration is merged into the returned accumulator.
    """
    from AthOnnxComps.OnnxRuntimeSessionConfig import OnnxRuntimeSessionToolCfg

    result = ComponentAccumulator()
    # Only build the default session when the caller did not supply one.
    # ``kwargs.setdefault(key, expr)`` evaluates ``expr`` unconditionally, so
    # the previous form configured (and merged into ``result``) an ONNX
    # session tool even when it was then discarded in favour of the caller's.
    if "ModelSession" not in kwargs:
        kwargs["ModelSession"] = result.popToolsAndMerge(
            OnnxRuntimeSessionToolCfg(flags, model_fname="MuonInference/edgecnn_multi_bucket_sparse_meta.onnx"))
    kwargs.setdefault("BiasClass0", 1.0)  # Working point selection bias for multi-class comparison
    # INFO level (1=VERBOSE, 2=DEBUG, 3=INFO, 4=WARNING, 5=ERROR, 6=FATAL).
    # The previous comment labelled this "DEBUG", contradicting the value 3.
    kwargs.setdefault("OutputLevel", 3)

    the_tool = CompFactory.MuonML.GraphBucketFilterTool(name, **kwargs)
    result.setPrivateTools(the_tool)
    return result
36 
InferenceConfig.GraphInferenceAlgCfg
def GraphInferenceAlgCfg(flags, name="GraphInferenceAlg", **kwargs)
Definition: InferenceConfig.py:6
python.JetAnalysisCommon.ComponentAccumulator
ComponentAccumulator
Definition: JetAnalysisCommon.py:342
InferenceConfig.GraphSPFilterToolCfg
def GraphSPFilterToolCfg(flags, name="GraphSPFilterTool", **kwargs)
Definition: InferenceConfig.py:12
python.OnnxRuntimeSessionConfig.OnnxRuntimeSessionToolCfg
def OnnxRuntimeSessionToolCfg(flags, str model_fname, Optional[OnnxRuntimeType] execution_provider=None, name="OnnxRuntimeSessionTool", **kwargs)
Definition: OnnxRuntimeSessionConfig.py:8
InferenceConfig.GraphBucketFilterToolCfg
def GraphBucketFilterToolCfg(flags, name="GraphBucketFilterTool", **kwargs)
Definition: InferenceConfig.py:24