Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More classical code options #27

Merged
merged 5 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 66 additions & 37 deletions qldpc/codes.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,51 +358,80 @@ def has_zero_row_or_column(matrix: galois.FieldArray) -> bool:
return ClassicalCode(matrix)

@classmethod
def repetition(cls, num_bits: int, field: int | None = None) -> ClassicalCode:
"""Construct a repetition code on the given number of bits."""
code_field = galois.GF(field or DEFAULT_FIELD_ORDER)
matrix = code_field.Zeros((num_bits - 1, num_bits))
for row in range(num_bits - 1):
matrix[row, row] = 1
matrix[row, row + 1] = -code_field(1)
return ClassicalCode(matrix)
def from_name(cls, name: str) -> ClassicalCode:
"""Named code in the GAP computer algebra system."""
standardized_name = name.strip().replace(" ", "") # remove whitespace
matrix, field = named_codes.get_code(standardized_name)
return ClassicalCode(matrix, field)

@classmethod
def ring(cls, num_bits: int, field: int | None = None) -> ClassicalCode:
"""Construct a repetition code with periodic boundary conditions."""
code_field = galois.GF(field or DEFAULT_FIELD_ORDER)
matrix = code_field.Zeros((num_bits, num_bits))
for row in range(num_bits):
matrix[row, row] = 1
matrix[row, (row + 1) % num_bits] = -code_field(1)
return ClassicalCode(matrix)
# TODO(?): maybe add modification options
# https://users.math.msu.edu/users/halljo/classes/codenotes/mod.pdf

@classmethod
def hamming(cls, rank: int, field: int | None = None) -> ClassicalCode:

class RepetitionCode(ClassicalCode):
"""Classical repetition code."""

def __init__(self, bits: int, field: int | None = None) -> None:
self._field = galois.GF(field or DEFAULT_FIELD_ORDER)
self._matrix = self.field.Zeros((bits - 1, bits))
for row in range(bits - 1):
self._matrix[row, row] = 1
self._matrix[row, row + 1] = -self.field(1)


class RingCode(ClassicalCode):
"""Classical ring code: repetition code with periodic boundary conditions."""

def __init__(self, bits: int, field: int | None = None) -> None:
self._field = galois.GF(field or DEFAULT_FIELD_ORDER)
self._matrix = self.field.Zeros((bits, bits))
for row in range(bits):
self._matrix[row, row] = 1
self._matrix[row, (row + 1) % bits] = -self.field(1)


class HammingCode(ClassicalCode):
"""Classical Hamming code."""

def __init__(self, rank: int, field: int | None = None) -> None:
"""Construct a Hamming code of a given rank."""
field = field or DEFAULT_FIELD_ORDER
if field == 2:
self._field = galois.GF(field or DEFAULT_FIELD_ORDER)
if self.field.order == 2:
# parity check matrix: columns = all nonzero bitstrings
bitstrings = list(itertools.product([0, 1], repeat=rank))
return ClassicalCode(np.array(bitstrings[1:]).T)
self._matrix = self.field(bitstrings[1:]).T

# More generally, columns = maximal set of linearly independent strings.
# This is achieved by collecting together all strings whose first nonzero element is a 1.
strings = [
(0,) * top_row + (1,) + rest
for top_row in range(rank - 1, -1, -1)
for rest in itertools.product(range(field), repeat=rank - top_row - 1)
]
return ClassicalCode(np.array(strings).T, field=field)
else:
# More generally, columns = [maximal set of linearly independent strings], so collect
# together all strings whose first nonzero element is a 1.
strings = [
(0,) * top_row + (1,) + rest
for top_row in range(rank - 1, -1, -1)
for rest in itertools.product(range(self.field.order), repeat=rank - top_row - 1)
]
self._matrix = self.field(strings).T

@classmethod
def from_name(cls, name: str) -> ClassicalCode:
"""Named code in the GAP computer algebra system."""
standardized_name = name.strip().replace(" ", "") # remove whitespace
matrix, field = named_codes.get_code(standardized_name)
return ClassicalCode(matrix, field)

# see https://mhostetter.github.io/galois/latest/api/#forward-error-correction
class ReedSolomonCode(ClassicalCode):
"""Classical Reed-Solomon code.

Source: https://galois.readthedocs.io/en/v0.3.8/api/galois.ReedSolomon/
Reference: https://errorcorrectionzoo.org/c/reed_solomon
"""

def __init__(self, bits: int, dimension: int) -> None:
ClassicalCode.__init__(self, galois.ReedSolomon(bits, dimension).H)


class BCHCode(ClassicalCode):
"""Classical binary BCH code code.

Source: https://galois.readthedocs.io/en/v0.3.8/api/galois.BCH/
Reference: https://errorcorrectionzoo.org/c/bch
"""

def __init__(self, bits: int, dimension: int) -> None:
ClassicalCode.__init__(self, galois.BCH(bits, dimension).H)


################################################################################
Expand Down
40 changes: 23 additions & 17 deletions qldpc/codes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ def get_random_qudit_code(qudits: int, checks: int, field: int = 2) -> codes.Qud
def test_classical_codes() -> None:
"""Construction of a few classical codes."""
assert codes.ClassicalCode.random(5, 3).num_bits == 5
assert codes.ClassicalCode.hamming(3).get_distance() == 3
assert codes.HammingCode(3).get_distance() == 3

