Skip to content

Commit

Permalink
Merge pull request #1191 from lark-parser/adjust_pr1152
Browse files Browse the repository at this point in the history
Adjustments for PR #1152
  • Loading branch information
erezsh authored Sep 16, 2022
2 parents f775df3 + dce017c commit f009312
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 45 deletions.
31 changes: 17 additions & 14 deletions lark/lark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys, os, pickle, hashlib
import tempfile
import types
import re
from typing import (
TypeVar, Type, List, Dict, Iterator, Callable, Union, Optional, Sequence,
Tuple, Iterable, IO, Any, TYPE_CHECKING, Collection
Expand All @@ -15,6 +16,7 @@
from typing import Literal
else:
from typing_extensions import Literal
from .parser_frontends import ParsingFrontend

from .exceptions import ConfigurationError, assert_config, UnexpectedInput
from .utils import Serialize, SerializeMemoizer, FS, isascii, logger
Expand All @@ -27,7 +29,7 @@
from .parser_frontends import _validate_frontend_args, _get_lexer_callbacks, _deserialize_parsing_frontend, _construct_parsing_frontend
from .grammar import Rule

import re

try:
import regex
_has_regex = True
Expand Down Expand Up @@ -176,7 +178,7 @@ class LarkOptions(Serialize):
'_plugins': {},
}

def __init__(self, options_dict):
def __init__(self, options_dict: Dict[str, Any]) -> None:
o = dict(options_dict)

options = {}
Expand Down Expand Up @@ -205,21 +207,21 @@ def __init__(self, options_dict):
if o:
raise ConfigurationError("Unknown options: %s" % o.keys())

def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
try:
return self.__dict__['options'][name]
except KeyError as e:
raise AttributeError(e)

def __setattr__(self, name, value):
def __setattr__(self, name: str, value: str) -> None:
assert_config(name, self.options.keys(), "%r isn't a valid option. Expected one of: %s")
self.options[name] = value

def serialize(self, memo):
def serialize(self, memo = None) -> Dict[str, Any]:
return self.options

@classmethod
def deserialize(cls, data, memo):
def deserialize(cls, data: Dict[str, Any], memo: Dict[int, Union[TerminalDef, Rule]]) -> "LarkOptions":
return cls(data)


Expand Down Expand Up @@ -252,7 +254,7 @@ class Lark(Serialize):
grammar: 'Grammar'
options: LarkOptions
lexer: Lexer
terminals: List[TerminalDef]
terminals: Collection[TerminalDef]

def __init__(self, grammar: 'Union[Grammar, str, IO[str]]', **options) -> None:
self.options = LarkOptions(options)
Expand Down Expand Up @@ -446,15 +448,15 @@ def __init__(self, grammar: 'Union[Grammar, str, IO[str]]', **options) -> None:

__serialize_fields__ = 'parser', 'rules', 'options'

def _build_lexer(self, dont_ignore=False):
def _build_lexer(self, dont_ignore: bool=False) -> BasicLexer:
lexer_conf = self.lexer_conf
if dont_ignore:
from copy import copy
lexer_conf = copy(lexer_conf)
lexer_conf.ignore = ()
return BasicLexer(lexer_conf)

def _prepare_callbacks(self):
def _prepare_callbacks(self) -> None:
self._callbacks = {}
# we don't need these callbacks if we aren't building a tree
if self.options.ambiguity != 'forest':
Expand All @@ -468,7 +470,7 @@ def _prepare_callbacks(self):
self._callbacks = self._parse_tree_builder.create_callback(self.options.transformer)
self._callbacks.update(_get_lexer_callbacks(self.options.transformer, self.terminals))

def _build_parser(self):
def _build_parser(self) -> "ParsingFrontend":
self._prepare_callbacks()
_validate_frontend_args(self.options.parser, self.options.lexer)
parser_conf = ParserConf(self.rules, self._callbacks, self.options.start)
Expand All @@ -480,7 +482,7 @@ def _build_parser(self):
options=self.options
)

