ATLAS Offline Software
AthExTriton_test.py
Go to the documentation of this file.
1 # Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2 
3 from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
4 from AthenaConfiguration.ComponentFactory import CompFactory
5 from AthenaCommon import Constants
6 
7 
def AthExTritonCfg(flags, name="AthExTritonExample", **kwargs):
    """Configure the example Triton-based ML inference algorithm.

    Args:
        flags: Configuration flags, forwarded to ``TritonToolCfg``.
        name: Name given to the ``ExampleMLInferenceWithTriton`` algorithm.
        **kwargs: Properties forwarded to the algorithm. Defaults are filled
            in for ``InferenceTool``, ``BatchSize``, ``InputDataPixel`` and
            ``OutputLevel`` when the caller does not supply them.

    Returns:
        A ``ComponentAccumulator`` holding the algorithm and whatever the
        Triton tool configuration merged into it.
    """
    # Create the accumulator locally. Without this, `acc` would either be
    # undefined (NameError on import) or silently resolve to a module-level
    # `acc` created by the __main__ section below.
    acc = ComponentAccumulator()

    from AthTritonComps.TritonToolConfig import TritonToolCfg

    # Model "MNIST_testModel" served at URL "localhost"; the remaining
    # TritonToolCfg arguments keep their defaults.
    kwargs.setdefault("InferenceTool", acc.popToolsAndMerge(
        TritonToolCfg(flags, "MNIST_testModel", "localhost", name="EvaluateModelTritonTool")
    ))

    input_data = "dev/MLTest/2020-03-31/t10k-images-idx3-ubyte"
    kwargs.setdefault("BatchSize", 2)
    kwargs.setdefault("InputDataPixel", input_data)
    kwargs.setdefault("OutputLevel", Constants.DEBUG)
    acc.addEventAlgo(CompFactory.AthInfer.ExampleMLInferenceWithTriton(name, **kwargs))

    return acc
24 
if __name__ == "__main__":
    # Stand-alone test job: build the configuration, dump it, pickle it and
    # run it.
    import sys

    from AthenaCommon.Logging import log as msg
    from AthenaConfiguration.AllConfigFlags import initConfigFlags
    from AthenaConfiguration.MainServicesConfig import MainServicesCfg

    msg.setLevel(Constants.DEBUG)

    flags = initConfigFlags()
    flags.lock()

    acc = MainServicesCfg(flags)
    acc.merge(AthExTritonCfg(flags))
    acc.printConfig(withDetails=True, summariseProps=True)

    # Use a context manager so the pickle file is flushed and closed
    # deterministically instead of leaking the handle.
    with open('test_AthExTritonCfg.pkl', 'wb') as pkl_file:
        acc.store(pkl_file)

    # isFailure() is truthy on failure, which gives a non-zero exit code.
    # NOTE(review): run(2) presumably limits the job to 2 events - confirm.
    sys.exit(acc.run(2).isFailure())
AthExTriton_test.AthExTritonCfg
def AthExTritonCfg(flags, name="AthExTritonExample", **kwargs)
Definition: AthExTriton_test.py:8
python.JetAnalysisCommon.ComponentAccumulator
ComponentAccumulator
Definition: JetAnalysisCommon.py:302
python.TritonToolConfig.TritonToolCfg
def TritonToolCfg(flags, str model_name, str url, int port=8001, str model_version="", float timeout=0., bool ssl=False, name="TritonTool", **kwargs)
Definition: TritonToolConfig.py:6
python.MainServicesConfig.MainServicesCfg
def MainServicesCfg(flags, LoopMgr='AthenaEventLoopMgr')
Definition: MainServicesConfig.py:312
Trk::open
@ open
Definition: BinningType.h:40
python.AllConfigFlags.initConfigFlags
def initConfigFlags()
Definition: AllConfigFlags.py:19