diff --git a/src/python/pants/build_graph/build_configuration.py b/src/python/pants/build_graph/build_configuration.py index 12a392b5b83..b5f491f1010 100644 --- a/src/python/pants/build_graph/build_configuration.py +++ b/src/python/pants/build_graph/build_configuration.py @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) -@dataclass +@dataclass(frozen=True) class BuildConfiguration: """Stores the types and helper functions exposed to BUILD files.""" diff --git a/src/python/pants/engine/BUILD b/src/python/pants/engine/BUILD index a71a7c67383..0dd549a16a1 100644 --- a/src/python/pants/engine/BUILD +++ b/src/python/pants/engine/BUILD @@ -207,6 +207,7 @@ python_library( name='rules', sources=['rules.py'], dependencies=[ + '3rdparty/python:dataclasses', ':goal', ':selectors', 'src/python/pants/util:collections', diff --git a/src/python/pants/engine/rules.py b/src/python/pants/engine/rules.py index 9e1e4a073c2..56b9fb2588f 100644 --- a/src/python/pants/engine/rules.py +++ b/src/python/pants/engine/rules.py @@ -2,13 +2,14 @@ # Licensed under the Apache License, Version 2.0 (see LICENSE). import ast +import dataclasses import inspect import itertools import sys import typing from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Tuple, Type, get_type_hints +from typing import Callable, Dict, List, Optional, Set, Tuple, Type, get_type_hints from pants.engine.goal import Goal from pants.engine.objects import union @@ -356,6 +357,20 @@ def dependency_optionables(self): """A tuple of Optionable classes that are known to be necessary to run this rule.""" return () + @classmethod + def _validate_type_field(cls, type_obj, description): + if not isinstance(type_obj, type): + raise TypeError(f"{description} provided to @rules must be types! Was: {type_obj}.") + if dataclasses.is_dataclass(type_obj): + if not type_obj.__dataclass_params__.frozen: + if not frozen_after_init.is_instance(type_obj): + raise TypeError( + f"{description} {type_obj} is a dataclass declared without `frozen=True`, or without " + "both `unsafe_hash=True` and the `@frozen_after_init` decorator! " + "The engine requires that fields in params are immutable for stable hashing!" + ) + return type_obj + @frozen_after_init @dataclass(unsafe_hash=True) @@ -387,8 +402,15 @@ def __init__( cacheable: bool = True, name: Optional[str] = None, ): - self._output_type = output_type - self.input_selectors = input_selectors + self._output_type = self._validate_type_field(output_type, "@rule output type") + self.input_selectors = tuple( + self._validate_type_field(t, "@rule input selector") for t in input_selectors + ) + for g in input_gets: + product_type = g.product + subject_type = g.subject_declared_type + self._validate_type_field(product_type, "Get product type") + self._validate_type_field(subject_type, "Get subject type") self.input_gets = input_gets self.func = func # type: ignore[assignment] # cannot assign to a method self._dependency_rules = dependency_rules or () @@ -431,7 +453,7 @@ class RootRule(Rule): _output_type: Type def __init__(self, output_type: Type) -> None: - self._output_type = output_type + self._output_type = self._validate_type_field(output_type, "RootRule declared type") @property def output_type(self): diff --git a/src/python/pants/rules/core/list_target_types.py b/src/python/pants/rules/core/list_target_types.py index b27b3bf4cf1..a987dd04319 100644 --- a/src/python/pants/rules/core/list_target_types.py +++ b/src/python/pants/rules/core/list_target_types.py @@ -98,8 +98,8 @@ def create(cls, field: Type[Field]) -> "FieldInfo": fallback_to_ancestors=True, ignored_ancestors={ *Field.mro(), - AsyncField, - PrimitiveField, + *AsyncField.mro(), + *PrimitiveField.mro(), BoolField, FloatField, IntField, diff --git a/src/python/pants/util/meta.py b/src/python/pants/util/meta.py index ed0d7b79086..e7bc9eb19fa 100644 --- a/src/python/pants/util/meta.py +++ b/src/python/pants/util/meta.py @@ -1,6 +1,7 @@ # Copyright 2015 Pants project contributors (see CONTRIBUTORS.md). # Licensed under the Apache License, Version 2.0 (see LICENSE). +import types from abc import ABC, abstractmethod from dataclasses import FrozenInstanceError from functools import wraps @@ -136,7 +137,14 @@ def __call__(self, cls: Type) -> Type: ... def define_instance_of(self, obj: Type, **kwargs) -> Type: - return type(obj.__name__, (obj,), {"_decorated_type_checkable_type": type(self), **kwargs}) + def update_class(ns): + ns["_decorated_type_checkable_type"] = type(self) + ns.update(**kwargs) + return ns + + bases = [obj, *getattr(obj, "__orig_bases__", [])] + + return types.new_class(obj.__name__, bases=tuple(bases), kwds=None, exec_body=update_class) def is_instance(self, obj: Type) -> bool: return getattr(obj, "_decorated_type_checkable_type", None) is type(self) @@ -193,4 +201,4 @@ def new_setattr(self, key: str, value: Any) -> None: cls.__init__ = new_init cls.__setattr__ = new_setattr - return cls + return frozen_after_init.define_instance_of(cls) diff --git a/src/python/pants/util/objects.py b/src/python/pants/util/objects.py index d821e47083e..8f441535547 100644 --- a/src/python/pants/util/objects.py +++ b/src/python/pants/util/objects.py @@ -46,6 +46,7 @@ def get_docstring( :param ignored_ancestors: if `fallback_to_ancestors` is True, do not use the docstring from these ancestors. """ + ignored_ancestors_set = frozenset(ignored_ancestors) if cls.__doc__ is not None: docstring = cls.__doc__.strip() else: @@ -56,7 +57,8 @@ def get_docstring( ( ancestor_cls.__doc__.strip() for ancestor_cls in cls.mro()[1:] - if ancestor_cls not in ignored_ancestors and ancestor_cls.__doc__ is not None + if ((ancestor_cls not in ignored_ancestors_set) and + (ancestor_cls.__doc__ is not None)) ), None, ) diff --git a/tests/python/pants_test/engine/test_scheduler.py b/tests/python/pants_test/engine/test_scheduler.py index 1005c2a2b39..4347431816d 100644 --- a/tests/python/pants_test/engine/test_scheduler.py +++ b/tests/python/pants_test/engine/test_scheduler.py @@ -20,6 +20,7 @@ remove_locations_from_traceback, ) from pants.testutil.test_base import TestBase +from pants.util.meta import frozen_after_init @dataclass(frozen=True) @@ -67,6 +68,43 @@ async def transitive_coroutine_rule(c: C) -> D: return D(b) +@dataclass +class NonFrozenDataclass: + x: int + + +@frozen_after_init +@dataclass(unsafe_hash=True) +class FrozenAfterInit: + x: int + + def __init__(self, x): + # This is an example of how you can assign within __init__() with @frozen_after_init. This + # particular example is not intended to be super useful. + self.x = x + 1 + + +@rule +def use_frozen_after_init_object(x: FrozenAfterInit) -> int: + return x.x + + +@dataclass(frozen=True) +class FrozenFieldsDataclass: + x: int + y: str + + +@dataclass(frozen=True) +class ResultDataclass: + something: str + + +@rule +def dataclass_rule(obj: FrozenFieldsDataclass) -> ResultDataclass: + return ResultDataclass(something=f"x={obj.x}, y={obj.y}") + + @union class UnionBase: pass @@ -178,6 +216,10 @@ def rules(cls): consumes_a_and_b, transitive_b_c, transitive_coroutine_rule, + dataclass_rule, + RootRule(FrozenAfterInit), + use_frozen_after_init_object, + RootRule(FrozenFieldsDataclass), RootRule(UnionWrapper), UnionRule(UnionBase, UnionA), UnionRule(UnionWithNonMemberErrorMsg, UnionX), @@ -239,6 +281,15 @@ def test_union_rules_no_docstring(self): with self._assert_execution_error("specific error message for UnionA instance"): self.request_single_product(UnionX, Params(UnionWrapper(UnionA()))) + def test_dataclass_products_rule(self): + (result,) = self.scheduler.product_request( + ResultDataclass, [Params(FrozenFieldsDataclass(3, "test string"))] + ) + self.assertEquals(result.something, "x=3, y=test string") + + (result,) = self.scheduler.product_request(int, [Params(FrozenAfterInit(x=3))]) + self.assertEquals(result, 4) + class SchedulerWithNestedRaiseTest(TestBase): @classmethod @@ -395,3 +446,39 @@ def test_trace_includes_rule_exception_traceback(self): + "\n\n", # Traces include two empty lines after. trace, ) + + +class RuleIndexingErrorTest(TestBase): + def test_non_frozen_dataclass_error(self): + with self.assertRaisesWithMessage( + TypeError, + dedent( + """\ + RootRule declared type is a dataclass declared without `frozen=True`, or without both `unsafe_hash=True` and the `@frozen_after_init` decorator! The engine requires that fields in params are immutable for stable hashing!""" + ), + ): + RootRule(NonFrozenDataclass) + + with self.assertRaisesWithMessage( + TypeError, + dedent( + """\ + @rule input selector is a dataclass declared without `frozen=True`, or without both `unsafe_hash=True` and the `@frozen_after_init` decorator! The engine requires that fields in params are immutable for stable hashing!""" + ), + ): + + @rule + def f(x: NonFrozenDataclass) -> int: + return 3 + + with self.assertRaisesWithMessage( + TypeError, + dedent( + """\ + @rule output type is a dataclass declared without `frozen=True`, or without both `unsafe_hash=True` and the `@frozen_after_init` decorator! The engine requires that fields in params are immutable for stable hashing!""" + ), + ): + + @rule + def g(x: int) -> NonFrozenDataclass: + return NonFrozenDataclass(x=x)