ATLAS Offline Software
TrigTauRecConfig.py
Go to the documentation of this file.
1 # Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
2 
3 from AthenaCommon.Logging import logging
4 log = logging.getLogger('TrigTauRecConfig')
5 
def trigTauRecMergedPrecisionMVACfg(flags, name, tau_ids=None, input_rois='', input_tracks='', output_name=None):
    '''
    Reconstruct the precision TauJet, from the first-step CaloMVA TauJet and precision-refitted tracks.

    :param flags: Config flags.
    :param name: Suffix for the main TrigTauRecMerged algorithm name.
    :param tau_ids: List of inference algorithms to execute.
        The specific configuration will be loaded from the matching ConfigFlags (Trigger.Offline.Tau.<alg-name>)
        Currently, only the `DeepSet` and `RNNLLP` algorithms will use the LVNN inference setup (json config files);
        all other ID algorithms will use the ONNX inference setup by default.
        If the algorithm name (`name` input variable) is `MVA`, `LLP` or `LRT`, and `tau_ids=['DeepSet', 'MesonCuts']` or `tau_ids=['RNNLLP']`,
        then the default TauJet RNN score and WP `isTau` decorators will be used (for the legacy
        `mediumRNN/tightRNN_tracktwoMVA/tracktwoLLP/trackLRT` triggers).
        Otherwise, all scores and WPs will be stored as `{tau_id}_Score`, `{tau_id}_ScoreSigTrans`, and `{tau_id}_{wp_name}`.
        The entries `perf`, `idperf` and `MesonCuts` and any duplicates are dropped; the remaining IDs are
        configured in sorted order.
    :param input_rois: RoIs container, where the reconstruction will be run.
    :param input_tracks: TrackParticle container, with the refitted precision tracks.
    :param output_name: Suffix for the output TauJet and TauTrack collections. If `None`, `name` will be used.

    :return: CA with the TauJet Precision reconstruction sequence.
    :raises ValueError: If a requested TauID has no `Trigger.Offline.Tau.<tau_id>` flags, or if more than one
        TauID would be stored in the built-in TauJet RNN score variables.
    '''

    # Output collections
    if output_name is None:
        output_name = name
    from TrigEDMConfig.TriggerEDM import recordable
    trigTauJetOutputContainer = recordable(f'HLT_TrigTauRecMerged_{output_name}')
    trigTauTrackOutputContainer = recordable(f'HLT_tautrack_{output_name}')

    # Main CA
    from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
    acc = ComponentAccumulator()


    # The TauJet reconstruction is handled by a set of tools, executed in the following order:
    vftools = []         # Vertex Finder tools
    tools_beforetf = []  # Common tools, ran before the Track Finder tools
    tftools = []         # Track Finder tools
    tools = []           # Common tools
    vvtools = []         # Vertex Vars tools
    idtools = []         # ID tools


    from TrigTauRec.TrigTauRecToolsConfig import trigTauVertexFinderCfg, trigTauTrackFinderCfg, tauVertexVariablesCfg
    from AthenaConfiguration.ComponentFactory import CompFactory

    # Associate RoI vertex or Beamspot to the tau - don't use TJVA
    vftools.append(acc.popToolsAndMerge(trigTauVertexFinderCfg(flags, name='TrigTau_TauVertexFinder')))

    # Set LC energy scale (0.2 cone) and intermediate axis (corrected for vertex: useless at trigger)
    tools_beforetf.append(CompFactory.TauAxisSetter(name='TrigTau_TauAxis', VertexCorrection=False))

    # Associate tracks to the tau
    tftools.append(acc.popToolsAndMerge(trigTauTrackFinderCfg(flags, name='TrigTauTightDZ_TauTrackFinder', TrackParticlesContainer=input_tracks)))

    # Decorate the clusters
    tools.append(CompFactory.TauClusterFinder(name='TrigTau_TauClusterFinder', UseOriginalCluster=False))
    tools.append(CompFactory.TauVertexedClusterDecorator(name='TrigTau_TauVertexedClusterDecorator', SeedJet=''))

    # Calculate cell-based quantities: strip variables, EM and Had energies/radii, centFrac, isolFrac and ring energies
    tools.append(CompFactory.TauCellVariables(name='TrigTau_CellVariables', VertexCorrection=False))

    # Compute MVA TES (ATR-17649), stores MVA TES as default tau pt
    tools.append(CompFactory.MvaTESVariableDecorator(name='TrigTau_MvaTESVariableDecorator', Key_vertexInputContainer='', EventShapeKey='', VertexCorrection=False))
    acc.addPublicTool(tools[-1])
    tools.append(CompFactory.MvaTESEvaluator(name='TrigTau_MvaTESEvaluator', WeightFileName=flags.Trigger.Offline.Tau.MvaTESConfig))
    acc.addPublicTool(tools[-1])

    # Vertex variables
    vvtools.append(acc.popToolsAndMerge(tauVertexVariablesCfg(flags, name='TrigTau_TauVertexVariables')))

    # Variables combining tracking and calorimeter information
    idtools.append(CompFactory.TauCommonCalcVars(name='TrigTau_TauCommonCalcVars'))

    # Cluster-based sub-structure, with dRMax
    idtools.append(CompFactory.TauSubstructureVariables(name='TrigTau_TauSubstructure', VertexCorrection=False))

    #---------------------------------------------------------------
    # Tau ID and score flattening
    #---------------------------------------------------------------
    # We can run multiple inferences at once. Each will be stored on different decorated variables
    # (or isTau(...) flags in the case of the legacy RNN/DeepSet tracktwoMVA/LLP/LRT triggers)

    # We first "remove" the "ids" that don't require any inference, and any duplicates.
    # The result is sorted so that the order of the ID tools (and of the monitored scores)
    # does not depend on the iteration order of a set of strings.
    tau_ids = sorted(set(tau_ids or ()) - {'perf', 'idperf', 'MesonCuts'})

    from TriggerMenuMT.HLT.Tau.TauConfigurationTools import getTauIDScoreVariables
    id_score_monitoring = {}

    # We can only have at most one TauID algorithm score/WPs being stored in the built-in TauJet RNN variables
    used_builtin_rnnscore = False

    for tau_id in tau_ids:
        # First check that the TauID algorithm has the necessary config flags defined.
        # A missing attribute is reported by getattr() as AttributeError (this is also
        # what the hasattr() check below relies on), so that is the exception to translate.
        try:
            id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
        except AttributeError as err:
            raise ValueError(f'Missing TauID ConfigFlags: Trigger.Offline.Tau.{tau_id}') from err

        # Now check if it's an ONNX-based TauID, or an LVNN-based TauID
        is_onnx = hasattr(id_flags, 'ONNXConfig')

        if is_onnx: # ONNX inference
            log.debug('Configuring TrigTauRecMerged with the ONNX Tau ID score inference: %s', tau_id)

            from TrigTauRec.TrigTauRecToolsConfig import trigTauJetONNXEvaluatorCfg, trigTauWPDecoratorCfg

            # ONNX (GNTau) inference
            idtools.append(acc.popToolsAndMerge(trigTauJetONNXEvaluatorCfg(flags, tau_id=tau_id)))
            acc.addPublicTool(idtools[-1])

            # ID score flattening and WPs
            idtools.append(acc.popToolsAndMerge(trigTauWPDecoratorCfg(flags, tau_id=tau_id, precision_seq_name=name)))
            acc.addPublicTool(idtools[-1])

        else: # LVNN inference
            log.debug('Configuring TrigTauRecMerged with the LVNN Tau ID score inference: %s', tau_id)

            from TriggerMenuMT.HLT.Tau.TauConfigurationTools import useBuiltInTauJetRNNScore

            # To support the legacy tracktwoMVA/LLP/LRT chains, only in those cases we store the
            # passed WPs in the built-in TauJet variables
            use_builtin_rnnscore = useBuiltInTauJetRNNScore(tau_id, precision_sequence=name)
            if use_builtin_rnnscore:
                if used_builtin_rnnscore:
                    message = 'Cannot store more than one TauID score in the built-in TauJet RNN score variables'
                    log.error(message)
                    raise ValueError(message)
                used_builtin_rnnscore = True

            # LVNN (RNN/DeepSet) inference
            from TrigTauRec.TrigTauRecToolsConfig import trigTauJetLVNNEvaluatorCfg
            idtools.append(acc.popToolsAndMerge(trigTauJetLVNNEvaluatorCfg(flags, tau_id=tau_id, use_taujet_rnnscore=use_builtin_rnnscore)))
            acc.addPublicTool(idtools[-1])

            # ID score flattening and WPs
            if use_builtin_rnnscore:
                from TrigTauRec.TrigTauRecToolsConfig import trigTauWPDecoratorRNNCfg
                idtools.append(acc.popToolsAndMerge(trigTauWPDecoratorRNNCfg(flags, tau_id=tau_id, precision_seq_name=name)))
                acc.addPublicTool(idtools[-1])
            else:
                from TrigTauRec.TrigTauRecToolsConfig import trigTauWPDecoratorCfg
                idtools.append(acc.popToolsAndMerge(trigTauWPDecoratorCfg(flags, tau_id=tau_id, precision_seq_name=name)))
                acc.addPublicTool(idtools[-1])

        id_score_monitoring[tau_id] = getTauIDScoreVariables(tau_id, precision_sequence=name)


    # Set trigger-specific configuration for all the reconstruction tools
    for tool in vftools + tools_beforetf + tftools + tools + vvtools + idtools:
        tool.inTrigger = True
        tool.calibFolder = flags.Trigger.Offline.Tau.tauRecToolsCVMFSPath


    from TrigTauRec.TrigTauRecMonitoring import tauMonitoringPrecision
    acc.addEventAlgo(CompFactory.TrigTauRecMerged(
        name=f'TrigTauRecMerged_Precision_{name}',
        VertexFinderTools=vftools,
        CommonToolsBeforeTF=tools_beforetf,
        TrackFinderTools=tftools,
        CommonTools=tools,
        VertexVarsTools=vvtools,
        IDTools=idtools,
        # Pass a real list of the configured IDs rather than a live dict view
        MonTool=tauMonitoringPrecision(flags, RoI_name='tauLRT' if 'LRT' in name else 'tauIso', tau_ids=list(id_score_monitoring), alg_name=name),
        MonitoredIDScores=id_score_monitoring,
        InputRoIs=input_rois,
        InputVertexContainer=flags.Tracking.ActiveConfig.vertex,
        InputTauTrackContainer='HLT_tautrack_dummy',
        InputTauJetContainer='HLT_TrigTauRecMerged_CaloMVAOnly',
        OutputTauTrackContainer=trigTauTrackOutputContainer,
        OutputTauJetContainer=trigTauJetOutputContainer,
    ))

    return acc
