diff --git a/roll/diceparser.py b/roll/diceparser.py index 3e41eab..4d3afef 100644 --- a/roll/diceparser.py +++ b/roll/diceparser.py @@ -25,7 +25,7 @@ Website used to do railroad diagrams: https://www.bottlecaps.de/rr/ui """ -from math import ceil, e, factorial, pi +from math import ceil, e, factorial, pi, sqrt from operator import add, floordiv, mod, mul, sub, truediv from random import randint from sys import version_info @@ -115,7 +115,8 @@ class DiceParser: "^": pow, "**": pow, "d": _roll_dice, - "!": factorial + "!": factorial, + "sqrt": sqrt } constants = { @@ -138,6 +139,8 @@ def _create_parser() -> Forward: ) expression = operatorPrecedence(atom, [ + (Literal('-'), 1, opAssoc.RIGHT), + (CaselessLiteral('sqrt'), 1, opAssoc.RIGHT), (oneOf('^ **'), 2, opAssoc.RIGHT), (Literal('-'), 1, opAssoc.RIGHT), @@ -212,6 +215,8 @@ def evaluate( result = current_rolls['total'] dice_rolls.append(current_rolls) + elif operator is sqrt: + result = operator(val) else: result = operator(result, val) else: diff --git a/roll/roll.py b/roll/roll.py index 4787614..a5eceaa 100644 --- a/roll/roll.py +++ b/roll/roll.py @@ -23,7 +23,7 @@ def roll(expression: str = '', verbose: bool = False) -> Union[int, float, dp.EvaluationResults]: """Evalute a string for dice and mathematical operations and calculate.""" - bad_chars: str = "0123456789d-/*() %+.!^pie" + bad_chars: str = "0123456789d-/*() %+.!^piesqrt" input_had_bad_chars: bool = len(expression.strip(bad_chars)) > 0 if input_had_bad_chars: diff --git a/tests/test_basic_math.py b/tests/test_basic_math.py index daab6ab..9558758 100644 --- a/tests/test_basic_math.py +++ b/tests/test_basic_math.py @@ -1,32 +1,25 @@ -from math import e, pi +from math import e, pi, sqrt +from typing import Union import pytest from roll import roll -def test_addition1(): - assert roll('2 + 2') == 4 - - -def test_addition2(): - assert roll('10 + 52') == 62 - - -def test_addition3(): - assert roll('1 + 10 + 100 + 1000') == 1111 - - -def test_addition_without_spaces1(): - assert roll('8+16') == 24 - - -def test_addition_without_spaces2(): - assert roll('5+17+202+81') == 305 - - -def test_addition_with_uneven_spaces1(): - assert roll('19 + 57') == 76 +@pytest.mark.parametrize("equation,result", [ + ('2 + 2', 4), + ('10 + 52', 62), + ('1 + 10 + 100 + 1000', 1111), + ('8 + 16', 24), + ('5 + 17 + 202 + 81', 305), + ('19 + 57', 76), + ('-5 + 20', 15), + ('50 + -25', 25), + ('1.5 + 0.5', 2), + ('204.5 + 20', 224.5), +]) +def test_addition(equation: str, result: Union[int, float]): + assert roll(equation) == result def test_addition_with_uneven_spaces2(): @@ -34,56 +27,35 @@ def test_addition_with_uneven_spaces2(): roll('321\t\n + \t\t\t \t 18') -def test_subtraction1(): - assert roll('10 - 5') == 5 - - -def test_subtraction2(): - assert roll('100 - 52') == 48 - - -def test_subtraction3(): - assert roll('1111 - 100 - 10 - 1') == 1000 - - -def test_subtraction4(): - assert roll('73 - 100') == -27 - - -def test_subtraction5(): - assert roll('5 - -10') == 15 - - -def test_multiplication1(): - assert roll('2 * 2') == 4 - - -def test_multiplication2(): - assert roll('20 * 5') == 100 - - -def test_multiplication3(): - assert roll('1 * 10 * 100') == 1000 - - -def test_multiplication4(): - assert roll('6 * 8 * 2 * 10') == 960 - - -def test_division1(): - assert roll('5 / 5') == 1 - - -def test_division2(): - assert roll('15 / 3') == 5 +@pytest.mark.parametrize("equation,result", [ + ('10 - 5', 5), + ('100 - 52', 48), + ('1111 - 100 - 10 - 1', 1000), + ('73 - 100', -27), + ('5 - -10', 15), +]) +def test_subtraction(equation: str, result: Union[int, float]): + assert roll(equation) == result -def test_division3(): - assert roll('48 / 6') == 8 +@pytest.mark.parametrize("equation,result", [ + ('2 * 2', 4), + ('20 * 5', 100), + ('1 * 10 * 100', 1000), + ('6 * 8 * 2 * 10', 960), +]) +def test_multiplication(equation: str, result: Union[int, float]): + assert roll(equation) == result -def test_division4(): - assert roll('54 / 9') == 6 +@pytest.mark.parametrize("equation,result", [ + ('5 / 5', 1), + ('15 / 3', 5), + ('48 / 6', 8), + ('54 / 9', 6), +]) +def test_division(equation: str, result: Union[int, float]): + assert roll(equation) == result def test_add_mul(): @@ -176,8 +148,38 @@ def test_bad_parens2(): def test_add_explonential(): assert roll("3 + 7**3") == 346 -def test_pi(): - assert roll('pi') == pi -def test_e(): - assert roll('e') == e +@pytest.mark.parametrize("equation,result", [ + ('pi', pi), + ('e', e), +]) +def test_constants(equation: str, result: Union[int, float]): + assert roll(equation) == result + + +@pytest.mark.parametrize("equation,result", [ + ('sqrt 25', 5), + # Addition + ('2 + sqrt 9', 5), + ('sqrt 36 + 7', 13), + # Subtraction + ('sqrt 16 - 4', 0), + ('20 - sqrt 100', 10), + # Multiplication + ('sqrt 4 * 12', 24), + ('10 * sqrt 81', 90), + # Division + ('sqrt 25 / 5', 1), + ('60 / sqrt 36', 10), + # Unary minus + ('sqrt --16', 4), + ('- sqrt 49', -7), + # Constants + ('sqrt e', sqrt(e)), + ('sqrt pi', sqrt(pi)), + # Exponentiation + ('sqrt 169 ** 2', 169), + ('5 ** sqrt 9', 125), +]) +def test_sqrt(equation: str, result: Union[int, float]): + assert roll(equation) == result diff --git a/tests/test_dice_rolling_with_math.py b/tests/test_dice_rolling_with_math.py index 7e4094a..5fb6903 100644 --- a/tests/test_dice_rolling_with_math.py +++ b/tests/test_dice_rolling_with_math.py @@ -1,4 +1,5 @@ +import pytest from roll import roll @@ -32,3 +33,11 @@ def test_dice_expo(): def test_dice_expo1(): assert roll('5d1**5') == 5 + + +@pytest.mark.parametrize("equation,range_low,range_high", [ + ('sqrt 25 d 6', 5, 30), + ('1d sqrt 36', 1, 6), +]) +def test_dice_sqrt(equation: str, range_low: int, range_high: int): + assert roll(equation) in range(range_low, range_high + 1)