10 execution_provider: Optional[OnnxRuntimeType] =
None,
11 name=
"OnnxRuntimeSessionTool", **kwargs):
12 """"Configure OnnxRuntimeSessionTool in Control/AthOnnx/AthOnnxComps/src"""
17 execution_provider = flags.AthOnnx.ExecutionProvider
if execution_provider
is None else execution_provider
18 name += execution_provider.name
20 if "OnnxRuntimeSvc" not in kwargs:
21 from AthOnnxComps.OnnxRuntimeSvcConfig
import OnnxRuntimeSvcCfg
22 kwargs.setdefault(
"OnnxRuntimeSvc", acc.getPrimaryAndMerge(
OnnxRuntimeSvcCfg(flags)))
23 kwargs.setdefault(
"ModelFileName", model_fname)
24 if execution_provider
is OnnxRuntimeType.CPU:
25 acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionToolCPU(name, **kwargs))
26 elif execution_provider
is OnnxRuntimeType.CUDA:
27 acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionToolCUDA(name, **kwargs))
29 raise ValueError(
"Unknown OnnxRuntime Execution Provider: %s" % execution_provider)