Skip to content

Commit aeb5dcc

Browse files
authored
✨ Support const=True to Literal[T] (#41)
1 parent a4a7c7d commit aeb5dcc

File tree

3 files changed

+54
-4
lines changed

3 files changed

+54
-4
lines changed

bump_pydantic/codemods/field.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from typing import List
1+
from typing import List, Union
22

33
import libcst as cst
44
from libcst import matchers as m
55
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
6-
from libcst.codemod.visitors import AddImportsVisitor
6+
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
77

88
RENAMED_KEYWORDS = {
99
"min_items": "min_length",
@@ -63,11 +63,33 @@ def leave_field_import(self, original_node: cst.Module, updated_node: cst.Module
6363
@m.visit(m.AnnAssign(value=m.Call(func=m.Name("Field"))))
6464
def visit_field_assign(self, node: cst.AnnAssign) -> None:
6565
self.inside_field_assign = True
66+
self._const: Union[cst.Arg, None] = None
6667

6768
@m.leave(m.AnnAssign(value=m.Call(func=m.Name("Field"))))
6869
def leave_field_assign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign:
6970
self.inside_field_assign = False
70-
return updated_node
71+
72+
if self._const is None:
73+
return updated_node
74+
75+
AddImportsVisitor.add_needed_import(self.context, "typing", "Literal")
76+
RemoveImportsVisitor.remove_unused_import(self.context, "pydantic", "Field")
77+
return updated_node.with_changes(
78+
annotation=cst.Annotation(
79+
annotation=cst.Subscript(
80+
value=cst.Name("Literal"),
81+
slice=[cst.SubscriptElement(slice=cst.Index(value=self._const.value))],
82+
)
83+
),
84+
value=self._const.value,
85+
)
86+
87+
@m.visit(m.Call(func=m.Name("Field")))
88+
def visit_field_call(self, node: cst.Call) -> None:
89+
# Check if there's a `const=True` argument.
90+
const_arg = m.Arg(value=m.Name("True"), keyword=m.Name("const"))
91+
if m.matches(node, m.Call(func=m.Name("Field"), args=[~m.Arg(value=m.Name("...")), const_arg])):
92+
self._const = node.args[0]
7193

7294
@m.leave(m.Call(func=m.Name("Field")))
7395
def leave_field_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:

bump_pydantic/codemods/replace_generic_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def leave_generic_model(self, original_node: cst.ClassDef, updated_node: cst.Cla
1818
AddImportsVisitor.add_needed_import(context=self.context, module="pydantic", obj="BaseModel")
1919
return updated_node.with_changes(
2020
bases=[
21-
base if not m.matches(base, GENERIC_MODEL_ARG) else cst.Arg(value=cst.Name("BaseModel"))
21+
cst.Arg(value=cst.Name("BaseModel")) if m.matches(base, GENERIC_MODEL_ARG) else base
2222
for base in updated_node.bases
2323
]
2424
)

tests/unit/test_field.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,31 @@ class Potato(BaseModel):
7474
potato: int = Field(..., env="POTATO")
7575
"""
7676
self.assertCodemod(code, code)
77+
78+
def test_replace_const_by_literal_type(self) -> None:
79+
before = """
80+
from enum import Enum
81+
82+
from pydantic import BaseModel, Field
83+
84+
85+
class MyEnum(Enum):
86+
POTATO = "potato"
87+
88+
class Potato(BaseModel):
89+
potato: MyEnum = Field(MyEnum.POTATO, const=True)
90+
"""
91+
after = """
92+
from enum import Enum
93+
94+
from pydantic import BaseModel
95+
from typing import Literal
96+
97+
98+
class MyEnum(Enum):
99+
POTATO = "potato"
100+
101+
class Potato(BaseModel):
102+
potato: Literal[MyEnum.POTATO] = MyEnum.POTATO
103+
"""
104+
self.assertCodemod(before, after)

0 commit comments

Comments
 (0)