diff --git a/manim/__init__.py b/manim/__init__.py index a4034ed134..4b4acefcdf 100644 --- a/manim/__init__.py +++ b/manim/__init__.py @@ -75,6 +75,7 @@ from .mobject.value_tracker import * from .mobject.vector_field import * from .renderer.cairo_renderer import * +from .scene.groups import * from .scene.moving_camera_scene import * from .scene.scene import * from .scene.scene_file_writer import * diff --git a/manim/scene/groups.py b/manim/scene/groups.py new file mode 100644 index 0000000000..f76e5659e7 --- /dev/null +++ b/manim/scene/groups.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import types +from collections.abc import Callable +from typing import TYPE_CHECKING, ClassVar, Generic, ParamSpec, TypeVar, final, overload + +from typing_extensions import Self, TypedDict, Unpack + +if TYPE_CHECKING: + from .scene import Scene + +__all__ = ["group"] + + +P = ParamSpec("P") +T = TypeVar("T") + + +class SectionGroupData(TypedDict, total=False): + """(Public) data for a :class:`.SectionGroup` in a :class:`.Scene`.""" + + skip: bool + order: int + + +# mark as final because _cls_instance_count doesn't +# work with inheritance +@final +class SectionGroup(Generic[P, T]): + """A section in a :class:`.Scene`. + + It holds data about each subsection, and keeps track of the order + of the sections via :attr:`~SectionGroup.order`. + + .. warning:: + + :attr:`~SectionGroup.func` is effectively a function - it is not + bound to the scene, and thus must be called with the first argument + as an instance of :class:`.Scene`. + """ + + _cls_instance_count: ClassVar[int] = 0 + """How many times the class has been instantiated. + + This is also used for ordering sections, because of the order + decorators are called in a class. + """ + + def __init__( + self, func: Callable[P, T], **kwargs: Unpack[SectionGroupData] + ) -> None: + self.func = func + + self.skip = kwargs.get("skip", False) + + # update the order counter + self.order = self._cls_instance_count + self.__class__._cls_instance_count += 1 + if "order" in kwargs: + self.order = kwargs["order"] + + def __str__(self) -> str: + skip = self.skip + order = self.order + return f"{self.__class__.__name__}({order=}, {skip=})" + + def __repr__(self) -> str: + # return a slightly more verbose repr + s = str(self).removesuffix(")") + func = self.func + return f"{s}, {func=})" + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + return self.func(*args, **kwargs) + + def bind(self, instance: Scene) -> Self: + """Binds :attr:`func` to the scene instance, making :attr:`func` a method. + + This allows the section to be called without the scene being passed explicitly. + """ + self.func = types.MethodType(self.func, instance) + return self + + def __get__(self, instance: Scene, _owner: type[Scene]) -> Self: + """Descriptor to bind the section to the scene instance. + + This is called implicitly by python when methods are being bound. + """ + return self # HELPME use binding + return self.bind(instance) + + +@overload +def group( + func: Callable[P, T], + **kwargs: Unpack[SectionGroupData], +) -> SectionGroup[P, T]: ... + + +@overload +def group( + func: None = None, + **kwargs: Unpack[SectionGroupData], +) -> Callable[[Callable[P, T]], SectionGroup[P, T]]: ... + + +def group( + func: Callable[P, T] | None = None, **kwargs: Unpack[SectionGroupData] +) -> SectionGroup[P, T] | Callable[[Callable[P, T]], SectionGroup[P, T]]: + r"""Decorator to create a SectionGroup in the scene. + + Example + ------- + + .. code-block:: python + + class MyScene(Scene): + SectionGroups_api = True + + @SectionGroup + def first_SectionGroup(self): + pass + + @SectionGroup(skip=True) + def second_SectionGroup(self): + pass + + Parameters + ---------- + func : Callable + The subsection. + skip : bool, optional + Whether to skip the section, by default False + """ + + def wrapper(func: Callable[P, T]) -> SectionGroup[P, T]: + return SectionGroup(func, **kwargs) + + return wrapper(func) if func is not None else wrapper diff --git a/manim/scene/scene.py b/manim/scene/scene.py index 02c548cf7f..8be7e0450a 100644 --- a/manim/scene/scene.py +++ b/manim/scene/scene.py @@ -50,6 +50,7 @@ from ..utils.family_ops import restructure_list_to_exclude_certain_family_members from ..utils.file_ops import open_media_file from ..utils.iterables import list_difference_update, list_update +from .groups import SectionGroup if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -98,6 +99,11 @@ def construct(self): """ + groups_api = False + section_groups = [] + """ Internal attributes to allow group decorator in the class. + TODO Document groups """ + def __init__( self, renderer=None, @@ -110,6 +116,7 @@ def __init__( self.always_update_mobjects = always_update_mobjects self.random_seed = random_seed self.skip_animations = skip_animations + self.group_skip_animations = False # group animation are played by default self.animations = None self.stop_condition = None @@ -154,6 +161,13 @@ def __init__( random.seed(self.random_seed) np.random.seed(self.random_seed) + self.section_groups = self.build_section_groups() + for group in self.section_groups: + if not isinstance(group, SectionGroup): + raise AttributeError( + f"The method {group} doesn't look like it is decorated with the @group decorator." + ) + @property def camera(self): return self.renderer.camera @@ -303,7 +317,18 @@ def construct(self): :meth:`Scene.tear_down` """ - pass # To be implemented in subclasses + for ( + group + ) in self.section_groups: # this is empty if section groups are disabled + self.group_skip_animations = group.skip + + self.next_section( + group.skip + ) # create a default section at the start of each group + group(self) # launch the group # HELPME make a clean call + + self.group_skip_animations = False + # To be implemented in subclasses if groups API is disabled def next_section( self, @@ -315,8 +340,43 @@ def next_section( ``skip_animations`` skips the rendering of all animations in this section. Refer to :doc:`the documentation` on how to use sections. """ + # if group is disabled, all sections in it are also disabled + skip_animations = skip_animations or self.group_skip_animations self.renderer.file_writer.next_section(name, section_type, skip_animations) + def build_section_groups(self) -> List[SectionGroup]: + """Builds the group list depending on the API used (method list, enabled, disabled).""" + if self.section_groups: + # if a group list is provided we use it by default + def get_group_object(group): + if hasattr(self, group): + return getattr(self, group) + else: + raise AttributeError( + f"Couldn't find method {group} in class {__cls__}. Did you spell it correctly?" + ) + + return [get_group_object(group) for group in self.section_groups] + elif self.groups_api: + # groups api enabled, but no list provided so we have to look at the decorated groups in order + return self.find_groups() + else: + # groups api disabled + return [] + + def find_groups(self) -> list[SectionGroup]: + """Find all groups in a :class:`.Scene` if groups api is turned on.""" + groups: list[SectionGroup] = [ + bound + for _, bound in inspect.getmembers( + self, predicate=lambda x: isinstance(x, SectionGroup) + ) + ] + # we can't care about the actual value of the order + # because that would break files with multiple scenes that have sections + groups.sort(key=lambda x: x.order) + return groups + def __str__(self): return self.__class__.__name__ diff --git a/tests/test_scene_rendering/simple_scenes.py b/tests/test_scene_rendering/simple_scenes.py index 7f626ffbe2..62938cd77c 100644 --- a/tests/test_scene_rendering/simple_scenes.py +++ b/tests/test_scene_rendering/simple_scenes.py @@ -17,6 +17,8 @@ "InteractiveStaticScene", "SceneWithSections", "ElaborateSceneWithSections", + "SceneWithGroupAPI", + "SceneWithGroupList", ] @@ -165,3 +167,39 @@ def construct(self): self.next_section("fade out") self.play(FadeOut(square)) self.wait() + + +class SceneWithGroupAPI(Scene): + groups_api = True + + def __init__(self): + super().__init__() + + self.square = Square() + self.circle = Circle() + + @group + def transform(self): + self.play(TransformFromCopy(self.square, self.circle)) + + @group + def back_transform(self): + self.play(Transform(self.circle, self.square)) + + +class SceneWithGroupList(Scene): + section_groups = ["transform", "back_transform"] + + def __init__(self): + super().__init__() + + self.square = Square() + self.circle = Circle() + + @group + def back_transform(self): + self.play(Transform(self.circle, self.square)) + + @group + def transform(self): + self.play(TransformFromCopy(self.square, self.circle)) diff --git a/tests/test_scene_rendering/test_sections.py b/tests/test_scene_rendering/test_sections.py index 049802f3cb..155a9ac3c2 100644 --- a/tests/test_scene_rendering/test_sections.py +++ b/tests/test_scene_rendering/test_sections.py @@ -8,6 +8,7 @@ from tests.assert_utils import assert_dir_exists, assert_dir_not_exists from ..utils.video_tester import video_comparison +from .simple_scenes import SceneWithGroupAPI, SceneWithGroupList, SquareToCircle @pytest.mark.slow @@ -103,3 +104,21 @@ def test_skip_animations(tmp_path, manim_cfg_file, simple_scenes_path): ] _, err, exit_code = capture(command) assert exit_code == 0, err + + +def test_groups_api(tmp_path): + find_api_scene = SceneWithGroupAPI() + list_api_scene = SceneWithGroupList() + + assert not SquareToCircle().section_groups + assert len(list_api_scene.section_groups) == len(find_api_scene.section_groups) == 2 + assert ( + list_api_scene.section_groups[0].func.__name__ + == find_api_scene.section_groups[0].func.__name__ + == "transform" + ) + assert ( + list_api_scene.section_groups[1].func.__name__ + == find_api_scene.section_groups[1].func.__name__ + == "back_transform" + )