Skip to content

Commit

Permalink
Add and/or/xor
Browse files Browse the repository at this point in the history
  • Loading branch information
nielstron committed Dec 20, 2024
1 parent 903c217 commit 7ed730f
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/generate_costing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import subprocess
from pathlib import Path

from uplc.tests.test_acceptance import acceptance_test_dirs
from tests.test_acceptance import acceptance_test_dirs

for dirpath in acceptance_test_dirs():
files = os.listdir(dirpath)
Expand Down
88 changes: 88 additions & 0 deletions tests/run_acceptance_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import json
import sys
from pathlib import Path
import os

from uplc import parse, dumps, UPLCDialect, eval
from uplc.cost_model import Budget
from uplc.transformer import unique_variables



def acceptance_test_dirs(root):
res = []
for dirpath, dirs, files in sorted(os.walk(root, topdown=True)):
if dirs:
# not a leaf directory
continue
res.append(dirpath)
return res



def run_acceptance_test(dirpath, log=False):
files = os.listdir(dirpath)
input_file = next(f for f in files if f.endswith("uplc"))
input_file_path = os.path.join(dirpath, input_file)
with open(input_file_path, "r") as fp:
input = fp.read()
if log:
print("----- Input -------")
print(input)
output_file = next(f for f in files if f.endswith("uplc.expected"))
output_file_path = os.path.join(dirpath, output_file)
with open(output_file_path, "r") as fp:
output = fp.read().strip()

if log:
print("----- Expected output -------")
print(output)
try:
input_parsed = parse(input, filename=input_file_path)
except Exception:
assert "parse error" == output, "Parsing program failed unexpectedly"
return
comp_res = eval(input_parsed)
res = comp_res.result
if log:
print("----- Actual output -------")
print(dumps(res))
if isinstance(res, Exception):
assert output == "evaluation failure", "Machine failed but should not fail."
return
assert output not in ("parse error", "evaluation failure"), "Program parsed and evaluated but should have thrown error"
output_parsed = parse(output, filename=output_file_path).term
res_parsed_unique = unique_variables.UniqueVariableTransformer().visit(res)
output_parsed_unique = unique_variables.UniqueVariableTransformer().visit(
output_parsed
)
res_dumps = dumps(res_parsed_unique, dialect=UPLCDialect.LegacyAiken)
output_dumps = dumps(output_parsed_unique, dialect=UPLCDialect.LegacyAiken)
assert output_dumps == res_dumps, "Program evaluated to wrong output"
try:
cost_file = next(f for f in files if f.endswith("cost"))
with open(Path(dirpath).joinpath(cost_file)) as f:
cost_content = f.read()
if cost_content == "error":
return
cost = json.loads(cost_content)
expected_spent_budget = Budget(cost["cpu"], cost["mem"])
assert expected_spent_budget == comp_res.cost, "Program evaluated with wrong cost."
except StopIteration:
pass

def main(test_root: str):
for path in acceptance_test_dirs(test_root):
failed = False
try:
run_acceptance_test(path)
except AssertionError:
failed = True
if failed:
print(path)
run_acceptance_test(path, log=True)


if __name__ == '__main__':
path = sys.argv[1]
main(path)
2 changes: 1 addition & 1 deletion tests/test_acceptance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from uplc.transformer import unique_variables
from uplc.optimizer import pre_evaluation, remove_traces, remove_force_delay

acceptance_test_path = Path("examples/acceptance_tests")
acceptance_test_path = Path(__file__).parent.parent / "examples/acceptance_tests"


def acceptance_test_dirs():
Expand Down
35 changes: 34 additions & 1 deletion uplc/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dataclasses import dataclass
from enum import Enum, auto
import hashlib
from itertools import zip_longest
from typing import List, Any, Dict, Union

import cbor2
Expand Down Expand Up @@ -725,7 +726,18 @@ class BuiltInFun(Enum):
# Bls12_381_MillerLoop = 68
# Bls12_381_MulMlResult = 69
# Bls12_381_FinalVerify = 70

AndByteString = 75
OrByteString = 76
XorByteString = 77
ComplementByteString = 78
ReadBit = 79
WriteBits = 80
ReplicateByte = 81
ShiftByteString = 82
RotateByteString = 83
CountSetBits = 84
FindFirstSetBit = 85
Ripemd_160 = 86

def typechecked(*typs):
def typecheck_decorator(fun):
Expand Down Expand Up @@ -876,6 +888,18 @@ def _MapData(x):
assert isinstance(x.sample_value, BuiltinPair), "Can only map over a list of pairs"
return PlutusMap({p.l_value: p.r_value for p in x.values})

def _map_trunc(foo, fill):
# implements the extending/truncating of and/or/xor
def ext_trunc_logic(switch, x, y):
x, y = x.value, y.value
if switch.value:
res = bytes(foo(xi,yi) for xi, yi in zip_longest(x, y, fillvalue=fill))
else:
# perform operation on bytes individually truncated to shorter sequence
res = bytes(foo(xi,yi) for xi, yi in zip(x, y))
return BuiltinByteString(res)
return ext_trunc_logic


two_ints = typechecked(BuiltinInteger, BuiltinInteger)
two_bytestrings = typechecked(BuiltinByteString, BuiltinByteString)
Expand Down Expand Up @@ -991,6 +1015,15 @@ def _MapData(x):
BuiltInFun.SerialiseData: single_data(
lambda x: BuiltinByteString(plutus_cbor_dumps(x))
),
BuiltInFun.AndByteString: typechecked(BuiltinBool, BuiltinByteString, BuiltinByteString)(
_map_trunc(lambda x, y: x & y, 255)
),
BuiltInFun.OrByteString: typechecked(BuiltinBool, BuiltinByteString, BuiltinByteString)(
_map_trunc(lambda x, y: x | y, 0)
),
BuiltInFun.XorByteString: typechecked(BuiltinBool, BuiltinByteString, BuiltinByteString)(
_map_trunc(lambda x, y: x ^ y, 0)
),
}

BuiltInFunForceMap = defaultdict(int)
Expand Down

0 comments on commit 7ed730f

Please sign in to comment.