diff --git a/pydantic_settings/main.py b/pydantic_settings/main.py index 6deca656..1744820b 100644 --- a/pydantic_settings/main.py +++ b/pydantic_settings/main.py @@ -28,6 +28,7 @@ class SettingsConfigDict(ConfigDict, total=False): env_ignore_empty: bool env_nested_delimiter: str | None env_parse_none_str: str | None + env_parse_enums: bool | None secrets_dir: str | Path | None json_file: PathType | None json_file_encoding: str | None @@ -65,6 +66,7 @@ class BaseSettings(BaseModel): _env_nested_delimiter: The nested env values delimiter. Defaults to `None`. _env_parse_none_str: The env string value that should be parsed (e.g. "null", "void", "None", etc.) into `None` type(None). Defaults to `None` type(None), which means no parsing should occur. + _env_parse_enums: Parse enum field names to values. Defaults to `None.`, which means no parsing should occur. _secrets_dir: The secret files directory. Defaults to `None`. """ @@ -77,6 +79,7 @@ def __init__( _env_ignore_empty: bool | None = None, _env_nested_delimiter: str | None = None, _env_parse_none_str: str | None = None, + _env_parse_enums: bool | None = None, _secrets_dir: str | Path | None = None, **values: Any, ) -> None: @@ -91,6 +94,7 @@ def __init__( _env_ignore_empty=_env_ignore_empty, _env_nested_delimiter=_env_nested_delimiter, _env_parse_none_str=_env_parse_none_str, + _env_parse_enums=_env_parse_enums, _secrets_dir=_secrets_dir, ) ) @@ -129,6 +133,7 @@ def _settings_build_values( _env_ignore_empty: bool | None = None, _env_nested_delimiter: str | None = None, _env_parse_none_str: str | None = None, + _env_parse_enums: bool | None = None, _secrets_dir: str | Path | None = None, ) -> dict[str, Any]: # Determine settings config values @@ -149,6 +154,7 @@ def _settings_build_values( env_parse_none_str = ( _env_parse_none_str if _env_parse_none_str is not None else self.model_config.get('env_parse_none_str') ) + env_parse_enums = _env_parse_enums if _env_parse_enums is not None else self.model_config.get('env_parse_enums') secrets_dir = _secrets_dir if _secrets_dir is not None else self.model_config.get('secrets_dir') # Configure built-in sources @@ -160,6 +166,7 @@ def _settings_build_values( env_nested_delimiter=env_nested_delimiter, env_ignore_empty=env_ignore_empty, env_parse_none_str=env_parse_none_str, + env_parse_enums=env_parse_enums, ) dotenv_settings = DotEnvSettingsSource( self.__class__, @@ -170,6 +177,7 @@ def _settings_build_values( env_nested_delimiter=env_nested_delimiter, env_ignore_empty=env_ignore_empty, env_parse_none_str=env_parse_none_str, + env_parse_enums=env_parse_enums, ) file_secret_settings = SecretsSettingsSource( @@ -201,6 +209,7 @@ def _settings_build_values( env_ignore_empty=False, env_nested_delimiter=None, env_parse_none_str=None, + env_parse_enums=None, json_file=None, json_file_encoding=None, yaml_file=None, diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 82ec7271..79d5e501 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -7,6 +7,7 @@ from abc import ABC, abstractmethod from collections import deque from dataclasses import is_dataclass +from enum import Enum from pathlib import Path from typing import TYPE_CHECKING, Any, List, Mapping, Sequence, Tuple, Union, cast @@ -179,6 +180,7 @@ def __init__( env_prefix: str | None = None, env_ignore_empty: bool | None = None, env_parse_none_str: str | None = None, + env_parse_enums: bool | None = None, ) -> None: super().__init__(settings_cls) self.case_sensitive = case_sensitive if case_sensitive is not None else self.config.get('case_sensitive', False) @@ -189,6 +191,7 @@ def __init__( self.env_parse_none_str = ( env_parse_none_str if env_parse_none_str is not None else self.config.get('env_parse_none_str') ) + self.env_parse_enums = env_parse_enums if env_parse_enums is not None else self.config.get('env_parse_enums') def _apply_case_sensitive(self, value: str) -> str: return value.lower() if not self.case_sensitive else value @@ -357,8 +360,11 @@ def __init__( env_prefix: str | None = None, env_ignore_empty: bool | None = None, env_parse_none_str: str | None = None, + env_parse_enums: bool | None = None, ) -> None: - super().__init__(settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str) + super().__init__( + settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums + ) self.secrets_dir = secrets_dir if secrets_dir is not None else self.config.get('secrets_dir') def __call__(self) -> dict[str, Any]: @@ -447,8 +453,11 @@ def __init__( env_nested_delimiter: str | None = None, env_ignore_empty: bool | None = None, env_parse_none_str: str | None = None, + env_parse_enums: bool | None = None, ) -> None: - super().__init__(settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str) + super().__init__( + settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums + ) self.env_nested_delimiter = ( env_nested_delimiter if env_nested_delimiter is not None else self.config.get('env_nested_delimiter') ) @@ -498,6 +507,10 @@ def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, val ValuesError: When There is an error in deserializing value for complex field. """ is_complex, allow_parse_failure = self._field_is_complex(field) + if self.env_parse_enums and lenient_issubclass(field.annotation, Enum): + if value in tuple(val.name for val in field.annotation): # type: ignore + value = field.annotation[value] # type: ignore + if is_complex or value_is_complex: if value is None: # field is complex but no value found so far, try explode_env_vars @@ -645,13 +658,20 @@ def __init__( env_nested_delimiter: str | None = None, env_ignore_empty: bool | None = None, env_parse_none_str: str | None = None, + env_parse_enums: bool | None = None, ) -> None: self.env_file = env_file if env_file != ENV_FILE_SENTINEL else settings_cls.model_config.get('env_file') self.env_file_encoding = ( env_file_encoding if env_file_encoding is not None else settings_cls.model_config.get('env_file_encoding') ) super().__init__( - settings_cls, case_sensitive, env_prefix, env_nested_delimiter, env_ignore_empty, env_parse_none_str + settings_cls, + case_sensitive, + env_prefix, + env_nested_delimiter, + env_ignore_empty, + env_parse_none_str, + env_parse_enums, ) def _load_env_vars(self) -> Mapping[str, str | None]: diff --git a/tests/test_settings.py b/tests/test_settings.py index b0c75e47..7d248ca7 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -4,6 +4,7 @@ import sys import uuid from datetime import datetime, timezone +from enum import IntEnum from pathlib import Path from typing import Any, Callable, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar, Union @@ -1853,6 +1854,40 @@ class Settings(BaseSettings): ] +def test_env_parse_enums(env): + class FruitsEnum(IntEnum): + pear = 0 + kiwi = 1 + lime = 2 + + class Settings(BaseSettings): + fruit: FruitsEnum + + with pytest.raises(ValidationError) as exc_info: + env.set('FRUIT', 'kiwi') + s = Settings() + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'int_parsing', + 'loc': ('fruit',), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'kiwi', + } + ] + + env.set('FRUIT', str(FruitsEnum.lime.value)) + s = Settings() + assert s.fruit == FruitsEnum.lime + + env.set('FRUIT', 'kiwi') + s = Settings(_env_parse_enums=True) + assert s.fruit == FruitsEnum.kiwi + + env.set('FRUIT', str(FruitsEnum.lime.value)) + s = Settings(_env_parse_enums=True) + assert s.fruit == FruitsEnum.lime + + def test_env_parse_none_str(env): env.set('x', 'null') env.set('y', 'y_override')