num_bits = 2
for code in [
codes.ClassicalCode.repetition(num_bits, field=3),
codes.ClassicalCode.ring(num_bits, field=3),
codes.RepetitionCode(num_bits, field=3),
codes.RingCode(num_bits, field=3),
]:
assert code.num_bits == num_bits
assert code.dimension == 1
Expand All @@ -47,9 +47,13 @@ def test_classical_codes() -> None:
assert code.get_weight() == 2
assert code.get_random_word() in code

# test that rank of repetition and hamming codes is independent of the field
assert codes.ClassicalCode.repetition(3, 2).rank == codes.ClassicalCode.repetition(3, 3).rank
assert codes.ClassicalCode.hamming(3, 2).rank == codes.ClassicalCode.hamming(3, 3).rank
# test that rank of repetition and Hamming codes is independent of the field
assert codes.RepetitionCode(3, 2).rank == codes.RepetitionCode(3, 3).rank
assert codes.HammingCode(3, 2).rank == codes.HammingCode(3, 3).rank

# check dimension of Reed-Solomon and BCH codes
assert codes.ReedSolomonCode(3, 2).dimension == 2
assert codes.BCHCode(15, 7).dimension == 7

# test invalid classical code construction
with pytest.raises(ValueError, match="inconsistent"):
Expand All @@ -58,10 +62,10 @@ def test_classical_codes() -> None:

def test_named_codes(order: int = 2) -> None:
"""Named codes from the GAP computer algebra system."""
code = codes.ClassicalCode.repetition(order)
matrix = [list(row) for row in code.matrix.view(np.ndarray)]
code = codes.RepetitionCode(order)
checks = [list(row) for row in code.matrix.view(np.ndarray)]

with unittest.mock.patch("qldpc.named_codes.get_code", return_value=(matrix, None)):
with unittest.mock.patch("qldpc.named_codes.get_code", return_value=(checks, None)):
named_code = codes.ClassicalCode.from_name(f"RepetitionCode({order})")
assert np.array_equal(named_code.matrix, code.matrix)

Expand Down Expand Up @@ -108,7 +112,7 @@ def test_conversions(bits: int = 5, checks: int = 3, field: int = 3) -> None:

def test_distance_from_classical_code(bits: int = 3) -> None:
"""Distance of a vector from a classical code."""
rep_code = codes.ClassicalCode.repetition(bits, field=2)
rep_code = codes.RepetitionCode(bits, field=2)
for vector in itertools.product(rep_code.field.elements, repeat=bits):
weight = np.count_nonzero(vector)
dist_brute = rep_code.get_distance_exact(vector=vector)
Expand Down Expand Up @@ -251,8 +255,8 @@ def test_twisted_XZZX(width: int = 3) -> None:
num_qudits = 2 * width**2

# construct check matrix directly
ring = codes.ClassicalCode.ring(width).matrix
mat_1 = codes.ClassicalCode.ring(num_qudits // 2).matrix
ring = codes.RingCode(width).matrix
mat_1 = codes.RingCode(num_qudits // 2).matrix
mat_2 = np.kron(ring, np.eye(width, dtype=int))
zero_1 = np.zeros((mat_1.shape[1],) * 2, dtype=int)
zero_2 = np.zeros((mat_1.shape[0],) * 2, dtype=int)
Expand Down Expand Up @@ -348,15 +352,17 @@ def test_tanner_code() -> None:

def test_surface_HGP_codes(distance: int = 2, field: int = 3) -> None:
"""The surface and toric codes as hypergraph product codes."""
bit_code: codes.ClassicalCode

# surface code
bit_code = codes.ClassicalCode.repetition(distance, field=field)
bit_code = codes.RepetitionCode(distance, field=field)
code = codes.HGPCode(bit_code)
assert code.num_qudits == distance**2 + (distance - 1) ** 2
assert code.dimension == 1
assert code.get_distance(bound=10) == distance

# toric code
bit_code = codes.ClassicalCode.ring(distance, field=field)
bit_code = codes.RingCode(distance, field=field)
code = codes.HGPCode(bit_code)
assert code.num_qudits == 2 * distance**2
assert code.dimension == 2
Expand All @@ -380,20 +386,20 @@ def test_toric_tanner_code(size: int = 4) -> None:
shift_x, shift_y = group.generators
subset_a = [shift_x, ~shift_x]
subset_b = [shift_y, ~shift_y]
subcode_a = codes.ClassicalCode.repetition(2, field=2)
subcode_a = codes.RepetitionCode(2, field=2)
code = codes.QTCode(subset_a, subset_b, subcode_a, bipartite=False)

# verify rotated toric code parameters
assert code.get_code_params(bound=10) == (size**2, 2, size, 4)

# raise error if constructing QTCode with codes over different fields
subcode_b = codes.ClassicalCode.repetition(2, field=subcode_a.field.order**2)
subcode_b = codes.RepetitionCode(2, field=subcode_a.field.order**2)
with pytest.raises(ValueError, match="different fields"):
code = codes.QTCode(subset_a, subset_b, subcode_a, subcode_b)


@pytest.mark.parametrize("field", [3, 4])
def test_qudit_distance(field: int) -> None:
"""Distance calculations for qudits."""
code = codes.HGPCode(codes.ClassicalCode.repetition(2, field=field))
code = codes.HGPCode(codes.RepetitionCode(2, field=field))
assert code.get_distance() == 2
Loading