diff --git a/memcache.py b/memcache.py index 9823a2f..ba17616 100644 --- a/memcache.py +++ b/memcache.py @@ -148,6 +148,7 @@ class Client(threading.local): _FLAG_INTEGER = 1 << 1 _FLAG_LONG = 1 << 2 _FLAG_COMPRESSED = 1 << 3 + _FLAG_TEXT = 1 << 4 _SERVER_RETRIES = 10 # how many times to try finding a free server. @@ -968,18 +969,23 @@ def _val_to_store_info(self, val, min_compress_len): the new value itself. """ flags = 0 - if isinstance(val, six.binary_type): + # Check against the exact type, rather than using isinstance, so that + # subclasses of native types (such as markup-safe strings) are pickled + # and restored as instances of the correct class. + type_ = type(val) + if type_ == six.binary_type: pass - elif isinstance(val, six.text_type): + elif type_ == six.text_type: + flags |= Client._FLAG_TEXT val = val.encode('utf-8') - elif isinstance(val, int): + elif type_ == int: flags |= Client._FLAG_INTEGER val = '%d' % val if six.PY3: val = val.encode('ascii') # force no attempt to compress this silly string. min_compress_len = 0 - elif six.PY2 and isinstance(val, long): + elif six.PY2 and type_ == long: flags |= Client._FLAG_LONG val = str(val) if six.PY3: @@ -1266,11 +1272,10 @@ def _recv_value(self, server, flags, rlen): flags &= ~Client._FLAG_COMPRESSED if flags == 0: - # Bare string - if six.PY3: - val = buf.decode('utf8') - else: + # Bare bytes val = buf + elif flags & Client._FLAG_TEXT: + val = buf.decode('utf8') elif flags & Client._FLAG_INTEGER: val = int(buf) elif flags & Client._FLAG_LONG: