Skip to content

Commit

Permalink
Add HE-enabled summation for single-node clusters.
Browse files Browse the repository at this point in the history
  • Loading branch information
lapets committed Dec 18, 2024
1 parent 922035a commit 552b5ad
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 10 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ license = {text = "MIT"}
readme = "README.rst"
requires-python = ">=3.11"
dependencies = [
"bcl~=2.3"
"bcl~=2.3",
"pailliers~=0.1"
]

[project.urls]
Expand Down
2 changes: 1 addition & 1 deletion src/nilql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""Allow users to access the functions directly."""
from nilql.nilql import secret_key, encrypt, decrypt, share
from nilql.nilql import secret_key, public_key, encrypt, decrypt, share
52 changes: 45 additions & 7 deletions src/nilql/nilql.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import secrets
import hashlib
import bcl
import pailliers

_PLAINTEXT_SIGNED_INTEGER_MIN = -2147483648
"""Minimum plaintext 32-bit signed integer value that can be encrypted."""
Expand Down Expand Up @@ -90,27 +91,56 @@ def secret_key(cluster: dict = None, operations: dict = None) -> dict:
# Create instance with default cluster configuration and operations
# specification, updating the configuration and specification with the
# supplied arguments.
operations = {} or operations
instance = {
'value': None,
'cluster': cluster,
'operations': operations
'operations': {} or operations
}

if len([op for (op, status) in instance['operations'].items() if status]) != 1:
raise ValueError('secret key must support exactly one operation')

if instance['operations'].get('store'):
if len(instance['cluster']['nodes']) == 1:
instance['value'] = bcl.symmetric.secret()

if instance['operations'].get('match'):
salt = secrets.token_bytes(64)
instance['value'] = {'salt': salt}

if instance['operations'].get('store'):
if instance['operations'].get('sum'):
if len(instance['cluster']['nodes']) == 1:
instance['value'] = bcl.symmetric.secret()
instance['value'] = pailliers.secret(2048)

return instance

def encrypt(key: dict, plaintext: Union[int, str]) -> bytes:
def public_key(secret_key: dict) -> dict: # pylint: disable=redefined-outer-name
"""
Return a public key built according to what is specified in the supplied
secret key.
>>> sk = secret_key({'nodes': [{}]}, {'sum': True})
>>> isinstance(public_key(sk), dict)
True
"""
# Create instance with default cluster configuration and operations
# specification, updating the configuration and specification with the
# supplied arguments.
instance = {
'value': None,
'cluster': secret_key['cluster'],
'operations': secret_key['operations']
}

if isinstance(secret_key['value'], pailliers.secret):
instance['value'] = pailliers.public(secret_key['value'])

return instance

def encrypt(
key: dict,
plaintext: Union[int, str]
) -> Union[bytes, Sequence[bytes], int, Sequence[int]]:
"""
Return the ciphertext obtained by using the supplied key to encrypt the
supplied plaintext.
Expand Down Expand Up @@ -165,7 +195,9 @@ def encrypt(key: dict, plaintext: Union[int, str]) -> bytes:

# Encrypt a numerical value for summation.
if key['operations'].get('sum'):
if len(key['cluster']['nodes']) > 1:
if len(key['cluster']['nodes']) == 1:
instance = pailliers.encrypt(key['value'], plaintext)
elif len(key['cluster']['nodes']) > 1:
# Use additive secret sharing for multi-node clusters.
shares = []
total = 0
Expand All @@ -179,7 +211,10 @@ def encrypt(key: dict, plaintext: Union[int, str]) -> bytes:

return instance

def decrypt(key: dict, ciphertext: Union[bytes, Sequence[bytes]]) -> bytes:
def decrypt(
key: dict,
ciphertext: Union[bytes, Sequence[bytes], int, Sequence[int]]
) -> Union[bytes, int]:
"""
Return the ciphertext obtained by using the supplied key to encrypt the
supplied plaintext.
Expand Down Expand Up @@ -226,6 +261,9 @@ def decrypt(key: dict, ciphertext: Union[bytes, Sequence[bytes]]) -> bytes:
return _decode(bytes_)

if key['operations'].get('sum'):
if len(key['cluster']['nodes']) == 1:
return pailliers.decrypt(key['value'], ciphertext)

if len(key['cluster']['nodes']) > 1:
total = 0
for share_ in ciphertext:
Expand Down
23 changes: 22 additions & 1 deletion test/test_nilql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_exports(self):
"""
module = import_module('nilql.nilql')
self.assertTrue({
'secret_key', 'encrypt', 'decrypt', 'share'
'secret_key', 'public_key', 'encrypt', 'decrypt', 'share'
}.issubset(module.__dict__.keys()))

def test_secret_key_creation(self):
Expand Down Expand Up @@ -85,6 +85,27 @@ def test_encrypt_of_str_for_match_multiple(self):
)
)

def test_encrypt_of_int_for_sum_single(self):
"""
Test encryption of string for matching.
"""
sk = nilql.secret_key({'nodes': [{}]}, {'sum': True})
pk = nilql.public_key(sk)
plaintext = 123
ciphertext = nilql.encrypt(pk, plaintext)
self.assertTrue(isinstance(ciphertext, int))

def test_decrypt_of_int_for_sum_single(self):
"""
Test encryption of string for matching.
"""
sk = nilql.secret_key({'nodes': [{}]}, {'sum': True})
pk = nilql.public_key(sk)
plaintext = 123
ciphertext = nilql.encrypt(pk, plaintext)
plaintext_ = nilql.decrypt(sk, ciphertext)
self.assertTrue(plaintext == plaintext_)

def test_encrypt_of_int_for_match_error(self):
"""
Test range error during encryption of integer for matching.
Expand Down

0 comments on commit 552b5ad

Please sign in to comment.