def save(self, f, exclude_options: Collection[str] = ()):
def save(self, f, exclude_options: Collection[str] = ()) -> None:
"""Saves the instance into the given file object
Useful for caching and multiprocessing.
Expand All @@ -491,15 +493,15 @@ def save(self, f, exclude_options: Collection[str] = ()):
pickle.dump({'data': data, 'memo': m}, f, protocol=pickle.HIGHEST_PROTOCOL)

@classmethod
def load(cls, f):
def load(cls: Type[_T], f) -> _T:
"""Loads an instance from the given file object
Useful for caching and multiprocessing.
"""
inst = cls.__new__(cls)
return inst._load(f)

def _deserialize_lexer_conf(self, data, memo, options):
def _deserialize_lexer_conf(self, data: Dict[str, Any], memo: Dict[int, Union[TerminalDef, Rule]], options: LarkOptions) -> LexerConf:
lexer_conf = LexerConf.deserialize(data['lexer_conf'], memo)
lexer_conf.callbacks = options.lexer_callbacks or {}
lexer_conf.re_module = regex if options.regex else re
Expand All @@ -509,7 +511,7 @@ def _deserialize_lexer_conf(self, data, memo, options):
lexer_conf.postlex = options.postlex
return lexer_conf

def _load(self, f, **kwargs):
def _load(self: _T, f: Any, **kwargs) -> _T:
if isinstance(f, dict):
d = f
else:
Expand Down Expand Up @@ -593,6 +595,7 @@ def lex(self, text: str, dont_ignore: bool=False) -> Iterator[Token]:
:raises UnexpectedCharacters: In case the lexer cannot find a suitable match.
"""
lexer: Lexer
if not hasattr(self, 'lexer') or dont_ignore:
lexer = self._build_lexer(dont_ignore)
else:
Expand Down
3 changes: 2 additions & 1 deletion lark/parsers/lalr_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Author: Erez Shinan (2017)
# Email : [email protected]
from copy import deepcopy, copy
from typing import Dict, Any
from ..lexer import Token
from ..utils import Serialize

Expand All @@ -29,7 +30,7 @@ def deserialize(cls, data, memo, callbacks, debug=False):
inst.parser = _Parser(inst._parse_table, callbacks, debug)
return inst

def serialize(self, memo):
def serialize(self, memo: Any = None) -> Dict[str, Any]:
return self._parse_table.serialize(memo)

def parse_interactive(self, lexer, start):
Expand Down
70 changes: 40 additions & 30 deletions lark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import os
from functools import reduce
from collections import deque
from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence

###{standalone
import sys, re
import logging

logger: logging.Logger = logging.getLogger("lark")
logger.addHandler(logging.StreamHandler())
# Set to highest level, since we have some warnings amongst the code
Expand All @@ -15,9 +17,11 @@

NO_VALUE = object()

T = TypeVar("T")


def classify(seq, key=None, value=None):
d = {}
def classify(seq: Sequence, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
d: Dict[Any, Any] = {}
for item in seq:
k = key(item) if (key is not None) else item
v = value(item) if (value is not None) else item
Expand All @@ -28,7 +32,7 @@ def classify(seq, key=None, value=None):
return d


def _deserialize(data, namespace, memo):
def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any:
if isinstance(data, dict):
if '__type__' in data: # Object
class_ = namespace[data['__type__']]
Expand All @@ -41,6 +45,8 @@ def _deserialize(data, namespace, memo):
return data


_T = TypeVar("_T", bound="Serialize")

class Serialize:
"""Safe-ish serialization interface that doesn't rely on Pickle
Expand All @@ -50,23 +56,23 @@ class Serialize:
Should include all field types that aren't builtin types.
"""

def memo_serialize(self, types_to_memoize):
def memo_serialize(self, types_to_memoize: List) -> Any:
memo = SerializeMemoizer(types_to_memoize)
return self.serialize(memo), memo.serialize()

def serialize(self, memo=None):
def serialize(self, memo = None) -> Dict[str, Any]:
if memo and memo.in_types(self):
return {'@': memo.memoized.get(self)}

fields = getattr(self, '__serialize_fields__')
res = {f: _serialize(getattr(self, f), memo) for f in fields}
res['__type__'] = type(self).__name__
if hasattr(self, '_serialize'):
self._serialize(res, memo)
self._serialize(res, memo) # type: ignore[attr-defined]
return res

@classmethod
def deserialize(cls, data, memo):
def deserialize(cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any]) -> _T:
namespace = getattr(cls, '__serialize_namespace__', [])
namespace = {c.__name__:c for c in namespace}

Expand All @@ -83,7 +89,7 @@ def deserialize(cls, data, memo):
raise KeyError("Cannot find key for class", cls, e)

if hasattr(inst, '_deserialize'):
inst._deserialize()
inst._deserialize() # type: ignore[attr-defined]

return inst

Expand All @@ -93,18 +99,18 @@ class SerializeMemoizer(Serialize):

__serialize_fields__ = 'memoized',

def __init__(self, types_to_memoize):
def __init__(self, types_to_memoize: List) -> None:
self.types_to_memoize = tuple(types_to_memoize)
self.memoized = Enumerator()

def in_types(self, value):
def in_types(self, value: Serialize) -> bool:
return isinstance(value, self.types_to_memoize)

def serialize(self):
def serialize(self) -> Dict[int, Any]: # type: ignore[override]
return _serialize(self.memoized.reversed(), None)

@classmethod
def deserialize(cls, data, namespace, memo):
def deserialize(cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any]) -> Dict[int, Any]: # type: ignore[override]
return _deserialize(data, namespace, memo)


