Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Revert dataclass engine params (#8540)" #8548

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/python/pants/build_graph/build_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
logger = logging.getLogger(__name__)


@dataclass
@dataclass(frozen=True)
class BuildConfiguration:
"""Stores the types and helper functions exposed to BUILD files."""

Expand Down
1 change: 1 addition & 0 deletions src/python/pants/engine/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ python_library(
name='rules',
sources=['rules.py'],
dependencies=[
'3rdparty/python:dataclasses',
':goal',
':selectors',
'src/python/pants/util:collections',
Expand Down
30 changes: 26 additions & 4 deletions src/python/pants/engine/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
# Licensed under the Apache License, Version 2.0 (see LICENSE).

import ast
import dataclasses
import inspect
import itertools
import sys
import typing
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Type, get_type_hints
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, get_type_hints

from pants.engine.goal import Goal
from pants.engine.objects import union
Expand Down Expand Up @@ -356,6 +357,20 @@ def dependency_optionables(self):
"""A tuple of Optionable classes that are known to be necessary to run this rule."""
return ()

@classmethod
def _validate_type_field(cls, type_obj, description):
if not isinstance(type_obj, type):
raise TypeError(f"{description} provided to @rules must be types! Was: {type_obj}.")
if dataclasses.is_dataclass(type_obj):
if not type_obj.__dataclass_params__.frozen:
if not frozen_after_init.is_instance(type_obj):
raise TypeError(
f"{description} {type_obj} is a dataclass declared without `frozen=True`, or without "
"both `unsafe_hash=True` and the `@frozen_after_init` decorator! "
"The engine requires that fields in params are immutable for stable hashing!"
)
return type_obj


@frozen_after_init
@dataclass(unsafe_hash=True)
Expand Down Expand Up @@ -387,8 +402,15 @@ def __init__(
cacheable: bool = True,
name: Optional[str] = None,
):
self._output_type = output_type
self.input_selectors = input_selectors
self._output_type = self._validate_type_field(output_type, "@rule output type")
self.input_selectors = tuple(
self._validate_type_field(t, "@rule input selector") for t in input_selectors
)
for g in input_gets:
product_type = g.product
subject_type = g.subject_declared_type
self._validate_type_field(product_type, "Get product type")
self._validate_type_field(subject_type, "Get subject type")
self.input_gets = input_gets
self.func = func # type: ignore[assignment] # cannot assign to a method
self._dependency_rules = dependency_rules or ()
Expand Down Expand Up @@ -431,7 +453,7 @@ class RootRule(Rule):
_output_type: Type

def __init__(self, output_type: Type) -> None:
self._output_type = output_type
self._output_type = self._validate_type_field(output_type, "RootRule declared type")

@property
def output_type(self):
Expand Down
4 changes: 2 additions & 2 deletions src/python/pants/rules/core/list_target_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def create(cls, field: Type[Field]) -> "FieldInfo":
fallback_to_ancestors=True,
ignored_ancestors={
*Field.mro(),
AsyncField,
PrimitiveField,
*AsyncField.mro(),
*PrimitiveField.mro(),
BoolField,
FloatField,
IntField,
Expand Down
12 changes: 10 additions & 2 deletions src/python/pants/util/meta.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2015 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

import types
from abc import ABC, abstractmethod
from dataclasses import FrozenInstanceError
from functools import wraps
Expand Down Expand Up @@ -136,7 +137,14 @@ def __call__(self, cls: Type) -> Type:
...

def define_instance_of(self, obj: Type, **kwargs) -> Type:
return type(obj.__name__, (obj,), {"_decorated_type_checkable_type": type(self), **kwargs})
def update_class(ns):
ns["_decorated_type_checkable_type"] = type(self)
ns.update(**kwargs)
return ns

bases = [obj, *getattr(obj, "__orig_bases__", [])]

return types.new_class(obj.__name__, bases=tuple(bases), kwds=None, exec_body=update_class)

def is_instance(self, obj: Type) -> bool:
return getattr(obj, "_decorated_type_checkable_type", None) is type(self)
Expand Down Expand Up @@ -193,4 +201,4 @@ def new_setattr(self, key: str, value: Any) -> None:
cls.__init__ = new_init
cls.__setattr__ = new_setattr

return cls
return frozen_after_init.define_instance_of(cls)
4 changes: 3 additions & 1 deletion src/python/pants/util/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def get_docstring(
:param ignored_ancestors: if `fallback_to_ancestors` is True, do not use the docstring from
these ancestors.
"""
ignored_ancestors_set = frozenset(ignored_ancestors)
if cls.__doc__ is not None:
docstring = cls.__doc__.strip()
else:
Expand All @@ -56,7 +57,8 @@ def get_docstring(
(
ancestor_cls.__doc__.strip()
for ancestor_cls in cls.mro()[1:]
if ancestor_cls not in ignored_ancestors and ancestor_cls.__doc__ is not None
if ((ancestor_cls not in ignored_ancestors_set) and
(ancestor_cls.__doc__ is not None))
),
None,
)
Expand Down
87 changes: 87 additions & 0 deletions tests/python/pants_test/engine/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
remove_locations_from_traceback,
)
from pants.testutil.test_base import TestBase
from pants.util.meta import frozen_after_init


@dataclass(frozen=True)
Expand Down Expand Up @@ -67,6 +68,43 @@ async def transitive_coroutine_rule(c: C) -> D:
return D(b)


@dataclass
class NonFrozenDataclass:
x: int


@frozen_after_init
@dataclass(unsafe_hash=True)
class FrozenAfterInit:
x: int

def __init__(self, x):
# This is an example of how you can assign within __init__() with @frozen_after_init. This
# particular example is not intended to be super useful.
self.x = x + 1


@rule
def use_frozen_after_init_object(x: FrozenAfterInit) -> int:
return x.x


@dataclass(frozen=True)
class FrozenFieldsDataclass:
x: int
y: str


@dataclass(frozen=True)
class ResultDataclass:
something: str


@rule
def dataclass_rule(obj: FrozenFieldsDataclass) -> ResultDataclass:
return ResultDataclass(something=f"x={obj.x}, y={obj.y}")


@union
class UnionBase:
pass
Expand Down Expand Up @@ -178,6 +216,10 @@ def rules(cls):
consumes_a_and_b,
transitive_b_c,
transitive_coroutine_rule,
dataclass_rule,
RootRule(FrozenAfterInit),
use_frozen_after_init_object,
RootRule(FrozenFieldsDataclass),
RootRule(UnionWrapper),
UnionRule(UnionBase, UnionA),
UnionRule(UnionWithNonMemberErrorMsg, UnionX),
Expand Down Expand Up @@ -239,6 +281,15 @@ def test_union_rules_no_docstring(self):
with self._assert_execution_error("specific error message for UnionA instance"):
self.request_single_product(UnionX, Params(UnionWrapper(UnionA())))

def test_dataclass_products_rule(self):
(result,) = self.scheduler.product_request(
ResultDataclass, [Params(FrozenFieldsDataclass(3, "test string"))]
)
self.assertEquals(result.something, "x=3, y=test string")

(result,) = self.scheduler.product_request(int, [Params(FrozenAfterInit(x=3))])
self.assertEquals(result, 4)


class SchedulerWithNestedRaiseTest(TestBase):
@classmethod
Expand Down Expand Up @@ -395,3 +446,39 @@ def test_trace_includes_rule_exception_traceback(self):
+ "\n\n", # Traces include two empty lines after.
trace,
)


class RuleIndexingErrorTest(TestBase):
def test_non_frozen_dataclass_error(self):
with self.assertRaisesWithMessage(
TypeError,
dedent(
"""\
RootRule declared type <class 'pants_test.engine.test_scheduler.NonFrozenDataclass'> is a dataclass declared without `frozen=True`, or without both `unsafe_hash=True` and the `@frozen_after_init` decorator! The engine requires that fields in params are immutable for stable hashing!"""
),
):
RootRule(NonFrozenDataclass)

with self.assertRaisesWithMessage(
TypeError,
dedent(
"""\
@rule input selector <class 'pants_test.engine.test_scheduler.NonFrozenDataclass'> is a dataclass declared without `frozen=True`, or without both `unsafe_hash=True` and the `@frozen_after_init` decorator! The engine requires that fields in params are immutable for stable hashing!"""
),
):

@rule
def f(x: NonFrozenDataclass) -> int:
return 3

with self.assertRaisesWithMessage(
TypeError,
dedent(
"""\
@rule output type <class 'pants_test.engine.test_scheduler.NonFrozenDataclass'> is a dataclass declared without `frozen=True`, or without both `unsafe_hash=True` and the `@frozen_after_init` decorator! The engine requires that fields in params are immutable for stable hashing!"""
),
):

@rule
def g(x: int) -> NonFrozenDataclass:
return NonFrozenDataclass(x=x)