Skip to content

Commit

Permalink
fixed an issue that prevented proper evaluation of comparisons on Qua…
Browse files Browse the repository at this point in the history
…ntumModulus
  • Loading branch information
positr0nium committed Apr 17, 2024
1 parent a046b64 commit 5db3f5e
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 79 deletions.
51 changes: 50 additions & 1 deletion src/qrisp/qtypes/quantum_modulus.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,32 @@
from qrisp.qtypes.quantum_float import QuantumFloat
from qrisp.environments import invert

def comparison_wrapper(func):

def res_func(self, other):

if self.m != 0:
raise Exception("Tried to evaluate QuantumModulus comparison with non-zero Montgomery shift")

conversion_flag = False
if isinstance(other, QuantumModulus):

if other.m != 0:
raise Exception("Tried to evaluate QuantumModulus comparison with non-zero Montgomery shift")

if self.modulus != other.modulus:
raise Exception("Tried to compare QuantumModulus instances of differing modulus")

other.__class__ = QuantumFloat
self.__class__ = QuantumFloat
res = func(self, other)
self.__class__ = QuantumModulus
if conversion_flag:
other.__class__ = QuantumModulus
return res

return res_func

class QuantumModulus(QuantumFloat):
r"""
This class is a subtype of :ref:`QuantumFloat`, which can be used to model and
Expand Down Expand Up @@ -258,7 +284,30 @@ def __isub__(self, other):
beauregard_adder(self, other, self.modulus)

return self

@comparison_wrapper
def __lt__(self, other):
return QuantumFloat.__lt__(self, other)

@comparison_wrapper
def __gt__(self, other):
return QuantumFloat.__gt__(self, other)

@comparison_wrapper
def __le__(self, other):
return QuantumFloat.__le__(self, other)

@comparison_wrapper
def __ge__(self, other):
return QuantumFloat.__ge__(self, other)

@comparison_wrapper
def __eq__(self, other):
return QuantumFloat.__eq__(self, other)

@comparison_wrapper
def __ne__(self, other):
return QuantumFloat.__ne__(self, other)


def __hash__(self):
return QuantumFloat.__hash__(self)
206 changes: 128 additions & 78 deletions tests/test_comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,115 +17,165 @@
"""


from qrisp import h, QuantumFloat, multi_measurement
from qrisp import h, QuantumFloat, multi_measurement, QuantumModulus
import numpy as np


def test_comparison():
def test_comparison_helper(qf_0, qf_1, comp):
qf_0 = qf_0.duplicate()
h(qf_0)

if isinstance(qf_1, QuantumFloat):
qf_1 = qf_1.duplicate()
h(qf_1)

if comp == "eq":
qf_res = qf_0 == qf_1
elif comp == "neq":
qf_res = qf_0 != qf_1
elif comp == "leq":
qf_res = qf_0 <= qf_1
elif comp == "geq":
qf_res = qf_0 >= qf_1
elif comp == "lt":
qf_res = qf_0 < qf_1
elif comp == "gt":
qf_res = qf_0 > qf_1

if isinstance(qf_1, QuantumFloat):
mes_res = multi_measurement([qf_0, qf_1, qf_res])
for a, b, c in mes_res.keys():
print(a, b, c)
if comp == "eq":
assert (a == b) == c
elif comp == "neq":
assert (a != b) == c
elif comp == "leq":
assert (a <= b) == c
elif comp == "geq":
assert (a >= b) == c
elif comp == "lt":
assert (a < b) == c
elif comp == "gt":
assert (a > b) == c

# Test correct phase behavior7
statevector = qf_res.qs.statevector("array")
angles = np.angle(
statevector[
np.abs(statevector) > 1 / 2 ** ((qf_0.size + qf_1.size) / 2)
]
)
assert np.sum(np.abs(angles)) < 0.1
else:
mes_res = multi_measurement([qf_0, qf_res])
for a, c in mes_res.keys():
b = qf_1
print(a, b, c)
if comp == "eq":
assert (a == b) == c
elif comp == "neq":
assert (a != b) == c
elif comp == "leq":
assert (a <= b) == c
elif comp == "geq":
assert (a >= b) == c
elif comp == "lt":
assert (a < b) == c
elif comp == "gt":
assert (a > b) == c
def test_quantum_float_comparison():

a = QuantumFloat(5, 2, signed=False)
b = QuantumFloat(5, -1, signed=True)
# Test all the operators
for comp in ["eq", "neq", "lt", "gt", "leq", "geq"]:
test_comparison_helper(a, b, comp)
comparison_helper(a, b, comp)

# Test specific constellations of QuantumFloat parameters
a = QuantumFloat(3, -1, signed=True)
b = QuantumFloat(5, 1, signed=True)
test_comparison_helper(a, b, "lt")
comparison_helper(a, b, "lt")

