3 from AthenaConfiguration.ComponentAccumulator
import ComponentAccumulator
4 from AthenaConfiguration.ComponentFactory
import CompFactory
5 from AthOnnxComps.OnnxRuntimeFlags
import OnnxRuntimeType
6 from typing
import Optional
7 from AthOnnxComps.OnnxRuntimeSessionConfig
import OnnxRuntimeSessionToolCfg
10 model_fname: str =
None,
11 execution_provider: Optional[OnnxRuntimeType] =
None,
12 name=
"OnnxRuntimeInferenceTool", **kwargs):
13 """Configure OnnxRuntimeInferenceTool in Control/AthOnnx/AthOnnxComps/src"""
17 if "OnnxRuntimeSvc" not in kwargs:
18 from AthOnnxComps.OnnxRuntimeSvcConfig
import OnnxRuntimeSvcCfg
19 kwargs.setdefault(
"OnnxRuntimeSvc", acc.getPrimaryAndMerge(
OnnxRuntimeSvcCfg(flags)))
20 kwargs.setdefault(
"ORTSessionTool", acc.popToolsAndMerge(
OnnxRuntimeSessionToolCfg(flags, model_fname, execution_provider)))
21 acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeInferenceTool(name, **kwargs))