Skip to content

Commit e129a81

Browse files
kschwabhramezani
andauthored
Add environment parsing support for enums. (#252)
Co-authored-by: Hasan Ramezani <[email protected]>
1 parent 7e5731a commit e129a81

File tree

3 files changed

+67
-3
lines changed

3 files changed

+67
-3
lines changed

pydantic_settings/main.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class SettingsConfigDict(ConfigDict, total=False):
2828
env_ignore_empty: bool
2929
env_nested_delimiter: str | None
3030
env_parse_none_str: str | None
31+
env_parse_enums: bool | None
3132
secrets_dir: str | Path | None
3233
json_file: PathType | None
3334
json_file_encoding: str | None
@@ -65,6 +66,7 @@ class BaseSettings(BaseModel):
6566
_env_nested_delimiter: The nested env values delimiter. Defaults to `None`.
6667
_env_parse_none_str: The env string value that should be parsed (e.g. "null", "void", "None", etc.)
6768
into `None` type(None). Defaults to `None` type(None), which means no parsing should occur.
69+
_env_parse_enums: Parse enum field names to values. Defaults to `None.`, which means no parsing should occur.
6870
_secrets_dir: The secret files directory. Defaults to `None`.
6971
"""
7072

@@ -77,6 +79,7 @@ def __init__(
7779
_env_ignore_empty: bool | None = None,
7880
_env_nested_delimiter: str | None = None,
7981
_env_parse_none_str: str | None = None,
82+
_env_parse_enums: bool | None = None,
8083
_secrets_dir: str | Path | None = None,
8184
**values: Any,
8285
) -> None:
@@ -91,6 +94,7 @@ def __init__(
9194
_env_ignore_empty=_env_ignore_empty,
9295
_env_nested_delimiter=_env_nested_delimiter,
9396
_env_parse_none_str=_env_parse_none_str,
97+
_env_parse_enums=_env_parse_enums,
9498
_secrets_dir=_secrets_dir,
9599
)
96100
)
@@ -129,6 +133,7 @@ def _settings_build_values(
129133
_env_ignore_empty: bool | None = None,
130134
_env_nested_delimiter: str | None = None,
131135
_env_parse_none_str: str | None = None,
136+
_env_parse_enums: bool | None = None,
132137
_secrets_dir: str | Path | None = None,
133138
) -> dict[str, Any]:
134139
# Determine settings config values
@@ -149,6 +154,7 @@ def _settings_build_values(
149154
env_parse_none_str = (
150155
_env_parse_none_str if _env_parse_none_str is not None else self.model_config.get('env_parse_none_str')
151156
)
157+
env_parse_enums = _env_parse_enums if _env_parse_enums is not None else self.model_config.get('env_parse_enums')
152158
secrets_dir = _secrets_dir if _secrets_dir is not None else self.model_config.get('secrets_dir')
153159

154160
# Configure built-in sources
@@ -160,6 +166,7 @@ def _settings_build_values(
160166
env_nested_delimiter=env_nested_delimiter,
161167
env_ignore_empty=env_ignore_empty,
162168
env_parse_none_str=env_parse_none_str,
169+
env_parse_enums=env_parse_enums,
163170
)
164171
dotenv_settings = DotEnvSettingsSource(
165172
self.__class__,
@@ -170,6 +177,7 @@ def _settings_build_values(
170177
env_nested_delimiter=env_nested_delimiter,
171178
env_ignore_empty=env_ignore_empty,
172179
env_parse_none_str=env_parse_none_str,
180+
env_parse_enums=env_parse_enums,
173181
)
174182

175183
file_secret_settings = SecretsSettingsSource(
@@ -201,6 +209,7 @@ def _settings_build_values(
201209
env_ignore_empty=False,
202210
env_nested_delimiter=None,
203211
env_parse_none_str=None,
212+
env_parse_enums=None,
204213
json_file=None,
205214
json_file_encoding=None,
206215
yaml_file=None,

pydantic_settings/sources.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from abc import ABC, abstractmethod
88
from collections import deque
99
from dataclasses import is_dataclass
10+
from enum import Enum
1011
from pathlib import Path
1112
from typing import TYPE_CHECKING, Any, List, Mapping, Sequence, Tuple, Union, cast
1213

@@ -179,6 +180,7 @@ def __init__(
179180
env_prefix: str | None = None,
180181
env_ignore_empty: bool | None = None,
181182
env_parse_none_str: str | None = None,
183+
env_parse_enums: bool | None = None,
182184
) -> None:
183185
super().__init__(settings_cls)
184186
self.case_sensitive = case_sensitive if case_sensitive is not None else self.config.get('case_sensitive', False)
@@ -189,6 +191,7 @@ def __init__(
189191
self.env_parse_none_str = (
190192
env_parse_none_str if env_parse_none_str is not None else self.config.get('env_parse_none_str')
191193
)
194+
self.env_parse_enums = env_parse_enums if env_parse_enums is not None else self.config.get('env_parse_enums')
192195

193196
def _apply_case_sensitive(self, value: str) -> str:
194197
return value.lower() if not self.case_sensitive else value
@@ -357,8 +360,11 @@ def __init__(
357360
env_prefix: str | None = None,
358361
env_ignore_empty: bool | None = None,
359362
env_parse_none_str: str | None = None,
363+
env_parse_enums: bool | None = None,
360364
) -> None:
361-
super().__init__(settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str)
365+
super().__init__(
366+
settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums
367+
)
362368
self.secrets_dir = secrets_dir if secrets_dir is not None else self.config.get('secrets_dir')
363369

364370
def __call__(self) -> dict[str, Any]:
@@ -447,8 +453,11 @@ def __init__(
447453
env_nested_delimiter: str | None = None,
448454
env_ignore_empty: bool | None = None,
449455
env_parse_none_str: str | None = None,
456+
env_parse_enums: bool | None = None,
450457
) -> None:
451-
super().__init__(settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str)
458+
super().__init__(
459+
settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums
460+
)
452461
self.env_nested_delimiter = (
453462
env_nested_delimiter if env_nested_delimiter is not None else self.config.get('env_nested_delimiter')
454463
)
@@ -498,6 +507,10 @@ def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, val
498507
ValuesError: When There is an error in deserializing value for complex field.
499508
"""
500509
is_complex, allow_parse_failure = self._field_is_complex(field)
510+
if self.env_parse_enums and lenient_issubclass(field.annotation, Enum):
511+
if value in tuple(val.name for val in field.annotation): # type: ignore
512+
value = field.annotation[value] # type: ignore
513+
501514
if is_complex or value_is_complex:
502515
if value is None:
503516
# field is complex but no value found so far, try explode_env_vars
@@ -645,13 +658,20 @@ def __init__(
645658
env_nested_delimiter: str | None = None,
646659
env_ignore_empty: bool | None = None,
647660
env_parse_none_str: str | None = None,
661+
env_parse_enums: bool | None = None,
648662
) -> None:
649663
self.env_file = env_file if env_file != ENV_FILE_SENTINEL else settings_cls.model_config.get('env_file')
650664
self.env_file_encoding = (
651665
env_file_encoding if env_file_encoding is not None else settings_cls.model_config.get('env_file_encoding')
652666
)
653667
super().__init__(
654-
settings_cls, case_sensitive, env_prefix, env_nested_delimiter, env_ignore_empty, env_parse_none_str
668+
settings_cls,
669+
case_sensitive,
670+
env_prefix,
671+
env_nested_delimiter,
672+
env_ignore_empty,
673+
env_parse_none_str,
674+
env_parse_enums,
655675
)
656676

657677
def _load_env_vars(self) -> Mapping[str, str | None]:

tests/test_settings.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import sys
55
import uuid
66
from datetime import datetime, timezone
7+
from enum import IntEnum
78
from pathlib import Path
89
from typing import Any, Callable, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar, Union
910

@@ -1853,6 +1854,40 @@ class Settings(BaseSettings):
18531854
]
18541855

