ATLAS Offline Software
trfMPITools.py
Go to the documentation of this file.
1 # Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2 
3 
8 
9 from enum import Enum
10 from copy import deepcopy
11 import os
12 import re
13 import logging
14 import pprint
15 import itertools as it
16 
17 from PyJobTransforms.trfExitCodes import trfExit
18 import PyJobTransforms.trfExceptions as trfExceptions
19 
# Logger shared by every function in this module.
msg: logging.Logger = logging.getLogger(__name__)

# Per-process MPI configuration. It stays None until setupMPIConfig() replaces
# it with a dict holding the "rank", "type" and "outputs" entries.
mpiConfig = None
23 
24 
class MPIType(Enum):
    """Role of the current process in an MPI job.

    NOMPI is used when the transform is not running under MPI, MPIMASTER for
    rank 0, and MPIWORKER for every other rank.
    """

    NOMPI = 0  # not running under MPI
    MPIMASTER = 1  # MPI rank 0
    MPIWORKER = 2  # any MPI rank other than 0
31 
32 
def signalError(message):
    """Log *message* as an error and abort transform setup.

    This function never returns normally. Callers rely on that (see the
    "placate PyRight" fallbacks after its call sites).

    Args:
        message: human-readable description of the setup problem.

    Raises:
        trfExceptions.TransformSetupException: always, carrying the TRF_SETUP
            exit code and *message*.
    """
    msg.error(message)
    raise trfExceptions.TransformSetupException(
        trfExit.nameToCode("TRF_SETUP"), message
    )
38 
39 
def getMPIRank():
    """Return the MPI rank of this process.

    The rank stored by setupMPIConfig takes precedence. Otherwise the rank is
    read from the $RANK environment variable.

    Returns:
        int: the rank, or -1 when $RANK is not set.

    Invalid (non-integer) $RANK values are reported through signalError, which
    does not return.
    """
    if mpiConfig is not None:
        return int(mpiConfig["rank"])
    rankText = os.environ.get("RANK")
    if rankText is None:
        return -1
    try:
        return int(rankText)
    except ValueError:
        signalError("$RANK environment variable is not an integer")
    return -2  # Only here to placate PyRight
52 
53 
def getMPIType():
    """Classify this process as non-MPI, MPI master (rank 0) or MPI worker.

    Uses the type stored by setupMPIConfig when available. Otherwise the
    classification is derived from the $RANK environment variable.

    Returns:
        MPIType: NOMPI if $RANK is unset, MPIMASTER for rank 0, MPIWORKER
        for any other rank.
    """
    if mpiConfig is not None:
        return mpiConfig["type"]
    if "RANK" in os.environ:
        return MPIType.MPIMASTER if getMPIRank() == 0 else MPIType.MPIWORKER
    return MPIType.NOMPI
64 
65 
def setupMPIConfig(output, dataDict):
    """Check the environment is correct for MPI mode and set up the module config.

    Fills the module-level ``mpiConfig`` with this process's rank, its MPIType
    and a deep copy of the argument object for each requested output type.
    Output filenames of the form ``prefix[a,b,...]suffix`` are expanded into
    one filename per comma-separated item.

    Args:
        output: iterable of data-type keys to look up in *dataDict*.
        dataDict: mapping from data type to an argument object exposing a
            ``value`` list of filenames and a writable ``multipleOK``.
            Presumably transform file-argument objects; TODO confirm.

    Errors are reported through signalError (which does not return) when:
        - $RANK is not set;
        - the working directory name does not end in ``rank-<rank>``;
        - a bracketed output filename cannot be parsed.
    """
    global mpiConfig
    if "RANK" not in os.environ:
        signalError(
            "Running in MPI mode but the $RANK environment variable is not set!"
        )
    rank = getMPIRank()
    if not os.getcwd().endswith("rank-{}".format(rank)):
        signalError(
            "Running in MPI mode with rank {0} but working directory is not called rank-{0}".format(
                rank
            )
        )
    mpiType = getMPIType()
    outputs = {dataType: deepcopy(dataDict[dataType]) for dataType in output}
    # expand any [ ] lists in output filenames
    output_proc_regex = re.compile(r"(.+)\[(.*)](.*)")
    for argument in outputs.values():
        argument.multipleOK = True
        new_list = []
        list_to_remove = []
        for fn in argument.value:
            if ("[" in fn) and ("]" in fn):
                match = output_proc_regex.match(fn)
                if match is None:
                    # e.g. "[a]x" (empty prefix) or "x]y[z"; previously this
                    # crashed with AttributeError on match.group.
                    signalError(
                        f"Cannot expand output filename {fn}: expected the form prefix[a,b,...]suffix"
                    )
                    continue  # Only here to placate PyRight
                prefix, items, suffix = match.groups()
                new_list.extend(f"{prefix}{item}{suffix}" for item in items.split(","))
                list_to_remove.append(prefix)
            else:
                new_list.append(fn)
                list_to_remove.append(fn)
        argument.value = new_list
        # De-duplicate while keeping first-seen order (set() order is arbitrary).
        argument.list_to_remove = list(dict.fromkeys(list_to_remove))
    # Publish the configuration only once it is complete, so a failure above
    # never leaves a half-initialised mpiConfig behind.
    mpiConfig = {"rank": rank, "type": mpiType, "outputs": outputs}
