ATLAS Offline Software
AthExOnnxRuntime_test.py
Go to the documentation of this file.
1 # Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
2 
3 from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
4 from AthenaConfiguration.ComponentFactory import CompFactory
5 from AthenaCommon import Constants
6 from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
7 
8 
def AthExOnnxRuntimeExampleCfg(flags, name="AthOnnxExample", **kwargs):
    """Configure the ``AthOnnx.EvaluateModel`` example algorithm.

    Builds an ONNX Runtime inference tool for the MNIST test model using the
    CPU execution provider, then schedules an ``AthOnnx.EvaluateModel`` event
    algorithm that uses it.

    Args:
        flags: Athena configuration flags, forwarded to
            ``OnnxRuntimeInferenceToolCfg``.
        name: Instance name of the ``AthOnnx.EvaluateModel`` algorithm.
        **kwargs: Property overrides for the algorithm. Any of
            ``ORTInferenceTool``, ``BatchSize``, ``InputDataPixel`` and
            ``OutputLevel`` that the caller does not supply are filled with
            the example defaults below.

    Returns:
        A ComponentAccumulator holding the event algorithm plus whatever
        components were merged in from the inference-tool configuration.
    """
    from AthOnnxComps.OnnxRuntimeInferenceConfig import OnnxRuntimeInferenceToolCfg

    result = ComponentAccumulator()

    onnx_model = "dev/MLTest/2020-03-02/MNIST_testModel.onnx"
    pixel_file = "dev/MLTest/2020-03-31/t10k-images-idx3-ubyte"

    # NOTE: the default tool is always built, and its configuration merged into
    # `result`, even when the caller already passes an ORTInferenceTool.
    default_tool = result.popToolsAndMerge(
        OnnxRuntimeInferenceToolCfg(flags, onnx_model, OnnxRuntimeType.CPU)
    )

    example_defaults = {
        "ORTInferenceTool": default_tool,
        "BatchSize": 3,
        "InputDataPixel": pixel_file,
        "OutputLevel": Constants.DEBUG,
    }
    # Caller-supplied kwargs win; only missing properties take the defaults.
    for prop, value in example_defaults.items():
        kwargs.setdefault(prop, value)

    result.addEventAlgo(CompFactory.AthOnnx.EvaluateModel(name, **kwargs))
    return result
26 
if __name__ == "__main__":
    # Standalone test: configure the example algorithm with the CPU execution
    # provider, dump the configuration to a pickle file, then run two events.
    import sys

    from AthenaCommon.Logging import log as msg
    from AthenaConfiguration.AllConfigFlags import initConfigFlags
    from AthenaConfiguration.MainServicesConfig import MainServicesCfg

    msg.setLevel(Constants.DEBUG)

    flags = initConfigFlags()
    flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU
    flags.lock()

    acc = MainServicesCfg(flags)
    acc.merge(AthExOnnxRuntimeExampleCfg(flags))
    acc.printConfig(withDetails=True, summariseProps=True)

    # Use a context manager so the pickle file is flushed and closed
    # deterministically before the job starts running.
    with open('test_AthExOnnxRuntimeExampleCfg.pkl', 'wb') as pkl_file:
        acc.store(pkl_file)

    # Process exit status is non-zero when running two events fails.
    sys.exit(acc.run(2).isFailure())
AthExOnnxRuntime_test.AthExOnnxRuntimeExampleCfg
def AthExOnnxRuntimeExampleCfg(flags, name="AthOnnxExample", **kwargs)
Definition: AthExOnnxRuntime_test.py:9
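python.JetAnalysisCommon.ComponentAccumulator
ComponentAccumulator
Definition: JetAnalysisCommon.py:302
(Note: this cross-reference is a Doxygen name-resolution mismatch. The ComponentAccumulator used in this file is imported from AthenaConfiguration.ComponentAccumulator, not from JetAnalysisCommon.)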
python.MainServicesConfig.MainServicesCfg
def MainServicesCfg(flags, LoopMgr='AthenaEventLoopMgr')
Definition: MainServicesConfig.py:260
python.OnnxRuntimeInferenceConfig.OnnxRuntimeInferenceToolCfg
def OnnxRuntimeInferenceToolCfg(flags, str model_fname=None, Optional[OnnxRuntimeType] execution_provider=None, name="OnnxRuntimeInferenceTool", **kwargs)
Definition: OnnxRuntimeInferenceConfig.py:9
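Trk::open
@ open
Definition: BinningType.h:40
(Note: this cross-reference is a Doxygen name-resolution mismatch. The open() call at line 42 of this file is Python's built-in open(), not the C++ symbol Trk::open.)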
python.AllConfigFlags.initConfigFlags
def initConfigFlags()
Definition: AllConfigFlags.py:19