ATLAS Offline Software
Loading...
Searching...
No Matches
trfMPITools.py
Go to the documentation of this file.
1# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
8
9from enum import Enum
10from copy import deepcopy
11import os
12import re
13import logging
14import pprint
15import itertools as it
16
17from PyJobTransforms.trfExitCodes import trfExit
18import PyJobTransforms.trfExceptions as trfExceptions
19
msg = logging.getLogger(__name__)

# Module-wide MPI state. None until setupMPIConfig runs; afterwards a dict with
# keys "rank" (from getMPIRank), "type" (from getMPIType) and "outputs"
# (data type -> deep copy of the corresponding output argument).
# getMPIRank/getMPIType consult this before falling back to $RANK.
mpiConfig = None
23
24
class MPIType(Enum):
    """Role of the current process in an MPI job.

    NOMPI     -- not running under MPI ($RANK is not set)
    MPIMASTER -- MPI rank 0
    MPIWORKER -- any MPI rank other than 0
    """

    # Plain integer values 0, 1, 2 in declaration order.
    NOMPI, MPIMASTER, MPIWORKER = range(3)
31
32
def signalError(message):
    """Log *message* as an error and abort transform setup.

    This function never returns normally.

    Args:
        message: human-readable description of the setup problem.

    Raises:
        trfExceptions.TransformSetupException: always, carrying the exit code
            registered under "TRF_SETUP" together with *message*.
    """
    msg.error(message)
    raise trfExceptions.TransformSetupException(
        trfExit.nameToCode("TRF_SETUP"), message
    )
38
39
def getMPIRank():
    """Return MPI rank

    Returns:
        The rank stored by setupMPIConfig once it has run. Before that, the
        integer value of the $RANK environment variable, or -1 when $RANK is
        not set (getMPIType treats that case as "not using MPI").

    Raises:
        Whatever signalError raises (it is expected to raise), when $RANK is
        set but cannot be parsed as an integer.
    """
    if mpiConfig is not None:
        return int(mpiConfig["rank"])
    rank_text = os.environ.get("RANK")
    if rank_text is None:
        return -1
    try:
        return int(rank_text)
    except ValueError:
        signalError("$RANK environment variable is not an integer")
    return -2  # Only here to placate PyRight: signalError does not return
52
53
def getMPIType():
    """Return MPI type

    Returns:
        The type stored by setupMPIConfig once it has run. Before that:
        MPIType.NOMPI when $RANK is not set, MPIType.MPIMASTER when the rank
        is 0, and MPIType.MPIWORKER for any other rank.
    """
    if mpiConfig is not None:
        return mpiConfig["type"]
    if "RANK" not in os.environ:
        return MPIType.NOMPI
    return MPIType.MPIMASTER if getMPIRank() == 0 else MPIType.MPIWORKER
64
65
def setupMPIConfig(output, dataDict):
    """Check environment is correct if we are in MPI mode, and setup dictionaries

    Verifies that $RANK is set and that the current working directory name
    ends with ``rank-<rank>``. It then publishes the module-level
    ``mpiConfig`` with this process's rank, its MPI type, and a deep copy of
    each requested output argument under ``"outputs"``.

    In each copied output argument:

    - ``multipleOK`` is set to True.
    - Filenames of the form ``prefix[a,b,...]suffix`` are expanded into
      ``prefixasuffix``, ``prefixbsuffix``, ...; other names are kept as-is.
    - A new ``list_to_remove`` attribute lists, without duplicates, the
      bracket prefix of every expanded name and every plain name.

    Args:
        output: iterable of data-type keys to copy out of *dataDict*.
        dataDict: mapping from data type to an argument object whose ``value``
            is a list of filename strings (presumably a trfArgClasses file
            argument; TODO confirm).

    Raises:
        Whatever signalError raises, if $RANK is unset or the working
        directory does not match the rank.
    """
    global mpiConfig
    if "RANK" not in os.environ:
        signalError(
            "Running in MPI mode but the $RANK environment variable is not set!"
        )
    rank = getMPIRank()
    if not os.getcwd().endswith("rank-{}".format(rank)):
        signalError(
            "Running in MPI mode with rank {0} but working directory is not called rank-{0}".format(
                rank
            )
        )
    mpiType = getMPIType()
    outputs = {dataType: deepcopy(dataDict[dataType]) for dataType in output}

    # expand any [ ] lists in output filenames
    output_proc_regex = re.compile(r"(.+)\[(.*)](.*)")
    for arg in outputs.values():
        arg.multipleOK = True
        new_list = []
        list_to_remove = []
        for fn in arg.value:
            match = None
            if ("[" in fn) and ("]" in fn):
                match = output_proc_regex.match(fn)
            if match is None:
                # Plain filename, or brackets not in "prefix[...]suffix" form
                # (e.g. "[a]b"): keep the name untouched instead of crashing.
                new_list.append(fn)
                list_to_remove.append(fn)
                continue
            prefix, choices, suffix = match.groups()
            new_list.extend(f"{prefix}{choice}{suffix}" for choice in choices.split(","))
            list_to_remove.append(prefix)
        arg.value = new_list
        # De-duplicate while keeping first-seen order (deterministic, unlike set)
        arg.list_to_remove = list(dict.fromkeys(list_to_remove))

    # Publish the configuration only once it is complete.
    mpiConfig = {"rank": rank, "type": mpiType, "outputs": outputs}
