ATLAS Offline Software
Loading...
Searching...
No Matches
AthExOnnxRuntime_test.py
Go to the documentation of this file.
1# Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
2
3from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
4from AthenaConfiguration.ComponentFactory import CompFactory
5from AthenaCommon import Constants
6from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
7
8
def AthExOnnxRuntimeExampleCfg(flags, name="AthOnnxExample", **kwargs):
    """Configure the ``AthOnnx::EvaluateModel`` example algorithm.

    Attaches an ONNX Runtime inference tool for the MNIST test model and the
    MNIST test-image input to an ``AthOnnx::EvaluateModel`` event algorithm.

    :param flags: configuration flags. ``flags.AthOnnx.ExecutionProvider``
        selects the execution provider passed to the inference-tool config.
    :param name: instance name of the ``AthOnnx::EvaluateModel`` algorithm.
    :param kwargs: extra properties forwarded to the algorithm. A caller-supplied
        ``ORTInferenceTool``, ``BatchSize``, ``InputDataPixel`` or
        ``OutputLevel`` takes precedence over the defaults set here.
    :return: a ComponentAccumulator holding the algorithm, plus whatever the
        inference-tool configuration merged in (only when that default tool
        is actually used).
    """
    acc = ComponentAccumulator()

    # Only build (and merge) the default inference tool when the caller did
    # not provide one. ``kwargs.setdefault(key, expr)`` would evaluate
    # ``expr`` unconditionally, merging the default tool's configuration into
    # ``acc`` even when it is then discarded.
    if "ORTInferenceTool" not in kwargs:
        from AthOnnxComps.OnnxRuntimeInferenceConfig import OnnxRuntimeInferenceToolCfg

        # Model path as handed to OnnxRuntimeInferenceToolCfg; how it is
        # resolved to a physical file is up to that helper.
        model_fname = "dev/MLTest/2020-03-02/MNIST_testModel.onnx"
        # Honour the configured execution provider instead of hard-coding
        # CPU, so setting flags.AthOnnx.ExecutionProvider (as the __main__
        # driver does) actually affects this example.
        execution_provider = flags.AthOnnx.ExecutionProvider
        kwargs["ORTInferenceTool"] = acc.popToolsAndMerge(
            OnnxRuntimeInferenceToolCfg(flags, model_fname, execution_provider)
        )

    # MNIST test-image file read by the algorithm (InputDataPixel property).
    input_data = "dev/MLTest/2020-03-31/t10k-images-idx3-ubyte"
    kwargs.setdefault("BatchSize", 3)
    kwargs.setdefault("InputDataPixel", input_data)
    kwargs.setdefault("OutputLevel", Constants.DEBUG)
    acc.addEventAlgo(CompFactory.AthOnnx.EvaluateModel(name, **kwargs))

    return acc
26
if __name__ == "__main__":
    # Stand-alone driver: configure the example with default services,
    # print and pickle the configuration, then run a short job.
    import sys

    from AthenaCommon.Logging import log as msg
    from AthenaConfiguration.AllConfigFlags import initConfigFlags
    from AthenaConfiguration.MainServicesConfig import MainServicesCfg

    msg.setLevel(Constants.DEBUG)

    flags = initConfigFlags()
    flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU
    flags.lock()

    acc = MainServicesCfg(flags)
    acc.merge(AthExOnnxRuntimeExampleCfg(flags))
    acc.printConfig(withDetails=True, summariseProps=True)

    # Write the pickled configuration through a context manager so the file
    # is flushed and closed before the job starts (the bare open() call
    # previously left the handle open for the lifetime of the process).
    with open('test_AthExOnnxRuntimeExampleCfg.pkl', 'wb') as pkl_file:
        acc.store(pkl_file)

    # run(2): presumably the maximum number of events to process -- confirm
    # against ComponentAccumulator.run. isFailure() is a bool, so the process
    # exits with status 1 on failure and 0 on success.
    sys.exit(acc.run(2).isFailure())
AthExOnnxRuntimeExampleCfg(flags, name="AthOnnxExample", **kwargs)