Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve performance for merging markers from overrides #10018

Merged
merged 1 commit into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 64 additions & 30 deletions src/poetry/puzzle/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,12 @@ def _solve_in_compatibility_mode(
self,
overrides: tuple[dict[Package, dict[str, Dependency]], ...],
) -> dict[Package, TransitivePackageInfo]:
packages: dict[Package, TransitivePackageInfo] = {}
override_packages: list[
tuple[
dict[Package, dict[str, Dependency]],
dict[Package, TransitivePackageInfo],
]
] = []
for override in overrides:
self._provider.debug(
# ignore the warning as provider does not do interpolation
Expand All @@ -149,9 +154,9 @@ def _solve_in_compatibility_mode(
)
self._provider.set_overrides(override)
new_packages = self._solve()
merge_packages_from_override(packages, new_packages, override)
override_packages.append((override, new_packages))

return packages
return merge_override_packages(override_packages)

def _solve(self) -> dict[Package, TransitivePackageInfo]:
if self._provider._overrides:
Expand Down Expand Up @@ -406,34 +411,63 @@ def calculate_markers(
transitive_info.markers = transitive_marker


def merge_packages_from_override(
packages: dict[Package, TransitivePackageInfo],
new_packages: dict[Package, TransitivePackageInfo],
override: dict[Package, dict[str, Dependency]],
) -> None:
override_marker: BaseMarker = AnyMarker()
for deps in override.values():
for dep in deps.values():
override_marker = override_marker.intersect(dep.marker.without_extras())
for new_package, new_package_info in new_packages.items():
if package_info := packages.get(new_package):
# update existing package
package_info.depth = max(package_info.depth, new_package_info.depth)
package_info.groups.update(new_package_info.groups)
for group, marker in new_package_info.markers.items():
package_info.markers[group] = package_info.markers.get(
group, EmptyMarker()
).union(override_marker.intersect(marker))
for package in packages:
if package == new_package:
for dep in new_package.requires:
if dep not in package.requires:
package.add_dependency(dep)

def merge_override_packages(
override_packages: list[
tuple[
dict[Package, dict[str, Dependency]], dict[Package, TransitivePackageInfo]
]
],
) -> dict[Package, TransitivePackageInfo]:
result: dict[Package, TransitivePackageInfo] = {}
all_packages: dict[
Package, list[tuple[Package, TransitivePackageInfo, BaseMarker]]
] = {}
for override, o_packages in override_packages:
override_marker: BaseMarker = AnyMarker()
for deps in override.values():
for dep in deps.values():
override_marker = override_marker.intersect(dep.marker.without_extras())
for package, info in o_packages.items():
all_packages.setdefault(package, []).append(
(package, info, override_marker)
)
for package_duplicates in all_packages.values():
base = package_duplicates[0]
package = base[0]
package_info = base[1]
first_override_marker = base[2]
result[package] = package_info
package_info.depth = max(info.depth for _, info, _ in package_duplicates)
package_info.groups = {
g for _, info, _ in package_duplicates for g in info.groups
}
if all(
info.markers == package_info.markers for _, info, _ in package_duplicates
):
# performance shortcut:
# if markers are the same for all overrides,
# we can use less expensive marker operations
override_marker = EmptyMarker()
for _, _, marker in package_duplicates:
override_marker = override_marker.union(marker)
package_info.markers = {
group: override_marker.intersect(marker)
for group, marker in package_info.markers.items()
}
else:
for group, marker in new_package_info.markers.items():
new_package_info.markers[group] = override_marker.intersect(marker)
packages[new_package] = new_package_info
# fallback / general algorithm with performance issues
for group, marker in package_info.markers.items():
package_info.markers[group] = first_override_marker.intersect(marker)
for _, info, override_marker in package_duplicates[1:]:
for group, marker in info.markers.items():
package_info.markers[group] = package_info.markers.get(
group, EmptyMarker()
).union(override_marker.intersect(marker))
for duplicate_package, _, _ in package_duplicates[1:]:
for dep in duplicate_package.requires:
if dep not in package.requires:
package.add_dependency(dep)
return result


@functools.cache
Expand Down
169 changes: 89 additions & 80 deletions tests/puzzle/test_solver_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from poetry.puzzle.solver import PackageNode
from poetry.puzzle.solver import Solver
from poetry.puzzle.solver import depth_first_search
from poetry.puzzle.solver import merge_packages_from_override
from poetry.puzzle.solver import merge_override_packages


if TYPE_CHECKING:
Expand Down Expand Up @@ -359,28 +359,29 @@ def test_propagate_markers_with_cycle(package: ProjectPackage, solver: Solver) -
}


def test_merge_packages_from_override_restricted(package: ProjectPackage) -> None:
def test_merge_override_packages_restricted(package: ProjectPackage) -> None:
"""Markers of dependencies should be intersected with override markers."""
a = Package("a", "1")

packages: dict[Package, TransitivePackageInfo] = {}
merge_packages_from_override(
packages,
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
)
},
{package: {"a": dep("b", 'python_version < "3.9"')}},
)
merge_packages_from_override(
packages,
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'linux'")}
)
},
{package: {"a": dep("b", 'python_version >= "3.9"')}},
packages = merge_override_packages(
[
(
{package: {"a": dep("b", 'python_version < "3.9"')}},
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
)
},
),
(
{package: {"a": dep("b", 'python_version >= "3.9"')}},
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'linux'")}
)
},
),
]
)
assert len(packages) == 1
assert packages[a].groups == {"main"}
Expand All @@ -392,28 +393,33 @@ def test_merge_packages_from_override_restricted(package: ProjectPackage) -> Non
}


