-
-
Notifications
You must be signed in to change notification settings - Fork 652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add a coerce_collections key_factory for @memoized to reduce boilerplate #7127
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,8 @@ | |
import inspect | ||
from contextlib import contextmanager | ||
|
||
from twitter.common.collections import OrderedSet | ||
|
||
from pants.util.meta import classproperty, staticproperty | ||
|
||
|
||
|
@@ -24,6 +26,44 @@ def equal_args(*args, **kwargs): | |
return key | ||
|
||
|
||
def coercing_arg_normalizer(type_coercions, base_normalizer, args, kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It would be wonderful to keep terminology consistent within a single file: `base_normalizer` is called `key_factory` elsewhere. Maybe … |
||
"""Generate a tuple based off of arguments, applying any specified coercions. | ||
|
||
:param dict type_coercions: Map of type -> 1-arg function. If an argument's type matches any | ||
element, the function is called on the argument and stored in the | ||
returned tuple in its place. | ||
:param func base_normalizer: A key factory like `equal_args`. | ||
|
||
:rtype: tuple | ||
""" | ||
args_key = base_normalizer(*args, **kwargs) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is no contract whatsoever on the shape of the cache key that is output; it could be a cryptographic-grade hash value. Normalization should really apply to `*args` and `**kwargs`, and the transformed versions of those should be passed to the key factory. |
||
coerced = [] | ||
for arg in args_key: | ||
arg_type = type(arg) | ||
coercing_function = type_coercions.get(arg_type, None) | ||
if coercing_function: | ||
arg = coercing_function(arg) | ||
coerced.append(arg) | ||
return tuple(coerced) | ||
|
||
|
||
collection_coercions = { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe … |
||
list: tuple, | ||
OrderedSet: tuple, | ||
} | ||
|
||
|
||
def coerce_collections(*args, **kwargs): | ||
"""Generate a key based off of arguments like `equal_args`, coercing ordered collections to tuple. | ||
|
||
Although `list` and `OrderedSet` are mutable and therefore python doesn't let them be hashable, | ||
since we convert these arguments to tuple (a hashable type) before entering them in the cache, we | ||
can accept a greater range of inputs. | ||
""" | ||
return coercing_arg_normalizer(type_coercions=collection_coercions, base_normalizer=equal_args, | ||
args=args, kwargs=kwargs) | ||
|
||
|
||
class InstanceKey(object): | ||
"""An equality wrapper for an arbitrary object instance. | ||
|
||
|
@@ -58,6 +98,12 @@ def per_instance(*args, **kwargs): | |
return equal_args(*instance_and_rest, **kwargs) | ||
|
||
|
||
def coerce_collections_per_instance(*args, **kwargs): | ||
"""Analogous to `coerce_collections`, but uses an `InstanceKey` like `per_instance`.""" | ||
return coercing_arg_normalizer(type_coercions=collection_coercions, base_normalizer=per_instance, | ||
args=args, kwargs=kwargs) | ||
|
||
|
||
def memoized(func=None, key_factory=equal_args, cache_factory=dict): | ||
"""Memoizes the results of a function call. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,12 +4,16 @@ | |
|
||
from __future__ import absolute_import, division, print_function, unicode_literals | ||
|
||
import re | ||
import unittest | ||
from builtins import object | ||
|
||
from pants.util.memo import (memoized, memoized_classmethod, memoized_classproperty, | ||
memoized_method, memoized_property, memoized_staticmethod, | ||
memoized_staticproperty, per_instance, testable_memoized_property) | ||
from twitter.common.collections import OrderedSet | ||
|
||
from pants.util.memo import (coerce_collections, coerce_collections_per_instance, memoized, | ||
memoized_classmethod, memoized_classproperty, memoized_method, | ||
memoized_property, memoized_staticmethod, memoized_staticproperty, | ||
per_instance, testable_memoized_property) | ||
|
||
|
||
class MemoizeTest(unittest.TestCase): | ||
|
@@ -354,3 +358,48 @@ def calls(self): | |
|
||
self.assertEqual(4, foo2.calls) | ||
self.assertEqual(4, foo2.calls) | ||
|
||
def test_collection_coercion(self): | ||
@memoized | ||
def f(x): | ||
return sum(x) | ||
with self.assertRaisesRegexp(TypeError, re.escape("unhashable type: 'list'")): | ||
f([3, 4]) | ||
|
||
g_called = self._Called(increment=1) | ||
@memoized(key_factory=coerce_collections) | ||
def g(x): | ||
g_called._called() | ||
return sum(x) | ||
x = [3, 4] | ||
# x is converted into a tuple by coerce_collections, so this will only call g once. | ||
self.assertEqual(7, g(tuple(x))) | ||
self.assertEqual(7, g(x)) | ||
x[0] = 2 | ||
self.assertEqual(6, g(x)) | ||
# OrderedSet is converted into a tuple which is equal to the previous call, so this should not | ||
# increase the call count. | ||
self.assertEqual(6, g(OrderedSet(x))) | ||
self.assertEqual(2, g_called._calls) | ||
|
||
class C(self._Called): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should be broken out into its own test method; it's testing a different function. |
||
def __init__(self): | ||
super(C, self).__init__(increment=1) | ||
|
||
@memoized_method | ||
def f(self, x): | ||
return sum(x) | ||
|
||
@memoized(key_factory=coerce_collections_per_instance) | ||
def g(self, x): | ||
self._called() | ||
return sum(x) | ||
c = C() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The key_factory is `_per_instance`, but you only ever make one instance. It would be good to make two and confirm each instance carries its own independent cache. |
||
x = [3, 4] | ||
with self.assertRaisesRegexp(TypeError, re.escape("unhashable type: 'list'")): | ||
c.f(x) | ||
self.assertEqual(7, c.f(tuple(x))) | ||
self.assertEqual(7, c.g(x)) | ||
x[0] = 2 | ||
self.assertEqual(6, c.g(x)) | ||
self.assertEqual(2, c._calls) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems this can be `_private`; AFAICT you only intend to export `coerce_collections` and `coerce_collections_per_instance`.