Expand All @@ -123,7 +129,7 @@ def deserialize(cls, data, namespace, memo):

categ_pattern = re.compile(r'\\p{[A-Za-z_]+}')

def get_regexp_width(expr):
def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]:
if _has_regex:
# Since `sre_parse` cannot deal with Unicode categories of the form `\p{Mn}`, we replace these with
# a simple letter, which makes no difference as we are only trying to get the possible lengths of the regex
Expand All @@ -134,7 +140,8 @@ def get_regexp_width(expr):
raise ImportError('`regex` module must be installed in order to use Unicode categories.', expr)
regexp_final = expr
try:
return [int(x) for x in sre_parse.parse(regexp_final).getwidth()]
# Fixed in next version (past 0.960) of typeshed
return [int(x) for x in sre_parse.parse(regexp_final).getwidth()] # type: ignore[attr-defined]
except sre_constants.error:
if not _has_regex:
raise ValueError(expr)
Expand All @@ -154,47 +161,50 @@ def get_regexp_width(expr):
_ID_START = 'Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Mn', 'Mc', 'Pc'
_ID_CONTINUE = _ID_START + ('Nd', 'Nl',)

def _test_unicode_category(s, categories):
def _test_unicode_category(s: str, categories: Sequence[str]) -> bool:
if len(s) != 1:
return all(_test_unicode_category(char, categories) for char in s)
return s == '_' or unicodedata.category(s) in categories

def is_id_continue(s):
def is_id_continue(s: str) -> bool:
"""
Checks if all characters in `s` are alphanumeric characters (Unicode standard, so diacritics, indian vowels, non-latin
numbers, etc. all pass). Synonymous with a Python `ID_CONTINUE` identifier. See PEP 3131 for details.
"""
return _test_unicode_category(s, _ID_CONTINUE)

def is_id_start(s):
def is_id_start(s: str) -> bool:
"""
Checks if all characters in `s` are alphabetic characters (Unicode standard, so diacritics, indian vowels, non-latin
numbers, etc. all pass). Synonymous with a Python `ID_START` identifier. See PEP 3131 for details.
"""
return _test_unicode_category(s, _ID_START)


def dedup_list(l):
def dedup_list(l: List[T]) -> List[T]:
"""Given a list (l) will removing duplicates from the list,
preserving the original order of the list. Assumes that
the list entries are hashable."""
dedup = set()
return [x for x in l if not (x in dedup or dedup.add(x))]
# This returns None, but that's expected
return [x for x in l if not (x in dedup or dedup.add(x))] # type: ignore[func-returns-value]
# 2x faster (ordered in PyPy and CPython 3.6+, gaurenteed to be ordered in Python 3.7+)
# return list(dict.fromkeys(l))


class Enumerator(Serialize):
def __init__(self):
self.enums = {}
def __init__(self) -> None:
self.enums: Dict[Any, int] = {}

def get(self, item):
def get(self, item) -> int:
if item not in self.enums:
self.enums[item] = len(self.enums)
return self.enums[item]

def __len__(self):
return len(self.enums)

def reversed(self):
def reversed(self) -> Dict[int, Any]:
r = {v: k for k, v in self.enums.items()}
assert len(r) == len(self.enums)
return r
Expand Down Expand Up @@ -240,11 +250,11 @@ def open(name, mode="r", **kwargs):



def isascii(s):
def isascii(s: str) -> bool:
""" str.isascii only exists in python3.7+ """
try:
if sys.version_info >= (3, 7):
return s.isascii()
except AttributeError:
else:
try:
s.encode('ascii')
return True
Expand All @@ -257,7 +267,7 @@ def __repr__(self):
return '{%s}' % ', '.join(map(repr, self))


def classify_bool(seq, pred):
def classify_bool(seq: Sequence, pred: Callable) -> Any:
true_elems = []
false_elems = []

Expand All @@ -270,7 +280,7 @@ def classify_bool(seq, pred):
return true_elems, false_elems


def bfs(initial, expand):
def bfs(initial: Sequence, expand: Callable) -> Iterator:
open_q = deque(list(initial))
visited = set(open_q)
while open_q:
Expand All @@ -290,7 +300,7 @@ def bfs_all_unique(initial, expand):
open_q += expand(node)


def _serialize(value, memo):
def _serialize(value: Any, memo: Optional[SerializeMemoizer]) -> Any:
if isinstance(value, Serialize):
return value.serialize(memo)
elif isinstance(value, list):
Expand All @@ -305,7 +315,7 @@ def _serialize(value, memo):



def small_factors(n, max_factor):
def small_factors(n: int, max_factor: int) -> List[Tuple[int, int]]:
"""
Splits n up into smaller factors and summands <= max_factor.
Returns a list of [(a, b), ...]
Expand Down

0 comments on commit f009312

Please sign in to comment.