5 from AthenaCommon.Logging
import logging
6 _msg = logging.getLogger(
'AccumulatorCache')
10 from abc
import ABC, abstractmethod
11 from copy
import deepcopy
12 from collections.abc
import Hashable, Iterable
13 from collections
import defaultdict
14 from dataclasses
import dataclass
17 from GaudiKernel.DataHandle
import DataHandle
23 """Exception thrown when AccumulatorCache is applied to non-hashable function call"""
25 super().__init__ (self)
30 """Abstract base for classes needing custom AccumulatorCache behavior."""
34 """This method is called by AccumulatorCache when an object is removed
35 from the cache. Implement this for custom cleanup actions."""
39 class AccumulatorDecorator:
40 """Class for use in function decorators, implements memoization.
42 Instances are callable objects that use the
43 hash value calculated from positional and keyword arguments
44 to implement memoization. Methods for suspending and
45 resuming memoization are provided.
60 _stats = defaultdict(CacheStats)
62 def __init__(self, func, size, verify, deepCopy):
63 """See AccumulatorCache decorator for documentation of arguments."""
65 functools.update_wrapper(self , func)
74 raise RuntimeError(f
"Invalid value for verify ({verify}) in AccumulatorCache for {func}")
77 """Return a dictionary with information about the cache size and cache usage"""
78 return {
"cache_size" : len(self.
_cache),
79 "misses" : self.
_stats[self].misses,
80 "hits" : self.
_stats[self].hits,
81 "function" : self.
_func,
86 """Print cache statistics"""
87 header =
"%-70s | Hits (time) | Misses (time) |" %
"AccumulatorCache"
88 print(
"-"*len(header))
90 print(
"-"*len(header))
92 for func, stats
in sorted(cls.
_stats.
items(), key=
lambda s:s[1].t_hits+s[1].t_misses, reverse=
True):
93 name = f
"{func.__module__}.{func.__name__}"
95 name =
'...' + name[-67:]
96 print(f
"{name:70} | {stats.hits:6} ({stats.t_hits:4.1f}s) | "
97 f
"{stats.misses:6} ({stats.t_misses:4.1f}s) |")
98 print(
"-"*len(header))
102 """Suspend memoization for all instances of AccumulatorDecorator."""
107 """Resume memoization for all instances of AccumulatorDecorator."""
112 """Clear all accumulator caches"""
116 decor._resultCache.clear()
121 if hasattr(x,
"athHash"):
123 elif isinstance(x, Hashable):
125 elif isinstance(x, DataHandle):
130 """Called when x is removed from the cache"""
131 if isinstance(x, AccumulatorCachable):
133 elif isinstance(x, Iterable)
and not isinstance(x, str):
135 AccumulatorDecorator._evict(el)
139 AccumulatorDecorator._evict(v)
142 """Support instance methods."""
143 return functools.partial(self.
__call__, obj)
148 t0 = time.perf_counter()
149 res, cacheHit = self.
_callImpl(*args, **kwargs)
151 except NotHashable
as e:
152 _msg.warning(f
"Argument value '{repr(e.value)}' in {self._func} is not hashable. "
153 "No caching is performed!")
155 return self.
_func(*args, **kwargs)
157 t1 = time.perf_counter()
159 self.
_stats[self].hits += 1
160 self.
_stats[self].t_hits += (t1-t0)
161 elif cacheHit
is False:
162 self.
_stats[self].misses += 1
163 self.
_stats[self].t_misses += (t1-t0)
166 """Implementation of __call__.
168 Returns: (result, cacheHit)
172 if not AccumulatorDecorator._memoize:
173 return (self.
_func(*args , **kwargs),
None)
176 hsh =
hash( (tuple(AccumulatorDecorator._getHash(a)
for a
in args),
177 frozenset((
hash(k), AccumulatorDecorator._getHash(v))
for k,v
in kwargs.items())) )
182 if AccumulatorDecorator.VERIFY_HASH == self.
_verify:
184 chkHsh = AccumulatorDecorator._getHash(res)
186 _msg.debug(
"Hash of function result, cached using AccumulatorDecorator, changed.")
188 res = self.
_func(*args , **kwargs)
190 self.
_resultCache[hsh] = AccumulatorDecorator._getHash(res)
197 return deepcopy(res), cacheHit
200 from AthenaConfiguration.ComponentAccumulator
import ComponentAccumulator
201 if isinstance(res, ComponentAccumulator):
206 _msg.debug(
'Hash not found in AccumulatorCache for function %s' , self.
_func)
208 _msg.warning(
"Cache limit (%d) reached for %s.%s",
211 AccumulatorDecorator._evict(oldest)
213 res = self.
_func(*args , **kwargs)
215 if AccumulatorDecorator.VERIFY_HASH == self.
_verify:
218 self.
_resultCache[hsh] = AccumulatorDecorator._getHash(res)
223 return (deepcopy(res)
if self.
_deepcopy else res,
False)
230 verifyResult = AccumulatorDecorator.VERIFY_NOTHING, deepCopy = True):
231 """Function decorator, implements memoization.
234 maxSize: maximum size for the cache associated with the function (default 128)
235 verifyResult: takes two possible values
237 AccumulatorDecorator.VERIFY_NOTHING - default, the cached function result is returned with no verification
238 AccumulatorDecorator.VERIFY_HASH - before returning the cached function value, the hash of the
239 result is checked to verify if this object was not modified
240 between function calls
241 deepCopy: if True (default) a deep copy of the function result will be stored in the cache.
244 An instance of AccumulatorDecorator.
247 def wrapper_accumulator(func):
250 return wrapper_accumulator(func)
if func
else wrapper_accumulator