Skip to content

Commit

Permalink
Serializable SecretFetcher (#32)
Browse files Browse the repository at this point in the history
* Make SecretFetcher serializable

* Update comments

* Update version (dev)

* Update version
  • Loading branch information
Peilun-Li authored Dec 8, 2022
1 parent 911ab0c commit 2c26391
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 5 deletions.
7 changes: 7 additions & 0 deletions datasets/dataset_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ class StorageOptions:
def to_json(self) -> dict:
    """Serialize this options dataclass into a JSON-compatible dict.

    None-valued fields are dropped, a "type" discriminator (the concrete
    class name) is added so the decoder can recover the right subclass, and
    any nested object exposing its own ``to_json`` (e.g. SecretFetcher) is
    serialized through that hook instead of the flat ``asdict`` output.
    """
    drop_none = lambda pairs: {key: value for key, value in pairs if value is not None}
    payload = dataclasses.asdict(self, dict_factory=drop_none)
    payload["type"] = type(self).__name__

    # Prefer a nested dataclass's own serialization over asdict's expansion.
    for spec in dataclasses.fields(self):
        nested = getattr(self, spec.name)
        if getattr(nested, "to_json", None):
            payload[spec.name] = nested.to_json()

    return payload


Expand Down
8 changes: 7 additions & 1 deletion datasets/metaflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from datasets.context import Context
from datasets.dataset_plugin import DatasetPlugin, StorageOptions
from datasets.mode import Mode
from datasets.utils.secret_fetcher import SecretFetcher


class _DatasetTypeClass(ParamType):
Expand Down Expand Up @@ -73,12 +74,17 @@ def get_storage_subclasses(cls: object) -> list[Type[StorageOptions]]:

return ret

def object_hook(self, obj: dict) -> Union[_DatasetParams, StorageOptions]:
def object_hook(self, obj: dict) -> Union[_DatasetParams, StorageOptions, SecretFetcher]:
type = obj.get("type")
if type:
# remove "type"
del obj["type"]

# SecretFetcher type
if type == "SecretFetcher":
return SecretFetcher(**obj)

# StorageOptions type
mapping: Dict[str, Type[StorageOptions]] = {
f.__name__.lower(): f for f in _DatasetParamsDecoder.get_storage_subclasses(StorageOptions)
}
Expand Down
71 changes: 68 additions & 3 deletions datasets/tests/test_dataset_plugin.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,36 @@
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union

import pandas as pd
import pytest

from datasets import Context, DataFrameType, Dataset
from datasets import Context, DataFrameType, Dataset, Mode
from datasets._typing import ColumnNames
from datasets.dataset_plugin import DatasetPlugin, StorageOptions
from datasets.metaflow import _DatasetTypeClass
from datasets.plugins import HiveDataset
from datasets.plugins.batch.hive_dataset import HiveOptions
from datasets.tests.conftest import TestExecutor
from datasets.utils.secret_fetcher import SecretFetcher


class _TestPlugin(DatasetPlugin):
# We'll need to inherit dict too to make this class json serializable
class _TestPlugin(DatasetPlugin, dict):
def __init__(
    self,
    name: str,
    logical_key: Optional[str] = None,
    columns: Optional[ColumnNames] = None,
    run_id: Optional[str] = None,
    run_time: Optional[int] = None,
    mode: Union[Mode, str] = Mode.READ,
    options: Optional[StorageOptions] = None,
):
    """Initialize the plugin and mirror key fields into the dict base.

    Inheriting from dict makes instances json.dumps-serializable; the dict
    payload carries the fields needed to round-trip through the JSON decoder.
    """
    super().__init__(name, logical_key, columns, run_id, run_time, mode, options)

    # The signature accepts Union[Mode, str]; only an actual Mode has a
    # .name attribute, so store a plain string as-is instead of crashing.
    mode_name = mode.name if isinstance(mode, Mode) else mode

    # Init the dict base so these fields are included by json.dumps.
    dict.__init__(self, name=name, mode=mode_name, options=options)

def write(self, data: DataFrameType, **kwargs):
    """Not supported by this test plugin; concrete subclasses must override."""
    raise NotImplementedError()

Expand Down Expand Up @@ -83,6 +101,21 @@ def __init__(self, name: str, options: FeeOnlineDatasetOptions, **kwargs):
super(FeeOnlineDatasetPluginTest, self).__init__(name=name, options=options, **kwargs)


@dataclass
class SecretDatasetTestOptions(StorageOptions):
    # Free-form option used by the tests to verify round-tripping.
    a: Optional[str] = None
    # Default fetcher targets the "test1" env var.
    # NOTE(review): this class-level SecretFetcher instance is shared by every
    # SecretDatasetTestOptions that relies on the default; if SecretFetcher is
    # mutable, field(default_factory=...) would be safer — TODO confirm.
    secret: Optional[SecretFetcher] = SecretFetcher(env_var="test1")


@DatasetPlugin.register(context=Context.BATCH, options_type=SecretDatasetTestOptions)
class SecretDatasetPluginTest(_TestPlugin):
    """Batch test plugin that exposes the SecretFetcher carried by its options."""

    def __init__(
        self, name: str, options: SecretDatasetTestOptions, mode: Union[Mode, str] = Mode.READ, **kwargs
    ):
        # Surface the fetcher directly on the plugin for convenient assertions,
        # then delegate everything else to the base test plugin.
        self.secret = options.secret
        super().__init__(name=name, options=options, mode=mode, **kwargs)


def test_dataset_factory_latency():
import datetime

Expand Down Expand Up @@ -122,6 +155,18 @@ def test_dataset_factory_constructor():
assert dataset.test_fee == "TestFee"
assert isinstance(dataset, FeeOnlineDatasetPluginTest)

dataset = Dataset("TestSecret", options=SecretDatasetTestOptions(a="Ta"))
assert dataset.name == "TestSecret"
assert dataset.secret.env_var == "test1"
assert isinstance(dataset, SecretDatasetPluginTest)

dataset = Dataset(
"TestSecret", options=SecretDatasetTestOptions(a="Ta", secret=SecretFetcher(env_var="test2"))
)
assert dataset.name == "TestSecret"
assert dataset.secret.env_var == "test2"
assert isinstance(dataset, SecretDatasetPluginTest)


def test_dataset_json_constructor():
dataset = _DatasetTypeClass().convert('{"name": "FooName"}', None, None)
Expand Down Expand Up @@ -153,6 +198,26 @@ def test_dataset_json_constructor():
assert dataset.test_fee == "TestFee"
assert isinstance(dataset, FeeOnlineDatasetPluginTest)

dataset = _DatasetTypeClass().convert(
'{"name": "TestSecret", "options":{"type": "SecretDatasetTestOptions", "a": "Ta"}}',
None,
None,
)
assert dataset.name == "TestSecret"
assert dataset.secret.env_var == "test1"
assert isinstance(dataset, SecretDatasetPluginTest)

dataset = _DatasetTypeClass().convert(
'{"name": "TestSecret",'
' "options":{"type": "SecretDatasetTestOptions",'
' "a": "Ta", "secret": {"type": "SecretFetcher", "env_var": "test2"}}}',
None,
None,
)
assert dataset.name == "TestSecret"
assert dataset.secret.env_var == "test2"
assert isinstance(dataset, SecretDatasetPluginTest)


def test_dataset_factory_constructor_unhappy():
@dataclass
Expand Down
21 changes: 21 additions & 0 deletions datasets/tests/test_metaflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
from datasets.plugins.batch.batch_base_plugin import BatchOptions
from datasets.plugins.batch.batch_dataset import BatchDataset
from datasets.plugins.batch.hive_dataset import HiveOptions
from datasets.tests.test_dataset_plugin import (
SecretDatasetPluginTest,
SecretDatasetTestOptions,
)
from datasets.utils.secret_fetcher import SecretFetcher


def test_dataset_dumps_load():
Expand All @@ -31,6 +36,22 @@ def test_dataset_dumps_load():
assert dataset == dataset2


def test_dataset_secret_dumps_load():
    """Round-trip a dataset whose options embed a SecretFetcher through JSON."""
    original = Dataset(
        name="TestSecret",
        mode=Mode.READ_WRITE,
        options=SecretDatasetTestOptions(secret=SecretFetcher(env_var="test2")),
    )

    # Serialize via json.dumps (works because the plugin subclasses dict),
    # then rebuild through the click ParamType's JSON decoder.
    restored = _DatasetTypeClass().convert(json.dumps(original), None, None)

    assert isinstance(restored, SecretDatasetPluginTest)
    assert restored.secret.env_var == "test2"
    assert restored.mode == Mode.READ_WRITE
    assert restored == original


def test_dataset_type_class():
json_value = (
'{"name": "HiDataset", "mode": "READ_WRITE", "options":{"type":"HiveOptions", "path":"/foo_hive"}}'
Expand Down
6 changes: 6 additions & 0 deletions datasets/utils/secret_fetcher.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import dataclasses
import importlib
import json
import logging
Expand Down Expand Up @@ -106,6 +107,11 @@ def value(self) -> SECRET_RETURN_TYPE:
if self.raw_secret:
return self._fetch_raw_secret()

def to_json(self) -> dict:
    """Serialize this fetcher to a JSON-compatible dict.

    None-valued fields are omitted, and a "type" marker (the class name) is
    included so the JSON decoder can reconstruct a SecretFetcher.
    """
    payload = dataclasses.asdict(
        self, dict_factory=lambda pairs: {key: value for key, value in pairs if value is not None}
    )
    payload["type"] = type(self).__name__
    return payload

def _fetch_kubernetes_secret(self) -> SECRET_RETURN_TYPE:
kubernetes = try_import_kubernetes()
# Try to fetch from cache first
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zdatasets"
version = "0.2.2"
version = "0.2.3"
description = "Dataset SDK for consistent read/write [batch, online, streaming] data."
classifiers = [
"Development Status :: 2 - Pre-Alpha",
Expand Down

0 comments on commit 2c26391

Please sign in to comment.