176 
177 
def trigTauRecMergedCaloMVACfg(flags):
    '''
    Reconstruct the CaloMVA TauJet from the calo-clusters.

    :param flags: Config flags.
    :return: CA with the TauJet CaloMVA reconstruction sequence.
    '''
    # Main CA
    from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
    acc = ComponentAccumulator()

    # Reconstruction tools, handed to the algorithm as its CommonTools in this order
    tools = []

    from AthenaConfiguration.ComponentFactory import CompFactory

    # Set seedcalo energy scale (Full RoI)
    tools.append(CompFactory.JetSeedBuilder())

    # Set LC energy scale (0.2 cone) and intermediate axis (corrected for vertex: useless at trigger)
    tools.append(CompFactory.TauAxisSetter(ClusterCone=0.2, VertexCorrection=False))

    # Decorate the clusters
    tools.append(CompFactory.TauClusterFinder(UseOriginalCluster=False)) # TODO: use JetRec.doVertexCorrection once available
    tools.append(CompFactory.TauVertexedClusterDecorator(SeedJet=''))

    # Calculate cell-based quantities: strip variables, EM and Had energies/radii, centFrac, isolFrac and ring energies
    from AthenaCommon.SystemOfUnits import GeV
    tools.append(CompFactory.TauCellVariables(StripEthreshold=0.2*GeV, CellCone=0.2, VertexCorrection=False))

    # Compute MVA TES (ATR-17649), stores MVA TES as the default tau pt
    tools.append(CompFactory.MvaTESVariableDecorator(Key_vertexInputContainer='', EventShapeKey='', VertexCorrection=False))
    acc.addPublicTool(tools[-1])
    tools.append(CompFactory.MvaTESEvaluator(WeightFileName=flags.Trigger.Offline.Tau.MvaTESConfig))
    acc.addPublicTool(tools[-1])


    # Set trigger-specific configuration for all the reconstruction tools
    for tool in tools:
        tool.inTrigger = True
        tool.calibFolder = flags.Trigger.Offline.Tau.tauRecToolsCVMFSPath


    from TrigEDMConfig.TriggerEDM import recordable
    from TrigTauRec.TrigTauRecMonitoring import tauMonitoringCaloOnlyMVA
    acc.addEventAlgo(CompFactory.TrigTauRecMerged(
        name='TrigTauRecMerged_TauCaloOnlyMVA',
        CommonTools=tools,
        MonTool=tauMonitoringCaloOnlyMVA(flags),
        InputRoIs='UpdatedCaloRoI',
        InputCaloClusterContainer='HLT_TopoCaloClustersLC',
        OutputTauTrackContainer='HLT_tautrack_dummy',
        OutputTauJetContainer='HLT_TrigTauRecMerged_CaloMVAOnly',
        OutputJetSeed=recordable('HLT_jet_seed'),
    ))

    return acc
