11 name="OnnxRuntimeSessionTool", **kwargs):
12 """"Configure OnnxRuntimeSessionTool in Control/AthOnnx/AthOnnxComps/src"""
13
14 acc = ComponentAccumulator()
15
16
17 execution_provider = flags.AthOnnx.ExecutionProvider if execution_provider is None else execution_provider
18 name += execution_provider.name
19
20 if "OnnxRuntimeSvc" not in kwargs:
21 from AthOnnxComps.OnnxRuntimeSvcConfig import OnnxRuntimeSvcCfg
22 kwargs.setdefault("OnnxRuntimeSvc", acc.getPrimaryAndMerge(OnnxRuntimeSvcCfg(flags)))
23 kwargs.setdefault("ModelFileName", model_fname)
24 if execution_provider is OnnxRuntimeType.CPU:
25 acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionToolCPU(name, **kwargs))
26 elif execution_provider is OnnxRuntimeType.CUDA:
27 acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionToolCUDA(name, **kwargs))
28 else:
29 raise ValueError("Unknown OnnxRuntime Execution Provider: %s" % execution_provider)
30
31 return acc