5 from AthenaCommon.Logging
import logging
6 _msg = logging.getLogger(
'AccumulatorCache')
10 from abc
import ABC, abstractmethod
11 from copy
import deepcopy
12 from collections.abc
import Hashable, Iterable
13 from collections
import defaultdict
14 from dataclasses
import dataclass
17 from GaudiKernel.DataHandle
import DataHandle
23 """Exception thrown when AccumulatorCache is applied to non-hashable function call"""
29 """Abstract base for classes needing custom AccumulatorCache behavior."""
33 """This method is called by AccumulatorCache when an object is removed
34 from the cache. Implement this for custom cleanup actions."""
38 class AccumulatorDecorator:
39 """Class for use in function decorators, implements memoization.
41 Instances are callable objects that use the
42 hash value calculated from positional and keyword arguments
43 to implement memoization. Methods for suspending and
44 resuming memoization are provided.
59 _stats = defaultdict(CacheStats)
61 def __init__(self, func, size, verify, deepCopy):
62 """See AccumulatorCache decorator for documentation of arguments."""
64 functools.update_wrapper(self , func)
73 raise RuntimeError(f
"Invalid value for verify ({verify}) in AccumulatorCache for {func}")
76 """Return a dictionary with information about the cache size and cache usage"""
77 return {
"cache_size" : len(self.
_cache),
78 "misses" : self.
_stats[self].misses,
79 "hits" : self.
_stats[self].hits,
80 "function" : self.
_func,
85 """Print cache statistics"""
86 header =
"%-70s | Hits (time) | Misses (time) |" %
"AccumulatorCache"
87 print(
"-"*len(header))
89 print(
"-"*len(header))
91 for func, stats
in sorted(cls.
_stats.
items(), key=
lambda s:s[1].t_hits+s[1].t_misses, reverse=
True):
92 name = f
"{func.__module__}.{func.__name__}"
94 name =
'...' + name[-67:]
95 print(f
"{name:70} | {stats.hits:6} ({stats.t_hits:4.1f}s) | "
96 f
"{stats.misses:6} ({stats.t_misses:4.1f}s) |")
97 print(
"-"*len(header))
101 """Suspend memoization for all instances of AccumulatorDecorator."""
106 """Resume memoization for all instances of AccumulatorDecorator."""
111 """Clear all accumulator caches"""
115 decor._resultCache.clear()
120 if hasattr(x,
"athHash"):
122 elif isinstance(x, Hashable):
124 elif isinstance(x, DataHandle):
129 """Called when x is removed from the cache"""
130 if isinstance(x, AccumulatorCachable):
132 elif isinstance(x, Iterable)
and not isinstance(x, str):
134 AccumulatorDecorator._evict(el)
138 AccumulatorDecorator._evict(v)
141 """Support instance methods."""
142 return functools.partial(self.
__call__, obj)
147 t0 = time.perf_counter()
148 res, cacheHit = self.
_callImpl(*args, **kwargs)
150 except NotHashable
as e:
151 _msg.warning(f
"Argument value '{repr(e.value)}' in {self._func} is not hashable. "
152 "No caching is performed!")
154 return self.
_func(*args, **kwargs)
156 t1 = time.perf_counter()
158 self.
_stats[self].hits += 1
159 self.
_stats[self].t_hits += (t1-t0)
160 elif cacheHit
is False:
161 self.
_stats[self].misses += 1
162 self.
_stats[self].t_misses += (t1-t0)
165 """Implementation of __call__.
167 Returns: (result, cacheHit)
171 if not AccumulatorDecorator._memoize:
172 return (self.
_func(*args , **kwargs),
None)
175 hsh =
hash( (tuple(AccumulatorDecorator._getHash(a)
for a
in args),
176 frozenset((
hash(k), AccumulatorDecorator._getHash(v))
for k,v
in kwargs.items())) )
181 if AccumulatorDecorator.VERIFY_HASH == self.
_verify:
183 chkHsh = AccumulatorDecorator._getHash(res)
185 _msg.debug(
"Hash of function result, cached using AccumulatorDecorator, changed.")
187 res = self.
_func(*args , **kwargs)
189 self.
_resultCache[hsh] = AccumulatorDecorator._getHash(res)
196 return deepcopy(res), cacheHit
199 from AthenaConfiguration.ComponentAccumulator
import ComponentAccumulator
200 if isinstance(res, ComponentAccumulator):
205 _msg.debug(
'Hash not found in AccumulatorCache for function %s' , self.
_func)
207 _msg.warning(
"Cache limit (%d) reached for %s.%s",
210 AccumulatorDecorator._evict(oldest)
212 res = self.
_func(*args , **kwargs)
214 if AccumulatorDecorator.VERIFY_HASH == self.
_verify:
217 self.
_resultCache[hsh] = AccumulatorDecorator._getHash(res)
222 return (deepcopy(res)
if self.
_deepcopy else res,
False)
229 verifyResult = AccumulatorDecorator.VERIFY_NOTHING, deepCopy = True):
230 """Function decorator, implements memoization.
233 maxSize: maximum size for the cache associated with the function (default 128)
234 verifyResult: takes two possible values
236 AccumulatorDecorator.VERIFY_NOTHING - default, the cached function result is returned with no verification
237 AccumulatorDecorator.VERIFY_HASH - before returning the cached function value, the hash of the
238 result is checked to verify if this object was not modified
239 between function calls
240 deepCopy: if True (default) a deep copy of the function result will be stored in the cache.
243 An instance of AccumulatorDecorator.
246 def wrapper_accumulator(func):
249 return wrapper_accumulator(func)
if func
else wrapper_accumulator