108 
109 
def mpiShouldValidate():
    """Return whether this process should run output validation.

    Returns:
        bool: True when not running under MPI or when this is rank 0,
        False for every other MPI rank.
    """
    if getMPIType() == MPIType.NOMPI:
        return True  # validate if we're not in MPI mode
    if getMPIRank() == 0:
        return True  # validate in rank 0
    return False  # don't validate in other ranks
116 
117 
def mpiOutputs():
    """Return the output argument objects prepared by setupMPIConfig.

    Returns:
        A view over the values of ``mpiConfig["outputs"]``, i.e. one
        (deep-copied, filename-expanded) argument object per output type.

    Raises:
        TypeError: if called before setupMPIConfig (``mpiConfig`` is None).
    """
    return mpiConfig["outputs"].values()
120 
121 
def mergeOutputs():
    """Merge the per-rank outputs into rank 0's working directory.

    Every rank:
      - collects all files in the sibling ``../rank-N`` directories (N > 0);
      - removes its local PoolFileCatalog.xml;
      - builds one merge list per expected output filename;
      - merges the list whose index equals its own rank into
        ``../rank-0/<filename>``;
      - touches a ``done_merging`` marker file in its working directory.

    Rank 0 also deletes the placeholder files listed in ``list_to_remove``.
    It then appends the ``ranks`` and ``event_log`` tables of every other
    rank's ``mpilog.db`` to its own, and blocks until each rank that was
    assigned a merge list has written its marker.

    If setupMPIConfig has not been called, this only logs a warning.
    """
    if mpiConfig is None:
        msg.warning("trfMPITools.mergeOutputs called when we are not in MPI mode")
        return
    rank = getMPIRank()
    rank_dir_regex = re.compile("rank-([0-9]+)$")
    rank_dirs = {
        int(m.group(1)): m.string
        for m in (rank_dir_regex.search(d.path) for d in os.scandir("..") if d.is_dir())
        if m and int(m.group(1)) > 0
    }
    if rank == 0:
        msg.info("Rank output directories are:\n{}".format(pprint.pformat(rank_dirs)))
    all_merge_inputs = [
        f.path
        for f in it.chain.from_iterable(map(os.scandir, rank_dirs.values()))
        if f.is_file()
    ]
    # Remove PoolFileCatalog
    try:
        os.remove("PoolFileCatalog.xml")
    except FileNotFoundError:
        pass
    # Defined up front so the rank-0 wait below is well defined even when there
    # are no output types at all (previously an UnboundLocalError).
    merge_lists = []
    # NOTE(review): a `break` below stops this rank from handling any later
    # output type, and the done_merging marker is shared by all output types,
    # so the rank-0 wait is only reliable with a single output type. Confirm
    # that is the intended usage.
    for dtype, defn in mpiConfig["outputs"].items():
        if rank == 0:
            msg.info(f"Output type is {dtype}")
        merge_helper = deepcopy(defn)
        merge_helper.multipleOK = True
        if rank == 0:
            for fn in defn.list_to_remove:
                # remove empty files from rank 0
                try:
                    os.remove(fn)
                except FileNotFoundError:
                    pass
        merge_lists = []
        for fn in defn.value:
            msg.info(f"Generating merge list by filtering for {fn} in {all_merge_inputs}")
            # NOTE(review): suffix matching means e.g. "AOD.pool.root" also
            # matches ".../DAOD.pool.root"; confirm output names cannot collide.
            merge_inputs = sorted(s for s in all_merge_inputs if s.endswith(fn))
            # Add to list
            merge_helper.value.extend(merge_inputs)
            merge_lists.append((fn, merge_inputs))
        # Remove non-existent output files from mpiOutputs
        defn.value = [fn for fn, inputs in merge_lists if inputs]
        # Merge each final output in a different rank
        if rank >= len(merge_lists):
            msg.info(f"In rank {rank}, not merging")
            break
        target, inputs = merge_lists[rank]
        if not inputs:
            msg.info(f"In rank {rank}, no inputs for ../rank-0/{target}")
            open("done_merging", "a").close()
            break
        msg.info(
            f"In rank {rank}, merging into ../rank-0/{target}. Inputs are \n{pprint.pformat(inputs)}"
        )
        merge_helper.selfMerge(f"../rank-0/{target}", inputs)
        # Create a file to indicate we are done
        open("done_merging", "a").close()
    if rank == 0:
        _mergeLogDatabases()
        # In rank 0, wait until all other ranks have finished merging
        _waitForMergingRanks(len(merge_lists))