18551856

1857+
def test_env_parse_enums(env):
1858+
class FruitsEnum(IntEnum):
1859+
pear = 0
1860+
kiwi = 1
1861+
lime = 2
1862+
1863+
class Settings(BaseSettings):
1864+
fruit: FruitsEnum
1865+
1866+
with pytest.raises(ValidationError) as exc_info:
1867+
env.set('FRUIT', 'kiwi')
1868+
s = Settings()
1869+
assert exc_info.value.errors(include_url=False) == [
1870+
{
1871+
'type': 'int_parsing',
1872+
'loc': ('fruit',),
1873+
'msg': 'Input should be a valid integer, unable to parse string as an integer',
1874+
'input': 'kiwi',
1875+
}
1876+
]
1877+
1878+
env.set('FRUIT', str(FruitsEnum.lime.value))
1879+
s = Settings()
1880+
assert s.fruit == FruitsEnum.lime
1881+
1882+
env.set('FRUIT', 'kiwi')
1883+
s = Settings(_env_parse_enums=True)
1884+
assert s.fruit == FruitsEnum.kiwi
1885+
1886+
env.set('FRUIT', str(FruitsEnum.lime.value))
1887+
s = Settings(_env_parse_enums=True)
1888+
assert s.fruit == FruitsEnum.lime
1889+
1890+
18561891
def test_env_parse_none_str(env):
18571892
env.set('x', 'null')
18581893
env.set('y', 'y_override')

0 commit comments

Comments
 (0)