ATLAS Offline Software
Loading...
Searching...
No Matches
AthExOnnxRuntime_test_async_infer.py
Go to the documentation of this file.
1# Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
2
3from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
4from AthenaConfiguration.ComponentFactory import CompFactory
5from AthenaCommon import Constants
6from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
7from AthOnnxComps.OnnxRuntimeInferenceConfig import OnnxRuntimeInferenceToolCfg
8
9
def AthExOnnxRuntimeExampleCfg(flags, name="AthOnnxExample", **kwargs):
    """Configure the asynchronous-inference ONNX Runtime example algorithm.

    Builds a ComponentAccumulator holding a
    ``CompFactory.AthOnnx.EvaluateModelWithAsyncInfer`` event algorithm.
    By default, the algorithm evaluates ``MNIST_testModel.onnx`` on the pixel
    data read from ``t10k-images-idx3-ubyte``.

    :param flags: configuration flags, forwarded to
        ``OnnxRuntimeInferenceToolCfg`` when the default tool is built.
    :param name: instance name of the event algorithm.
    :param kwargs: property overrides for the algorithm. Values given for
        ``ORTInferenceTool``, ``BatchSize``, ``InputDataPixel`` or
        ``OutputLevel`` take precedence over the defaults set here; any other
        keys are passed through to the algorithm unchanged.
    :returns: ComponentAccumulator containing the algorithm. When no
        ``ORTInferenceTool`` is supplied, it also holds the default inference
        tool's merged dependencies.
    """
    acc = ComponentAccumulator()

    # Build the default inference tool only when the caller did not supply
    # one. ``kwargs.setdefault(key, expr)`` evaluates ``expr`` eagerly.
    # Without this guard, the default tool's configuration would be merged
    # into ``acc`` even when it is then discarded in favour of the caller's
    # tool.
    if "ORTInferenceTool" not in kwargs:
        model_fname = "dev/MLTest/2020-03-02/MNIST_testModel.onnx"
        # NOTE(review): the execution provider is fixed to CPU here; this
        # function does not consult flags.AthOnnx.ExecutionProvider, so
        # setting that flag has no effect on this tool. Confirm this is intended.
        execution_provider = OnnxRuntimeType.CPU
        kwargs["ORTInferenceTool"] = acc.popToolsAndMerge(
            OnnxRuntimeInferenceToolCfg(flags, model_fname, execution_provider)
        )

    input_data = "dev/MLTest/2020-03-31/t10k-images-idx3-ubyte"
    kwargs.setdefault("BatchSize", 100)
    kwargs.setdefault("InputDataPixel", input_data)
    kwargs.setdefault("OutputLevel", Constants.DEBUG)
    acc.addEventAlgo(
        CompFactory.AthOnnx.EvaluateModelWithAsyncInfer(name, **kwargs))

    return acc
27
28
if __name__ == "__main__":
    # Stand-alone test driver, in three steps:
    #   1. Configure a multi-threaded job (3 threads, 1 offload thread)
    #      running the async-inference example.
    #   2. Print and pickle the configuration.
    #   3. Run the job with acc.run(2).
    from AthenaCommon.Logging import log as msg
    from AthenaConfiguration.AllConfigFlags import initConfigFlags
    from AthenaConfiguration.MainServicesConfig import MainServicesCfg

    msg.setLevel(Constants.DEBUG)

    flags = initConfigFlags()
    flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU
    flags.Concurrency.NumThreads = 3
    flags.Concurrency.NumOffloadThreads = 1
    flags.Exec.FPE = -1
    flags.lock()

    acc = MainServicesCfg(flags)
    acc.merge(AthExOnnxRuntimeExampleCfg(flags))
    acc.printConfig(withDetails=True, summariseProps=True)

    # Write the pickled configuration through a context manager. This closes
    # and flushes the file deterministically, before the job starts, instead
    # of relying on garbage collection of an anonymous file object.
    with open('test_AsyncInferORTExampleCfg.pkl', 'wb') as pkl_file:
        acc.store(pkl_file)

    import sys
    # sys.exit(True) exits with status 1, so the script fails when the run
    # reports failure.
    sys.exit(acc.run(2).isFailure())
AthExOnnxRuntimeExampleCfg(flags, name="AthOnnxExample", **kwargs)