def test_merge_packages_from_override_extras(package: ProjectPackage) -> None:
def test_merge_override_packages_extras(package: ProjectPackage) -> None:
"""Extras from overrides should not be visible in the resulting marker."""
a = Package("a", "1")

packages: dict[Package, TransitivePackageInfo] = {}
merge_packages_from_override(
packages,
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
)
},
{package: {"a": dep("b", 'python_version < "3.9" and extra == "foo"')}},
)
merge_packages_from_override(
packages,
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'linux'")}
)
},
{package: {"a": dep("b", 'python_version >= "3.9" and extra == "foo"')}},
packages = merge_override_packages(
[
(
{package: {"a": dep("b", 'python_version < "3.9" and extra == "foo"')}},
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
)
},
),
(
{
package: {
"a": dep("b", 'python_version >= "3.9" and extra == "foo"')
}
},
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'linux'")}
)
},
),
]
)
assert len(packages) == 1
assert packages[a].groups == {"main"}
Expand All @@ -425,21 +431,23 @@ def test_merge_packages_from_override_extras(package: ProjectPackage) -> None:
}


def test_merge_packages_from_override_multiple_deps(package: ProjectPackage) -> None:
def test_merge_override_packages_multiple_deps(package: ProjectPackage) -> None:
"""All override markers should be intersected."""
a = Package("a", "1")

packages: dict[Package, TransitivePackageInfo] = {}
merge_packages_from_override(
packages,
{a: TransitivePackageInfo(0, {"main"}, {"main": AnyMarker()})},
{
package: {
"a": dep("b", 'python_version < "3.9"'),
"c": dep("d", 'sys_platform == "linux"'),
},
a: {"e": dep("f", 'python_version >= "3.8"')},
},
packages = merge_override_packages(
[
(
{
package: {
"a": dep("b", 'python_version < "3.9"'),
"c": dep("d", 'sys_platform == "linux"'),
},
a: {"e": dep("f", 'python_version >= "3.8"')},
},
{a: TransitivePackageInfo(0, {"main"}, {"main": AnyMarker()})},
),
]
)

assert len(packages) == 1
Expand All @@ -452,44 +460,45 @@ def test_merge_packages_from_override_multiple_deps(package: ProjectPackage) ->
}


def test_merge_packages_from_override_groups(package: ProjectPackage) -> None:
def test_merge_override_packages_groups(package: ProjectPackage) -> None:
a = Package("a", "1")
b = Package("b", "1")

packages: dict[Package, TransitivePackageInfo] = {}
merge_packages_from_override(
packages,
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
),
b: TransitivePackageInfo(
0,
{"main", "dev"},
packages = merge_override_packages(
[
(
{package: {"a": dep("b", 'python_version < "3.9"')}},
{
"main": parse_marker("sys_platform == 'win32'"),
"dev": parse_marker("sys_platform == 'linux'"),
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
),
b: TransitivePackageInfo(
0,
{"main", "dev"},
{
"main": parse_marker("sys_platform == 'win32'"),
"dev": parse_marker("sys_platform == 'linux'"),
},
),
},
),
},
{package: {"a": dep("b", 'python_version < "3.9"')}},
)
merge_packages_from_override(
packages,
{
a: TransitivePackageInfo(
0, {"dev"}, {"dev": parse_marker("sys_platform == 'linux'")}
),
b: TransitivePackageInfo(
0,
{"main", "dev"},
(
{package: {"a": dep("b", 'python_version >= "3.9"')}},
{
"main": parse_marker("platform_machine == 'amd64'"),
"dev": parse_marker("platform_machine == 'aarch64'"),
a: TransitivePackageInfo(
0, {"dev"}, {"dev": parse_marker("sys_platform == 'linux'")}
),
b: TransitivePackageInfo(
0,
{"main", "dev"},
{
"main": parse_marker("platform_machine == 'amd64'"),
"dev": parse_marker("platform_machine == 'aarch64'"),
},
),
},
),
},
{package: {"a": dep("b", 'python_version >= "3.9"')}},
]
)
assert len(packages) == 2
assert packages[a].groups == {"main", "dev"}
Expand Down
Loading