|
1 |
| -from typing import List |
| 1 | +from typing import List, Union |
2 | 2 |
|
3 | 3 | import libcst as cst
|
4 | 4 | from libcst import matchers as m
|
5 | 5 | from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
|
6 |
| -from libcst.codemod.visitors import AddImportsVisitor |
| 6 | +from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor |
7 | 7 |
|
8 | 8 | RENAMED_KEYWORDS = {
|
9 | 9 | "min_items": "min_length",
|
@@ -63,11 +63,33 @@ def leave_field_import(self, original_node: cst.Module, updated_node: cst.Module
|
63 | 63 | @m.visit(m.AnnAssign(value=m.Call(func=m.Name("Field"))))
|
64 | 64 | def visit_field_assign(self, node: cst.AnnAssign) -> None:
|
65 | 65 | self.inside_field_assign = True
|
| 66 | + self._const: Union[cst.Arg, None] = None |
66 | 67 |
|
67 | 68 | @m.leave(m.AnnAssign(value=m.Call(func=m.Name("Field"))))
|
68 | 69 | def leave_field_assign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign:
|
69 | 70 | self.inside_field_assign = False
|
70 |
| - return updated_node |
| 71 | + |
| 72 | + if self._const is None: |
| 73 | + return updated_node |
| 74 | + |
| 75 | + AddImportsVisitor.add_needed_import(self.context, "typing", "Literal") |
| 76 | + RemoveImportsVisitor.remove_unused_import(self.context, "pydantic", "Field") |
| 77 | + return updated_node.with_changes( |
| 78 | + annotation=cst.Annotation( |
| 79 | + annotation=cst.Subscript( |
| 80 | + value=cst.Name("Literal"), |
| 81 | + slice=[cst.SubscriptElement(slice=cst.Index(value=self._const.value))], |
| 82 | + ) |
| 83 | + ), |
| 84 | + value=self._const.value, |
| 85 | + ) |
| 86 | + |
| 87 | + @m.visit(m.Call(func=m.Name("Field"))) |
| 88 | + def visit_field_call(self, node: cst.Call) -> None: |
| 89 | + # Check if there's a `const=True` argument. |
| 90 | + const_arg = m.Arg(value=m.Name("True"), keyword=m.Name("const")) |
| 91 | + if m.matches(node, m.Call(func=m.Name("Field"), args=[~m.Arg(value=m.Name("...")), const_arg])): |
| 92 | + self._const = node.args[0] |
71 | 93 |
|
72 | 94 | @m.leave(m.Call(func=m.Name("Field")))
|
73 | 95 | def leave_field_call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
|
|
0 commit comments