3 from AthenaConfiguration.ComponentAccumulator
import ComponentAccumulator
4 from AthenaConfiguration.ComponentFactory
import CompFactory
5 from AthenaCommon
import Constants
6 from AthOnnxComps.OnnxRuntimeFlags
import OnnxRuntimeType
# NOTE(review): this region is a mangled extraction of a configuration
# function body and is not valid Python as shown. Each statement carries a
# fused listing line number and is split across lines, and original lines
# are missing. The missing pieces include:
#   - the `def` header, which must bind `name` and `kwargs`;
#   - the creation of `acc`;
#   - the arguments to `acc.popToolsAndMerge(` and its closing parentheses;
#   - whatever the function returns.
# Restore from the upstream file rather than reconstructing by hand.
#
# Path fragment of the ONNX model file used by the example. How it is
# resolved to a full path is not visible here.
12 model_fname =
"dev/MLTest/2020-03-02/MNIST_testModel.onnx"
# Selects the CPU execution provider. Presumably consumed by the missing
# inference-tool call arguments below — confirm against upstream.
13 execution_provider = OnnxRuntimeType.CPU
# Local import. Presumably used to build the inference tool config that is
# merged via acc.popToolsAndMerge below; that call's arguments are missing
# from this listing.
14 from AthOnnxComps.OnnxRuntimeInferenceConfig
import OnnxRuntimeInferenceToolCfg
# Default the algorithm's "ORTInferenceTool" property to the merged tool.
# setdefault means a caller-supplied kwargs entry takes precedence.
15 kwargs.setdefault(
"ORTInferenceTool", acc.popToolsAndMerge(
# Path fragment of the input data file handed to the algorithm.
19 input_data =
"dev/MLTest/2020-03-31/t10k-images-idx3-ubyte"
# Algorithm property defaults. setdefault keeps any value the caller
# already passed in kwargs.
20 kwargs.setdefault(
"BatchSize", 3)
21 kwargs.setdefault(
"InputDataPixel", input_data)
22 kwargs.setdefault(
"OutputLevel", Constants.DEBUG)
# Add the EvaluateModel algorithm (from CompFactory.AthOnnx) to the
# accumulator as an event algorithm, configured with the collected kwargs.
23 acc.addEventAlgo(CompFactory.AthOnnx.EvaluateModel(name, **kwargs))
# NOTE(review): mangled extraction of the script entry point. The fused
# listing line numbers and split statements make it invalid Python as
# shown. Lines are missing between the visible ones, at least:
#   - the creation of `flags` (initConfigFlags is imported but never called
#     here);
#   - the creation of `acc` (MainServicesCfg is imported but never called
#     here);
#   - any merge of the example configuration into `acc`;
#   - `import sys`.
# Restore from the upstream file rather than reconstructing by hand.
27 if __name__ ==
"__main__":
# Imports local to the script entry point.
28 from AthenaCommon.Logging
import log
as msg
29 from AthenaConfiguration.AllConfigFlags
import initConfigFlags
30 from AthenaConfiguration.MainServicesConfig
import MainServicesCfg
# Set the logger level to DEBUG.
32 msg.setLevel(Constants.DEBUG)
# Select the CPU execution provider through the AthOnnx configuration flag.
35 flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU
# Print the configuration, with details and summarised properties.
40 acc.printConfig(withDetails=
True, summariseProps=
True)
# Store the configuration into a .pkl file, opened for binary writing.
# NOTE(review): the file object returned by open() is never closed
# explicitly. Prefer `with open(...) as f: acc.store(f)`.
42 acc.store(
open(
'test_AthExOnnxRuntimeExampleCfg.pkl',
'wb'))
# Run the configuration. The argument 2 is presumably an event count —
# confirm. sys.exit(True) gives exit status 1 and sys.exit(False) gives 0,
# assuming isFailure() returns a bool.
45 sys.exit(acc.run(2).isFailure())