Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] refactor datatypes mtypes - example fixtures #458

Merged
merged 11 commits into from
Sep 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions skpro/datatypes/_base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Base module for datatypes."""

from skpro.datatypes._base._base import BaseConverter, BaseDatatype
from skpro.datatypes._base._base import BaseConverter, BaseDatatype, BaseExample

__all__ = ["BaseConverter", "BaseDatatype"]
__all__ = ["BaseConverter", "BaseDatatype", "BaseExample"]
37 changes: 37 additions & 0 deletions skpro/datatypes/_base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,43 @@ def _get_key(self):
return (mtype_from, mtype_to, scitype)


class BaseExample(BaseObject):
"""Base class for Example fixtures used in tests and get_examples."""

_tags = {
"object_type": "datatype_example",
"scitype": None,
"mtype": None,
"python_version": None,
"python_dependencies": None,
"index": None, # integer index of the example to match with other mtypes
"lossy": False, # whether the example is lossy
}

def __init__(self):
super().__init__()

def _get_key(self):
"""Get unique dictionary key corresponding to self.

Private function, used in collecting a dictionary of examples.
"""
mtype = self.get_class_tag("mtype")
scitype = self.get_class_tag("scitype")
index = self.get_class_tag("index")
return (mtype, scitype, index)

def build(self):
"""Build example.

Returns
-------
obj : any
Example object.
"""
raise NotImplementedError


def _coerce_str_to_cls(cls_or_str):
"""Get class from string.

