diff --git a/pytest_container/container.py b/pytest_container/container.py index 447cec4..6c2c7c9 100644 --- a/pytest_container/container.py +++ b/pytest_container/container.py @@ -43,6 +43,7 @@ import pytest import testinfra from _pytest.mark import ParameterSet +from filelock import BaseFileLock from filelock import FileLock from pytest_container.helpers import get_always_pull_option from pytest_container.inspect import ContainerHealth @@ -76,6 +77,11 @@ def __str__(self) -> str: return "oci" if self == ImageFormat.OCIv1 else "docker" +def lock_host_port_search(rootdir: Path) -> BaseFileLock: + """Generate a filelock for finding free ports on the host.""" + return FileLock(rootdir / "port_check.lock") + + def create_host_port_port_forward( port_forwards: List[PortForwarding], ) -> List[PortForwarding]: @@ -87,38 +93,34 @@ def create_host_port_port_forward( """ finished_forwards: List[PortForwarding] = [] - # list of sockets that will be created and cleaned up afterwards - # We have to defer the cleanup, as otherwise the OS might give us a - # previously freed socket again. But it will not do that, if we are still - # listening on it. - sockets: List[socket.socket] = [] - - for port in port_forwards: - if socket.has_ipv6 and (":" in port.bind_ip or not port.bind_ip): - family = socket.AF_INET6 - else: - family = socket.AF_INET + # We have to defer the cleanup of all sockets via an ExitStack, as otherwise + # the OS might give us a previously freed port again. But it will not do + # that, if we are still listening on it + with contextlib.ExitStack() as stack: + for port in port_forwards: + if socket.has_ipv6 and (":" in port.bind_ip or not port.bind_ip): + family = socket.AF_INET6 + else: + family = socket.AF_INET - sock = socket.socket( - family=family, - type=port.protocol.SOCK_CONST, - ) - sock.bind((port.bind_ip, max(0, port.host_port))) + sock = stack.enter_context( + socket.socket( + family=family, + type=port.protocol.SOCK_CONST, + ) + ) + sock.bind((port.bind_ip, max(0, port.host_port))) - port_num: int = sock.getsockname()[1] + port_num: int = sock.getsockname()[1] - finished_forwards.append( - PortForwarding( - container_port=port.container_port, - protocol=port.protocol, - host_port=port_num, - bind_ip=port.bind_ip, + finished_forwards.append( + PortForwarding( + container_port=port.container_port, + protocol=port.protocol, + host_port=port_num, + bind_ip=port.bind_ip, + ) ) - ) - sockets.append(sock) - - for sock in sockets: - sock.close() assert len(port_forwards) == len(finished_forwards) return finished_forwards @@ -1095,7 +1097,7 @@ def release_lock() -> None: # port forwards must be launched while the lock is being held. Otherwise # another container could pick the same ports before this one launches. if forwarded_ports and self._expose_ports: - with FileLock(self.rootdir / "port_check.lock"): + with lock_host_port_search(self.rootdir): self._new_port_forwards = create_host_port_port_forward( forwarded_ports ) diff --git a/pytest_container/pod.py b/pytest_container/pod.py index 83c74fc..629cc81 100644 --- a/pytest_container/pod.py +++ b/pytest_container/pod.py @@ -12,12 +12,12 @@ from typing import Union from _pytest.mark import ParameterSet -from filelock import FileLock from pytest_container.container import Container from pytest_container.container import ContainerData from pytest_container.container import ContainerLauncher from pytest_container.container import create_host_port_port_forward from pytest_container.container import DerivedContainer +from pytest_container.container import lock_host_port_search from pytest_container.inspect import PortForwarding from pytest_container.logging import _logger from pytest_container.runtime import get_selected_runtime @@ -122,7 +122,7 @@ def launch_pod(self) -> None: ) if self.pod.forwarded_ports: - with FileLock(self.rootdir / "port_check.lock"): + with lock_host_port_search(self.rootdir): self._new_port_forwards = create_host_port_port_forward( self.pod.forwarded_ports ) diff --git a/source/conf.py b/source/conf.py index 343c64c..13b59da 100644 --- a/source/conf.py +++ b/source/conf.py @@ -59,4 +59,7 @@ nitpicky = True nitpick_ignore = [("py:class", "py._path.local.LocalPath")] -nitpick_ignore_regex = [("py:class", "_pytest.*")] +nitpick_ignore_regex = [ + ("py:class", "_pytest.*"), + ("py:class", ".*BaseFileLock.*"), +] diff --git a/tests/test_port_forwarding.py b/tests/test_port_forwarding.py index ae1b0c1..111beed 100644 --- a/tests/test_port_forwarding.py +++ b/tests/test_port_forwarding.py @@ -7,10 +7,15 @@ import pytest from pytest_container.container import ContainerData +from pytest_container.container import ContainerLauncher from pytest_container.container import DerivedContainer +from pytest_container.container import lock_host_port_search from pytest_container.container import PortForwarding from pytest_container.inspect import NetworkProtocol +from pytest_container.pod import Pod +from pytest_container.pod import PodLauncher from pytest_container.runtime import LOCALHOST +from pytest_container.runtime import OciRuntimeBase from pytest_container.runtime import Version from .images import NGINX_URL @@ -219,29 +224,65 @@ def test_bind_to_address(addr: str, container: ContainerData, host) -> None: assert host.run_expect([7], cmd) -_sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM) -_sock.bind(("", 0)) -_PORT = _sock.getsockname()[1] -_sock.close() +def test_container_bind_to_host_port( + container_runtime: OciRuntimeBase, host, pytestconfig: pytest.Config +) -> None: + with lock_host_port_search(pytestconfig.rootpath): + with socket.socket( + family=socket.AF_INET, type=socket.SOCK_STREAM + ) as sock: + sock.bind(("", 0)) + PORT = sock.getsockname()[1] + + ctr = DerivedContainer( + base=WEB_SERVER, + forwarded_ports=[ + PortForwarding(container_port=8000, host_port=PORT) + ], + ) + with ContainerLauncher( + container=ctr, + container_runtime=container_runtime, + rootdir=pytestconfig.rootpath, + ) as launcher: + launcher.launch_container() + + assert launcher.container_data.forwarded_ports[0].host_port == PORT + assert ( + host.run_expect( + [0], f"{_CURL} http://localhost:{PORT}" + ).stdout.strip() + == "Hello Green World!" + ) + + +def test_pod_bind_to_host_port( + container_runtime: OciRuntimeBase, host, pytestconfig: pytest.Config +) -> None: + if not container_runtime.runner_binary.endswith("podman"): + pytest.skip("pods are only supported with podman") + + with lock_host_port_search(pytestconfig.rootpath): + with socket.socket( + family=socket.AF_INET, type=socket.SOCK_STREAM + ) as sock: + sock.bind(("", 0)) + PORT = sock.getsockname()[1] + + pod = Pod( + containers=[WEB_SERVER], + forwarded_ports=[ + PortForwarding(container_port=8000, host_port=PORT) + ], + ) + with PodLauncher(pod=pod, rootdir=pytestconfig.rootpath) as launcher: + launcher.launch_pod() -@pytest.mark.parametrize( - "container", - ( - DerivedContainer( - base=WEB_SERVER, - forwarded_ports=[ - PortForwarding(container_port=8000, host_port=_PORT) - ], - ), - ), - indirect=True, -) -def test_bind_to_host_port(container: ContainerData, host) -> None: - assert container.forwarded_ports[0].host_port == _PORT - assert ( - host.run_expect( - [0], f"{_CURL} http://localhost:{_PORT}" - ).stdout.strip() - == "Hello Green World!" - ) + assert launcher.pod_data.forwarded_ports[0].host_port == PORT + assert ( + host.run_expect( + [0], f"{_CURL} http://localhost:{PORT}" + ).stdout.strip() + == "Hello Green World!" + )