Skip to content

Commit ae72a18

Browse files
committed
🐛 Do not register multiple times same child schemas (#188)
1 parent 708b16f commit ae72a18

File tree

4 files changed

+100
-39
lines changed

4 files changed

+100
-39
lines changed

flama/schemas/generator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import yaml
99

10-
from flama import routing, schemas, types, url
10+
from flama import exceptions, routing, schemas, types, url
1111
from flama.schemas import Schema, openapi
1212
from flama.schemas.data_structures import Parameter
1313

@@ -205,20 +205,20 @@ def register(self, schema: schemas.Schema, name: t.Optional[str] = None) -> int:
205205
:return: Schema ID.
206206
"""
207207
if schema in self:
208-
raise ValueError("Schema is already registered.")
208+
raise exceptions.ApplicationError(f"Schema '{schema}' is already registered.")
209209

210210
s = schemas.Schema(schema)
211211

212212
try:
213213
schema_name = name or s.name
214214
except ValueError as e: # pragma: no cover
215-
raise ValueError("Cannot infer schema name.") from e
215+
raise exceptions.ApplicationError("Cannot infer schema name.") from e
216216

217217
schema_instance = s.unique_schema
218218
schema_id = id(schema_instance)
219219
self[schema_id] = SchemaInfo(name=schema_name, schema=schema_instance)
220220

221-
for child_schema in [schemas.Schema(x) for x in s.nested_schemas() if x not in self]:
221+
for child_schema in (schemas.Schema(x) for x in s.nested_schemas() if x not in self):
222222
self.register(schema=child_schema.schema, name=child_schema.name)
223223

224224
return schema_id

tests/conftest.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,6 @@ def model_cls(self, framework: str):
165165
return self._models_cls[framework]
166166

167167
def _sklearn(self):
168-
assert np
169-
assert sklearn
170-
171168
model = sklearn.neural_network.MLPClassifier(activation="tanh", max_iter=2000, hidden_layer_sizes=(10,))
172169
model.fit(
173170
np.array([[0, 0], [0, 1], [1, 0], [1, 1]]),
@@ -176,9 +173,6 @@ def _sklearn(self):
176173
return model, sklearn.neural_network.MLPClassifier
177174

178175
def _sklearn_pipeline(self):
179-
assert np
180-
assert sklearn
181-
182176
model = sklearn.neural_network.MLPClassifier(activation="tanh", max_iter=2000, hidden_layer_sizes=(10,))
183177
numerical_transformer = sklearn.pipeline.Pipeline(
184178
[
@@ -205,9 +199,6 @@ def _sklearn_pipeline(self):
205199
return pipeline, sklearn.pipeline.Pipeline
206200

207201
def _tensorflow(self):
208-
assert np
209-
assert tf
210-
211202
model = tf.keras.models.Sequential(
212203
[
213204
tf.keras.Input((2,)),
@@ -227,9 +218,6 @@ def _tensorflow(self):
227218
return model, tf.keras.models.Sequential
228219

229220
def _torch(self):
230-
assert np
231-
assert torch
232-
233221
class Model(torch.nn.Module):
234222
def __init__(self):
235223
super().__init__()

tests/schemas/conftest.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,45 @@ def bar_optional_schema(app, foo_schema):
7474
return namedtuple("BarOptionalSchema", ("schema", "name"))(schema=schema, name=name)
7575

7676

77+
@pytest.fixture(scope="function")
78+
def bar_multiple_schema(app, foo_schema):
79+
child_schema = foo_schema.schema
80+
if app.schema.schema_library.name == "pydantic":
81+
schema = pydantic.create_model(
82+
"BarMultiple",
83+
first=(t.Union[child_schema, None], None),
84+
second=(t.Union[child_schema, None], None),
85+
__module__="pydantic.main",
86+
)
87+
name = "pydantic.main.BarOptional"
88+
elif app.schema.schema_library.name == "typesystem":
89+
schema = typesystem.Schema(
90+
title="BarOptional",
91+
fields={
92+
"first": typesystem.Reference(
93+
to="Foo", definitions=typesystem.Definitions({"Foo": child_schema}), allow_null=True, default=None
94+
),
95+
"second": typesystem.Reference(
96+
to="Foo", definitions=typesystem.Definitions({"Foo": child_schema}), allow_null=True, default=None
97+
),
98+
},
99+
)
100+
name = "typesystem.schemas.BarOptional"
101+
elif app.schema.schema_library.name == "marshmallow":
102+
schema = type(
103+
"BarOptional",
104+
(marshmallow.Schema,),
105+
{
106+
"first": marshmallow.fields.Nested(child_schema(), required=False, dump_default=None, allow_none=True),
107+
"second": marshmallow.fields.Nested(child_schema(), required=False, dump_default=None, allow_none=True),
108+
},
109+
)
110+
name = "abc.BarOptional"
111+
else:
112+
raise ValueError(f"Wrong schema lib: {app.schema.schema_library.name}")
113+
return namedtuple("BarMultipleSchema", ("schema", "name"))(schema=schema, name=name)
114+
115+
77116
@pytest.fixture(scope="function")
78117
def bar_list_schema(app, foo_schema):
79118
child_schema = foo_schema.schema
@@ -131,5 +170,12 @@ def bar_dict_schema(app, foo_schema):
131170

132171

133172
@pytest.fixture(scope="function")
134-
def schemas(foo_schema, bar_schema, bar_list_schema, bar_dict_schema):
135-
return {"Foo": foo_schema, "Bar": bar_schema, "BarList": bar_list_schema, "BarDict": bar_dict_schema}
173+
def schemas(foo_schema, bar_schema, bar_optional_schema, bar_multiple_schema, bar_list_schema, bar_dict_schema):
174+
return {
175+
"Foo": foo_schema,
176+
"Bar": bar_schema,
177+
"BarOptional": bar_optional_schema,
178+
"BarMultiple": bar_multiple_schema,
179+
"BarList": bar_list_schema,
180+
"BarDict": bar_dict_schema,
181+
}

tests/schemas/test_generator.py

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import typesystem
99
import typesystem.fields
1010

11-
from flama import Flama, schemas
11+
from flama import Flama, exceptions, schemas
1212
from flama.endpoints import HTTPEndpoint
1313
from flama.schemas import openapi
1414
from flama.schemas.generator import SchemaRegistry
@@ -782,30 +782,57 @@ def test_used(self, registry, schemas, spec, operation, register_schemas, output
782782
assert set(registry.used(spec).keys()) == expected_output
783783

784784
@pytest.mark.parametrize(
785-
["schema", "explicit_name", "output"],
785+
["cases"],
786786
[
787-
pytest.param("Foo", "Foo", {"Foo": "Foo"}, id="explicit_name"),
788-
pytest.param("Foo", None, {"Foo": None}, id="infer_name"),
789-
pytest.param("Bar", "Bar", {"Bar": "Bar"}, id="nested_schemas"),
787+
pytest.param(
788+
[
789+
("Foo", "Foo", {"Foo": "Foo"}, None),
790+
],
791+
id="explicit_name",
792+
),
793+
pytest.param(
794+
[
795+
("Foo", None, {"Foo": None}, None),
796+
],
797+
id="infer_name",
798+
),
799+
pytest.param(
800+
[
801+
("Bar", "Bar", {"Bar": "Bar"}, None),
802+
],
803+
id="nested_schemas",
804+
),
805+
pytest.param(
806+
[
807+
("Foo", "Foo", {"Foo": "Foo"}, None),
808+
("Foo", None, None, (exceptions.ApplicationError, r"Schema '.*' is already registered.")),
809+
],
810+
id="error_already_registered",
811+
),
812+
pytest.param(
813+
[
814+
("BarMultiple", "BarMultiple", {}, None),
815+
],
816+
id="multiple_child_schema",
817+
),
790818
],
791819
)
792-
def test_register(self, registry, schemas, schema, explicit_name, output):
793-
schema, name = schemas[schema]
794-
expected_name = name if not explicit_name else explicit_name
795-
exception = (
796-
contextlib.ExitStack() if expected_name else pytest.raises(ValueError, match="Cannot infer schema name.")
797-
)
798-
with exception:
799-
registry.register(schema, name=explicit_name)
800-
for s, n in output.items():
801-
assert schemas[s].schema in registry
802-
assert registry[schemas[s].schema].name == (n or schemas[s].name)
820+
def test_register(self, registry, schemas, cases):
821+
for schema_key, explicit_name, output, exception in cases:
822+
schema, name = schemas[schema_key]
803823

804-
def test_register_already_registered(self, registry, foo_schema):
805-
schema = foo_schema.schema
806-
registry.register(schema, name="Foo")
807-
with pytest.raises(ValueError, match="Schema is already registered."):
808-
registry.register(schema, name="Foo")
824+
if explicit_name is None and name is None:
825+
exception = pytest.raises(ValueError, match="Cannot infer schema name.")
826+
elif exception is not None:
827+
exception = pytest.raises(exception[0], match=exception[1])
828+
else:
829+
exception = contextlib.ExitStack()
830+
831+
with exception:
832+
registry.register(schema, name=explicit_name)
833+
for s, n in output.items():
834+
assert schemas[s].schema in registry
835+
assert registry[schemas[s].schema].name == (n or schemas[s].name)
809836

810837
@pytest.mark.parametrize(
811838
["multiple", "result"],

0 commit comments

Comments
 (0)