Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sourcery refactored master branch #1

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions eopsin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,7 @@ def visit_Module(self, node: TypedModule) -> plt.AST:
raise RuntimeError(
"The contract can not always detect if it was passed three or two parameters on-chain."
)
cp = plt.Program("1.0.0", validator)
return cp
return plt.Program("1.0.0", validator)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function UPLCCompiler.visit_Module refactored with the following changes:


def visit_Constant(self, node: TypedConstant) -> plt.AST:
plt_type = ConstantMap.get(type(node.value))
Expand Down Expand Up @@ -796,10 +795,7 @@ def visit_ListComp(self, node: TypedListComp) -> plt.AST:
lst = plt.Apply(self.visit(gen.iter), plt.Var(STATEMONAD))
ifs = None
for ifexpr in gen.ifs:
if ifs is None:
ifs = self.visit(ifexpr)
else:
ifs = plt.And(ifs, self.visit(ifexpr))
ifs = self.visit(ifexpr) if ifs is None else plt.And(ifs, self.visit(ifexpr))
Comment on lines -799 to +798
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function UPLCCompiler.visit_ListComp refactored with the following changes:

map_fun = plt.Lambda(
["x"],
plt.Apply(
Expand Down
27 changes: 10 additions & 17 deletions eopsin/ledger/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@ def compare(a: int, b: int) -> int:
# a == b: 0
# a > b: -1
if a < b:
result = 1
return 1
elif a == b:
result = 0
return 0
else:
result = -1
return result
return -1
Comment on lines -9 to +13
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function compare refactored with the following changes:



def compare_extended_helper(time: ExtendedPOSIXTime) -> int:
Expand All @@ -31,21 +30,15 @@ def compare_extended(a: ExtendedPOSIXTime, b: ExtendedPOSIXTime) -> int:
# a > b: -1
a_val = compare_extended_helper(a)
b_val = compare_extended_helper(b)
if a_val == 0 and b_val == 0:
a_finite: FinitePOSIXTime = a
b_finite: FinitePOSIXTime = b
result = compare(a_finite.time, b_finite.time)
else:
result = compare(a_val, b_val)
return result
if a_val != 0 or b_val != 0:
return compare(a_val, b_val)
a_finite: FinitePOSIXTime = a
b_finite: FinitePOSIXTime = b
return compare(a_finite.time, b_finite.time)
Comment on lines -34 to +37
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function compare_extended refactored with the following changes:



def get_bool(b: BoolData) -> bool:
if isinstance(b, TrueData):
result = True
else:
result = False
return result
return isinstance(b, TrueData)
Comment on lines -44 to +41
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_bool refactored with the following changes:



def compare_upper_bound(a: UpperBoundPOSIXTime, b: UpperBoundPOSIXTime) -> int:
Expand Down Expand Up @@ -76,7 +69,7 @@ def contains(a: POSIXTimeRange, b: POSIXTimeRange) -> bool:
# Returns True if the interval `b` is entirely contained in `a`.
lower = compare_lower_bound(a.lower_bound, b.lower_bound)
upper = compare_upper_bound(a.upper_bound, b.upper_bound)
return (lower == 1 or lower == 0) and (upper == 0 or upper == -1)
return lower in [1, 0] and upper in [0, -1]
Comment on lines -79 to +72
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function contains refactored with the following changes:

  • Replace multiple comparisons of same variable with in operator [×2] (merge-comparisons)



def make_range(
Expand Down
5 changes: 1 addition & 4 deletions eopsin/optimize/optimize_remove_deadvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,7 @@ class OptimizeRemoveDeadvars(CompilingNodeTransformer):

def guaranteed(self, name: str) -> bool:
name = name
for scope in reversed(self.guaranteed_avail_names):
if name in scope:
return True
return False
return any(name in scope for scope in reversed(self.guaranteed_avail_names))
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function OptimizeRemoveDeadvars.guaranteed refactored with the following changes:

  • Use any() instead of for loop (use-any)


def enter_scope(self):
self.guaranteed_avail_names.append([])
Expand Down
7 changes: 2 additions & 5 deletions eopsin/optimize/optimize_varlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def visit_FunctionDef(self, node: FunctionDef):
def bs_from_int(i: int):
hex_str = f"{i:x}"
if len(hex_str) % 2 == 1:
hex_str = "0" + hex_str
hex_str = f"0{hex_str}"
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function bs_from_int refactored with the following changes:

return bytes.fromhex(hex_str)


Expand All @@ -47,11 +47,8 @@ def visit_Module(self, node: Module) -> Module:
# collect all variable names
collector = NameCollector()
collector.visit(node)
# sort by most used
varmap = {}
varnames = sorted(collector.vars.items(), key=lambda x: x[1], reverse=True)
for i, (v, _) in enumerate(varnames):
varmap[v] = bs_from_int(i)
varmap = {v: bs_from_int(i) for i, (v, _) in enumerate(varnames)}
Comment on lines -50 to +51
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function OptimizeVarlen.visit_Module refactored with the following changes:

This removes the following comments ( why? ):

# sort by most used

self.varmap = varmap
node_cp = copy(node)
node_cp.body = [self.visit(s) for s in node.body]
Expand Down
26 changes: 11 additions & 15 deletions eopsin/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,11 +495,11 @@ def all_tokens_unlocked_from_address(
) -> int:
"""Returns how many tokens of specified type are unlocked from given address"""
return sum(
[
txi.resolved.value.get(token.policy_id, {b"": 0}).get(token.token_name, 0)
for txi in txins
if txi.resolved.address == address
]
txi.resolved.value.get(token.policy_id, {b"": 0}).get(
token.token_name, 0
)
for txi in txins
if txi.resolved.address == address
Comment on lines -498 to +502
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function all_tokens_unlocked_from_address refactored with the following changes:

)


Expand All @@ -508,11 +508,9 @@ def all_tokens_locked_at_address_with_datum(
) -> int:
"""Returns how many tokens of specified type are locked at then given address with the specified datum"""
return sum(
[
txo.value.get(token.policy_id, {b"": 0}).get(token.token_name, 0)
for txo in txouts
if txo.address == address and txo.datum == output_datum
]
txo.value.get(token.policy_id, {b"": 0}).get(token.token_name, 0)
for txo in txouts
if txo.address == address and txo.datum == output_datum
Comment on lines -511 to +513
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function all_tokens_locked_at_address_with_datum refactored with the following changes:

)


Expand All @@ -521,11 +519,9 @@ def all_tokens_locked_at_address(
) -> int:
"""Returns how many tokens of specified type are locked at the given address"""
return sum(
[
txo.value.get(token.policy_id, {b"": 0}).get(token.token_name, 0)
for txo in txouts
if txo.address == address
]
txo.value.get(token.policy_id, {b"": 0}).get(token.token_name, 0)
for txo in txouts
if txo.address == address
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function all_tokens_locked_at_address refactored with the following changes:

)


Expand Down
3 changes: 1 addition & 2 deletions eopsin/rewrite/rewrite_augassign.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@ class RewriteAugAssign(CompilingNodeTransformer):
def visit_AugAssign(self, node: AugAssign) -> Assign:
target_cp = copy(node.target)
target_cp.ctx = Load()
a = Assign(
return Assign(
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function RewriteAugAssign.visit_AugAssign refactored with the following changes:

[self.visit(node.target)],
BinOp(
self.visit(target_cp),
self.visit(node.op),
self.visit(node.value),
),
)
return a
2 changes: 1 addition & 1 deletion eopsin/rewrite/rewrite_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def visit_ImportFrom(
node.names[0].name == "*"
), "The import must have the form 'from <pkg> import *'"
assert (
node.names[0].asname == None
node.names[0].asname is None
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function RewriteImport.visit_ImportFrom refactored with the following changes:

), "The import must have the form 'from <pkg> import *'"
# TODO set anchor point according to own package
if self.filename:
Expand Down
2 changes: 1 addition & 1 deletion eopsin/rewrite/rewrite_import_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def visit_ImportFrom(self, node: ImportFrom) -> Optional[ImportFrom]:
node.names[i].name == n
), "The program must contain one 'from dataclasses import dataclass'"
assert (
node.names[i].asname == None
node.names[i].asname is None
), "The program must contain one 'from dataclasses import dataclass'"
self.imports_dataclasses = True
return None
Expand Down
2 changes: 1 addition & 1 deletion eopsin/rewrite/rewrite_import_plutusdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def visit_ImportFrom(self, node: ImportFrom) -> Optional[ImportFrom]:
node.names[1].name == "PlutusData"
), "The program must contain one 'from pycardano import Datum as Anything, PlutusData'"
assert (
node.names[1].asname == None
node.names[1].asname is None
), "The program must contain one 'from pycardano import Datum as Anything, PlutusData'"
self.imports_plutus_data = True
return None
Expand Down
2 changes: 1 addition & 1 deletion eopsin/rewrite/rewrite_import_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def visit_ImportFrom(self, node: ImportFrom) -> Optional[ImportFrom]:
node.names[i].name == n
), "The program must contain one 'from typing import Dict, List, Union'"
assert (
node.names[i].asname == None
node.names[i].asname is None
), "The program must contain one 'from typing import Dict, List, Union'"
self.imports_typing = True
return None
Expand Down
24 changes: 11 additions & 13 deletions eopsin/rewrite/rewrite_tuple_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,15 @@ def visit_Assign(self, node: Assign) -> typing.List[stmt]:
uid = self.unique_id
self.unique_id += 1
assignments = [Assign([Name(f"{uid}_tup", Store())], self.visit(node.value))]
for i, t in enumerate(node.targets[0].elts):
assignments.append(
Assign(
[t],
Subscript(
value=Name(f"{uid}_tup", Load()),
slice=Index(value=Constant(i)),
ctx=Load(),
),
)
assignments.extend(
Assign(
[t],
Subscript(
value=Name(f"{uid}_tup", Load()),
slice=Index(value=Constant(i)),
ctx=Load(),
),
)
# recursively resolve multiple layers of tuples
transformed = sum([self.visit(a) for a in assignments], [])
return transformed
for i, t in enumerate(node.targets[0].elts)
)
return sum((self.visit(a) for a in assignments), [])
9 changes: 3 additions & 6 deletions eopsin/tests/test_ledger/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def test_contains(a: POSIXTimeRange, b: POSIXTimeRange):
lower = compare_lower_bound(a.lower_bound, b.lower_bound)
upper = compare_upper_bound(a.upper_bound, b.upper_bound)
if contains(a, b):
assert lower == 1 or lower == 0
assert upper == 0 or upper == -1
assert lower in [1, 0]
assert upper in [0, -1]
else:
assert lower == -1 or upper == 1

Expand Down Expand Up @@ -200,8 +200,5 @@ def test_fuzz_compare_extended_helper(time: ExtendedPOSIXTime) -> None:

@given(b=st.booleans())
def test_get_bool(b: bool) -> None:
if b:
bool_data = TrueData()
else:
bool_data = FalseData()
bool_data = TrueData() if b else FalseData()
assert get_bool(bool_data) == b
2 changes: 1 addition & 1 deletion eopsin/tests/test_stdlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def validator(x: None) -> str:
self.assertEqual(ret, xs, "literal string returned wrong value")

def test_constant_unit(self):
source_code = f"""
source_code = """
def validator(x: None) -> None:
return None
"""
Expand Down
8 changes: 3 additions & 5 deletions eopsin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ def visit_If(self, node: If) -> TypedIf:
typed_if = copy(node)
if (
isinstance(typed_if.test, Call)
and (typed_if.test.func, Name)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function AggressiveTypeInferencer.visit_If refactored with the following changes:

and typed_if.test.func.id == "isinstance"
):
tc = typed_if.test
Expand All @@ -246,7 +245,7 @@ def visit_If(self, node: If) -> TypedIf:
assert isinstance(target_class, RecordType), "Can only cast to PlutusData"
assert (
target_class in target_inst_class.typ.typs
), f"Trying to cast an instance of Union type to non-instance of union type"
), "Trying to cast an instance of Union type to non-instance of union type"
typed_if.test = self.visit(
Compare(
left=Attribute(tc.args[0], "CONSTR_ID"),
Expand Down Expand Up @@ -475,7 +474,7 @@ def visit_Subscript(self, node: Subscript) -> TypedSubscript:
elif isinstance(ts.value.typ.typ, DictType):
# TODO could be implemented with potentially just erroring. It might be desired to avoid this though.
raise TypeInferenceError(
f"Could not infer type of subscript of dict. Use 'get' with a default value instead."
"Could not infer type of subscript of dict. Use 'get' with a default value instead."
Comment on lines -478 to +477
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function AggressiveTypeInferencer.visit_Subscript refactored with the following changes:

)
else:
raise TypeInferenceError(
Expand Down Expand Up @@ -517,8 +516,7 @@ def visit_Call(self, node: Call) -> TypedCall:
raise TypeInferenceError("Could not infer type of call")

def visit_Pass(self, node: Pass) -> TypedPass:
tp = copy(node)
return tp
return copy(node)
Comment on lines -520 to +519
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function AggressiveTypeInferencer.visit_Pass refactored with the following changes:


def visit_Return(self, node: Return) -> TypedReturn:
tp = copy(node)
Expand Down
6 changes: 3 additions & 3 deletions eopsin/typed_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,7 +1041,7 @@ def empty_list(p: Type):
),
)
)
if isinstance(p.typ, RecordType) or isinstance(p.typ, AnyType):
if isinstance(p.typ, (RecordType, AnyType)):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function empty_list refactored with the following changes:

return plt.EmptyDataList()
raise NotImplementedError(f"Empty lists of type {p} can't be constructed yet")

Expand Down Expand Up @@ -1113,7 +1113,7 @@ def visit(self, node):
node_class_name = node.__class__.__name__
if node_class_name.startswith("Typed"):
node_class_name = node_class_name[len("Typed") :]
method = "visit_" + node_class_name
method = f"visit_{node_class_name}"
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function TypedNodeTransformer.visit refactored with the following changes:

visitor = getattr(self, method, self.generic_visit)
return visitor(node)

Expand All @@ -1124,6 +1124,6 @@ def visit(self, node):
node_class_name = node.__class__.__name__
if node_class_name.startswith("Typed"):
node_class_name = node_class_name[len("Typed") :]
method = "visit_" + node_class_name
method = f"visit_{node_class_name}"
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function TypedNodeVisitor.visit refactored with the following changes:

visitor = getattr(self, method, self.generic_visit)
return visitor(node)
2 changes: 1 addition & 1 deletion examples/dict_datum.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ def validator(d: D2) -> bool:
return (
D(b"\x01") in d.dict_field.keys()
and 2 in d.dict_field.values()
and not D(b"") in d.dict_field.keys()
and D(b"") not in d.dict_field.keys()
)
7 changes: 3 additions & 4 deletions examples/fib_rec.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
def fib(n: int) -> int:
if n == 0:
res = 0
return 0
elif n == 1:
res = 1
return 1
else:
res = fib(n - 1) + fib(n - 2)
return res
return fib(n - 1) + fib(n - 2)


def validator(n: int) -> int:
Expand Down
12 changes: 5 additions & 7 deletions examples/list_comprehensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@


def validator(n: int, even: bool) -> List[int]:
if even:
# generate even squares
res = [k * k for k in range(n) if k % 2 == 0]
else:
# generate all squares
res = [k * k for k in range(n)]
return res
return (
[k * k for k in range(n) if k % 2 == 0]
if even
else [k * k for k in range(n)]
)
2 changes: 1 addition & 1 deletion examples/list_datum.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ class D2(PlutusData):


def validator(d: D2) -> bool:
return b"\x01" == d.list_field[0]
return d.list_field[0] == b"\x01"
6 changes: 1 addition & 5 deletions examples/mult_for.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,2 @@
def validator(a: int, b: int) -> int:
# trivial implementation of c = a * b
c = 0
for k in range(b):
c += a
return c
return sum(a for _ in range(b))
2 changes: 1 addition & 1 deletion examples/mult_while.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
def validator(a: int, b: int) -> int:
# trivial implementation of c = a * b
c = 0
while 0 < b:
while b > 0:
c += a
b -= 1
return c
2 changes: 1 addition & 1 deletion examples/showcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def validator(n: int) -> int:
a += 5
while b < 5:
b += 1
for i in range(2):
for _ in range(2):
print("loop")

# sha256, sha3_256 and blake2b
Expand Down
Loading