a = QuantumFloat(5, 2, signed=False)
b = QuantumFloat(5, -1, signed=True)
test_comparison_helper(a, b, "lt")
comparison_helper(a, b, "lt")

a = QuantumFloat(4, -4, signed=False)
b = QuantumFloat(4, 0, signed=True)
test_comparison_helper(a, b, "lt")
comparison_helper(a, b, "lt")

a = QuantumFloat(4, 1, signed=True)
b = QuantumFloat(3, -1, signed=False)
test_comparison_helper(a, b, "lt")
comparison_helper(a, b, "lt")

# Test semi-classical comparisons
a = QuantumFloat(4, signed=False)
b = 4
test_comparison_helper(a, b, "lt")
test_comparison_helper(a, b, "gt")
comparison_helper(a, b, "lt")
comparison_helper(a, b, "gt")

a = QuantumFloat(4, signed=True)
b = 4
test_comparison_helper(a, b, "lt")
test_comparison_helper(a, b, "gt")
comparison_helper(a, b, "lt")
comparison_helper(a, b, "gt")

a = QuantumFloat(4, -2, signed=False)
b = 4
test_comparison_helper(a, b, "lt")
test_comparison_helper(a, b, "gt")
comparison_helper(a, b, "lt")
comparison_helper(a, b, "gt")

a = QuantumFloat(8, 3, signed=False)
b = 16
test_comparison_helper(a, b, "lt")
test_comparison_helper(a, b, "gt")
comparison_helper(a, b, "lt")
comparison_helper(a, b, "gt")


def test_quantum_modulus_comparison():

a = QuantumModulus(5)
b = QuantumModulus(5)

for comp in ["eq", "neq", "lt", "gt", "leq", "geq"]:
comparison_helper(a, b, comp)

a = QuantumModulus(7)
b = QuantumModulus(7)

for comp in ["eq", "neq", "lt", "gt", "leq", "geq"]:
comparison_helper(a, b, comp)

a = QuantumModulus(13)
b = QuantumModulus(13)

for comp in ["eq", "neq", "lt", "gt", "leq", "geq"]:
comparison_helper(a, b, comp)

a = QuantumModulus(5)
b = 3

for comp in ["eq", "neq", "lt", "gt", "leq", "geq"]:
comparison_helper(a, b, comp)

a = QuantumModulus(7)
b = 5

for comp in ["eq", "neq", "lt", "gt", "leq", "geq"]:
comparison_helper(a, b, comp)

a = QuantumModulus(13)
b = 7

for comp in ["eq", "neq", "lt", "gt", "leq", "geq"]:
comparison_helper(a, b, comp)

def comparison_helper(qf_0, qf_1, comp):
qf_0 = qf_0.duplicate()
h(qf_0)

if isinstance(qf_1, (QuantumFloat, QuantumModulus)):
qf_1 = qf_1.duplicate()
h(qf_1)

if comp == "eq":
qf_res = qf_0 == qf_1
elif comp == "neq":
qf_res = qf_0 != qf_1
elif comp == "leq":
qf_res = qf_0 <= qf_1
elif comp == "geq":
qf_res = qf_0 >= qf_1
elif comp == "lt":
qf_res = qf_0 < qf_1
elif comp == "gt":
qf_res = qf_0 > qf_1

if isinstance(qf_1, (QuantumFloat, QuantumModulus)):
mes_res = multi_measurement([qf_0, qf_1, qf_res])
for a, b, c in mes_res.keys():

if a is np.nan or b is np.nan:
continue

print(a, b, c)


if comp == "eq":
assert (a == b) == c
elif comp == "neq":
assert (a != b) == c
elif comp == "leq":
assert (a <= b) == c
elif comp == "geq":
assert (a >= b) == c
elif comp == "lt":
assert (a < b) == c
elif comp == "gt":
assert (a > b) == c

# Test correct phase behavior7
statevector = qf_res.qs.statevector("array")
angles = np.angle(
statevector[
np.abs(statevector) > 1 / 2 ** ((qf_0.size + qf_1.size) / 2)
]
)
assert np.sum(np.abs(angles)) < 0.1
else:
mes_res = multi_measurement([qf_0, qf_res])
for a, c in mes_res.keys():

if a is np.nan:
continue

b = qf_1
print(a, b, c)
if comp == "eq":
assert (a == b) == c
elif comp == "neq":
assert (a != b) == c
elif comp == "leq":
assert (a <= b) == c
elif comp == "geq":
assert (a >= b) == c
elif comp == "lt":
assert (a < b) == c
elif comp == "gt":
assert (a > b) == c

0 comments on commit 5db3f5e

Please sign in to comment.