Expand Down
57 changes: 34 additions & 23 deletions skpro/datatypes/_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
e.g., metadata such as column names are missing
"""

from functools import lru_cache

from skpro.datatypes._registry import mtype_to_scitype

__author__ = ["fkiraly"]
Expand All @@ -21,29 +23,36 @@
"get_examples",
]

from skpro.datatypes._proba import (
example_dict_lossy_Proba,
example_dict_metadata_Proba,
example_dict_Proba,
)
from skpro.datatypes._table import (
example_dict_lossy_Table,
example_dict_metadata_Table,
example_dict_Table,
)

# pool example_dict-s
example_dict = dict()
example_dict.update(example_dict_Proba)
example_dict.update(example_dict_Table)
@lru_cache(maxsize=1)
def generate_example_dicts(soft_deps="present"):
"""Generate example dicts using lookup."""
from skbase.utils.dependencies import _check_estimator_deps

from skpro.datatypes._base import BaseExample
from skpro.utils.retrieval import _all_classes

classes = _all_classes("skpro.datatypes")
classes = [x[1] for x in classes]
classes = [x for x in classes if issubclass(x, BaseExample)]
classes = [x for x in classes if not x.__name__.startswith("Base")]

example_dict_lossy = dict()
example_dict_lossy.update(example_dict_lossy_Proba)
example_dict_lossy.update(example_dict_lossy_Table)
# subset only to data types with soft dependencies present
if soft_deps == "present":
classes = [x for x in classes if _check_estimator_deps(x, severity="none")]

example_dict_metadata = dict()
example_dict_metadata.update(example_dict_metadata_Proba)
example_dict_metadata.update(example_dict_metadata_Table)
example_dict = dict()
example_dict_lossy = dict()
example_dict_metadata = dict()
for cls in classes:
k = cls()
key = k._get_key()
key_meta = (key[1], key[2])
example_dict[key] = k
example_dict_lossy[key] = k.get_class_tags().get("lossy", False)
example_dict_metadata[key_meta] = k.get_class_tags().get("metadata", {})

return example_dict, example_dict_lossy, example_dict_metadata


def get_examples(
Expand Down Expand Up @@ -79,6 +88,8 @@ def get_examples(
if as_scitype is None:
as_scitype = mtype_to_scitype(mtype)

example_dict, example_dict_lossy, example_dict_metadata = generate_example_dicts()

# retrieve all keys that match the query
exkeys = example_dict.keys()
keys = [k for k in exkeys if k[0] == mtype and k[1] == as_scitype]
Expand All @@ -88,14 +99,14 @@ def get_examples(

for k in keys:
if return_lossy:
fixtures[k[2]] = (example_dict.get(k), example_dict_lossy.get(k))
fixtures[k[2]] = (example_dict.get(k).build(), example_dict_lossy.get(k))
elif return_metadata:
fixtures[k[2]] = (
example_dict.get(k),
example_dict.get(k).build(),
example_dict_lossy.get(k),
example_dict_metadata.get((k[1], k[2])),
)
else:
fixtures[k[2]] = example_dict.get(k)
fixtures[k[2]] = example_dict.get(k).build()

return fixtures
10 changes: 0 additions & 10 deletions skpro/datatypes/_proba/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,11 @@

from skpro.datatypes._proba._check import check_dict as check_dict_Proba
from skpro.datatypes._proba._convert import convert_dict as convert_dict_Proba
from skpro.datatypes._proba._examples import example_dict as example_dict_Proba
from skpro.datatypes._proba._examples import (
example_dict_lossy as example_dict_lossy_Proba,
)
from skpro.datatypes._proba._examples import (
example_dict_metadata as example_dict_metadata_Proba,
)
from skpro.datatypes._proba._registry import MTYPE_LIST_PROBA, MTYPE_REGISTER_PROBA

__all__ = [
"check_dict_Proba",
"convert_dict_Proba",
"MTYPE_LIST_PROBA",
"MTYPE_REGISTER_PROBA",
"example_dict_Proba",
"example_dict_lossy_Proba",
"example_dict_metadata_Proba",
]
142 changes: 95 additions & 47 deletions skpro/datatypes/_proba/_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,64 +31,112 @@
import numpy as np
import pandas as pd

example_dict = dict()
example_dict_lossy = dict()
example_dict_metadata = dict()
from skpro.datatypes._base import BaseExample

###
# example 0: univariate

pred_q = pd.DataFrame({0.2: [1, 2, 3], 0.6: [2, 3, 4]})
pred_q.columns = pd.MultiIndex.from_product([["foo"], [0.2, 0.6]])

# we need to use this due to numerical inaccuracies from the binary based representation
pseudo_0_2 = 2 * np.abs(0.6 - 0.5)
class _ProbaUniv(BaseExample):
_tags = {
"scitype": "Proba",
"index": 0,
"metadata": {
"is_univariate": True,
"is_empty": False,
"has_nans": False,
},
}

example_dict[("pred_quantiles", "Proba", 0)] = pred_q
example_dict_lossy[("pred_quantiles", "Proba", 0)] = False

pred_int = pd.DataFrame({0.2: [1, 2, 3], 0.6: [2, 3, 4]})
pred_int.columns = pd.MultiIndex.from_tuples(
[("foo", 0.6, "lower"), ("foo", pseudo_0_2, "upper")]
)
class _ProbaUnivPredQ(_ProbaUniv):
_tags = {
"mtype": "pred_quantiles",
"python_dependencies": None,
"lossy": False,
}

example_dict[("pred_interval", "Proba", 0)] = pred_int
example_dict_lossy[("pred_interval", "Proba", 0)] = False
def build(self):
pred_q = pd.DataFrame({0.2: [1, 2, 3], 0.6: [2, 3, 4]})
pred_q.columns = pd.MultiIndex.from_product([["foo"], [0.2, 0.6]])

return pred_q


class _ProbaUnivPredInt(_ProbaUniv):
_tags = {
"mtype": "pred_interval",
"python_dependencies": None,
"lossy": False,
}

def build(self):
# we need to use this due to numerical inaccuracies
# from the binary based representation
pseudo_0_2 = 2 * np.abs(0.6 - 0.5)

pred_int = pd.DataFrame({0.2: [1, 2, 3], 0.6: [2, 3, 4]})
pred_int.columns = pd.MultiIndex.from_tuples(
[("foo", 0.6, "lower"), ("foo", pseudo_0_2, "upper")]
)

return pred_int

example_dict_metadata[("Proba", 0)] = {
"is_univariate": True,
"is_empty": False,
"has_nans": False,
}

###
# example 1: multi

pred_q = pd.DataFrame({0.2: [1, 2, 3], 0.6: [2, 3, 4], 42: [5, 3, -1], 46: [5, 3, -1]})
pred_q.columns = pd.MultiIndex.from_product([["foo", "bar"], [0.2, 0.6]])

example_dict[("pred_quantiles", "Proba", 1)] = pred_q
example_dict_lossy[("pred_quantiles", "Proba", 1)] = False

pred_int = pd.DataFrame(
{0.2: [1, 2, 3], 0.6: [2, 3, 4], 42: [5, 3, -1], 46: [5, 3, -1]}
)
pred_int.columns = pd.MultiIndex.from_tuples(
[
("foo", 0.6, "lower"),
("foo", pseudo_0_2, "upper"),
("bar", 0.6, "lower"),
("bar", pseudo_0_2, "upper"),
]
)

example_dict[("pred_interval", "Proba", 1)] = pred_int
example_dict_lossy[("pred_interval", "Proba", 1)] = False


example_dict_metadata[("Proba", 1)] = {
"is_univariate": False,
"is_empty": False,
"has_nans": False,
}

class _ProbaMulti(BaseExample):
_tags = {
"scitype": "Proba",
"index": 1,
"metadata": {
"is_univariate": False,
"is_empty": False,
"has_nans": False,
},
}


class _ProbaMultiPredQ(_ProbaMulti):
_tags = {
"mtype": "pred_quantiles",
"python_dependencies": None,
"lossy": False,
}

def build(self):
pred_q = pd.DataFrame(
{0.2: [1, 2, 3], 0.6: [2, 3, 4], 42: [5, 3, -1], 46: [5, 3, -1]}
)
pred_q.columns = pd.MultiIndex.from_product([["foo", "bar"], [0.2, 0.6]])

return pred_q


class _ProbaMultiPredInt(_ProbaMulti):
_tags = {
"mtype": "pred_interval",
"python_dependencies": None,
"lossy": False,
}

def build(self):
# we need to use this due to numerical inaccuracies
# from the binary based representation
pseudo_0_2 = 2 * np.abs(0.6 - 0.5)

pred_int = pd.DataFrame(
{0.2: [1, 2, 3], 0.6: [2, 3, 4], 42: [5, 3, -1], 46: [5, 3, -1]}
)
pred_int.columns = pd.MultiIndex.from_tuples(
[
("foo", 0.6, "lower"),
("foo", pseudo_0_2, "upper"),
("bar", 0.6, "lower"),
("bar", pseudo_0_2, "upper"),
]
)

return pred_int
10 changes: 0 additions & 10 deletions skpro/datatypes/_table/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,10 @@
"""Module exports: Series type checkers, converters and mtype inference."""

from skpro.datatypes._table._convert import convert_dict as convert_dict_Table
from skpro.datatypes._table._examples import example_dict as example_dict_Table
from skpro.datatypes._table._examples import (
example_dict_lossy as example_dict_lossy_Table,
)
from skpro.datatypes._table._examples import (
example_dict_metadata as example_dict_metadata_Table,
)
from skpro.datatypes._table._registry import MTYPE_LIST_TABLE, MTYPE_REGISTER_TABLE

__all__ = [
"convert_dict_Table",
"MTYPE_LIST_TABLE",
"MTYPE_REGISTER_TABLE",
"example_dict_Table",
"example_dict_lossy_Table",
"example_dict_metadata_Table",
]
Loading
Loading