ATLAS Offline Software
Loading...
Searching...
No Matches
AthExTriton_test.py
Go to the documentation of this file.
1# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
3from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
4from AthenaConfiguration.ComponentFactory import CompFactory
5from AthenaCommon import Constants
6
7
def AthExTritonCfg(flags, name="AthExTritonExample", **kwargs):
    """Configure the ``AthInfer::ExampleMLInferenceWithTriton`` example algorithm.

    Builds a ComponentAccumulator holding one event algorithm of type
    ``AthInfer.ExampleMLInferenceWithTriton``. Any property not supplied by the
    caller falls back to the defaults below.

    :param flags: configuration flags, forwarded to ``TritonToolCfg`` when the
        default inference tool is built.
    :param name: instance name of the algorithm.
    :param kwargs: algorithm properties; recognised defaults are
        ``InferenceTool`` (a Triton tool named ``EvaluateModelTritonTool``
        configured with ``"MNIST_testModel"`` and ``"localhost"`` — presumably
        model name and server host; confirm against ``TritonToolCfg``),
        ``BatchSize`` (2), ``InputDataPixel`` (the t10k-images test file path)
        and ``OutputLevel`` (``Constants.DEBUG``).
    :return: a ComponentAccumulator containing the configured algorithm and,
        when the default inference tool is used, that tool's merged
        configuration.
    """
    acc = ComponentAccumulator()

    # Only build the default Triton tool when the caller did not pass one.
    # ``kwargs.setdefault(key, expr)`` evaluates ``expr`` unconditionally, so the
    # previous form always configured the default tool and merged its
    # accumulator into ``acc``, even when that tool was then discarded.
    if "InferenceTool" not in kwargs:
        from AthTritonComps.TritonToolConfig import TritonToolCfg
        kwargs["InferenceTool"] = acc.popToolsAndMerge(
            TritonToolCfg(flags, "MNIST_testModel", "localhost",
                          name="EvaluateModelTritonTool")
        )

    kwargs.setdefault("BatchSize", 2)
    # Path of the input image file; how it is resolved (e.g. via a path
    # resolver / calibration area) is decided by the algorithm — not shown here.
    kwargs.setdefault("InputDataPixel",
                      "dev/MLTest/2020-03-31/t10k-images-idx3-ubyte")
    kwargs.setdefault("OutputLevel", Constants.DEBUG)

    acc.addEventAlgo(CompFactory.AthInfer.ExampleMLInferenceWithTriton(name, **kwargs))
    return acc
24
if __name__ == "__main__":
    # Standalone test: configure the example algorithm on top of the main
    # services, dump the configuration, pickle it and run two events.
    import sys

    from AthenaCommon.Logging import log as msg
    from AthenaConfiguration.AllConfigFlags import initConfigFlags
    from AthenaConfiguration.MainServicesConfig import MainServicesCfg

    msg.setLevel(Constants.DEBUG)

    flags = initConfigFlags()
    flags.lock()

    acc = MainServicesCfg(flags)
    acc.merge(AthExTritonCfg(flags))
    acc.printConfig(withDetails=True, summariseProps=True)

    # Use a context manager so the pickle file is flushed and closed before
    # the job runs; previously the handle was left open for the GC to close.
    with open("test_AthExTritonCfg.pkl", "wb") as pkl_file:
        acc.store(pkl_file)

    # Process two events; a failed run yields a non-zero exit status.
    sys.exit(acc.run(2).isFailure())
AthExTritonCfg(flags, name="AthExTritonExample", **kwargs)