Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test serializability of SerializableDataclass subclasses #315

Merged
merged 1 commit into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions dbt_semantic_interfaces/dataclass_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@

import dataclasses
import datetime
import inspect
import logging
from abc import ABC
from builtins import NameError
from dataclasses import dataclass
from enum import Enum
from typing import (
Any,
ClassVar,
Dict,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Expand Down Expand Up @@ -141,25 +146,47 @@ def _get_type_parameter_for_sequence_like_tuple_type(field_type: Type) -> Type:
return args[0]


class SerializableDataclass(ABC):
    """Describes a dataclass that can be serialized using `DataclassSerializer`.

    Previously, Pydantic has been used for defining objects as it provides built-in support for serialization and
    deserialization. However, Pydantic object initialization is slow compared to dataclass initialization, with tests
    showing 10x-100x slower performance. This is an issue if many objects are created, which can happen during plan
    generation. Using `BaseModel.construct()` is still not as fast as dataclass initialization, and it also makes for
    an awkward developer interface. Because of this, MF implements a simple custom serializer / deserializer to work
    with the built-in Python dataclass.

    The dataclass must have concrete types for all fields and not all types are supported. Please see implementation
    details in `DataclassSerializer`. Not adding post_init checks as there have been previous issues with slow object
    initialization.
    """

    # Contains all known implementing (non-abstract) subclasses. Stays `None` until the first subclass is defined.
    # The forward references are quoted so the annotation is valid whether or not the module enables postponed
    # evaluation of annotations.
    _concrete_subclass_registry: ClassVar[Optional[Set[Type["SerializableDataclass"]]]] = None

    @classmethod
    def concrete_subclasses_for_testing(cls) -> Sequence[Type["SerializableDataclass"]]:
        """Returns the registered subclasses that implement this interface.

        This is intended to be used in tests to verify the ability to serialize the class. The registry is shared by
        the whole hierarchy, so the result is the same regardless of which subclass this is called on.

        Returns:
            The registered classes ordered by module, then class name. The qualified name is used as a final
            tie-breaker so that same-named nested classes in one module are still returned in a stable order.
        """
        return sorted(
            cls._concrete_subclass_registry or (),
            key=lambda class_type: (class_type.__module__, class_type.__name__, class_type.__qualname__),
        )

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Adds the implementing class to the registry unless it is abstract.

        The fields are not validated here. It would be helpful to check that the fields of the dataclass are concrete
        fields, but that would need to be done after class initialization, and checking in `__post_init__` adds
        significant overhead.
        """
        super().__init_subclass__(**kwargs)

        # Always store the registry on the base class so that every subclass shares a single set.
        if SerializableDataclass._concrete_subclass_registry is None:
            SerializableDataclass._concrete_subclass_registry = set()

        if not inspect.isabstract(cls):
            SerializableDataclass._concrete_subclass_registry.add(cls)


SerializableDataclassT = TypeVar("SerializableDataclassT", bound=SerializableDataclass)
Expand Down
1 change: 1 addition & 0 deletions dbt_semantic_interfaces/references.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class GroupByMetricReference(LinkableElementReference):
pass


@dataclass(frozen=True, order=True)
class ModelReference(SerializableDataclass):
"""A reference to something in the model.

Expand Down
Empty file.
38 changes: 38 additions & 0 deletions dbt_semantic_interfaces/test_helpers/dataclass_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Iterable, Sequence, Type

from dbt_semantic_interfaces.dataclass_serialization import (
DataClassDeserializer,
DataclassSerializer,
SerializableDataclass,
)


def assert_includes_all_serializable_dataclass_types(
    instances: Sequence[SerializableDataclass], excluded_classes: Iterable[Type[SerializableDataclass]]
) -> None:
    """Verify that every known concrete subclass is represented by at least one of the given instances.

    Classes listed in `excluded_classes` are not required to appear in `instances`. An `AssertionError` naming the
    uncovered classes is raised if any remain.
    """
    covered_types = set(map(type, instances))
    known_types = set(SerializableDataclass.concrete_subclasses_for_testing())
    ignored_types = set(excluded_classes)

    uncovered_types = known_types - covered_types - ignored_types
    missing_type_names = sorted(uncovered_type.__name__ for uncovered_type in uncovered_types)
    assert not missing_type_names, f"Missing instances of the following classes: {missing_type_names}. Please add them."


def assert_serializable(instances: Sequence[SerializableDataclass]) -> None:
    """Verify that each given instance survives a serialize / deserialize round trip unchanged.

    Any exception raised while serializing or deserializing is re-raised as an `AssertionError` that names the
    offending instance, with the original exception chained as the cause.
    """
    serializer = DataclassSerializer()
    deserializer = DataClassDeserializer()

    for instance in instances:
        try:
            payload = serializer.pydantic_serialize(instance)
            round_tripped = deserializer.pydantic_deserialize(type(instance), payload)
        except Exception as e:
            raise AssertionError(f"Error serializing {instance=}") from e
        else:
            # Only reached when both conversion steps succeeded.
            assert instance == round_tripped
Empty file added tests/serialization/__init__.py
Empty file.
76 changes: 76 additions & 0 deletions tests/serialization/test_serializable_dataclass_subclasses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import itertools
import logging

from dbt_semantic_interfaces.references import (
DimensionReference,
ElementReference,
EntityReference,
GroupByMetricReference,
LinkableElementReference,
MeasureReference,
MetricModelReference,
MetricReference,
ModelReference,
SemanticModelElementReference,
SemanticModelReference,
TimeDimensionReference,
)
from dbt_semantic_interfaces.test_helpers.dataclass_serialization import (
assert_includes_all_serializable_dataclass_types,
assert_serializable,
)
from tests.test_dataclass_serialization import (
DataclassWithDataclassDefault,
DataclassWithDefaultTuple,
DataclassWithOptional,
DataclassWithPrimitiveTypes,
DataclassWithTuple,
DeeplyNestedDataclass,
NestedDataclass,
NestedDataclassWithProtocol,
SimpleClassWithProtocol,
SimpleDataclass,
)

logger = logging.getLogger(__name__)


def test_serializable_dataclass_subclasses() -> None:
    """Check that all subclasses of `SerializableDataclass` can be serialized and deserialized."""
    counter = itertools.count()

    def _unique_field_value() -> str:
        # Every call consumes the next counter value, so each generated string is distinct.
        return f"field_{next(counter)}"

    # These are classes defined and used in a separate test.
    classes_covered_elsewhere = [
        DataclassWithDataclassDefault,
        DataclassWithDefaultTuple,
        DataclassWithOptional,
        DataclassWithPrimitiveTypes,
        DataclassWithTuple,
        DeeplyNestedDataclass,
        NestedDataclass,
        NestedDataclassWithProtocol,
        SimpleClassWithProtocol,
        SimpleDataclass,
    ]

    # One instance per reference class under test.
    instances = [
        LinkableElementReference(_unique_field_value()),
        ElementReference(_unique_field_value()),
        SemanticModelElementReference(_unique_field_value(), _unique_field_value()),
        EntityReference(_unique_field_value()),
        SemanticModelReference(_unique_field_value()),
        TimeDimensionReference(_unique_field_value()),
        MetricReference(_unique_field_value()),
        GroupByMetricReference(_unique_field_value()),
        MetricModelReference(_unique_field_value()),
        DimensionReference(_unique_field_value()),
        MeasureReference(_unique_field_value()),
        ModelReference(),
    ]

    assert_includes_all_serializable_dataclass_types(
        instances=instances,
        excluded_classes=classes_covered_elsewhere,
    )
    assert_serializable(instances)
Loading