Skip to content

Commit

Permalink
Replace MarkerTypes by BaseMarker in type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
radoering authored and abn committed Mar 5, 2022
1 parent c5f4cda commit c335b76
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 61 deletions.
4 changes: 2 additions & 2 deletions src/poetry/core/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from poetry.core.packages.types import DependencyTypes
from poetry.core.poetry import Poetry
from poetry.core.spdx.license import License
from poetry.core.version.markers import MarkerTypes
from poetry.core.version.markers import BaseMarker

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -347,7 +347,7 @@ def create_dependency(
)

if not markers:
marker: "MarkerTypes" = AnyMarker()
marker: "BaseMarker" = AnyMarker()
if python_versions:
marker = marker.intersect(
parse_marker(
Expand Down
112 changes: 55 additions & 57 deletions src/poetry/core/version/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import Iterator
from typing import Iterable
from typing import List
from typing import Type
from typing import Union

from poetry.core.version.grammars import GRAMMAR_PEP_508_MARKERS
Expand All @@ -16,10 +17,6 @@

from poetry.core.semver.helpers import VersionTypes

MarkerTypes = Union[
"AnyMarker", "EmptyMarker", "SingleMarker", "MultiMarker", "MarkerUnion"
]


class InvalidMarker(ValueError):
"""
Expand Down Expand Up @@ -88,10 +85,10 @@ def __repr__(self) -> str:


class AnyMarker(BaseMarker):
def intersect(self, other: MarkerTypes) -> MarkerTypes:
def intersect(self, other: BaseMarker) -> BaseMarker:
return other

def union(self, other: MarkerTypes) -> MarkerTypes:
def union(self, other: BaseMarker) -> BaseMarker:
return self

def is_any(self) -> bool:
Expand All @@ -103,13 +100,13 @@ def is_empty(self) -> bool:
def validate(self, environment: Dict[str, Any]) -> bool:
return True

def without_extras(self) -> MarkerTypes:
def without_extras(self) -> BaseMarker:
return self

def exclude(self, marker_name: str) -> MarkerTypes:
def exclude(self, marker_name: str) -> BaseMarker:
return self

def only(self, *marker_names: str) -> MarkerTypes:
def only(self, *marker_names: str) -> BaseMarker:
return self

def invert(self) -> "EmptyMarker":
Expand All @@ -132,10 +129,10 @@ def __eq__(self, other: object) -> bool:


class EmptyMarker(BaseMarker):
def intersect(self, other: MarkerTypes) -> MarkerTypes:
def intersect(self, other: BaseMarker) -> BaseMarker:
return self

def union(self, other: MarkerTypes) -> MarkerTypes:
def union(self, other: BaseMarker) -> BaseMarker:
return other

def is_any(self) -> bool:
Expand Down Expand Up @@ -258,13 +255,13 @@ def operator(self) -> str:
def value(self) -> str:
return self._value

def intersect(self, other: MarkerTypes) -> MarkerTypes:
def intersect(self, other: BaseMarker) -> BaseMarker:
if isinstance(other, SingleMarker):
return MultiMarker.of(self, other)

return other.intersect(self)

def union(self, other: MarkerTypes) -> MarkerTypes:
def union(self, other: BaseMarker) -> BaseMarker:
if isinstance(other, SingleMarker):
if self == other:
return self
Expand All @@ -285,10 +282,10 @@ def validate(self, environment: Dict[str, Any]) -> bool:

return self._constraint.allows(self._parser(environment[self._name]))

def without_extras(self) -> MarkerTypes:
def without_extras(self) -> BaseMarker:
return self.exclude("extra")

def exclude(self, marker_name: str) -> MarkerTypes:
def exclude(self, marker_name: str) -> BaseMarker:
if self.name == marker_name:
return AnyMarker()

Expand All @@ -300,7 +297,7 @@ def only(self, *marker_names: str) -> Union["SingleMarker", EmptyMarker]:

return self

def invert(self) -> MarkerTypes:
def invert(self) -> BaseMarker:
if self._operator in ("===", "=="):
operator = "!="
elif self._operator == "!=":
Expand Down Expand Up @@ -361,8 +358,9 @@ def __str__(self) -> str:


def _flatten_markers(
markers: Iterator[Union["MarkerUnion", "MultiMarker"]], flatten_class: Any
) -> List[MarkerTypes]:
markers: Iterable[BaseMarker],
flatten_class: Type[Union["MarkerUnion", "MultiMarker"]],
) -> List[BaseMarker]:
flattened = []

for marker in markers:
Expand All @@ -375,23 +373,23 @@ def _flatten_markers(


class MultiMarker(BaseMarker):
def __init__(self, *markers: MarkerTypes) -> None:
def __init__(self, *markers: BaseMarker) -> None:
self._markers = []

markers = _flatten_markers(markers, MultiMarker)
flattened_markers = _flatten_markers(markers, MultiMarker)

for m in markers:
for m in flattened_markers:
self._markers.append(m)

@classmethod
def of(cls, *markers: MarkerTypes) -> MarkerTypes:
def of(cls, *markers: BaseMarker) -> BaseMarker:
new_markers = _flatten_markers(markers, MultiMarker)
markers = []
old_markers: List[BaseMarker] = []

while markers != new_markers:
markers = new_markers
while old_markers != new_markers:
old_markers = new_markers
new_markers = []
for marker in markers:
for marker in old_markers:
if marker in new_markers:
continue

Expand Down Expand Up @@ -446,10 +444,10 @@ def of(cls, *markers: MarkerTypes) -> MarkerTypes:
return MultiMarker(*new_markers)

@property
def markers(self) -> List[MarkerTypes]:
def markers(self) -> List[BaseMarker]:
return self._markers

def intersect(self, other: MarkerTypes) -> MarkerTypes:
def intersect(self, other: BaseMarker) -> BaseMarker:
if other.is_any():
return self

Expand All @@ -460,7 +458,7 @@ def intersect(self, other: MarkerTypes) -> MarkerTypes:

return MultiMarker.of(*new_markers)

def union(self, other: MarkerTypes) -> MarkerTypes:
def union(self, other: BaseMarker) -> BaseMarker:
if other in self._markers:
return other

Expand All @@ -472,10 +470,10 @@ def union(self, other: MarkerTypes) -> MarkerTypes:
def validate(self, environment: Dict[str, Any]) -> bool:
return all(m.validate(environment) for m in self._markers)

def without_extras(self) -> MarkerTypes:
def without_extras(self) -> BaseMarker:
return self.exclude("extra")

def exclude(self, marker_name: str) -> MarkerTypes:
def exclude(self, marker_name: str) -> BaseMarker:
new_markers = []

for m in self._markers:
Expand All @@ -490,7 +488,7 @@ def exclude(self, marker_name: str) -> MarkerTypes:

return self.of(*new_markers)

def only(self, *marker_names: str) -> MarkerTypes:
def only(self, *marker_names: str) -> BaseMarker:
new_markers = []

for m in self._markers:
Expand All @@ -505,7 +503,7 @@ def only(self, *marker_names: str) -> MarkerTypes:

return self.of(*new_markers)

def invert(self) -> MarkerTypes:
def invert(self) -> BaseMarker:
markers = [marker.invert() for marker in self._markers]

return MarkerUnion.of(*markers)
Expand Down Expand Up @@ -535,28 +533,28 @@ def __str__(self) -> str:


class MarkerUnion(BaseMarker):
def __init__(self, *markers: MarkerTypes) -> None:
def __init__(self, *markers: BaseMarker) -> None:
self._markers = list(markers)

@property
def markers(self) -> List[MarkerTypes]:
def markers(self) -> List[BaseMarker]:
return self._markers

@classmethod
def of(cls, *markers: BaseMarker) -> MarkerTypes:
def of(cls, *markers: BaseMarker) -> BaseMarker:
flattened_markers = _flatten_markers(markers, MarkerUnion)

markers = []
new_markers: List[BaseMarker] = []
for marker in flattened_markers:
if marker in markers:
if marker in new_markers:
continue

if (
isinstance(marker, SingleMarker)
and marker.name in PYTHON_VERSION_MARKERS
):
included = False
for i, mark in enumerate(markers):
for i, mark in enumerate(new_markers):
if (
not isinstance(mark, SingleMarker)
or mark.name not in PYTHON_VERSION_MARKERS
Expand All @@ -568,7 +566,7 @@ def of(cls, *markers: BaseMarker) -> MarkerTypes:
included = True
break
elif union == marker.constraint:
markers[i] = marker
new_markers[i] = marker
included = True
break
elif union.is_any():
Expand All @@ -577,26 +575,26 @@ def of(cls, *markers: BaseMarker) -> MarkerTypes:
if included:
continue

markers.append(marker)
new_markers.append(marker)

if any(m.is_any() for m in markers):
if any(m.is_any() for m in new_markers):
return AnyMarker()

if not markers:
if not new_markers:
return EmptyMarker()

if len(markers) == 1:
return markers[0]
if len(new_markers) == 1:
return new_markers[0]

return MarkerUnion(*markers)
return MarkerUnion(*new_markers)

def append(self, marker: MarkerTypes) -> None:
def append(self, marker: BaseMarker) -> None:
if marker in self._markers:
return

self._markers.append(marker)

def intersect(self, other: MarkerTypes) -> MarkerTypes:
def intersect(self, other: BaseMarker) -> BaseMarker:
if other.is_any():
return self

Expand All @@ -620,7 +618,7 @@ def intersect(self, other: MarkerTypes) -> MarkerTypes:

return MarkerUnion.of(*new_markers)

def union(self, other: MarkerTypes) -> MarkerTypes:
def union(self, other: BaseMarker) -> BaseMarker:
if other.is_any():
return other

Expand All @@ -634,10 +632,10 @@ def union(self, other: MarkerTypes) -> MarkerTypes:
def validate(self, environment: Dict[str, Any]) -> bool:
return any(m.validate(environment) for m in self._markers)

def without_extras(self) -> MarkerTypes:
def without_extras(self) -> BaseMarker:
return self.exclude("extra")

def exclude(self, marker_name: str) -> MarkerTypes:
def exclude(self, marker_name: str) -> BaseMarker:
new_markers = []

for m in self._markers:
Expand All @@ -652,7 +650,7 @@ def exclude(self, marker_name: str) -> MarkerTypes:

return self.of(*new_markers)

def only(self, *marker_names: str) -> MarkerTypes:
def only(self, *marker_names: str) -> BaseMarker:
new_markers = []

for m in self._markers:
Expand All @@ -667,7 +665,7 @@ def only(self, *marker_names: str) -> MarkerTypes:

return self.of(*new_markers)

def invert(self) -> MarkerTypes:
def invert(self) -> BaseMarker:
markers = [marker.invert() for marker in self._markers]

return MultiMarker.of(*markers)
Expand Down Expand Up @@ -697,7 +695,7 @@ def is_empty(self) -> bool:
return all(m.is_empty() for m in self._markers)


def parse_marker(marker: str) -> MarkerTypes:
def parse_marker(marker: str) -> BaseMarker:
if marker == "<empty>":
return EmptyMarker()

Expand All @@ -711,10 +709,10 @@ def parse_marker(marker: str) -> MarkerTypes:
return markers


def _compact_markers(tree_elements: "Tree", tree_prefix: str = "") -> MarkerTypes:
def _compact_markers(tree_elements: "Tree", tree_prefix: str = "") -> BaseMarker:
from lark import Token

groups = [MultiMarker()]
groups: List[BaseMarker] = [MultiMarker()]
for token in tree_elements:
if isinstance(token, Token):
if token.type == f"{tree_prefix}BOOL_OP" and token.value == "or":
Expand Down
4 changes: 2 additions & 2 deletions tests/version/test_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


if TYPE_CHECKING:
from poetry.core.version.markers import MarkerTypes
from poetry.core.version.markers import BaseMarker


def assert_requirement(
Expand All @@ -22,7 +22,7 @@ def assert_requirement(
url: Optional[str] = None,
extras: Optional[List[str]] = None,
constraint: str = "*",
marker: Optional["MarkerTypes"] = None,
marker: Optional["BaseMarker"] = None,
):
if extras is None:
extras = []
Expand Down

0 comments on commit c335b76

Please sign in to comment.