Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typing fixes. Mypy now produces 0 type errors #1354

Merged
merged 1 commit into from
Oct 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions lark/load_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import pkgutil
from ast import literal_eval
from contextlib import suppress
from typing import List, Tuple, Union, Callable, Dict, Optional, Sequence
from typing import List, Tuple, Union, Callable, Dict, Optional, Sequence, Generator

from .utils import bfs, logger, classify_bool, is_id_continue, is_id_start, bfs_all_unique, small_factors, OrderedSet
from .lexer import Token, TerminalDef, PatternStr, PatternRE
from .lexer import Token, TerminalDef, PatternStr, PatternRE, Pattern

from .parse_tree_builder import ParseTreeBuilder
from .parser_frontends import ParsingFrontend
Expand Down Expand Up @@ -195,10 +195,10 @@


class FindRuleSize(Transformer):
def __init__(self, keep_all_tokens):
def __init__(self, keep_all_tokens: bool):
self.keep_all_tokens = keep_all_tokens

def _will_not_get_removed(self, sym):
def _will_not_get_removed(self, sym: Symbol) -> bool:
if isinstance(sym, NonTerminal):
return not sym.name.startswith('_')
if isinstance(sym, Terminal):
Expand All @@ -207,7 +207,7 @@ def _will_not_get_removed(self, sym):
return False
assert False, sym

def _args_as_int(self, args):
def _args_as_int(self, args: List[Union[int, Symbol]]) -> Generator[int, None, None]:
for a in args:
if isinstance(a, int):
yield a
Expand All @@ -216,10 +216,10 @@ def _args_as_int(self, args):
else:
assert False

def expansion(self, args):
def expansion(self, args) -> int:
return sum(self._args_as_int(args))

def expansions(self, args):
def expansions(self, args) -> int:
return max(self._args_as_int(args))


Expand All @@ -232,7 +232,7 @@ def __init__(self):
self.i = 0
self.rule_options = None

def _name_rule(self, inner):
def _name_rule(self, inner: str):
new_name = '__%s_%s_%d' % (self.prefix, inner, self.i)
self.i += 1
return new_name
Expand All @@ -243,7 +243,7 @@ def _add_rule(self, key, name, expansions):
self.rules_cache[key] = t
return t

def _add_recurse_rule(self, type_, expr):
def _add_recurse_rule(self, type_: str, expr: Tree):
try:
return self.rules_cache[expr]
except KeyError:
Expand Down Expand Up @@ -312,7 +312,7 @@ def _add_repeat_opt_rule(self, a, b, target, target_opt, atom):
])
return self._add_rule(key, new_name, tree)

def _generate_repeats(self, rule, mn, mx):
def _generate_repeats(self, rule: Tree, mn: int, mx: int):
"""Generates a rule tree that repeats ``rule`` exactly between ``mn`` to ``mx`` times.
"""
# For a small number of repeats, we can take the naive approach
Expand Down Expand Up @@ -343,7 +343,7 @@ def _generate_repeats(self, rule, mn, mx):

return ST('expansions', [ST('expansion', [mn_target] + [diff_opt_target])])

def expr(self, rule, op, *args):
def expr(self, rule: Tree, op: Token, *args):
if op.value == '?':
empty = ST('expansion', [])
return ST('expansions', [rule, empty])
Expand Down Expand Up @@ -372,7 +372,7 @@ def expr(self, rule, op, *args):

assert False, op

def maybe(self, rule):
def maybe(self, rule: Tree):
keep_all_tokens = self.rule_options and self.rule_options.keep_all_tokens
rule_size = FindRuleSize(keep_all_tokens).transform(rule)
empty = ST('expansion', [_EMPTY] * rule_size)
Expand All @@ -382,11 +382,11 @@ def maybe(self, rule):
class SimplifyRule_Visitor(Visitor):

@staticmethod
def _flatten(tree):
def _flatten(tree: Tree):
while tree.expand_kids_by_data(tree.data):
pass

def expansion(self, tree):
def expansion(self, tree: Tree):
# rules_list unpacking
# a : b (c|d) e
# -->
Expand Down Expand Up @@ -417,7 +417,7 @@ def alias(self, tree):
tree.data = 'expansions'
tree.children = aliases

def expansions(self, tree):
def expansions(self, tree: Tree):
self._flatten(tree)
# Ensure all children are unique
if len(set(tree.children)) != len(tree.children):
Expand Down Expand Up @@ -610,23 +610,25 @@ def range(self, start, end):
return ST('pattern', [PatternRE(regexp)])


def _make_joined_pattern(regexp, flags_set):
def _make_joined_pattern(regexp, flags_set) -> PatternRE:
return PatternRE(regexp, ())

class TerminalTreeToPattern(Transformer_NonRecursive):
def pattern(self, ps):
p ,= ps
return p

def expansion(self, items):
assert items
def expansion(self, items: List[Pattern]) -> Pattern:
if not items:
return PatternStr('')

if len(items) == 1:
return items[0]

pattern = ''.join(i.to_regexp() for i in items)
return _make_joined_pattern(pattern, {i.flags for i in items})

def expansions(self, exps):
def expansions(self, exps: List[Pattern]) -> Pattern:
if len(exps) == 1:
return exps[0]

Expand All @@ -637,7 +639,8 @@ def expansions(self, exps):
pattern = '(?:%s)' % ('|'.join(i.to_regexp() for i in exps))
return _make_joined_pattern(pattern, {i.flags for i in exps})

def expr(self, args):
def expr(self, args) -> Pattern:
inner: Pattern
inner, op = args[:2]
if op == '~':
if len(args) == 3:
Expand Down
5 changes: 3 additions & 2 deletions lark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def is_id_start(s: str) -> bool:
return _test_unicode_category(s, _ID_START)


def dedup_list(l: List[T]) -> List[T]:
def dedup_list(l: Sequence[T]) -> List[T]:
"""Given a list (l) will removing duplicates from the list,
preserving the original order of the list. Assumes that
the list entries are hashable."""
Expand Down Expand Up @@ -231,7 +231,8 @@ def combine_alternatives(lists):
return list(product(*lists))

try:
import atomicwrites
# atomicwrites doesn't have type bindings
import atomicwrites # type: ignore[import]
_has_atomicwrites = True
except ImportError:
_has_atomicwrites = False
Expand Down
4 changes: 2 additions & 2 deletions lark/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def __mul__(
return TransformerChain(*self.transformers + (other,))


class Transformer_InPlace(Transformer):
class Transformer_InPlace(Transformer[_Leaf_T, _Return_T]):
"""Same as Transformer, but non-recursive, and changes the tree in-place instead of returning new instances

Useful for huge trees. Conservative in memory.
Expand All @@ -282,7 +282,7 @@ def transform(self, tree: Tree[_Leaf_T]) -> _Return_T:
return self._transform_tree(tree)


class Transformer_NonRecursive(Transformer):
class Transformer_NonRecursive(Transformer[_Leaf_T, _Return_T]):
"""Same as Transformer but non-recursive.

Like Transformer, it doesn't change the original tree.
Expand Down
Loading