58 forceFold=-1, **kwargs):
59 """Configure the Normalizing Flow-based photon shower shape correction tool"""
60 acc = ComponentAccumulator()
61
62 from AthenaConfiguration.Enums import LHCPeriod
63
64 if not flags.Input.isMC:
65 raise RuntimeError("ElectronPhotonVariableNFCorrectionToolCfg: "
66 "NF correction tool should not be called for data"
67 )
68
69 isFullSim = flags.Sim.ISF.Simulator.isFullSim()
70 isRun3 = flags.GeoModel.Run is LHCPeriod.Run3
71 isRun2 = flags.GeoModel.Run is LHCPeriod.Run2
72
73 if isFullSim and isRun3:
74 default_conf = "EGammaVariableCorrection/NF_y_TUNE1/Run3FS/ElectronPhotonVariableNFCorrectionTool.conf"
75 elif isFullSim and isRun2:
76 default_conf = "EGammaVariableCorrection/NF_y_TUNE1/Run2FS/ElectronPhotonVariableNFCorrectionTool.conf"
77 elif not isFullSim and isRun3:
78 default_conf = "EGammaVariableCorrection/NF_y_TUNE1/Run3AF3/ElectronPhotonVariableNFCorrectionTool.conf"
79 elif not isFullSim and isRun2:
80
81 default_conf = "EGammaVariableCorrection/NF_y_TUNE1/Run3AF3/ElectronPhotonVariableNFCorrectionTool.conf"
82 else:
83 raise RuntimeError(
84 f"ElectronPhotonVariableNFCorrectionToolCfg: no NF correction config available for Run period {flags.GeoModel.Run} "
85 f"(isFullSim={isFullSim}). Only Run2 and Run3 are supported."
86 )
87
88 conf_key = kwargs.setdefault("ConfigFile", default_conf)
89 if forceFold>=0:
90 kwargs.setdefault("forceOneFold", True)
91
93 if not conf_file:
94 raise RuntimeError(f"PathResolver cannot find {conf_key}")
95
96
97 n_folds = None
98 pattern = None
99 with open(conf_file, 'r') as f:
100 for line in f:
101 line = line.strip()
102 if not line or line.startswith('#'):
103 continue
104 key, _, value = line.partition(':')
105 key = key.strip()
106 value = value.strip()
107 if key == 'NFolds':
108 n_folds = int(value)
109 elif key == 'ONNXnamePattern':
110 pattern = value
111
112 if n_folds is None or pattern is None:
113 raise RuntimeError(f'NFolds or ONNXnamePattern not found in config: {conf_file}')
114
115
116 forward_tools = []
117 backward_tools = []
118 for i in range(n_folds):
119
120 if forceFold>=0 and i!=forceFold:
121 continue
122
123 fwd_session = CompFactory.AthOnnx.OnnxRuntimeSessionToolCPU(
124 f'NFCorrectionORTSessionToolForward_{i}',
125 ModelFileName=f'{pattern}_forward_{i}.onnx')
126 fwd_tool = CompFactory.AthOnnx.OnnxRuntimeInferenceTool(
127 f'NFCorrectionOnnxToolForward_{i}',
128 ORTSessionTool=fwd_session)
129 forward_tools.append(fwd_tool)
130
131 bwd_session = CompFactory.AthOnnx.OnnxRuntimeSessionToolCPU(
132 f'NFCorrectionORTSessionToolBackward_{i}',
133 ModelFileName=f'{pattern}_backward_{i}.onnx')
134 bwd_tool = CompFactory.AthOnnx.OnnxRuntimeInferenceTool(
135 f'NFCorrectionOnnxToolBackward_{i}',
136 ORTSessionTool=bwd_session)
137 backward_tools.append(bwd_tool)
138
139 kwargs.setdefault("OnnxInferenceToolsForward", forward_tools)
140 kwargs.setdefault("OnnxInferenceToolsBackward", backward_tools)
141
142 acc.setPrivateTools(
143 CompFactory.ElectronPhotonVariableNFCorrectionTool(name, **kwargs))
144 return acc
145
146
static std::string FindCalibFile(const std::string &logical_file_name)