108
109
111 if getMPIType() == MPIType.NOMPI:
112 return True # validate if we're not in MPI mode
113 if getMPIRank() == 0:
114 return True # validate in rank 0
115 return False # don't validate in other ranks
116
117
119 return mpiConfig["outputs"].values()
120
121
def _mergeLogDatabases():
    """Append the tables of every ../rank-[1-9]*/mpilog.db into ./mpilog.db."""
    import sqlite3 as sq3
    from glob import glob

    conn = sq3.connect("mpilog.db")
    try:
        cur = conn.cursor()
        tables = ["ranks", "files", "event_log"]
        for db in glob("../rank-[1-9]*/mpilog.db"):
            cur.execute("ATTACH DATABASE ? as db", (db,))
            for table in tables:
                # For "files", rows that clash with an existing unique key are skipped
                upsert = "INSERT OR IGNORE" if table == "files" else "INSERT"
                cur.execute(f"{upsert} INTO {table} SELECT * from db.{table}")
            conn.commit()
            cur.execute("DETACH DATABASE db")
    finally:
        conn.close()


def _waitForMergingRanks(nRanks):
    """Block until ../rank-k/done_merging exists for every k in range(nRanks)."""
    from time import sleep

    files_to_check = [f"../rank-{r}/done_merging" for r in range(nRanks)]
    count = 0
    check = [os.path.exists(f) for f in files_to_check]
    # all() is True for an empty list, unlike reduce(and_, []) which raises TypeError
    while not all(check):
        if count % 10 == 0:
            msg.info("Waiting for other ranks to finish merging")
            msg.debug(f"Looking for {files_to_check}")
            msg.debug(f"Result: {check}")
        count = count + 1
        sleep(6)
        check = [os.path.exists(f) for f in files_to_check]
    msg.info("All ranks done merging")


def mergeOutputs():
    """Merge outputs into rank 0

    Intended to be called by every rank after setupMPIConfig; if mpiConfig is
    not set it only logs a warning and returns.

    1. Every regular file in the sibling ``../rank-N`` directories (N > 0) is
       collected as a candidate merge input.
    2. For each output type in ``mpiConfig["outputs"]``, the k-th output
       filename is merged by rank k from the inputs whose path ends with that
       filename, into ``../rank-0/<filename>``. Filenames with no inputs are
       dropped from the output's ``value``.
    3. Each rank then writes a ``done_merging`` marker in its own directory.
    4. Rank 0 merges the per-rank ``mpilog.db`` databases into its own and
       waits until every rank that was assigned a merge has written its marker.
    """
    if mpiConfig is None:
        msg.warning("trfMPITools.mergeOutputs called when we are not in MPI mode")
        return
    rank = getMPIRank()
    rank_dir_regex = re.compile("rank-([0-9]+)$")
    # Sibling directories of the worker ranks, keyed by rank (rank 0 excluded)
    rank_dirs = {
        int(m.group(1)): m.string
        for m in (rank_dir_regex.search(d.path) for d in os.scandir("..") if d.is_dir())
        if m and int(m.group(1)) > 0
    }
    if rank == 0:
        msg.info("Rank output directories are:\n{}".format(pprint.pformat(rank_dirs)))
    all_merge_inputs = [
        entry.path
        for rank_dir in rank_dirs.values()
        for entry in os.scandir(rank_dir)
        if entry.is_file()
    ]
    # Remove PoolFileCatalog
    try:
        os.remove("PoolFileCatalog.xml")
    except FileNotFoundError:
        pass

    # Largest number of merge jobs over all output types: rank 0 waits for
    # that many ranks, whichever output type gave them work.
    n_merge_jobs = 0
    for dtype, defn in mpiConfig["outputs"].items():
        if rank == 0:
            msg.info(f"Output type is {dtype}")
        merge_helper = deepcopy(defn)
        merge_helper.multipleOK = True
        if rank == 0:
            for fn in defn.list_to_remove:
                # remove empty files from rank 0
                try:
                    os.remove(fn)
                except FileNotFoundError:
                    pass
        merge_lists = []
        for fn in defn.value:
            msg.info(f"Generating merge list by filtering for {fn} in {all_merge_inputs}")
            # NOTE(review): suffix match, so an output name that is a suffix of
            # another file name (e.g. "AOD.pool.root" / "DAOD.pool.root") also
            # picks that file up -- confirm this is intended.
            merge_inputs = sorted(s for s in all_merge_inputs if s.endswith(fn))
            merge_helper.value.extend(merge_inputs)
            merge_lists.append((fn, merge_inputs))
        # Remove non-existent output files from mpiOutputs
        defn.value = [name for name, inputs in merge_lists if inputs]
        n_merge_jobs = max(n_merge_jobs, len(merge_lists))
        # Merge each final output in a different rank: rank k takes file k.
        # `continue` (not `break`) so later output types are still handled.
        if rank >= len(merge_lists):
            msg.info(f"In rank {rank}, not merging")
            continue
        target, inputs = merge_lists[rank]
        if not inputs:
            msg.info(f"In rank {rank}, no inputs for ../rank-0/{target}")
            continue
        msg.info(
            f"In rank {rank}, merging into ../rank-0/{target}. Inputs are \n{pprint.pformat(inputs)}"
        )
        merge_helper.selfMerge(f"../rank-0/{target}", inputs)

    # Create a file to indicate we are done -- only after every output type,
    # so rank 0 cannot move on while this rank is still merging.
    with open("done_merging", "a"):
        pass

    if rank == 0:
        # Merge log databases
        _mergeLogDatabases()
        # In rank 0, wait until all other ranks have finished merging
        _waitForMergingRanks(n_merge_jobs)
static void reduce(HepMC::GenEvent *ge, HepMC::GenParticle *gp)
Remove an unwanted particle from the event, collapsing the graph structure consistently.
Definition FixHepMC.cxx:39
STL class.
STL class.
std::vector< std::string > split(const std::string &s, const std::string &t=":")
Definition hcg.cxx:177
Module for transform exit codes.
setupMPIConfig(output, dataDict)