ATLAS Offline Software
Loading...
Searching...
No Matches
AccumulatorCache.py
Go to the documentation of this file.
2# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3#
4
5from AthenaCommon.Logging import logging
6_msg = logging.getLogger('AccumulatorCache')
7
8import functools
9import time
10from abc import ABC, abstractmethod
11from copy import deepcopy
12from collections.abc import Hashable, Iterable
13from collections import defaultdict
14from dataclasses import dataclass
15
# Use the real DataHandle when GaudiKernel is importable; otherwise define an
# empty stand-in so that isinstance() checks against DataHandle still work.
try:
    from GaudiKernel.DataHandle import DataHandle
except ImportError:
    class DataHandle:
        pass  # for analysis releases
20
21
class NotHashable(Exception):
    """Exception thrown when AccumulatorCache is applied to non-hashable function call.

    Attributes:
        value: the argument value for which no hash could be computed.
    """

    def __init__(self, value):
        # Hand the offending value to the base class, not the exception
        # itself: an exception whose only argument is itself recurses without
        # bound as soon as str() or repr() is applied to it.
        super().__init__(value)
        self.value = value
27
28
class AccumulatorCachable(ABC):
    """Abstract base for classes needing custom AccumulatorCache behavior."""

    @abstractmethod
    def _cacheEvict(self):
        """This method is called by AccumulatorCache when an object is removed
        from the cache. Implement this for custom cleanup actions."""
        pass
