Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
InferenceConfig.py
Go to the documentation of this file.
1 # Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2 
3 from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
4 from AthenaConfiguration.ComponentFactory import CompFactory
5 
def GraphInferenceAlgCfg(flags, name = "GraphInferenceAlg", **kwargs):
    """Schedule the MuonML graph-inference algorithm.

    Creates a ComponentAccumulator holding a MuonML::InferenceAlg event
    algorithm (marked primary) configured from **kwargs.

    Args:
        flags: the job configuration flags (unused here, kept for the
            standard Cfg-function signature).
        name: instance name of the algorithm.
        **kwargs: forwarded verbatim to the algorithm's properties.

    Returns:
        ComponentAccumulator with the algorithm scheduled.
    """
    acc = ComponentAccumulator()
    acc.addEventAlgo(CompFactory.MuonML.InferenceAlg(name, **kwargs),
                     primary = True)
    return acc
11 
def GraphBucketFilterToolCfg(flags, name ="GraphBucketFilterTool", **kwargs):
    """Configure the MuonML graph-bucket filter private tool.

    Sets up an ONNX runtime session tool as the default ``ModelSession``
    (unless the caller supplies one) and a default ML score cut, then
    returns the MuonML::GraphBucketFilterTool as a private tool.

    Args:
        flags: the job configuration flags, forwarded to the ONNX session
            configuration.
        name: instance name of the tool.
        **kwargs: tool properties; ``ModelSession`` and ``MLFilterCut``
            are only defaulted if not already present.

    Returns:
        ComponentAccumulator with the tool set as private tools.
    """
    from AthOnnxComps.OnnxRuntimeSessionConfig import OnnxRuntimeSessionToolCfg

    result = ComponentAccumulator()
    # NOTE(review): the default model is a hard-coded absolute EOS path to a
    # *test* model (a calibration-area file, e.g. the quantized EdgeGAT model
    # previously used here, should replace it before production) — TODO confirm.
    kwargs.setdefault(
        "ModelSession",
        result.popToolsAndMerge(
            OnnxRuntimeSessionToolCfg(
                flags,
                model_fname="/eos/atlas/atlascerngroupdisk/data-art/grid-input/MuonRecRTT/TestModel.onnx")))
    # Default score threshold below which buckets are filtered out
    # (semantics defined by the C++ tool) — TODO confirm direction of the cut.
    kwargs.setdefault("MLFilterCut", -2.7)

    the_tool = CompFactory.MuonML.GraphBucketFilterTool(name, **kwargs)
    result.setPrivateTools(the_tool)
    return result
InferenceConfig.GraphInferenceAlgCfg
def GraphInferenceAlgCfg(flags, name="GraphInferenceAlg", **kwargs)
Definition: InferenceConfig.py:6
python.JetAnalysisCommon.ComponentAccumulator
ComponentAccumulator
Definition: JetAnalysisCommon.py:302
python.OnnxRuntimeSessionConfig.OnnxRuntimeSessionToolCfg
def OnnxRuntimeSessionToolCfg(flags, str model_fname, Optional[OnnxRuntimeType] execution_provider=None, name="OnnxRuntimeSessionTool", **kwargs)
Definition: OnnxRuntimeSessionConfig.py:8
InferenceConfig.GraphBucketFilterToolCfg
def GraphBucketFilterToolCfg(flags, name="GraphBucketFilterTool", **kwargs)
Definition: InferenceConfig.py:12