diff --git a/docs/index.md b/docs/index.md index 59fee9d8..fc48cf67 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1486,6 +1486,36 @@ print(Settings()) #> foobar='test' ``` +#### Accesing the result of previous sources + +Each source of settings can access the output of the previous ones. + +```python +from typing import Any, Dict, Tuple + +from pydantic.fields import FieldInfo + +from pydantic_settings import PydanticBaseSettingsSource + + +class MyCustomSource(PydanticBaseSettingsSource): + def get_field_value( + self, field: FieldInfo, field_name: str + ) -> Tuple[Any, str, bool]: ... + + def __call__(self) -> Dict[str, Any]: + # Retrieve the aggregated settings from previous sources + current_state = self.current_state + current_state.get('some_setting') + + # Retrive settings from all sources individually + # self.settings_sources_data["SettingsSourceName"]: Dict[str, Any] + settings_sources_data = self.settings_sources_data + settings_sources_data['SomeSettingsSource'].get('some_setting') + + # Your code here... +``` + ### Removing sources You might also want to disable a source: diff --git a/pydantic_settings/main.py b/pydantic_settings/main.py index 333980f7..3cee6ab0 100644 --- a/pydantic_settings/main.py +++ b/pydantic_settings/main.py @@ -308,7 +308,19 @@ def _settings_build_values( ) sources = (cli_settings,) + sources if sources: - return deep_update(*reversed([source() for source in sources])) + state: dict[str, Any] = {} + states: dict[str, dict[str, Any]] = {} + for source in sources: + if isinstance(source, PydanticBaseSettingsSource): + source._set_current_state(state) + source._set_settings_sources_data(states) + + source_name = source.__name__ if hasattr(source, '__name__') else type(source).__name__ + source_state = source() + + states[source_name] = source_state + state = deep_update(source_state, state) + return state else: # no one should mean to do this, but I think returning an empty dict is marginally preferable # to an informative error and much better than a confusing error diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 3f19f9a3..95e9d2b8 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -126,6 +126,36 @@ class PydanticBaseSettingsSource(ABC): def __init__(self, settings_cls: type[BaseSettings]): self.settings_cls = settings_cls self.config = settings_cls.model_config + self._current_state: dict[str, Any] = {} + self._settings_sources_data: dict[str, dict[str, Any]] = {} + + def _set_current_state(self, state: dict[str, Any]) -> None: + """ + Record the state of settings from the previous settings sources. This should + be called right before __call__. + """ + self._current_state = state + + def _set_settings_sources_data(self, states: dict[str, dict[str, Any]]) -> None: + """ + Record the state of settings from all previous settings sources. This should + be called right before __call__. + """ + self._settings_sources_data = states + + @property + def current_state(self) -> dict[str, Any]: + """ + The current state of the settings, populated by the previous settings sources. + """ + return self._current_state + + @property + def settings_sources_data(self) -> dict[str, dict[str, Any]]: + """ + The state of all previous settings sources. + """ + return self._settings_sources_data @abstractmethod def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: diff --git a/tests/test_settings.py b/tests/test_settings.py index de80faeb..3222e50c 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -4025,3 +4025,76 @@ class Settings(BaseSettings): env.set('nested__FOO', '["string1", "string2"]') s = Settings() assert s.model_dump() == {'nested': {'FOO': ['string1', 'string2']}} + + +def test_settings_source_current_state(env): + class SettingsSource(PydanticBaseSettingsSource): + def get_field_value(self, field: FieldInfo, field_name: str) -> Any: + pass + + def __call__(self) -> Dict[str, Any]: + current_state = self.current_state + if current_state.get('one') == '1': + return {'two': '1'} + + return {} + + class Settings(BaseSettings): + one: bool = False + two: bool = False + + @classmethod + def settings_customise_sources( + cls, + settings_cls: Type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> Tuple[PydanticBaseSettingsSource, ...]: + return (env_settings, SettingsSource(settings_cls)) + + env.set('one', '1') + s = Settings() + assert s.two is True + + +def test_settings_source_settings_sources_data(env): + class SettingsSource(PydanticBaseSettingsSource): + def get_field_value(self, field: FieldInfo, field_name: str) -> Any: + pass + + def __call__(self) -> Dict[str, Any]: + settings_sources_data = self.settings_sources_data + if settings_sources_data == { + 'InitSettingsSource': {'one': True, 'two': True}, + 'EnvSettingsSource': {'one': '1'}, + 'function_settings_source': {'three': 'false'}, + }: + return {'four': '1'} + + return {} + + def function_settings_source(): + return {'three': 'false'} + + class Settings(BaseSettings): + one: bool = False + two: bool = False + three: bool = False + four: bool = False + + @classmethod + def settings_customise_sources( + cls, + settings_cls: Type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> Tuple[PydanticBaseSettingsSource, ...]: + return (env_settings, init_settings, function_settings_source, SettingsSource(settings_cls)) + + env.set('one', '1') + s = Settings(one=True, two=True) + assert s.four is True