37
38
class AccumulatorDecorator:
    """Class for use in function decorators, implements memoization.

    Instances are callable objects that use the
    hash value calculated from positional and keyword arguments
    to implement memoization. Methods for suspending and
    resuming memoization are provided.
    """

    # Class-wide switch: when False every instance bypasses its cache.
    _memoize = True

    # Allowed values for the ``verify`` constructor argument.
    VERIFY_NOTHING = 0
    VERIFY_HASH = 1

    @dataclass
    class CacheStats:
        """Hit/miss counters and accumulated wall-clock time for one decorator."""
        hits: int = 0         # number of calls answered from the cache
        misses: int = 0       # number of calls that executed the function
        t_hits: float = 0     # seconds spent in calls counted as hits
        t_misses: float = 0   # seconds spent in calls counted as misses

    # Statistics of all decorators, keyed by the decorator instance.
    _stats = defaultdict(CacheStats)

    def __init__(self, func, size, verify, deepCopy):
        """See AccumulatorCache decorator for documentation of arguments."""

        functools.update_wrapper(self, func)
        self._maxSize = size
        self._func = func
        self._cache = {}         # argument hash -> function result
        self._resultCache = {}   # argument hash -> hash of the result (VERIFY_HASH only)
        self._verify = verify
        self._deepcopy = deepCopy

        if self._verify not in [self.VERIFY_NOTHING, self.VERIFY_HASH]:
            raise RuntimeError(f"Invalid value for verify ({verify}) in AccumulatorCache for {func}")

    def getInfo(self):
        """Return a dictionary with information about the cache size and cache usage"""
        return {"cache_size" : len(self._cache),
                "misses" : self._stats[self].misses,
                "hits" : self._stats[self].hits,
                "function" : self._func,
                "result_cache_size" : len(self._resultCache)}

    @classmethod
    def printStats(cls):
        """Print cache statistics"""
        header = "%-70s | Hits (time) | Misses (time) |" % "AccumulatorCache"
        print("-"*len(header))
        print(header)
        print("-"*len(header))
        # Print sorted by hit+miss time:
        for func, stats in sorted(cls._stats.items(), key=lambda s:s[1].t_hits+s[1].t_misses, reverse=True):
            name = f"{func.__module__}.{func.__name__}"
            if len(name) > 70:
                # keep the tail of the name so that the row stays 70 characters wide
                name = '...' + name[-67:]
            print(f"{name:70} | {stats.hits:6} ({stats.t_hits:4.1f}s) | "
                  f"{stats.misses:6} ({stats.t_misses:4.1f}s) |")
        print("-"*len(header))

    # NOTE(review): the name of this method is inferred from its docstring --
    # confirm it against the callers.
    @classmethod
    def suspendCaching(cls):
        """Suspend memoization for all instances of AccumulatorDecorator."""
        cls._memoize = False

    # NOTE(review): the name of this method is inferred from its docstring --
    # confirm it against the callers.
    @classmethod
    def resumeCaching(cls):
        """Resume memoization for all instances of AccumulatorDecorator."""
        cls._memoize = True

    @classmethod
    def clearCache(cls):
        """Clear all accumulator caches"""
        # Only decorators that own a statistics entry are visited.
        for decor in cls._stats:
            decor._evictAll()
            decor._cache.clear()
            decor._resultCache.clear()

        cls._stats.clear()

    @staticmethod
    def _getHash(x):
        """Return the hash used to identify x in the cache.

        An ``athHash()`` method takes precedence, then the builtin hash for
        hashable objects; a DataHandle is hashed through its repr().

        Raises:
            NotHashable: if none of the above applies to x.
        """
        if hasattr(x, "athHash"):
            return x.athHash()
        elif isinstance(x, Hashable):
            return hash(x)
        elif isinstance(x, DataHandle):
            return hash(repr(x))
        raise NotHashable(x)

    @staticmethod
    def _evict(x):
        """Called when x is removed from the cache"""
        if isinstance(x, AccumulatorCachable):
            x._cacheEvict()
        elif isinstance(x, Iterable) and not isinstance(x, str):
            # descend into containers; str is excluded because iterating a
            # string yields strings again
            for el in x:
                AccumulatorDecorator._evict(el)

    def _evictAll(self):
        """Run the eviction hook on every result held in this cache."""
        for v in self._cache.values():
            AccumulatorDecorator._evict(v)

    def __get__(self, obj, objtype=None):
        """Support instance methods."""
        if obj is None:
            # Looked up on the class: there is no instance to bind, so hand
            # out the decorator itself instead of binding None.
            return self
        return functools.partial(self.__call__, obj)

    def __call__(self, *args, **kwargs):
        """Call the wrapped function through the cache and record timing statistics."""
        cacheHit = None
        # started outside the try block so that the finally clause can always read it
        t0 = time.perf_counter()
        try:
            res, cacheHit = self._callImpl(*args, **kwargs)
            return res
        except NotHashable as e:
            _msg.warning(f"Argument value '{repr(e.value)}' in {self._func} is not hashable. "
                         "No caching is performed!")
            cacheHit = False
            return self._func(*args, **kwargs)  # perform regular function call
        finally:
            t1 = time.perf_counter()
            # cacheHit stays None when memoization is suspended or the call
            # raised: such calls are not counted at all
            if cacheHit is True:
                self._stats[self].hits += 1
                self._stats[self].t_hits += (t1-t0)
            elif cacheHit is False:
                self._stats[self].misses += 1
                self._stats[self].t_misses += (t1-t0)

    def _callImpl(self, *args, **kwargs):
        """Implementation of __call__.

        Returns: (result, cacheHit)
        """

        # AccumulatorCache enabled?
        if not AccumulatorDecorator._memoize:
            return (self._func(*args , **kwargs), None)

        # frozen set makes the order of keyword arguments irrelevant
        hsh = hash( (tuple(AccumulatorDecorator._getHash(a) for a in args),
                     frozenset((hash(k), AccumulatorDecorator._getHash(v)) for k,v in kwargs.items())) )

        # NOTE: a cached result of None cannot be told apart from a missing
        # entry here, so a function returning None is executed on every call.
        res = self._cache.get(hsh, None)
        if res is not None:
            cacheHit = None
            if AccumulatorDecorator.VERIFY_HASH == self._verify:
                resHsh = self._resultCache[hsh]
                chkHsh = AccumulatorDecorator._getHash(res)
                if chkHsh != resHsh:
                    # the cached object was modified after it was stored:
                    # recompute it and count the call as a miss
                    _msg.debug("Hash of function result, cached using AccumulatorDecorator, changed.")
                    cacheHit = False
                    res = self._func(*args , **kwargs)
                    self._cache[hsh] = res
                    self._resultCache[hsh] = AccumulatorDecorator._getHash(res)
                else:
                    cacheHit = True
            else:
                cacheHit = True

            if self._deepcopy:
                return deepcopy(res), cacheHit
            else:
                # shallow copied CA still needs to undergo merging
                from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
                if isinstance(res, ComponentAccumulator):
                    res._wasMerged=False
                return res, cacheHit

        else:
            _msg.debug('Hash not found in AccumulatorCache for function %s' , self._func)
            if len(self._cache) >= self._maxSize:
                _msg.warning("Cache limit (%d) reached for %s.%s",
                             self._maxSize, self._func.__module__, self._func.__name__)
                # dicts iterate in insertion order: the first key is the oldest entry
                oldest = self._cache.pop(next(iter(self._cache)))
                AccumulatorDecorator._evict(oldest)

            res = self._func(*args , **kwargs)

            if AccumulatorDecorator.VERIFY_HASH == self._verify:
                if len(self._resultCache) >= self._maxSize:
                    del self._resultCache[next(iter(self._resultCache))]
                self._resultCache[hsh] = AccumulatorDecorator._getHash(res)
                self._cache[hsh] = res
            else:
                self._cache[hsh] = res

            return (deepcopy(res) if self._deepcopy else res, False)

    def __del__(self):
        """Run the eviction hook on all cached results when the decorator is destroyed."""
        self._evictAll()
227
228
def AccumulatorCache(func = None, maxSize = 128,
                     verifyResult = AccumulatorDecorator.VERIFY_NOTHING, deepCopy = True):
    """Function decorator, implements memoization.

    Can be applied bare (``@AccumulatorCache``) or with keyword arguments
    (``@AccumulatorCache(maxSize=...)``).

    Keyword arguments:
        maxSize: maximum size for the cache associated with the function (default 128)
        verifyResult: takes two possible values

            AccumulatorDecorator.VERIFY_NOTHING - default, the cached function result is returned with no verification
            AccumulatorDecorator.VERIFY_HASH - before returning the cached function value, the hash of the
                                               result is checked to verify if this object was not modified
                                               between function calls
        deepCopy: if True (default) every call returns a deep copy of the cached function result;
                  if False the cached object itself is returned.

    Returns:
        An instance of AccumulatorDecorator.
    """

    def wrapper_accumulator(func):
        return AccumulatorDecorator(func, maxSize, verifyResult, deepCopy)

    # Compare against None rather than testing truthiness: a callable object
    # that evaluates to False must still be decorated.
    return wrapper_accumulator(func) if func is not None else wrapper_accumulator
void print(char *figname, TCanvas *c1)
__init__(self, func, size, verify, deepCopy)
T * get(TKey *tobj)
get a TObject* from a TKey* (why can't a TObject be a TKey?)
Definition hcg.cxx:130