diff --git a/examples/test_example_fail_in_thread.py b/examples/test_example_fail_in_thread.py new file mode 100644 index 0000000..b67487b --- /dev/null +++ b/examples/test_example_fail_in_thread.py @@ -0,0 +1,21 @@ +from concurrent.futures.thread import ThreadPoolExecutor +from threading import Thread + + +import pytest_check as check + + +def force_fail(comparison): + check.equal(1 + 1, comparison, f"1 + 1 is 2, not {comparison}") + + +def test_threadpool(): + with ThreadPoolExecutor() as executor: + task = executor.submit(force_fail, 3) + task.result() + + +def test_threading(): + t = Thread(target=force_fail, args=(4, )) + t.start() + t.join() diff --git a/examples/test_example_pass_in_thread.py b/examples/test_example_pass_in_thread.py new file mode 100644 index 0000000..8936317 --- /dev/null +++ b/examples/test_example_pass_in_thread.py @@ -0,0 +1,21 @@ +from concurrent.futures.thread import ThreadPoolExecutor +from threading import Thread + + +import pytest_check as check + + +def always_pass(): + check.equal(1 + 1, 2) + + +def test_threadpool(): + with ThreadPoolExecutor() as executor: + task = executor.submit(always_pass) + task.result() + + +def test_threading(): + t = Thread(target=always_pass) + t.start() + t.join() diff --git a/src/pytest_check/pseudo_traceback.py b/src/pytest_check/pseudo_traceback.py index 4ea67e3..4061b2a 100644 --- a/src/pytest_check/pseudo_traceback.py +++ b/src/pytest_check/pseudo_traceback.py @@ -4,8 +4,8 @@ _traceback_style = "auto" -def get_full_context(level): - (_, filename, line, funcname, contextlist) = inspect.stack()[level][0:5] +def get_full_context(frame): + (_, filename, line, funcname, contextlist) = frame[0:5] try: filename = os.path.relpath(filename) except ValueError: # pragma: no cover @@ -28,11 +28,12 @@ def _build_pseudo_trace_str(): if _traceback_style == "no": return "" - level = 4 + skip_own_frames = 3 pseudo_trace = [] func = "" - while "test_" not in func: - (file, line, func, context) = get_full_context(level) + context_stack = inspect.stack()[skip_own_frames:] + while "test_" not in func and context_stack: + (file, line, func, context) = get_full_context(context_stack.pop(0)) # we want to trace through user code, not 3rd party or builtin libs if "site-packages" in file: break @@ -41,6 +42,5 @@ def _build_pseudo_trace_str(): break line = f"{file}:{line} in {func}() -> {context}" pseudo_trace.append(line) - level += 1 return "\n".join(reversed(pseudo_trace)) + "\n" diff --git a/tests/test_thread.py b/tests/test_thread.py new file mode 100644 index 0000000..023db09 --- /dev/null +++ b/tests/test_thread.py @@ -0,0 +1,12 @@ +def test_failing_threaded_testcode(pytester): + pytester.copy_example("examples/test_example_fail_in_thread.py") + result = pytester.runpytest() + result.assert_outcomes(failed=2, passed=0) + result.stdout.fnmatch_lines(["*1 + 1 is 2, not 3*"]) + result.stdout.fnmatch_lines(["*1 + 1 is 2, not 4*"]) + + +def test_passing_threaded_testcode(pytester): + pytester.copy_example("examples/test_example_pass_in_thread.py") + result = pytester.runpytest() + result.assert_outcomes(failed=0, passed=2)