ATLAS Offline Software
OnnxRuntimeSessionConfig.py
Go to the documentation of this file.
1 # Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2 
3 from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
4 from AthenaConfiguration.ComponentFactory import CompFactory
5 from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
6 from typing import Optional
7 
def OnnxRuntimeSessionToolCfg(flags,
                              model_fname: str,
                              execution_provider: Optional[OnnxRuntimeType] = None,
                              name="OnnxRuntimeSessionTool", **kwargs):
    """Configure OnnxRuntimeSessionTool in Control/AthOnnx/AthOnnxComps/src.

    Args:
        flags: Configuration flags. ``flags.AthOnnx.ExecutionProvider`` is
            read when ``execution_provider`` is None, and the flags are
            passed on to ``OnnxRuntimeSvcCfg`` when the service has to be
            configured here.
        model_fname: Value used for the tool's ``ModelFileName`` property
            unless the caller already supplied one through ``kwargs``.
        execution_provider: Which session tool flavour to build. None means
            "use ``flags.AthOnnx.ExecutionProvider``".
        name: Base name of the tool; the execution provider's ``name`` is
            appended to it.
        **kwargs: Properties forwarded to the tool constructor. An
            ``OnnxRuntimeSvc`` entry, if present, is used as-is.

    Returns:
        A ComponentAccumulator whose private tool is the configured
        session tool.

    Raises:
        ValueError: If the execution provider is neither CPU nor CUDA.
    """
    acc = ComponentAccumulator()

    if execution_provider is None:
        execution_provider = flags.AthOnnx.ExecutionProvider
    # Suffix the tool name with the provider so that CPU and CUDA
    # instances get distinct names.
    name += execution_provider.name

    # Only configure (and merge) the runtime service when the caller did
    # not hand one in.
    if "OnnxRuntimeSvc" not in kwargs:
        from AthOnnxComps.OnnxRuntimeSvcConfig import OnnxRuntimeSvcCfg
        kwargs["OnnxRuntimeSvc"] = acc.getPrimaryAndMerge(OnnxRuntimeSvcCfg(flags))
    kwargs.setdefault("ModelFileName", model_fname)

    # The component classes are looked up lazily, inside the selected
    # branch only.
    if execution_provider is OnnxRuntimeType.CPU:
        tool = CompFactory.AthOnnx.OnnxRuntimeSessionToolCPU(name, **kwargs)
    elif execution_provider is OnnxRuntimeType.CUDA:
        tool = CompFactory.AthOnnx.OnnxRuntimeSessionToolCUDA(name, **kwargs)
    else:
        raise ValueError(f"Unknown OnnxRuntime Execution Provider: {execution_provider!s}")

    acc.setPrivateTools(tool)
    return acc
python.JetAnalysisCommon.ComponentAccumulator
ComponentAccumulator
Definition: JetAnalysisCommon.py:302
python.OnnxRuntimeSvcConfig.OnnxRuntimeSvcCfg
def OnnxRuntimeSvcCfg(flags, name="OnnxRuntimeSvc", **kwargs)
Definition: OnnxRuntimeSvcConfig.py:6
python.OnnxRuntimeSessionConfig.OnnxRuntimeSessionToolCfg
def OnnxRuntimeSessionToolCfg(flags, str model_fname, Optional[OnnxRuntimeType] execution_provider=None, name="OnnxRuntimeSessionTool", **kwargs)
Definition: OnnxRuntimeSessionConfig.py:8