Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use forkserver start method for multiprocessing #296

Merged
merged 2 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions green/cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from __future__ import annotations


import atexit
import os
import shutil
import sys
import tempfile
from typing import Sequence
Expand Down Expand Up @@ -87,22 +88,18 @@ def _main(argv: Sequence[str] | None, testing: bool) -> int:
def main(argv: Sequence[str] | None = None, testing: bool = False) -> int:
# create the temp dir only once (i.e., not while in the recursed call)
if not os.environ.get("TMPDIR"): # pragma: nocover
# Use `atexit` to cleanup `temp_dir_for_tests` so that multiprocessing can run its
# own cleanup before its temp directory is deleted.
temp_dir_for_tests = tempfile.mkdtemp()
atexit.register(lambda: shutil.rmtree(temp_dir_for_tests, ignore_errors=True))
os.environ["TMPDIR"] = temp_dir_for_tests
prev_tempdir = tempfile.tempdir
tempfile.tempdir = temp_dir_for_tests
try:
with tempfile.TemporaryDirectory() as temp_dir_for_tests:
try:
os.environ["TMPDIR"] = temp_dir_for_tests
tempfile.tempdir = temp_dir_for_tests
return _main(argv, testing)
finally:
del os.environ["TMPDIR"]
tempfile.tempdir = None
except OSError as os_error:
if os_error.errno == 39:
# "Directory not empty" when trying to delete the temp dir can just be a warning
print(f"warning: {os_error.strerror}")
return 0
else:
raise os_error
return _main(argv, testing)
finally:
del os.environ["TMPDIR"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we would want reset TMPDIR to the previous version if any. Since it was the pre-existing behavior I'm approving since the changes are definitely a net improvement.

Thank you!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We checked that TMPDIR was not set before entering this if branch. (Technically, it was either unset or set to an empty string, but these are the same for python's tempfile logic, so I think it's fine to delete in either case.)

tempfile.tempdir = prev_tempdir
else:
return _main(argv, testing)

Expand Down
3 changes: 1 addition & 2 deletions green/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import copy # pragma: no cover
import functools # pragma: no cover
import logging # pragma: no cover
import multiprocessing # pragma: no cover
import os # pragma: no cover
import pathlib # pragma: no cover
import sys # pragma: no cover
Expand All @@ -36,7 +35,7 @@ def get_default_args() -> argparse.Namespace:
"""
return argparse.Namespace( # pragma: no cover
targets=["."], # Not in configs
processes=multiprocessing.cpu_count(),
processes=os.cpu_count(),
initializer="",
finalizer="",
maxtasksperchild=None,
Expand Down
13 changes: 11 additions & 2 deletions green/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,21 @@ def run(
# The call to toParallelTargets needs to happen before pool stuff so we can crash if there
# are, for example, syntax errors in the code to be loaded.
parallel_targets = toParallelTargets(suite, args.targets)
# Use "forkserver" method when available to avoid problems with "fork". See, for example,
# https://github.com/python/cpython/issues/84559
if "forkserver" in multiprocessing.get_all_start_methods():
mp_method = "forkserver"
else:
mp_method = None
mp_context = multiprocessing.get_context(mp_method)
pool = LoggingDaemonlessPool(
processes=args.processes or None,
initializer=InitializerOrFinalizer(args.initializer),
finalizer=InitializerOrFinalizer(args.finalizer),
maxtasksperchild=args.maxtasksperchild,
context=mp_context,
)
manager: SyncManager = multiprocessing.Manager()
manager: SyncManager = mp_context.Manager()
targets: list[tuple[str, Queue]] = [
(target, manager.Queue()) for target in parallel_targets
]
Expand Down Expand Up @@ -165,10 +173,11 @@ def run(

pool.close()
pool.join()
manager.shutdown()

result.stopTestRun()

# Ignore the type mismatch untile we make GreenTestResult a subclass of unittest.TestResult.
# Ignore the type mismatch until we make GreenTestResult a subclass of unittest.TestResult.
removeResult(result) # type: ignore

return result
11 changes: 7 additions & 4 deletions green/test/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from textwrap import dedent
import unittest
from unittest import mock
import warnings
import weakref

from green.config import get_default_args
Expand Down Expand Up @@ -114,7 +115,7 @@ def setUp(self):
self.loader = GreenTestLoader()

def tearDown(self):
del self.tmpdir
shutil.rmtree(self.tmpdir, ignore_errors=True)
del self.stream

def test_stdout(self):
Expand Down Expand Up @@ -162,7 +163,7 @@ def test01(self):

def test_warnings(self):
"""
setting warnings='always' doesn't crash
test runner does not generate warnings
"""
self.args.warnings = "always"
sub_tmpdir = pathlib.Path(tempfile.mkdtemp(dir=self.tmpdir))
Expand All @@ -177,10 +178,12 @@ def test01(self):
(sub_tmpdir / "test_warnings.py").write_text(content, encoding="utf-8")
os.chdir(sub_tmpdir)
try:
tests = self.loader.loadTargets("test_warnings")
result = run(tests, self.stream, self.args)
with warnings.catch_warnings(record=True) as recorded:
tests = self.loader.loadTargets("test_warnings")
result = run(tests, self.stream, self.args)
finally:
os.chdir(self.startdir)
self.assertEqual(recorded, [])
self.assertEqual(result.testsRun, 1)
self.assertIn("OK", self.stream.getvalue())

Expand Down