diff --git a/build-support/githooks/pre-commit b/build-support/githooks/pre-commit index 45b6d8108cb..27524b031cd 100755 --- a/build-support/githooks/pre-commit +++ b/build-support/githooks/pre-commit @@ -53,7 +53,7 @@ echo "* Checking shell scripts via our custom linter" # fails in pants changed. if git rev-parse --verify "${MERGE_BASE}" &>/dev/null; then echo "* Checking imports" - ./build-support/bin/isort.sh || die "To fix import sort order, run \`\"$(pwd)/build-support/bin/isort.sh\" -f\`" + ./build-support/bin/isort.sh || die "To fix import sort order, run \`build-support/bin/isort.sh -f\`" # TODO(CMLivingston) Make lint use `-q` option again after addressing proper workunit labeling: # https://github.com/pantsbuild/pants/issues/6633 diff --git a/src/python/pants/build_graph/build_configuration.py b/src/python/pants/build_graph/build_configuration.py index 5512d3a3910..d99070e9ddf 100644 --- a/src/python/pants/build_graph/build_configuration.py +++ b/src/python/pants/build_graph/build_configuration.py @@ -170,7 +170,7 @@ def rules(self): return list(self._rules) def union_rules(self): - """Returns a mapping of registered union base types -> [a list of union member types]. + """Returns a mapping of registered union base types -> [OrderedSet of union member types]. :rtype: OrderedDict """ diff --git a/src/python/pants/engine/rules.py b/src/python/pants/engine/rules.py index 865ef11464a..d5b06e34acc 100644 --- a/src/python/pants/engine/rules.py +++ b/src/python/pants/engine/rules.py @@ -10,8 +10,9 @@ from abc import ABC, abstractmethod from collections import OrderedDict from collections.abc import Iterable +from dataclasses import dataclass from textwrap import dedent -from typing import Any, Callable, Type, cast +from typing import Any, Callable, Dict, Type, cast import asttokens from twitter.common.collections import OrderedSet @@ -428,6 +429,17 @@ def __new__(cls, union_base, union_member): return super().__new__(cls, union_base, union_member) +@dataclass(frozen=True) +class UnionMembership: + union_rules: Dict[type, typing.Iterable[type]] + + def is_member(self, union_type, putative_member): + members = self.union_rules.get(union_type) + if members is None: + raise TypeError(f'Not a registered union type: {union_type}') + return type(putative_member) in members + + class Rule(ABC): """Rules declare how to produce products for the product graph. diff --git a/src/python/pants/init/engine_initializer.py b/src/python/pants/init/engine_initializer.py index 5df054546b7..65bea4ed791 100644 --- a/src/python/pants/init/engine_initializer.py +++ b/src/python/pants/init/engine_initializer.py @@ -42,7 +42,7 @@ from pants.engine.mapper import AddressMapper from pants.engine.parser import SymbolTable from pants.engine.platform import create_platform_rules -from pants.engine.rules import RootRule, rule +from pants.engine.rules import RootRule, UnionMembership, rule from pants.engine.scheduler import Scheduler from pants.engine.selectors import Params from pants.init.options_initializer import BuildConfigInitializer, OptionsInitializer @@ -357,6 +357,10 @@ def build_configuration_singleton() -> BuildConfiguration: def symbol_table_singleton() -> SymbolTable: return symbol_table + @rule + def union_membership_singleton() -> UnionMembership: + return UnionMembership(build_configuration.union_rules()) + # Create a Scheduler containing graph and filesystem rules, with no installed goals. The # LegacyBuildGraph will explicitly request the products it needs. rules = ( @@ -365,6 +369,7 @@ def symbol_table_singleton() -> SymbolTable: glob_match_error_behavior_singleton, build_configuration_singleton, symbol_table_singleton, + union_membership_singleton, ] + create_legacy_graph_tasks() + create_fs_rules() + diff --git a/src/python/pants/rules/core/test.py b/src/python/pants/rules/core/test.py index 0883889b0fe..7f4c950beae 100644 --- a/src/python/pants/rules/core/test.py +++ b/src/python/pants/rules/core/test.py @@ -2,14 +2,16 @@ # Licensed under the Apache License, Version 2.0 (see LICENSE). import logging +from dataclasses import dataclass +from typing import Optional from pants.base.exiter import PANTS_FAILED_EXIT_CODE, PANTS_SUCCEEDED_EXIT_CODE -from pants.build_graph.address import Address +from pants.build_graph.address import Address, BuildFileAddress from pants.engine.addressable import BuildFileAddresses from pants.engine.console import Console from pants.engine.goal import Goal from pants.engine.legacy.graph import HydratedTarget -from pants.engine.rules import console_rule, rule +from pants.engine.rules import UnionMembership, console_rule, rule from pants.engine.selectors import Get from pants.rules.core.core_test_model import Status, TestResult, TestTarget @@ -24,18 +26,27 @@ class Test(Goal): name = 'test' +@dataclass(frozen=True) +class AddressAndTestResult: + address: BuildFileAddress + test_result: Optional[TestResult] # If None, target was not a test target. + + @console_rule def fast_test(console: Console, addresses: BuildFileAddresses) -> Test: - test_results = yield [Get(TestResult, Address, address.to_address()) for address in addresses] + results = yield [Get(AddressAndTestResult, Address, addr.to_address()) for addr in addresses] did_any_fail = False - for address, test_result in zip(addresses, test_results): + filtered_results = [(x.address, x.test_result) for x in results if x.test_result is not None] + + for address, test_result in filtered_results: if test_result.status == Status.FAILURE: did_any_fail = True if test_result.stdout: console.write_stdout( "{} stdout:\n{}\n".format( address.reference(), - console.red(test_result.stdout) if test_result.status == Status.FAILURE else test_result.stdout + (console.red(test_result.stdout) if test_result.status == Status.FAILURE + else test_result.stdout) ) ) if test_result.stderr: @@ -44,14 +55,16 @@ def fast_test(console: Console, addresses: BuildFileAddresses) -> Test: console.write_stdout( "{} stderr:\n{}\n".format( address.reference(), - console.red(test_result.stderr) if test_result.status == Status.FAILURE else test_result.stderr + (console.red(test_result.stderr) if test_result.status == Status.FAILURE + else test_result.stderr) ) ) console.write_stdout("\n") - for address, test_result in zip(addresses, test_results): - console.print_stdout('{0:80}.....{1:>10}'.format(address.reference(), test_result.status.value)) + for address, test_result in filtered_results: + console.print_stdout('{0:80}.....{1:>10}'.format( + address.reference(), test_result.status.value)) if did_any_fail: console.print_stderr(console.red('Tests failed')) @@ -63,19 +76,24 @@ def fast_test(console: Console, addresses: BuildFileAddresses) -> Test: @rule -def coordinator_of_tests(target: HydratedTarget) -> TestResult: +def coordinator_of_tests(target: HydratedTarget, + union_membership: UnionMembership) -> AddressAndTestResult: # TODO(#6004): when streaming to live TTY, rely on V2 UI for this information. When not a # live TTY, periodically dump heavy hitters to stderr. See # https://github.com/pantsbuild/pants/issues/6004#issuecomment-492699898. - logger.info("Starting tests: {}".format(target.address.reference())) - # NB: This has the effect of "casting" a TargetAdaptor to a member of the TestTarget union. If the - # TargetAdaptor is not a member of the union, it will fail at runtime with a useful error message. - result = yield Get(TestResult, TestTarget, target.adaptor) - logger.info("Tests {}: {}".format( - "succeeded" if result.status == Status.SUCCESS else "failed", - target.address.reference(), - )) - yield result + if union_membership.is_member(TestTarget, target.adaptor): + logger.info("Starting tests: {}".format(target.address.reference())) + # NB: This has the effect of "casting" a TargetAdaptor to a member of the TestTarget union. + # The adaptor will always be a member because of the union membership check above, but if + # it were not it would fail at runtime with a useful error message. + result = yield Get(TestResult, TestTarget, target.adaptor) + logger.info("Tests {}: {}".format( + "succeeded" if result.status == Status.SUCCESS else "failed", + target.address.reference(), + )) + else: + result = None # Not a test target. + yield AddressAndTestResult(target.address, result) def rules(): diff --git a/tests/python/pants_test/rules/test_test.py b/tests/python/pants_test/rules/test_test.py index 5ae5a309b93..c51635c26b3 100644 --- a/tests/python/pants_test/rules/test_test.py +++ b/tests/python/pants_test/rules/test_test.py @@ -7,7 +7,15 @@ from pants.build_graph.address import Address, BuildFileAddress from pants.engine.legacy.graph import HydratedTarget from pants.engine.legacy.structs import PythonTestsAdaptor -from pants.rules.core.test import Status, TestResult, coordinator_of_tests, fast_test +from pants.engine.rules import UnionMembership +from pants.rules.core.core_test_model import TestTarget +from pants.rules.core.test import ( + AddressAndTestResult, + Status, + TestResult, + coordinator_of_tests, + fast_test, +) from pants_test.engine.util import MockConsole, run_rule from pants_test.test_base import TestBase @@ -16,14 +24,16 @@ class TestTest(TestBase): def single_target_test(self, result, expected_console_output, success=True): console = MockConsole(use_colors=False) - res = run_rule(fast_test, console, (self.make_build_target_address("some/target"),), { - (TestResult, Address): lambda _: result, + addr = self.make_build_target_address("some/target") + res = run_rule(fast_test, console, (addr,), { + (AddressAndTestResult, Address): lambda _: AddressAndTestResult(addr, result), }) self.assertEquals(console.stdout.getvalue(), expected_console_output) self.assertEquals(0 if success else 1, res.exit_code) - def make_build_target_address(self, spec): + @staticmethod + def make_build_target_address(spec): address = Address.parse(spec) return BuildFileAddress( build_file=None, @@ -61,14 +71,15 @@ def test_output_mixed(self): def make_result(target): if target == target1: - return TestResult(status=Status.SUCCESS, stdout='I passed\n', stderr='') + tr = TestResult(status=Status.SUCCESS, stdout='I passed\n', stderr='') elif target == target2: - return TestResult(status=Status.FAILURE, stdout='I failed\n', stderr='') + tr = TestResult(status=Status.FAILURE, stdout='I failed\n', stderr='') else: raise Exception("Unrecognised target") + return AddressAndTestResult(target, tr) res = run_rule(fast_test, console, (target1, target2), { - (TestResult, Address): make_result, + (AddressAndTestResult, Address): make_result, }) self.assertEqual(1, res.exit_code) @@ -97,10 +108,19 @@ def test_stderr(self): ) def test_coordinator_python_test(self): + addr = Address.parse("some/target") target_adaptor = PythonTestsAdaptor(type_alias='python_tests') with self.captured_logging(logging.INFO): - result = run_rule(coordinator_of_tests, HydratedTarget(Address.parse("some/target"), target_adaptor, ()), { - (TestResult, PythonTestsAdaptor): lambda _: TestResult(status=Status.FAILURE, stdout='foo', stderr=''), - }) - - self.assertEqual(result, TestResult(status=Status.FAILURE, stdout='foo', stderr='')) + result = run_rule( + coordinator_of_tests, + HydratedTarget(addr, target_adaptor, ()), + UnionMembership(union_rules={TestTarget: [PythonTestsAdaptor]}), + { + (TestResult, PythonTestsAdaptor): + lambda _: TestResult(status=Status.FAILURE, stdout='foo', stderr=''), + }) + + self.assertEqual( + result, + AddressAndTestResult(addr, TestResult(status=Status.FAILURE, stdout='foo', stderr='')) + )