diff --git a/src/nilql/nilql.py b/src/nilql/nilql.py index 8a5b9d2..6f0bf4e 100644 --- a/src/nilql/nilql.py +++ b/src/nilql/nilql.py @@ -170,9 +170,9 @@ def encrypt(key: dict, plaintext: Union[int, str]) -> bytes: shares = [] total = 0 for _ in range(len(key['cluster']['nodes']) - 1): - share = secrets.randbelow(_SECRET_SHARED_SIGNED_INTEGER_MODULUS) - shares.append(share) - total = (total + share) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS + share_ = secrets.randbelow(_SECRET_SHARED_SIGNED_INTEGER_MODULUS) + shares.append(share_) + total = (total + share_) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS shares.append((plaintext - total) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS) instance = shares @@ -220,16 +220,16 @@ def decrypt(key: dict, ciphertext: Union[bytes, Sequence[bytes]]) -> bytes: # Multi-node clusters use XOR-based secret sharing. shares = ciphertext bytes_ = bytes(len(shares[0])) - for share in shares: - bytes_ = bytes(a ^ b for (a, b) in zip(bytes_, share)) + for share_ in shares: + bytes_ = bytes(a ^ b for (a, b) in zip(bytes_, share_)) return _decode(bytes_) if key['operations'].get('sum'): if len(key['cluster']['nodes']) > 1: total = 0 - for share in ciphertext: - total = (total + share) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS + for share_ in ciphertext: + total = (total + share_) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS if total > _PLAINTEXT_SIGNED_INTEGER_MAX: total -= _SECRET_SHARED_SIGNED_INTEGER_MODULUS @@ -238,5 +238,83 @@ def decrypt(key: dict, ciphertext: Union[bytes, Sequence[bytes]]) -> bytes: raise ValueError('cannot decrypt supplied ciphertext using the supplied key') +def share(document: Union[int, str, dict]) -> Sequence[dict]: + """ + Convert a document that may contain ciphertexts intended for decentralized + clusters into secret shares of that document. Shallow copies are created + whenever possible. + + >>> d = { + ... 'id': 0, + ... 'age': {'$share': [1, 2, 3]}, + ... 'dat': {'loc': {'$share': [4, 5, 6]}} + ... } + >>> for d in share(d): print(d) + {'id': 0, 'age': {'%share': 1}, 'dat': {'loc': {'%share': 4}}} + {'id': 0, 'age': {'%share': 2}, 'dat': {'loc': {'%share': 5}}} + {'id': 0, 'age': {'%share': 3}, 'dat': {'loc': {'%share': 6}}} + + A document with no ciphertexts intended for decentralized clusters is + unmodofied; a list containing this document is returned. + + >>> share({'id': 0, 'age': 23}) + [{'id': 0, 'age': 23}] + + Any attempt to convert a document that has an incorrect structure raises + an exception. + + >>> share([]) + Traceback (most recent call last): + ... + TypeError: document must be an integer, string, or dictionary + >>> share({'id': 0, 'age': {'$share': [1, 2, 3], 'extra': [1, 2, 3]}}) + Traceback (most recent call last): + ... + ValueError: share object has incorrect structure + >>> share({ + ... 'id': 0, + ... 'age': {'$share': [1, 2, 3]}, + ... 'dat': {'loc': {'$share': [4, 5]}} + ... }) + Traceback (most recent call last): + ... + ValueError: inconsistent share quantities in document + """ + # Return a single share for integer and string values. + if isinstance(document, (int, str)): + return [document] + + if not isinstance(document, dict): + raise TypeError('document must be an integer, string, or dictionary') + + # Handle the relevant base case: a document containing shares that were + # obtained using the ``encrypt`` function. + keys = set(document.keys()) + if '$share' in keys: + shares = document['$share'] + if not isinstance(shares, list) or len(keys) != 1: + raise ValueError('share object has incorrect structure') + return [{'%share': s} for s in shares] + + # Determine the number of shares in each subdocument. + k_to_vs = {} + for k, v in document.items(): + k_to_vs[k] = share(v) + quantity = max(len(vs) for vs in k_to_vs.values()) + + # Build each of the shares. + shares = [{} for _ in range(quantity)] + for k, vs in k_to_vs.items(): + if len(vs) == 1: + for i in range(quantity): + shares[i][k] = vs[0] + elif len(vs) == quantity: + for i in range(quantity): + shares[i][k] = vs[i] + else: + raise ValueError('inconsistent share quantities in document') + + return shares + if __name__ == '__main__': doctest.testmod() # pragma: no cover diff --git a/test/test_nilql.py b/test/test_nilql.py index 134db3a..a6d85ce 100644 --- a/test/test_nilql.py +++ b/test/test_nilql.py @@ -17,7 +17,7 @@ def test_exports(self): """ module = import_module('nilql.nilql') self.assertTrue({ - 'secret_key', 'encrypt', 'decrypt' + 'secret_key', 'encrypt', 'decrypt', 'share' }.issubset(module.__dict__.keys())) def test_secret_key_creation(self):