234 
235 
236 
if __name__ == '__main__':
    # Stand-alone check: build one configuration and print it, without running a job.
    from AthenaConfiguration.AllConfigFlags import initConfigFlags
    from AthenaConfiguration.TestDefaults import defaultTestFiles
    flags = initConfigFlags()
    flags.Input.Files = defaultTestFiles.RAW_RUN2
    flags.lock()

    # NOTE(review): the statement assigning `acc` was missing here, so `acc` was used
    # while undefined. The CaloMVA builder is the only one in this file that can be
    # called with `flags` alone - confirm this is the intended sequence to print.
    acc = trigTauRecMergedCaloMVACfg(flags)
    acc.printConfig(withDetails=True, summariseProps=True)
    acc.wasMerged() # Do not run, do not save, we just want to see the config
SystemOfUnits
python.JetAnalysisCommon.ComponentAccumulator
ComponentAccumulator
Definition: JetAnalysisCommon.py:302
TrigTauRecToolsConfig.trigTauTrackFinderCfg
def trigTauTrackFinderCfg(flags, name='', TrackParticlesContainer='')
Definition: TrigTauRecToolsConfig.py:23
TrigTauRecToolsConfig.trigTauWPDecoratorRNNCfg
def trigTauWPDecoratorRNNCfg(flags, str tau_id, str precision_seq_name)
Definition: TrigTauRecToolsConfig.py:152
TrigTauRecConfig.trigTauRecMergedPrecisionMVACfg
def trigTauRecMergedPrecisionMVACfg(flags, name, tau_ids=None, input_rois='', input_tracks='', output_name=None)
Definition: TrigTauRecConfig.py:6
python.HLT.Tau.TauConfigurationTools.useBuiltInTauJetRNNScore
bool useBuiltInTauJetRNNScore(str tau_id, str precision_sequence)
Definition: TauConfigurationTools.py:82
TrigTauRecToolsConfig.trigTauVertexFinderCfg
def trigTauVertexFinderCfg(flags, name='')
Definition: TrigTauRecToolsConfig.py:7
TrigTauRecToolsConfig.trigTauWPDecoratorCfg
def trigTauWPDecoratorCfg(flags, str tau_id, str precision_seq_name)
Definition: TrigTauRecToolsConfig.py:192
TrigTauRecMonitoring.tauMonitoringPrecision
def tauMonitoringPrecision(flags, str name='Precision', str RoI_name='tauIso', list[str] tau_ids=[], str alg_name='')
Definition: TrigTauRecMonitoring.py:71
python.HLT.Tau.TauConfigurationTools.getTauIDScoreVariables
tuple[str, str] getTauIDScoreVariables(str tau_id, str precision_sequence)
Definition: TauConfigurationTools.py:91
TrigTauRecConfig.trigTauRecMergedCaloMVACfg
def trigTauRecMergedCaloMVACfg(flags)
Definition: TrigTauRecConfig.py:178
histSizes.list
def list(name, path='/')
Definition: histSizes.py:38
TrigTauRecMonitoring.tauMonitoringCaloOnlyMVA
def tauMonitoringCaloOnlyMVA(flags, str name='CaloMVA', str RoI_name='L1')
Definition: TrigTauRecMonitoring.py:5
CxxUtils::set
constexpr std::enable_if_t< is_bitmask_v< E >, E & > set(E &lhs, E rhs)
Convenience function to set bits in a class enum bitmask.
Definition: bitmask.h:232
TrigTauRecToolsConfig.tauVertexVariablesCfg
def tauVertexVariablesCfg(flags, name='')
Definition: TrigTauRecToolsConfig.py:57
python.AllConfigFlags.initConfigFlags
def initConfigFlags()
Definition: AllConfigFlags.py:19
TrigTauRecToolsConfig.trigTauJetLVNNEvaluatorCfg
def trigTauJetLVNNEvaluatorCfg(flags, tau_id='', use_taujet_rnnscore=True)
Definition: TrigTauRecToolsConfig.py:75
python.TriggerEDM.recordable
def recordable(arg, runVersion=3)
Definition: TriggerEDM.py:34
TrigTauRecToolsConfig.trigTauJetONNXEvaluatorCfg
def trigTauJetONNXEvaluatorCfg(flags, tau_id='')
Definition: TrigTauRecToolsConfig.py:114