ATLAS Offline Software
Loading...
Searching...
No Matches
AthExTriton_test.py
Go to the documentation of this file.
1# Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
2
3from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
4from AthenaConfiguration.ComponentFactory import CompFactory
5from AthenaCommon import Constants
6
7
def AthExTritonCfg(flags, name="AthExTritonExample", **kwargs):
    """Configure the ``AthInfer::ExampleMLInferenceWithTriton`` example algorithm.

    Builds a ComponentAccumulator holding one ExampleMLInferenceWithTriton
    event algorithm. By default the algorithm's inference tool is a Triton
    tool (``TritonToolCfg``) for the ``MNIST_testModel`` model on
    ``localhost``, named ``EvaluateModelTritonTool``.

    :param flags: Configuration flags, passed through to ``TritonToolCfg``.
    :param name: Instance name of the algorithm.
    :param kwargs: Extra algorithm properties. ``InferenceTool``,
        ``BatchSize``, ``InputDataPixel`` and ``OutputLevel`` receive the
        defaults below only when the caller does not supply them.
    :returns: ComponentAccumulator containing the configured algorithm, plus
        whatever ``TritonToolCfg`` contributes when the default tool is built.

    If building the default Triton tool raises ``RuntimeError``, the error is
    logged as a warning and the process exits with status 2. This signals to
    the test harness that the test was skipped rather than failed.
    """
    acc = ComponentAccumulator()

    # Build the default Triton tool only when the caller did not pass one.
    # ``kwargs.setdefault(key, expr)`` would evaluate ``expr`` unconditionally.
    # That would configure and merge an unused tool, or even exit the job via
    # the skip path below, although an InferenceTool had already been given.
    if "InferenceTool" not in kwargs:
        from AthTritonComps.TritonToolConfig import TritonToolCfg
        try:
            kwargs["InferenceTool"] = acc.popToolsAndMerge(
                TritonToolCfg(flags, "MNIST_testModel", "localhost",
                              name="EvaluateModelTritonTool"))
        except RuntimeError as e:
            import sys
            from AthenaCommon.Logging import log as msg
            msg.warning(e)
            sys.exit(2)  # indicate test is skipped, not failed

    # File name matches the MNIST test-set image file. The "dev/..." prefix is
    # presumably resolved by the algorithm against a data area (TODO confirm).
    input_data = "dev/MLTest/2020-03-31/t10k-images-idx3-ubyte"
    kwargs.setdefault("BatchSize", 2)
    kwargs.setdefault("InputDataPixel", input_data)
    kwargs.setdefault("OutputLevel", Constants.DEBUG)
    acc.addEventAlgo(
        CompFactory.AthInfer.ExampleMLInferenceWithTriton(name, **kwargs))

    return acc
30
31
if __name__ == "__main__":
    # Stand-alone test driver:
    #   1. build the job configuration;
    #   2. print it;
    #   3. pickle it to test_AthExTritonCfg.pkl;
    #   4. run two events, using the outcome as the process exit status.
    import sys

    from AthenaCommon.Logging import log as msg
    from AthenaConfiguration.AllConfigFlags import initConfigFlags
    from AthenaConfiguration.MainServicesConfig import MainServicesCfg

    msg.setLevel(Constants.DEBUG)

    flags = initConfigFlags()
    flags.lock()

    acc = MainServicesCfg(flags)
    acc.merge(AthExTritonCfg(flags))
    acc.printConfig(withDetails=True, summariseProps=True)

    # Close (and thereby flush) the pickle file deterministically before the
    # job runs, rather than leaking the handle until interpreter shutdown.
    with open('test_AthExTritonCfg.pkl', 'wb') as pkl_file:
        acc.store(pkl_file)

    # Exit status follows the run's isFailure() result: a truthy value gives
    # a non-zero status, a falsy one gives 0.
    sys.exit(acc.run(2).isFailure())
AthExTritonCfg(flags, name="AthExTritonExample", **kwargs)