Skip to content

Commit a7b399f

Browse files
authored
✨ Add None default value to fields in Optional[T] = Field() (#72)
1 parent e081fca commit a7b399f

File tree

4 files changed

+33
-7
lines changed

4 files changed

+33
-7
lines changed

bump_pydantic/codemods/add_default_none.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef
5656
self.inside_base_model = False
5757
return updated_node
5858

59-
def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None:
59+
def visit_AnnAssign(self, node: cst.AnnAssign) -> None:
6060
if m.matches(
6161
node.annotation.annotation,
6262
m.Subscript(m.Name("Optional") | m.Attribute(m.Name("typing"), m.Name("Optional")))
@@ -75,11 +75,29 @@ def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None:
7575
| m.BinaryOperation(operator=m.BitOr(), right=m.Name("None")),
7676
):
7777
self.should_add_none = True
78-
return super().visit_AnnAssign(node)
78+
return None
7979

8080
def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign:
81-
if self.inside_base_model and self.should_add_none and updated_node.value is None:
82-
updated_node = updated_node.with_changes(value=cst.Name("None"))
81+
if self.inside_base_model and self.should_add_none:
82+
if updated_node.value is None:
83+
updated_node = updated_node.with_changes(value=cst.Name("None"))
84+
# TODO: Should accept `pydantic.Field` as well.
85+
elif m.matches(updated_node.value, m.Call(func=m.Name("Field"))):
86+
assert isinstance(updated_node.value, cst.Call)
87+
if updated_node.value.args:
88+
arg = updated_node.value.args[0]
89+
if (arg.keyword is None or arg.keyword.value == "default") and m.matches(arg.value, m.Ellipsis()):
90+
updated_node = updated_node.with_changes(
91+
value=updated_node.value.with_changes(
92+
args=[arg.with_changes(value=cst.Name("None")), *updated_node.value.args[1:]]
93+
)
94+
)
95+
# This is the case where `Field` is called without any arguments e.g. `Field()`.
96+
else:
97+
updated_node = updated_node.with_changes(
98+
value=updated_node.value.with_changes(args=[cst.Arg(value=cst.Name("None"))]) # type: ignore
99+
)
100+
83101
self.inside_an_assign = False
84102
self.should_add_none = False
85103
return updated_node

tests/integration/cases/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from .add_none import cases as add_none_cases
55
from .base_settings import cases as base_settings_cases
66
from .config_to_model import cases as config_to_model_cases
7+
from .field import cases as generic_model_cases
78
from .folder_inside_folder import cases as folder_inside_folder_cases
8-
from .generic_model import cases as generic_model_cases
99
from .is_base_model import cases as is_base_model_cases
1010
from .replace_validator import cases as replace_validator_cases
1111
from .root_model import cases as root_model_cases

tests/integration/cases/add_none.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
content=[
1010
"from typing import Any, Dict, Optional, Union",
1111
"",
12-
"from pydantic import BaseModel",
12+
"from pydantic import BaseModel, Field",
1313
"",
1414
"",
1515
"class A(BaseModel):",
@@ -18,14 +18,18 @@
1818
" c: Union[int, None]",
1919
" d: Any",
2020
" e: Dict[str, str]",
21+
" f: Optional[int] = Field(..., lt=10)",
22+
" g: Optional[int] = Field()",
23+
" h: Optional[int] = Field(...)",
24+
" i: Optional[int] = Field(default_factory=lambda: None)",
2125
],
2226
),
2327
expected=File(
2428
"add_none.py",
2529
content=[
2630
"from typing import Any, Dict, Optional, Union",
2731
"",
28-
"from pydantic import BaseModel",
32+
"from pydantic import BaseModel, Field",
2933
"",
3034
"",
3135
"class A(BaseModel):",
@@ -34,6 +38,10 @@
3438
" c: Union[int, None] = None",
3539
" d: Any = None",
3640
" e: Dict[str, str]",
41+
" f: Optional[int] = Field(None, lt=10)",
42+
" g: Optional[int] = Field(None)",
43+
" h: Optional[int] = Field(None)",
44+
" i: Optional[int] = Field(default_factory=lambda: None)",
3745
],
3846
),
3947
)

0 commit comments

Comments
 (0)