Skip to content

Commit 99c3241

Browse files
feat(VGroup): Make VGroup Generic in VMobjectT
1 parent 410e76f commit 99c3241

File tree

15 files changed

+74
-59
lines changed

15 files changed

+74
-59
lines changed

docs/source/contributing/docs/typings.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,14 @@ Typing guidelines
141141
from manim.typing import Vector3D
142142
# type stuff with Vector3D
143143
144+
* When typing something like ``VGroup``, type it as if it were a list, not as if it was a tuple.
145+
146+
.. code:: py
147+
# not VGroup[Line, Line]
148+
def get_two_lines() -> VGroup[Line]:
149+
return VGroup(Line(), Line().shift(LEFT))
150+
151+
144152
Missing Sections for typehints are:
145153
-----------------------------------
146154

manim/animation/changing.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Callable
88

99
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
10-
from manim.mobject.types.vectorized_mobject import VGroup, VMobject
10+
from manim.mobject.types.vectorized_mobject import VGroup, VMobject, VMobjectT
1111
from manim.utils.color import (
1212
BLUE_B,
1313
BLUE_D,
@@ -19,7 +19,7 @@
1919
from manim.utils.rate_functions import smooth
2020

2121

22-
class AnimatedBoundary(VGroup):
22+
class AnimatedBoundary(VGroup[VMobjectT]):
2323
"""Boundary of a :class:`.VMobject` with animated color change.
2424
2525
Examples
@@ -38,11 +38,11 @@ def construct(self):
3838

3939
def __init__(
4040
self,
41-
vmobject,
42-
colors=[BLUE_D, BLUE_B, BLUE_E, GREY_BROWN],
43-
max_stroke_width=3,
44-
cycle_rate=0.5,
45-
back_and_forth=True,
41+
vmobject: VMobjectT,
42+
colors: list[ParsableManimColor] = [BLUE_D, BLUE_B, BLUE_E, GREY_BROWN],
43+
max_stroke_width: float = 3,
44+
cycle_rate: float = 0.5,
45+
back_and_forth: bool = True,
4646
draw_rate_func=smooth,
4747
fade_rate_func=smooth,
4848
**kwargs,
@@ -60,7 +60,7 @@ def __init__(
6060
]
6161
self.add(*self.boundary_copies)
6262
self.total_time = 0
63-
self.add_updater(lambda m, dt: self.update_boundary_copies(dt))
63+
self.add_updater(lambda _, dt: self.update_boundary_copies(dt))
6464

6565
def update_boundary_copies(self, dt):
6666
# Not actual time, but something which passes at

manim/mobject/geometry/shape_matchers.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44

55
__all__ = ["SurroundingRectangle", "BackgroundRectangle", "Cross", "Underline"]
66

7-
from typing import Any
8-
9-
from typing_extensions import Self
7+
from typing import TYPE_CHECKING, Any
108

119
from manim import config, logger
1210
from manim.constants import *
@@ -16,6 +14,9 @@
1614
from manim.mobject.types.vectorized_mobject import VGroup
1715
from manim.utils.color import BLACK, RED, YELLOW, ManimColor, ParsableManimColor
1816

17+
if TYPE_CHECKING:
18+
from typing_extensions import Self
19+
1920

2021
class SurroundingRectangle(RoundedRectangle):
2122
r"""A rectangle surrounding a :class:`~.Mobject`
@@ -133,7 +134,7 @@ def get_fill_color(self) -> ManimColor:
133134
return self.color
134135

135136

136-
class Cross(VGroup):
137+
class Cross(VGroup[Line]):
137138
"""Creates a cross.
138139
139140
Parameters
@@ -166,9 +167,7 @@ def __init__(
166167
scale_factor: float = 1.0,
167168
**kwargs,
168169
) -> None:
169-
super().__init__(
170-
Line(UP + LEFT, DOWN + RIGHT), Line(UP + RIGHT, DOWN + LEFT), **kwargs
171-
)
170+
super().__init__(Line(UL, DR), Line(UR, DL), **kwargs)
172171
if mobject is not None:
173172
self.replace(mobject, stretch=True)
174173
self.scale(scale_factor)

manim/mobject/logo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ..animation.creation import Create, SpiralIn
1717
from ..animation.fading import FadeIn
1818
from ..mobject.svg.svg_mobject import VMobjectFromSVGPath
19-
from ..mobject.types.vectorized_mobject import VGroup
19+
from ..mobject.types.vectorized_mobject import VGroup, VMobjectT
2020
from ..utils.rate_functions import ease_in_out_cubic, smooth
2121

2222
MANIM_SVG_PATHS: list[se.Path] = [
@@ -100,7 +100,7 @@
100100
]
101101

102102

103-
class ManimBanner(VGroup):
103+
class ManimBanner(VGroup[VMobjectT]):
104104
r"""Convenience class representing Manim's banner.
105105
106106
Can be animated using custom methods.

manim/mobject/mobject.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2979,8 +2979,6 @@ def set_z_index_by_z_Point3D(self) -> Self:
29792979
self.set_z_index(z_coord)
29802980
return self
29812981

2982-
def __class_getitem__(cls, item: type) -> str:
2983-
return f"{cls.__name__}[{item.__name__}]"
29842982

29852983
class Group(Mobject, metaclass=ConvertToOpenGL):
29862984
"""Groups together multiple :class:`Mobjects <.Mobject>`.

manim/mobject/table.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,12 @@ def construct(self):
7878
from ..animation.composition import AnimationGroup
7979
from ..animation.creation import Create, Write
8080
from ..animation.fading import FadeIn
81-
from ..mobject.types.vectorized_mobject import VGroup, VMobject
81+
from ..mobject.types.vectorized_mobject import VGroup, VMobject, VMobjectT
8282
from ..utils.color import BLACK, YELLOW, ManimColor, ParsableManimColor
8383
from .utils import get_vectorized_mobject_class
8484

8585

86-
class Table(VGroup):
86+
class Table(VGroup[VMobjectT]):
8787
"""A mobject that displays a table on the screen.
8888
8989
Parameters

manim/mobject/text/code_mobject.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@
2323
from manim.mobject.geometry.polygram import RoundedRectangle
2424
from manim.mobject.geometry.shape_matchers import SurroundingRectangle
2525
from manim.mobject.text.text_mobject import Paragraph
26-
from manim.mobject.types.vectorized_mobject import VGroup
26+
from manim.mobject.types.vectorized_mobject import VGroup, VMobjectT
2727
from manim.utils.color import WHITE
2828

2929
__all__ = ["Code"]
3030

3131

32-
class Code(VGroup):
32+
class Code(VGroup[VMobjectT]):
3333
"""A highlighted source code listing.
3434
3535
An object ``listing`` of :class:`.Code` is a :class:`.VGroup` consisting

manim/mobject/text/text_mobject.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def construct(self):
7171
from manim.constants import *
7272
from manim.mobject.geometry.arc import Dot
7373
from manim.mobject.svg.svg_mobject import SVGMobject
74-
from manim.mobject.types.vectorized_mobject import VGroup, VMobject
74+
from manim.mobject.types.vectorized_mobject import VGroup, VMobject, VMobjectT
7575
from manim.utils.color import ManimColor, ParsableManimColor, color_gradient
7676
from manim.utils.deprecation import deprecated
7777

@@ -116,7 +116,7 @@ def remove_invisible_chars(mobject: SVGMobject) -> SVGMobject:
116116
return mobject_without_dots
117117

118118

119-
class Paragraph(VGroup):
119+
class Paragraph(VGroup[VMobjectT]):
120120
r"""Display a paragraph of text.
121121
122122
For a given :class:`.Paragraph` ``par``, the attribute ``par.chars`` is a

manim/mobject/three_d/polyhedra.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@
99
from manim.mobject.geometry.polygram import Polygon
1010
from manim.mobject.graph import Graph
1111
from manim.mobject.three_d.three_dimensions import Dot3D
12-
from manim.mobject.types.vectorized_mobject import VGroup
12+
from manim.mobject.types.vectorized_mobject import VGroup, VMobjectT
1313

1414
if TYPE_CHECKING:
1515
from manim.mobject.mobject import Mobject
1616

1717
__all__ = ["Polyhedron", "Tetrahedron", "Octahedron", "Icosahedron", "Dodecahedron"]
1818

1919

20-
class Polyhedron(VGroup):
20+
class Polyhedron(VGroup[VMobjectT]):
2121
"""An abstract polyhedra class.
2222
2323
In this implementation, polyhedra are defined with a list of vertex coordinates in space, and a list

manim/mobject/three_d/three_dimensions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from manim.mobject.mobject import *
3232
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
3333
from manim.mobject.opengl.opengl_mobject import OpenGLMobject
34-
from manim.mobject.types.vectorized_mobject import VGroup, VMobject
34+
from manim.mobject.types.vectorized_mobject import VGroup, VMobject, VMobjectT
3535
from manim.utils.color import (
3636
BLUE,
3737
BLUE_D,
@@ -463,7 +463,7 @@ def __init__(
463463
self.set_color(color)
464464

465465

466-
class Cube(VGroup):
466+
class Cube(VGroup[VMobjectT]):
467467
"""A three-dimensional cube.
468468
469469
Parameters

manim/mobject/types/vectorized_mobject.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
TYPE_CHECKING,
1919
Callable,
2020
Generator,
21+
Generic,
2122
Hashable,
2223
Iterable,
2324
Literal,
@@ -522,7 +523,7 @@ def get_fill_colors(self) -> list[ManimColor | None]:
522523
def get_fill_opacities(self) -> npt.NDArray[ManimFloat]:
523524
return self.get_fill_rgbas()[:, 3]
524525

525-
def get_stroke_rgbas(self, background: bool = False) -> RGBA_Array_float | Zeros:
526+
def get_stroke_rgbas(self, background: bool = False) -> RGBA_Array_Float | Zeros:
526527
try:
527528
if background:
528529
self.background_stroke_rgbas: RGBA_Array_Float
@@ -1934,7 +1935,10 @@ def force_direction(self, target_direction: Literal["CW", "CCW"]) -> Self:
19341935
return self
19351936

19361937

1937-
class VGroup(VMobject, metaclass=ConvertToOpenGL):
1938+
VMobjectT = TypeVar("VMobjectT", bound=VMobject, default=VMobject)
1939+
1940+
1941+
class VGroup(VMobject, Generic[VMobjectT], metaclass=ConvertToOpenGL):
19381942
"""A group of vectorized mobjects.
19391943
19401944
This can be used to group multiple :class:`~.VMobject` instances together
@@ -1991,7 +1995,7 @@ def construct(self):
19911995
19921996
"""
19931997

1994-
def __init__(self, *vmobjects, **kwargs):
1998+
def __init__(self, *vmobjects: VMobjectT, **kwargs):
19951999
super().__init__(**kwargs)
19962000
self.add(*vmobjects)
19972001

@@ -2478,7 +2482,7 @@ def set_location(self, new_loc: Point3D):
24782482
self.set_points(np.array([new_loc]))
24792483

24802484

2481-
class CurvesAsSubmobjects(VGroup):
2485+
class CurvesAsSubmobjects(VGroup[VMobject]):
24822486
"""Convert a curve's elements to submobjects.
24832487
24842488
Examples

0 commit comments

Comments
 (0)