Skip to content

Commit 4433101

Browse files
kschwabhramezani
andauthored
Coerce env vars if strict is True. (#693)
Co-authored-by: Hasan Ramezani <[email protected]>
1 parent 4d2ebfd commit 4433101

File tree

3 files changed

+76
-6
lines changed

3 files changed

+76
-6
lines changed

pydantic_settings/sources/providers/env.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Any,
88
)
99

10+
from pydantic import Json, TypeAdapter, ValidationError
1011
from pydantic._internal._utils import deep_update, is_model_class
1112
from pydantic.dataclasses import is_pydantic_dataclass
1213
from pydantic.fields import FieldInfo
@@ -17,6 +18,7 @@
1718
from ..base import PydanticBaseEnvSettingsSource
1819
from ..types import EnvNoneType
1920
from ..utils import (
21+
_annotation_contains_types,
2022
_annotation_enum_name_to_val,
2123
_get_model_fields,
2224
_union_is_complex,
@@ -125,7 +127,7 @@ def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, val
125127
return value
126128
elif value is not None:
127129
# simplest case, field is not complex, we only need to add the value if it was found
128-
return value
130+
return self._coerce_env_val_strict(field, value)
129131

130132
def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
131133
"""
@@ -256,10 +258,31 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[
256258
raise e
257259
if isinstance(env_var, dict):
258260
if last_key not in env_var or not isinstance(env_val, EnvNoneType) or env_var[last_key] == {}:
259-
env_var[last_key] = env_val
260-
261+
env_var[last_key] = self._coerce_env_val_strict(target_field, env_val)
261262
return result
262263

264+
def _coerce_env_val_strict(self, field: FieldInfo | None, value: Any) -> Any:
265+
"""
266+
Coerce environment string values based on field annotation if model config is `strict=True`.
267+
268+
Args:
269+
field: The field.
270+
value: The value to coerce.
271+
272+
Returns:
273+
The coerced value if successful, otherwise the original value.
274+
"""
275+
try:
276+
if self.config.get('strict') and isinstance(value, str) and field is not None:
277+
if value == self.env_parse_none_str:
278+
return value
279+
if not _annotation_contains_types(field.annotation, (Json,), is_instance=True):
280+
return TypeAdapter(field.annotation).validate_python(value)
281+
except ValidationError:
282+
# Allow validation error to be raised at time of instatiation
283+
pass
284+
return value
285+
263286
def __repr__(self) -> str:
264287
return (
265288
f'{self.__class__.__name__}(env_nested_delimiter={self.env_nested_delimiter!r}, '

pydantic_settings/sources/utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,24 @@ def _annotation_contains_types(
9292
types: tuple[Any, ...],
9393
is_include_origin: bool = True,
9494
is_strip_annotated: bool = False,
95+
is_instance: bool = False,
9596
) -> bool:
9697
"""Check if a type annotation contains any of the specified types."""
9798
if is_strip_annotated:
9899
annotation = _strip_annotated(annotation)
99-
if is_include_origin is True and get_origin(annotation) in types:
100-
return True
100+
if is_include_origin is True:
101+
origin = get_origin(annotation)
102+
if origin in types:
103+
return True
104+
if is_instance and any(isinstance(origin, type_) for type_ in types):
105+
return True
101106
for type_ in get_args(annotation):
102-
if _annotation_contains_types(type_, types, is_include_origin=True, is_strip_annotated=is_strip_annotated):
107+
if _annotation_contains_types(
108+
type_, types, is_include_origin=True, is_strip_annotated=is_strip_annotated, is_instance=is_instance
109+
):
103110
return True
111+
if is_instance and any(isinstance(annotation, type_) for type_ in types):
112+
return True
104113
return annotation in types
105114

106115

tests/test_settings.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3200,3 +3200,41 @@ class Settings(BaseSettings):
32003200
f'source to the settings sources via the settings_customise_sources hook.'
32013201
)
32023202
assert warning.message.args[0] == expected_message
3203+
3204+
3205+
def test_env_strict_coercion(env):
3206+
class SubModel(BaseModel):
3207+
my_str: str
3208+
my_int: int
3209+
3210+
class Settings(BaseSettings, env_nested_delimiter='__'):
3211+
my_str: str
3212+
my_int: int
3213+
sub_model: SubModel
3214+
3215+
env.set('MY_STR', '0')
3216+
env.set('MY_INT', '0')
3217+
env.set('SUB_MODEL__MY_STR', '1')
3218+
env.set('SUB_MODEL__MY_INT', '1')
3219+
Settings().model_dump() == {
3220+
'my_str': '0',
3221+
'my_int': 0,
3222+
'sub_model': {
3223+
'my_str': '1',
3224+
'my_int': 1,
3225+
},
3226+
}
3227+
3228+
class StrictSettings(BaseSettings, env_nested_delimiter='__', strict=True):
3229+
my_str: str
3230+
my_int: int
3231+
sub_model: SubModel
3232+
3233+
StrictSettings().model_dump() == {
3234+
'my_str': '0',
3235+
'my_int': 0,
3236+
'sub_model': {
3237+
'my_str': '1',
3238+
'my_int': 1,
3239+
},
3240+
}

0 commit comments

Comments
 (0)