def _mergeLogDatabases():
    """Append the log tables of every other rank's mpilog.db to the local one."""
    import sqlite3 as sq3
    from glob import glob

    tables = ["ranks", "event_log"]
    conn = sq3.connect("mpilog.db")
    try:
        cur = conn.cursor()
        for db in glob("../rank-[1-9]*/mpilog.db"):
            cur.execute("ATTACH DATABASE ? as db", (db,))
            for table in tables:
                # Table names come from the fixed list above, not from input.
                cur.execute(f"INSERT INTO {table} SELECT * from db.{table}")
            # Commit before DETACH: a database cannot be detached mid-transaction.
            conn.commit()
            cur.execute("DETACH DATABASE db")
    finally:
        conn.close()


def _waitForMergingRanks(nMerging):
    """Block until ranks 0..nMerging-1 have each written their done_merging marker."""
    from time import sleep

    files_to_check = [f"../rank-{r}/done_merging" for r in range(nMerging)]
    count = 0
    check = [os.path.exists(f) for f in files_to_check]
    # all() is True for an empty list; reduce(and_, []) used to raise TypeError.
    while not all(check):
        if count % 10 == 0:
            msg.info("Waiting for other ranks to finish merging")
            msg.debug(f"Looking for {files_to_check}")
            msg.debug(f"Result: {check}")
        count = count + 1
        sleep(6)
        check = [os.path.exists(f) for f in files_to_check]
    msg.info("All ranks done merging")
DerivationFramework::TriggerMatchingUtils::sorted
std::vector< typename R::value_type > sorted(const R &r, PROJ proj={})
Helper function to create a sorted vector from an unsorted range.
python.trfMPITools.mergeOutputs
def mergeOutputs()
Definition: trfMPITools.py:122
python.trfExceptions.TransformSetupException
Setup exceptions.
Definition: trfExceptions.py:42
vtune_athena.format
format
Definition: vtune_athena.py:14
python.trfMPITools.signalError
def signalError(message)
Definition: trfMPITools.py:33
PyJobTransforms.trfExitCodes
Module for transform exit codes.
reduce
void reduce(HepMC::GenEvent *ge, std::vector< HepMC::GenParticlePtr > toremove)
Remove unwanted particles from the event, collapsing the graph structure consistently.
Definition: FixHepMC.cxx:84
python.Bindings.values
values
Definition: Control/AthenaPython/python/Bindings.py:808
covarianceTool.filter
filter
Definition: covarianceTool.py:514
plotBeamSpotVxVal.range
range
Definition: plotBeamSpotVxVal.py:194
histSizes.list
def list(name, path='/')
Definition: histSizes.py:38
python.trfMPITools.getMPIRank
def getMPIRank()
Definition: trfMPITools.py:40
CxxUtils::set
constexpr std::enable_if_t< is_bitmask_v< E >, E & > set(E &lhs, E rhs)
Convenience function to set bits in a class enum bitmask.
Definition: bitmask.h:232
python.trfMPITools.setupMPIConfig
def setupMPIConfig(output, dataDict)
Definition: trfMPITools.py:66
python.trfMPITools.mpiShouldValidate
def mpiShouldValidate()
Definition: trfMPITools.py:110
TrigJetMonitorAlgorithm.items
items
Definition: TrigJetMonitorAlgorithm.py:71
Trk::open
@ open
Definition: BinningType.h:40
python.trfMPITools.getMPIType
def getMPIType()
Definition: trfMPITools.py:54
python.CaloAddPedShiftConfig.int
int
Definition: CaloAddPedShiftConfig.py:45
python.trfMPITools.mpiOutputs
def mpiOutputs()
Definition: trfMPITools.py:118
python.trfMPITools.MPIType
Definition: trfMPITools.py:25
Trk::split
@ split
Definition: LayerMaterialProperties.h:38