Skip to content

Commit

Permalink
Merge pull request #315 from dbt-labs/plypaul--64--serializable-datac…
Browse files Browse the repository at this point in the history
…lass-updates

Test serializability of `SerializableDataclass` subclasses
  • Loading branch information
plypaul authored Jul 17, 2024
2 parents eade721 + 573d0cc commit d1bf57a
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 10 deletions.
47 changes: 37 additions & 10 deletions dbt_semantic_interfaces/dataclass_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@

import dataclasses
import datetime
import inspect
import logging
from abc import ABC
from builtins import NameError
from dataclasses import dataclass
from enum import Enum
from typing import (
Any,
ClassVar,
Dict,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Expand Down Expand Up @@ -141,25 +146,47 @@ def _get_type_parameter_for_sequence_like_tuple_type(field_type: Type) -> Type:
return args[0]


class SerializableDataclass:
"""Describes a dataclass that can be serialized using DataclassSerializer.
class SerializableDataclass(ABC):
"""Describes a dataclass that can be serialized using `DataclassSerializer`.
Previously, Pydnatic has been used for defining objects as it provides built in support for serialization and
Previously, Pydantic has been used for defining objects as it provides built in support for serialization and
deserialization. However, Pydantic object is slow compared to dataclass initialization, with tests showing 10x-100x
slower performance. This is an issue if many objects are created, which can happen in during plan generation. Using
the BaseModel.construct() is still not as fast as dataclass initiaization and it also makes for an awkward developer
interface. Because of this, MF implements a simple custom serializer / deserializer to work with the built-in
Python dataclass.
the BaseModel.construct() is still not as fast as dataclass initialization, and it also makes for an awkward
developer interface. Because of this, MF implements a simple custom serializer / deserializer to work with the
built-in Python dataclass.
The dataclass must have concrete types for all fields and not all types are supported. Please see implementation
details in DataclassSerializer. Not adding post_init checks as there have been previous issues with slow object
initialization.
This is a concrete object as MyPy currently throws a type error if a Python dataclass is defined with an abstract
parent class.
"""

pass
# Contains all known implementing subclasses.
_concrete_subclass_registry: ClassVar[Optional[Set[Type[SerializableDataclass]]]] = None

@classmethod
def concrete_subclasses_for_testing(cls) -> Sequence[Type[SerializableDataclass]]:
"""Returns subclasses that implement this interface.
This is intended to be used in tests to verify the ability to serialize the class.
"""
return sorted(
cls._concrete_subclass_registry or (), key=lambda class_type: (class_type.__module__, class_type.__name__)
)

def __init_subclass__(cls, **kwargs) -> None:
"""Adds the implementing class to the registry and check for non-concrete fields.
It would be helpful to check that the fields of the dataclass are concrete fields, but that would need to be
done after class initialization, and checking in `__post_init__` adds significant overhead.
"""
super().__init_subclass__(**kwargs)

if SerializableDataclass._concrete_subclass_registry is None:
SerializableDataclass._concrete_subclass_registry = set()

if not inspect.isabstract(cls):
SerializableDataclass._concrete_subclass_registry.add(cls)


SerializableDataclassT = TypeVar("SerializableDataclassT", bound=SerializableDataclass)
Expand Down
1 change: 1 addition & 0 deletions dbt_semantic_interfaces/references.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class GroupByMetricReference(LinkableElementReference):
pass


@dataclass(frozen=True, order=True)
class ModelReference(SerializableDataclass):
"""A reference to something in the model.
Expand Down
Empty file.
38 changes: 38 additions & 0 deletions dbt_semantic_interfaces/test_helpers/dataclass_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Iterable, Sequence, Type

from dbt_semantic_interfaces.dataclass_serialization import (
DataClassDeserializer,
DataclassSerializer,
SerializableDataclass,
)


def assert_includes_all_serializable_dataclass_types(
instances: Sequence[SerializableDataclass], excluded_classes: Iterable[Type[SerializableDataclass]]
) -> None:
"""Verify that the given instances include at least one instance of the known subclasses."""
instance_types = {type(instance) for instance in instances}
missing_instance_types = (
set(SerializableDataclass.concrete_subclasses_for_testing())
.difference(instance_types)
.difference(excluded_classes)
)
missing_type_names = sorted(instance_type.__name__ for instance_type in missing_instance_types)
assert (
len(missing_type_names) == 0
), f"Missing instances of the following classes: {missing_type_names}. Please add them."


def assert_serializable(instances: Sequence[SerializableDataclass]) -> None:
"""Verify that the given instances are actually serializable."""
serializer = DataclassSerializer()
deserializer = DataClassDeserializer()

for instance in instances:
try:
serialized_output = serializer.pydantic_serialize(instance)
deserialized_instance = deserializer.pydantic_deserialize(type(instance), serialized_output)
except Exception as e:
raise AssertionError(f"Error serializing {instance=}") from e

assert instance == deserialized_instance
Empty file added tests/serialization/__init__.py
Empty file.
76 changes: 76 additions & 0 deletions tests/serialization/test_serializable_dataclass_subclasses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import itertools
import logging

from dbt_semantic_interfaces.references import (
DimensionReference,
ElementReference,
EntityReference,
GroupByMetricReference,
LinkableElementReference,
MeasureReference,
MetricModelReference,
MetricReference,
ModelReference,
SemanticModelElementReference,
SemanticModelReference,
TimeDimensionReference,
)
from dbt_semantic_interfaces.test_helpers.dataclass_serialization import (
assert_includes_all_serializable_dataclass_types,
assert_serializable,
)
from tests.test_dataclass_serialization import (
DataclassWithDataclassDefault,
DataclassWithDefaultTuple,
DataclassWithOptional,
DataclassWithPrimitiveTypes,
DataclassWithTuple,
DeeplyNestedDataclass,
NestedDataclass,
NestedDataclassWithProtocol,
SimpleClassWithProtocol,
SimpleDataclass,
)

logger = logging.getLogger(__name__)


def test_serializable_dataclass_subclasses() -> None:
"""Verify that all subclasses of `SerializableDataclass` are serializable."""
counter = itertools.count(start=0)

def _get_next_field_str() -> str:
return f"field_{next(counter)}"

instances = [
LinkableElementReference(_get_next_field_str()),
ElementReference(_get_next_field_str()),
SemanticModelElementReference(_get_next_field_str(), _get_next_field_str()),
EntityReference(_get_next_field_str()),
SemanticModelReference(_get_next_field_str()),
TimeDimensionReference(_get_next_field_str()),
MetricReference(_get_next_field_str()),
GroupByMetricReference(_get_next_field_str()),
MetricModelReference(_get_next_field_str()),
DimensionReference(_get_next_field_str()),
MeasureReference(_get_next_field_str()),
ModelReference(),
]

assert_includes_all_serializable_dataclass_types(
instances=instances,
# These are classes defined and used in a separate test.
excluded_classes=[
DataclassWithDataclassDefault,
DataclassWithDefaultTuple,
DataclassWithOptional,
DataclassWithPrimitiveTypes,
DataclassWithTuple,
DeeplyNestedDataclass,
NestedDataclass,
NestedDataclassWithProtocol,
SimpleClassWithProtocol,
SimpleDataclass,
],
)
assert_serializable(instances)

0 comments on commit d1bf57a

Please sign in to comment.