Skip to content

Commit 36c68a7

Browse files
authored
Merge pull request #1582 from HeyHugo/fix/defaultdict-issues
Fix issues with nested default dicts
2 parents d009420 + 8b56455 commit 36c68a7

File tree

3 files changed

+112
-77
lines changed

3 files changed

+112
-77
lines changed

ninja/params/models.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from abc import ABC, abstractmethod
2-
from collections import defaultdict
32
from typing import (
43
TYPE_CHECKING,
54
Any,
@@ -39,10 +38,6 @@
3938
TModels = List[TModel]
4039

4140

42-
def NestedDict() -> DictStrAny:
43-
return defaultdict(NestedDict)
44-
45-
4641
class ParamModel(BaseModel, ABC):
4742
__ninja_param_source__ = None
4843

@@ -65,11 +60,6 @@ def resolve(
6560
return cls()
6661

6762
data = cls._map_data_paths(data)
68-
# Convert defaultdict to dict for pydantic 2.12+ compatibility
69-
# In pydantic 2.12+, accessing missing keys in defaultdict creates nested
70-
# defaultdicts which then fail validation
71-
if isinstance(data, defaultdict):
72-
data = dict(data)
7363
return cls.model_validate(data, context={"request": request})
7464

7565
@classmethod
@@ -78,22 +68,20 @@ def _map_data_paths(cls, data: DictStrAny) -> DictStrAny:
7868
if not flatten_map:
7969
return data
8070

81-
mapped_data: DictStrAny = NestedDict()
82-
for k in flatten_map:
83-
if k in data:
84-
cls._map_data_path(mapped_data, data[k], flatten_map[k])
85-
else:
86-
cls._map_data_path(mapped_data, None, flatten_map[k])
87-
71+
mapped_data: DictStrAny = {}
72+
for key, path in flatten_map.items():
73+
cls._map_data_path(mapped_data, data.get(key), path)
8874
return mapped_data
8975

9076
@classmethod
91-
def _map_data_path(cls, data: DictStrAny, value: Any, path: Tuple) -> None:
92-
if len(path) == 1:
93-
if value is not None:
94-
data[path[0]] = value
95-
else:
96-
cls._map_data_path(data[path[0]], value, path[1:])
77+
def _map_data_path(
78+
cls, data: DictStrAny, value: Any, path: Tuple[str, ...]
79+
) -> None:
80+
current = data
81+
for key in path[:-1]:
82+
current = current.setdefault(key, {})
83+
if value is not None:
84+
current[path[-1]] = value
9785

9886

9987
class QueryModel(ParamModel):

tests/test_params_models.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from typing import Optional
2+
3+
from ninja.params.models import DictStrAny, ParamModel
4+
5+
6+
class _NestedParamModel(ParamModel):
7+
outer: DictStrAny
8+
leaf: Optional[int]
9+
10+
__ninja_flatten_map__ = {
11+
"foo": ("outer", "foo"),
12+
"bar": ("outer", "bar"),
13+
"leaf": ("leaf",),
14+
}
15+
16+
17+
def test_map_data_paths_creates_parent_for_missing_nested_values():
18+
assert _NestedParamModel._map_data_paths({}) == {"outer": {}}
19+
20+
21+
def test_map_data_paths_sets_values_when_present():
22+
data = _NestedParamModel._map_data_paths({"foo": 1, "leaf": 2})
23+
assert data == {"outer": {"foo": 1}, "leaf": 2}

tests/test_query_schema.py

Lines changed: 78 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from datetime import datetime
22
from enum import IntEnum
33

4-
from pydantic import Field
4+
from pydantic import BaseModel, Field
55

66
from ninja import NinjaAPI, Query, Schema
7+
from ninja.testing.client import TestClient
78

89

910
class Range(IntEnum):
@@ -12,7 +13,7 @@ class Range(IntEnum):
1213
TWO_HUNDRED = 200
1314

1415

15-
class Filter(Schema):
16+
class Filter(BaseModel):
1617
to_datetime: datetime = Field(alias="to")
1718
from_datetime: datetime = Field(alias="from")
1819
range: Range = Range.TWENTY
@@ -28,7 +29,7 @@ class Data(Schema):
2829

2930
@api.get("/test")
3031
def query_params_schema(request, filters: Filter = Query(...)):
31-
return filters.dict()
32+
return filters.model_dump()
3233

3334

3435
@api.get("/test-mixed")
@@ -39,57 +40,80 @@ def query_params_mixed_schema(
3940
filters: Filter = Query(...),
4041
data: Data = Query(...),
4142
):
42-
return dict(query1=query1, query2=query2, filters=filters.dict(), data=data.dict())
43-
44-
45-
# def test_request():
46-
# client = TestClient(api)
47-
# response = client.get("/test?from=1&to=2&range=20&foo=1&range2=50")
48-
# print("!", response.json())
49-
# assert response.json() == {
50-
# "to_datetime": "1970-01-01T00:00:02Z",
51-
# "from_datetime": "1970-01-01T00:00:01Z",
52-
# "range": 20,
53-
# }
54-
55-
# response = client.get("/test?from=1&to=2&range=21")
56-
# assert response.status_code == 422
57-
58-
59-
# def test_request_mixed():
60-
# client = TestClient(api)
61-
# response = client.get(
62-
# "/test-mixed?from=1&to=2&range=20&foo=1&range2=50&query1=2&int=3&float=1.6"
63-
# )
64-
# print(response.json())
65-
# assert response.json() == {
66-
# "data": {"a_float": 1.6, "an_int": 3},
67-
# "filters": {
68-
# "from_datetime": "1970-01-01T00:00:01Z",
69-
# "range": 20,
70-
# "to_datetime": "1970-01-01T00:00:02Z",
71-
# },
72-
# "query1": 2,
73-
# "query2": 5,
74-
# }
75-
76-
# response = client.get(
77-
# "/test-mixed?from=1&to=2&range=20&foo=1&range2=50&query1=2&query2=10"
78-
# )
79-
# print(response.json())
80-
# assert response.json() == {
81-
# "data": {"a_float": 1.5, "an_int": 0},
82-
# "filters": {
83-
# "from_datetime": "1970-01-01T00:00:01Z",
84-
# "range": 20,
85-
# "to_datetime": "1970-01-01T00:00:02Z",
86-
# },
87-
# "query1": 2,
88-
# "query2": 10,
89-
# }
90-
91-
# response = client.get("/test-mixed?from=1&to=2")
92-
# assert response.status_code == 422
43+
return dict(
44+
query1=query1,
45+
query2=query2,
46+
filters=filters.model_dump(),
47+
data=data.model_dump(),
48+
)
49+
50+
51+
def test_request():
52+
client = TestClient(api)
53+
response = client.get("/test?from=1&to=2&range=20&foo=1&range2=50")
54+
print("!", response.json())
55+
assert response.json() == {
56+
"to_datetime": "1970-01-01T00:00:02Z",
57+
"from_datetime": "1970-01-01T00:00:01Z",
58+
"range": 20,
59+
}
60+
61+
response = client.get("/test?from=1&to=2&range=21")
62+
assert response.status_code == 422
63+
64+
65+
def test_request_mixed():
66+
client = TestClient(api)
67+
response = client.get(
68+
"/test-mixed?from=1&to=2&range=20&foo=1&range2=50&query1=2&int=3&float=1.6"
69+
)
70+
print(response.json())
71+
assert response.json() == {
72+
"data": {"a_float": 1.6, "an_int": 3},
73+
"filters": {
74+
"from_datetime": "1970-01-01T00:00:01Z",
75+
"range": 20,
76+
"to_datetime": "1970-01-01T00:00:02Z",
77+
},
78+
"query1": 2,
79+
"query2": 5,
80+
}
81+
82+
response = client.get(
83+
"/test-mixed?from=1&to=2&range=20&foo=1&range2=50&query1=2&query2=10"
84+
)
85+
print(response.json())
86+
assert response.json() == {
87+
"data": {"a_float": 1.5, "an_int": 0},
88+
"filters": {
89+
"from_datetime": "1970-01-01T00:00:01Z",
90+
"range": 20,
91+
"to_datetime": "1970-01-01T00:00:02Z",
92+
},
93+
"query1": 2,
94+
"query2": 10,
95+
}
96+
97+
response = client.get("/test-mixed?from=1&to=2")
98+
assert response.status_code == 422
99+
100+
101+
def test_request_query_params_using_basemodel():
102+
class Foo(BaseModel):
103+
start: int
104+
optional: int = 42
105+
106+
temp_api = NinjaAPI()
107+
108+
@temp_api.get("/foo")
109+
def view(request, foo: Foo = Query(...)):
110+
return foo.model_dump()
111+
112+
client = TestClient(temp_api)
113+
resp = client.get("/foo?start=1")
114+
115+
assert resp.status_code == 200
116+
assert resp.json() == {"start": 1, "optional": 42}
93117

94118

95119
def test_schema():

0 commit comments

Comments
 (0)