import sys

from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
from AthenaConfiguration.ComponentFactory import CompFactory
from AthenaCommon import Constants
def AthExTritonCfg(flags, name="AthExTritonExample", **kwargs):
    """Configure the ``AthInfer::ExampleMLInferenceWithTriton`` example algorithm.

    Builds a ComponentAccumulator holding one event algorithm. By default, its
    inference tool is the Triton client tool returned by ``TritonToolCfg`` for
    model ``"MNIST_testModel"`` on ``"localhost"``.

    Args:
        flags: Athena configuration flags, forwarded to ``TritonToolCfg``.
        name: Instance name given to the algorithm.
        **kwargs: Extra algorithm properties. ``InferenceTool``, ``BatchSize``,
            ``InputDataPixel`` and ``OutputLevel`` are filled with the defaults
            below only when the caller has not supplied them.

    Returns:
        ComponentAccumulator: the accumulator containing the configured
        algorithm (plus anything merged in from ``TritonToolCfg``).
    """
    acc = ComponentAccumulator()

    # Only build (and merge) the default Triton tool when the caller did not
    # provide one. ``dict.setdefault`` would evaluate TritonToolCfg eagerly and
    # merge its components into ``acc`` even when the default is unused.
    if "InferenceTool" not in kwargs:
        # Local import: the Triton configuration is only needed on this path.
        from AthTritonComps.TritonToolConfig import TritonToolCfg
        kwargs["InferenceTool"] = acc.popToolsAndMerge(
            TritonToolCfg(flags, "MNIST_testModel", "localhost",
                          name="EvaluateModelTritonTool"))

    # Path to the input image file handed to the algorithm's InputDataPixel
    # property (presumably resolved by the algorithm itself -- confirm).
    input_data = "dev/MLTest/2020-03-31/t10k-images-idx3-ubyte"
    kwargs.setdefault("BatchSize", 2)
    kwargs.setdefault("InputDataPixel", input_data)
    kwargs.setdefault("OutputLevel", Constants.DEBUG)

    acc.addEventAlgo(
        CompFactory.AthInfer.ExampleMLInferenceWithTriton(name, **kwargs))
    return acc
if __name__ == "__main__":
    from AthenaCommon.Logging import log as msg
    from AthenaConfiguration.AllConfigFlags import initConfigFlags
    from AthenaConfiguration.MainServicesConfig import MainServicesCfg

    msg.setLevel(Constants.DEBUG)

    flags = initConfigFlags()
    # Configuration functions are expected to receive locked flags.
    flags.lock()

    acc = MainServicesCfg(flags)
    # Add the example algorithm to the job; without this the job would run
    # only the core services.
    acc.merge(AthExTritonCfg(flags))
    acc.printConfig(withDetails=True, summariseProps=True)

    # Pickle the configuration; the context manager guarantees the file is
    # flushed and closed instead of leaking the handle.
    with open('test_AthExTritonCfg.pkl', 'wb') as pkl_file:
        acc.store(pkl_file)

    # Run two events; the process exit status is non-zero if the run failed.
    sys.exit(acc.run(2).isFailure())