Skip to content

Commit

Permalink
fix(mypy): type annotations for cipher algorithms (TheAlgorithms#4306)
Browse files Browse the repository at this point in the history
* fix(mypy): type annotations for cipher algorithms

* Update mypy workflow to include cipher directory

* fix: mypy errors in hill_cipher.py

* fix build errors
  • Loading branch information
dhruvmanila authored Apr 4, 2021
1 parent 806b386 commit 6089536
Show file tree
Hide file tree
Showing 21 changed files with 196 additions and 199 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
python -m pip install mypy pytest-cov -r requirements.txt
# FIXME: #4052 fix mypy errors in the exclude directories and remove them below
- run: mypy --ignore-missing-imports
--exclude '(ciphers|conversions|data_structures|digital_image_processing|dynamic_programming|graphs|linear_algebra|maths|matrix|other|project_euler|scripts|searches|strings*)/$' .
--exclude '(conversions|data_structures|digital_image_processing|dynamic_programming|graphs|linear_algebra|maths|matrix|other|project_euler|scripts|searches|strings*)/$' .
- name: Run tests
run: pytest --doctest-modules --ignore=project_euler/ --ignore=scripts/ --cov-report=term-missing:skip-covered --cov=. .
- if: ${{ success() }}
Expand Down
8 changes: 2 additions & 6 deletions ciphers/diffie_hellman.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,7 @@ def generate_shared_key(self, other_key_str: str) -> str:
return sha256(str(shared_key).encode()).hexdigest()

@staticmethod
def is_valid_public_key_static(
local_private_key_str: str, remote_public_key_str: str, prime: int
) -> bool:
def is_valid_public_key_static(remote_public_key_str: int, prime: int) -> bool:
# check if the other public key is valid based on NIST SP800-56
if 2 <= remote_public_key_str and remote_public_key_str <= prime - 2:
if pow(remote_public_key_str, (prime - 1) // 2, prime) == 1:
Expand All @@ -257,9 +255,7 @@ def generate_shared_key_static(
local_private_key = int(local_private_key_str, base=16)
remote_public_key = int(remote_public_key_str, base=16)
prime = primes[group]["prime"]
if not DiffieHellman.is_valid_public_key_static(
local_private_key, remote_public_key, prime
):
if not DiffieHellman.is_valid_public_key_static(remote_public_key, prime):
raise ValueError("Invalid public key")
shared_key = pow(remote_public_key, local_private_key, prime)
return sha256(str(shared_key).encode()).hexdigest()
Expand Down
23 changes: 10 additions & 13 deletions ciphers/hill_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,12 @@ class HillCipher:

to_int = numpy.vectorize(lambda x: round(x))

def __init__(self, encrypt_key: int):
def __init__(self, encrypt_key: numpy.ndarray) -> None:
"""
encrypt_key is an NxN numpy array
"""
self.encrypt_key = self.modulus(encrypt_key) # mod36 calc's on the encrypt key
self.check_determinant() # validate the determinant of the encryption key
self.decrypt_key = None
self.break_key = encrypt_key.shape[0]

def replace_letters(self, letter: str) -> int:
Expand Down Expand Up @@ -139,8 +138,8 @@ def encrypt(self, text: str) -> str:

for i in range(0, len(text) - self.break_key + 1, self.break_key):
batch = text[i : i + self.break_key]
batch_vec = [self.replace_letters(char) for char in batch]
batch_vec = numpy.array([batch_vec]).T
vec = [self.replace_letters(char) for char in batch]
batch_vec = numpy.array([vec]).T
batch_encrypted = self.modulus(self.encrypt_key.dot(batch_vec)).T.tolist()[
0
]
Expand All @@ -151,7 +150,7 @@ def encrypt(self, text: str) -> str:

return encrypted

def make_decrypt_key(self):
def make_decrypt_key(self) -> numpy.ndarray:
"""
>>> hill_cipher = HillCipher(numpy.array([[2, 5], [1, 6]]))
>>> hill_cipher.make_decrypt_key()
Expand Down Expand Up @@ -184,17 +183,15 @@ def decrypt(self, text: str) -> str:
>>> hill_cipher.decrypt('85FF00')
'HELLOO'
"""
self.decrypt_key = self.make_decrypt_key()
decrypt_key = self.make_decrypt_key()
text = self.process_text(text.upper())
decrypted = ""

for i in range(0, len(text) - self.break_key + 1, self.break_key):
batch = text[i : i + self.break_key]
batch_vec = [self.replace_letters(char) for char in batch]
batch_vec = numpy.array([batch_vec]).T
batch_decrypted = self.modulus(self.decrypt_key.dot(batch_vec)).T.tolist()[
0
]
vec = [self.replace_letters(char) for char in batch]
batch_vec = numpy.array([vec]).T
batch_decrypted = self.modulus(decrypt_key.dot(batch_vec)).T.tolist()[0]
decrypted_batch = "".join(
self.replace_digits(num) for num in batch_decrypted
)
Expand All @@ -203,12 +200,12 @@ def decrypt(self, text: str) -> str:
return decrypted


def main():
def main() -> None:
N = int(input("Enter the order of the encryption key: "))
hill_matrix = []

print("Enter each row of the encryption key with space separated integers")
for i in range(N):
for _ in range(N):
row = [int(x) for x in input().split()]
hill_matrix.append(row)

Expand Down
18 changes: 9 additions & 9 deletions ciphers/mixed_keyword_cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,32 @@ def mixed_keyword(key: str = "college", pt: str = "UNIVERSITY") -> str:
# print(temp)
alpha = []
modalpha = []
for i in range(65, 91):
t = chr(i)
for j in range(65, 91):
t = chr(j)
alpha.append(t)
if t not in temp:
temp.append(t)
# print(temp)
r = int(26 / 4)
# print(r)
k = 0
for i in range(r):
t = []
for _ in range(r):
s = []
for j in range(len_temp):
t.append(temp[k])
s.append(temp[k])
if not (k < 25):
break
k += 1
modalpha.append(t)
modalpha.append(s)
# print(modalpha)
d = {}
j = 0
k = 0
for j in range(len_temp):
for i in modalpha:
if not (len(i) - 1 >= j):
for m in modalpha:
if not (len(m) - 1 >= j):
break
d[alpha[k]] = i[j]
d[alpha[k]] = m[j]
if not k < 25:
break
k += 1
Expand Down
8 changes: 6 additions & 2 deletions ciphers/mono_alphabetic_ciphers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Literal

LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"


def translate_message(key, message, mode):
def translate_message(
key: str, message: str, mode: Literal["encrypt", "decrypt"]
) -> str:
"""
>>> translate_message("QWERTYUIOPASDFGHJKLZXCVBNM","Hello World","encrypt")
'Pcssi Bidsm'
Expand Down Expand Up @@ -40,7 +44,7 @@ def decrypt_message(key: str, message: str) -> str:
return translate_message(key, message, "decrypt")


def main():
def main() -> None:
message = "Hello World"
key = "QWERTYUIOPASDFGHJKLZXCVBNM"
mode = "decrypt" # set to 'encrypt' or 'decrypt'
Expand Down
2 changes: 1 addition & 1 deletion ciphers/morse_code_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def decrypt(message: str) -> str:
return decipher


def main():
def main() -> None:
message = "Morse code here"
result = encrypt(message.upper())
print(result)
Expand Down
9 changes: 5 additions & 4 deletions ciphers/onepad_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@


class Onepad:
def encrypt(self, text: str) -> ([str], [int]):
@staticmethod
def encrypt(text: str) -> tuple[list[int], list[int]]:
"""Function to encrypt text using pseudo-random numbers"""
plain = [ord(i) for i in text]
key = []
Expand All @@ -14,14 +15,14 @@ def encrypt(self, text: str) -> ([str], [int]):
key.append(k)
return cipher, key

def decrypt(self, cipher: [str], key: [int]) -> str:
@staticmethod
def decrypt(cipher: list[int], key: list[int]) -> str:
"""Function to decrypt text using pseudo-random numbers."""
plain = []
for i in range(len(key)):
p = int((cipher[i] - (key[i]) ** 2) / key[i])
plain.append(chr(p))
plain = "".join([i for i in plain])
return plain
return "".join([i for i in plain])


if __name__ == "__main__":
Expand Down
5 changes: 3 additions & 2 deletions ciphers/playfair_cipher.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import itertools
import string
from typing import Generator, Iterable


def chunker(seq, size):
def chunker(seq: Iterable[str], size: int) -> Generator[tuple[str, ...], None, None]:
it = iter(seq)
while True:
chunk = tuple(itertools.islice(it, size))
Expand Down Expand Up @@ -37,7 +38,7 @@ def prepare_input(dirty: str) -> str:
return clean


def generate_table(key: str) -> [str]:
def generate_table(key: str) -> list[str]:

# I and J are used interchangeably to allow
# us to use a 5x5 table (25 letters)
Expand Down
49 changes: 21 additions & 28 deletions ciphers/porta_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
}


def generate_table(key: str) -> [(str, str)]:
def generate_table(key: str) -> list[tuple[str, str]]:
"""
>>> generate_table('marvin') # doctest: +NORMALIZE_WHITESPACE
[('ABCDEFGHIJKLM', 'UVWXYZNOPQRST'), ('ABCDEFGHIJKLM', 'NOPQRSTUVWXYZ'),
Expand Down Expand Up @@ -60,30 +60,21 @@ def decrypt(key: str, words: str) -> str:
return encrypt(key, words)


def get_position(table: [(str, str)], char: str) -> (int, int) or (None, None):
def get_position(table: tuple[str, str], char: str) -> tuple[int, int]:
"""
>>> table = [
... ('ABCDEFGHIJKLM', 'UVWXYZNOPQRST'), ('ABCDEFGHIJKLM', 'NOPQRSTUVWXYZ'),
... ('ABCDEFGHIJKLM', 'STUVWXYZNOPQR'), ('ABCDEFGHIJKLM', 'QRSTUVWXYZNOP'),
... ('ABCDEFGHIJKLM', 'WXYZNOPQRSTUV'), ('ABCDEFGHIJKLM', 'UVWXYZNOPQRST')]
>>> get_position(table, 'A')
(None, None)
>>> get_position(generate_table('marvin')[0], 'M')
(0, 12)
"""
if char in table[0]:
row = 0
else:
row = 1 if char in table[1] else -1
return (None, None) if row == -1 else (row, table[row].index(char))
# `char` is either in the 0th row or the 1st row
row = 0 if char in table[0] else 1
col = table[row].index(char)
return row, col


def get_opponent(table: [(str, str)], char: str) -> str:
def get_opponent(table: tuple[str, str], char: str) -> str:
"""
>>> table = [
... ('ABCDEFGHIJKLM', 'UVWXYZNOPQRST'), ('ABCDEFGHIJKLM', 'NOPQRSTUVWXYZ'),
... ('ABCDEFGHIJKLM', 'STUVWXYZNOPQR'), ('ABCDEFGHIJKLM', 'QRSTUVWXYZNOP'),
... ('ABCDEFGHIJKLM', 'WXYZNOPQRSTUV'), ('ABCDEFGHIJKLM', 'UVWXYZNOPQRST')]
>>> get_opponent(table, 'A')
'A'
>>> get_opponent(generate_table('marvin')[0], 'M')
'T'
"""
row, col = get_position(table, char.upper())
if row == 1:
Expand All @@ -97,14 +88,16 @@ def get_opponent(table: [(str, str)], char: str) -> str:

doctest.testmod() # Fist ensure that all our tests are passing...
"""
ENTER KEY: marvin
ENTER TEXT TO ENCRYPT: jessica
ENCRYPTED: QRACRWU
DECRYPTED WITH KEY: JESSICA
Demo:
Enter key: marvin
Enter text to encrypt: jessica
Encrypted: QRACRWU
Decrypted with key: JESSICA
"""
key = input("ENTER KEY: ").strip()
text = input("ENTER TEXT TO ENCRYPT: ").strip()
key = input("Enter key: ").strip()
text = input("Enter text to encrypt: ").strip()
cipher_text = encrypt(key, text)

print(f"ENCRYPTED: {cipher_text}")
print(f"DECRYPTED WITH KEY: {decrypt(key, cipher_text)}")
print(f"Encrypted: {cipher_text}")
print(f"Decrypted with key: {decrypt(key, cipher_text)}")
10 changes: 5 additions & 5 deletions ciphers/rail_fence_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def encrypt(input_string: str, key: int) -> str:
...
TypeError: sequence item 0: expected str instance, int found
"""
grid = [[] for _ in range(key)]
temp_grid: list[list[str]] = [[] for _ in range(key)]
lowest = key - 1

if key <= 0:
Expand All @@ -31,8 +31,8 @@ def encrypt(input_string: str, key: int) -> str:
for position, character in enumerate(input_string):
num = position % (lowest * 2) # puts it in bounds
num = min(num, lowest * 2 - num) # creates zigzag pattern
grid[num].append(character)
grid = ["".join(row) for row in grid]
temp_grid[num].append(character)
grid = ["".join(row) for row in temp_grid]
output_string = "".join(grid)

return output_string
Expand Down Expand Up @@ -63,7 +63,7 @@ def decrypt(input_string: str, key: int) -> str:
if key == 1:
return input_string

temp_grid = [[] for _ in range(key)] # generates template
temp_grid: list[list[str]] = [[] for _ in range(key)] # generates template
for position in range(len(input_string)):
num = position % (lowest * 2) # puts it in bounds
num = min(num, lowest * 2 - num) # creates zigzag pattern
Expand All @@ -84,7 +84,7 @@ def decrypt(input_string: str, key: int) -> str:
return output_string


def bruteforce(input_string: str) -> dict:
def bruteforce(input_string: str) -> dict[int, str]:
"""Uses decrypt function by guessing every key
>>> bruteforce("HWe olordll")[4]
Expand Down
2 changes: 1 addition & 1 deletion ciphers/rot13.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def dencrypt(s: str, n: int = 13) -> str:
return out


def main():
def main() -> None:
s0 = input("Enter message: ")

s1 = dencrypt(s0, 13)
Expand Down
Loading

0 comments on commit 6089536

Please sign in to comment.