Skip to content

Commit

Permalink
fix: compatibility with mypyc
Browse files Browse the repository at this point in the history
  • Loading branch information
akaihola committed Sep 18, 2024
1 parent 4a555ae commit ecf9055
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 29 deletions.
17 changes: 8 additions & 9 deletions pgtricks/pg_dump_splitsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,23 @@
MEMORY_UNITS = {"": 1, "k": KIBIBYTE, "m": MEBIBYTE, "g": GIBIBYTE}


def try_float(s1: str, s2: str) -> tuple[str, str] | tuple[float, float]:
def try_float(s1: str, s2: str) -> tuple[float, float]:
"""Convert two strings to floats. Return original ones on conversion error."""
if not s1 or not s2 or s1[0] not in '0123456789.-' or s2[0] not in '0123456789.-':
# optimization
return s1, s2
try:
return float(s1), float(s2)
except ValueError:
return s1, s2
raise ValueError
return float(s1), float(s2)


def linecomp(l1: str, l2: str) -> int:
p1 = l1.split('\t', 1)
p2 = l2.split('\t', 1)
# TODO: unquote cast after support for Python 3.8 is dropped
v1, v2 = cast("tuple[float, float]", try_float(p1[0], p2[0]))
result = (v1 > v2) - (v1 < v2)
# modifying a line to see whether Darker works:
try:
v1, v2 = try_float(p1[0], p2[0])
result = (v1 > v2) - (v1 < v2)
except ValueError:
result = (p1[0] > p2[0]) - (p1[0] < p2[0])
if not result and len(p1) == len(p2) == 2:
return linecomp(p1[1], p2[1])
return result
Expand Down
44 changes: 24 additions & 20 deletions pgtricks/tests/test_pg_dump_splitsort.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import nullcontext
from functools import cmp_to_key
from textwrap import dedent

Expand Down Expand Up @@ -36,29 +37,32 @@ def test_sql_copy_regular_expression(test_input, expected):
@pytest.mark.parametrize(
's1, s2, expect',
[
('', '', ('', '')),
('foo', '', ('foo', '')),
('foo', 'bar', ('foo', 'bar')),
('0', '1', (0.0, 1.0)),
('0', 'one', ('0', 'one')),
('0.0', '0.0', (0.0, 0.0)),
('0.0', 'one point zero', ('0.0', 'one point zero')),
('0.', '1.', (0.0, 1.0)),
('0.', 'one', ('0.', 'one')),
('4.2', '0.42', (4.2, 0.42)),
('4.2', 'four point two', ('4.2', 'four point two')),
('-.42', '-0.042', (-0.42, -0.042)),
('-.42', 'minus something', ('-.42', 'minus something')),
(r'\N', r'\N', (r'\N', r'\N')),
('foo', r'\N', ('foo', r'\N')),
('-4.2', r'\N', ('-4.2', r'\N')),
("", "", ValueError),
("foo", "", ValueError),
("foo", "bar", ValueError),
("0", "1", (0.0, 1.0)),
("0", "one", ValueError),
("0.0", "0.0", (0.0, 0.0)),
("0.0", "one point zero", ValueError),
("0.", "1.", (0.0, 1.0)),
("0.", "one", ValueError),
("4.2", "0.42", (4.2, 0.42)),
("4.2", "four point two", ValueError),
("-.42", "-0.042", (-0.42, -0.042)),
("-.42", "minus something", ValueError),
(r"\N", r"\N", ValueError),
("foo", r"\N", ValueError),
("-4.2", r"\N", ValueError),
],
)
def test_try_float(s1, s2, expect):
result1, result2 = try_float(s1, s2)
assert type(result1) is type(expect[0])
assert type(result2) is type(expect[1])
assert (result1, result2) == expect
with pytest.raises(expect) if expect is ValueError else nullcontext():

result1, result2 = try_float(s1, s2)

assert type(result1) is type(expect[0])
assert type(result2) is type(expect[1])
assert (result1, result2) == expect


@pytest.mark.parametrize(
Expand Down

0 comments on commit ecf9055

Please sign in to comment.