Skip to content

Fix for JSON on optional nested types. #217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -324,9 +324,6 @@ print(Settings().model_dump())
`env_nested_delimiter` can be configured via the `model_config` as shown above, or via the
`_env_nested_delimiter` keyword argument on instantiation.

JSON is only parsed in top-level fields, if you need to parse JSON in sub-models, you will need to implement
validators on those models.

Nested environment variables take precedence over the top-level environment variable JSON
(e.g. in the example above, `SUB_MODEL__V2` trumps `SUB_MODEL`).

Expand Down
31 changes: 20 additions & 11 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

from dotenv import dotenv_values
from pydantic import AliasChoices, AliasPath, BaseModel, Json, TypeAdapter
from pydantic._internal._typing_extra import origin_is_union
from pydantic._internal._utils import deep_update, lenient_issubclass
from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union
from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass
from pydantic.fields import FieldInfo
from typing_extensions import get_args, get_origin

Expand Down Expand Up @@ -188,6 +188,8 @@ def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[s
)
else: # string validation alias
field_info.append((v_alias, self._apply_case_sensitive(v_alias), False))
elif origin_is_union(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), True))
else:
field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), False))

Expand Down Expand Up @@ -478,24 +480,21 @@ def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, val
# simplest case, field is not complex, we only need to add the value if it was found
return value

def _union_is_complex(self, annotation: type[Any] | None, metadata: list[Any]) -> bool:
return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation))

def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
"""
Find out if a field is complex, and if so whether JSON errors should be ignored
"""
if self.field_is_complex(field):
allow_parse_failure = False
elif origin_is_union(get_origin(field.annotation)) and self._union_is_complex(field.annotation, field.metadata):
elif origin_is_union(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
allow_parse_failure = True
else:
return False, False

return True, allow_parse_failure

@staticmethod
def next_field(field: FieldInfo | None, key: str) -> FieldInfo | None:
def next_field(field: FieldInfo | Any | None, key: str) -> FieldInfo | None:
"""
Find the field in a sub model by key(env name)

Expand Down Expand Up @@ -524,11 +523,17 @@ class Cfg(BaseSettings):
Returns:
Field if it finds the next field otherwise `None`.
"""
if not field or origin_is_union(get_origin(field.annotation)):
# no support for Unions of complex BaseSettings fields
if not field:
return None
elif field.annotation and hasattr(field.annotation, 'model_fields') and field.annotation.model_fields.get(key):
return field.annotation.model_fields[key]

annotation = field.annotation if isinstance(field, FieldInfo) else field
if origin_is_union(get_origin(annotation)) or isinstance(annotation, WithArgsTypes):
for type_ in get_args(annotation):
type_has_key = EnvSettingsSource.next_field(type_, key)
if type_has_key:
return type_has_key
elif is_model_class(annotation) and annotation.model_fields.get(key):
return annotation.model_fields[key]

return None

Expand Down Expand Up @@ -716,3 +721,7 @@ def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool:
return lenient_issubclass(annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)) or is_dataclass(
annotation
)


def _union_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation))
31 changes: 31 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,22 @@ class Cfg(BaseSettings):
}


def test_nested_env_optional_json(env):
class Child(BaseModel):
num_list: Optional[List[int]] = None

class Cfg(BaseSettings, env_nested_delimiter='__'):
child: Optional[Child] = None

env.set('CHILD__NUM_LIST', '[1,2,3]')
cfg = Cfg()
assert cfg.model_dump() == {
'child': {
'num_list': [1, 2, 3],
},
}


def test_nested_env_delimiter_with_prefix(env):
class Subsettings(BaseSettings):
banana: str
Expand Down Expand Up @@ -1212,6 +1228,21 @@ class Settings(BaseSettings):
assert Settings().model_dump() == {'foo': {'a': 'b'}}


def test_secrets_nested_optional_json(tmp_path):
p = tmp_path / 'foo'
p.write_text('{"a": 10}')

class Foo(BaseModel):
a: int

class Settings(BaseSettings):
foo: Optional[Foo] = None

model_config = SettingsConfigDict(secrets_dir=tmp_path)

assert Settings().model_dump() == {'foo': {'a': 10}}


def test_secrets_path_invalid_json(tmp_path):
p = tmp_path / 'foo'
p.write_text('{"a": "b"')
Expand Down