8 changes: 4 additions & 4 deletions lib/dl_api_client/dl_api_client/dsmaker/api/dataset_api.py
@@ -225,16 +225,16 @@ def load_dataset(self, dataset: Dataset) -> HttpDatasetApiResponse:
             dataset=dataset,
         )
 
-    def export_dataset(self, dataset: Dataset, data: dict, bi_headers: dict) -> HttpDatasetApiResponse:
-        response = self._request(f"/api/v1/datasets/export/{dataset.id}", method="post", data=data, headers=bi_headers)
+    def export_dataset(self, dataset: Dataset, data: dict, headers: dict) -> HttpDatasetApiResponse:
+        response = self._request(f"/api/v1/datasets/export/{dataset.id}", method="post", data=data, headers=headers)
         return HttpDatasetApiResponse(
             json=response.json,
             status_code=response.status_code,
             dataset=None,
         )
 
-    def import_dataset(self, data: dict, bi_headers: dict) -> HttpDatasetApiResponse:
-        response = self._request("/api/v1/datasets/import", method="post", data=data, headers=bi_headers)
+    def import_dataset(self, data: dict, headers: dict) -> HttpDatasetApiResponse:
+        response = self._request("/api/v1/datasets/import", method="post", data=data, headers=headers)
         return HttpDatasetApiResponse(
             json=response.json,
             status_code=response.status_code,
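Note: a minimal sketch of how a caller exercises the renamed `headers` parameter end to end; the `control_api` and `saved_dataset` fixtures and the token value are placeholders, not part of this change.

```python
# Sketch only: `control_api` is assumed to be a configured SyncHttpDatasetApiV1
# and the token value is a placeholder.
from dl_api_commons.base_models import DLHeadersCommon

headers = {DLHeadersCommon.US_MASTER_TOKEN.value: "<us-master-token>"}

# Export, then re-import under an empty id mapping.
export_resp = control_api.export_dataset(dataset=saved_dataset, data={"id_mapping": {}}, headers=headers)
assert export_resp.status_code == 200

import_resp = control_api.import_dataset(
    data={"id_mapping": {}, "data": {"workbook_id": None, "dataset": export_resp.json["dataset"]}},
    headers=headers,
)
assert import_resp.status_code == 200
```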
30 changes: 21 additions & 9 deletions lib/dl_api_lib/dl_api_lib/app/control_api/resources/dataset.py
@@ -28,7 +28,7 @@
     log_dataset_field_stats,
 )
 from dl_api_lib.enums import USPermissionKind
-from dl_api_lib.schemas import main as dl_api_main_schemas
+import dl_api_lib.schemas.main
 import dl_api_lib.schemas.data
 import dl_api_lib.schemas.dataset_base
 import dl_api_lib.schemas.validation
@@ -85,9 +85,9 @@ def generate_dataset_location(cls, body: dict) -> EntryLocation:
     @put_to_request_context(endpoint_code="DatasetCreate")
     @schematic_request(
         ns=ns,
-        body=dl_api_main_schemas.CreateDatasetSchema(),
+        body=dl_api_lib.schemas.main.CreateDatasetSchema(),
         responses={
-            200: ("Success", dl_api_main_schemas.CreateDatasetResponseSchema()),
+            200: ("Success", dl_api_lib.schemas.main.CreateDatasetResponseSchema()),
         },
     )
     def post(self, body: dict) -> dict:
@@ -310,7 +310,6 @@ def post(self, dataset_id: str, body: dict) -> dict:
 
         ds, _ = self.get_dataset(dataset_id=dataset_id, body={})
         ds_dict = ds.as_dict()
-        us_manager.load_dependencies(ds)
         ds_dict.update(
             self.make_dataset_response_data(
                 dataset=ds, us_entry_buffer=us_manager.get_entry_buffer(), conn_id_mapping=body["id_mapping"]
@@ -322,7 +321,6 @@ def post(self, dataset_id: str, body: dict) -> dict:
         ds_dict["dataset"]["name"] = dl_loc.entry_name
 
         ds_dict["dataset"]["revision_id"] = None
-        del ds_dict["dataset"]["rls"]
 
         notifications = []
         localizer = self.get_service_registry().get_localizer()
@@ -350,15 +348,29 @@ def generate_dataset_location(cls, body: dict) -> EntryLocation:
 
     @classmethod
     def replace_conn_ids(cls, data: dict, conn_id_mapping: dict) -> None:
-        for sources in data["dataset"]["sources"]:
-            sources["connection_id"] = conn_id_mapping[sources["connection_id"]]
+        if "sources" not in data["dataset"]:
+            LOGGER.info("There are no sources in the passed dataset data, so nothing to replace")
+            return
+
+        for source in data["dataset"]["sources"]:
+            assert isinstance(source, dict)
+            fake_conn_id = source["connection_id"]
+            if fake_conn_id not in conn_id_mapping:
+                LOGGER.info(
+                    'Can not find "%s" in conn id mapping for source with id %s, going to replace it with None',
+                    fake_conn_id,
+                    source.get("id"),
+                )
+                source["connection_id"] = None
+            else:
+                source["connection_id"] = conn_id_mapping[fake_conn_id]
 
     @put_to_request_context(endpoint_code="DatasetImport")
     @schematic_request(
         ns=ns,
-        body=dl_api_main_schemas.DatasetImportRequestSchema(),
+        body=dl_api_lib.schemas.main.DatasetImportRequestSchema(),
         responses={
-            200: ("Success", dl_api_main_schemas.ImportResponseSchema()),
+            200: ("Success", dl_api_lib.schemas.main.ImportResponseSchema()),
         },
     )
     def post(self, body: dict) -> dict:
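The heart of the new `replace_conn_ids` is a tolerant lookup: ids missing from the mapping are logged and nulled instead of raising `KeyError`. A standalone sketch of the same pattern (the function name here is hypothetical):

```python
import logging

LOGGER = logging.getLogger(__name__)

def remap_connection_ids(data: dict, conn_id_mapping: dict) -> None:
    """Rewrite each source's connection_id in place; ids missing from the
    mapping become None instead of failing the whole import."""
    if "sources" not in data["dataset"]:
        LOGGER.info("No sources in the dataset data, nothing to replace")
        return
    for source in data["dataset"]["sources"]:
        # dict.get() yields None for unknown keys; the real method also logs each miss.
        source["connection_id"] = conn_id_mapping.get(source["connection_id"])

# Usage: an unknown id is nulled out rather than aborting the import.
data = {"dataset": {"sources": [{"id": "src_1", "connection_id": "fake_id"}]}}
remap_connection_ids(data, conn_id_mapping={})
assert data["dataset"]["sources"][0]["connection_id"] is None
```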
@@ -146,10 +146,16 @@ def dump_dataset_data(
        origin_dsrc = dsrc_coll.get_strict(role=DataSourceRole.origin)
        connection_id = dsrc_coll.get_connection_id(DataSourceRole.origin)
        if conn_id_mapping is not None:
-            try:
+            if connection_id not in conn_id_mapping:
+                LOGGER.info(
+                    'Can not find "%s" in conn id mapping for source with id %s, going to replace it with None',
+                    connection_id,
+                    source_id,
+                )
+                connection_id = None
+            else:
                 connection_id = conn_id_mapping[connection_id]
-            except KeyError:
-                raise DatasetExportError(f"Error to find {connection_id} in connection_id_mapping")
 
        sources.append(
            {
                "id": source_id,
2 changes: 1 addition & 1 deletion lib/dl_api_lib/dl_api_lib/request_model/data.py
@@ -141,7 +141,7 @@ class DeleteObligatoryFilterAction(ObligatoryFilterActionBase):
 
 @attr.s(frozen=True, kw_only=True, auto_attribs=True)
 class ReplaceConnection:
-    id: str
+    id: str | None
     new_id: str
 
 
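With `id` widened to `str | None`, a replace action can now target a source whose connection id was exported as None; a quick usage sketch of the attrs model above:

```python
import attr

@attr.s(frozen=True, kw_only=True, auto_attribs=True)
class ReplaceConnection:
    id: str | None
    new_id: str

# A dataset imported with a nulled-out connection can still be repointed.
action = ReplaceConnection(id=None, new_id="conn_new_1")
assert action.id is None and action.new_id == "conn_new_1"
```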
9 changes: 8 additions & 1 deletion lib/dl_api_lib/dl_api_lib/schemas/main.py
@@ -170,7 +170,11 @@ class DashSQLRequestSchema(BaseSchema):
 
 
 class IdMappingContentSchema(BaseSchema):
-    id_mapping = ma_fields.Dict(ma_fields.String(), ma_fields.String(), required=True)
+    id_mapping = ma_fields.Dict(
+        ma_fields.String(allow_none=True),
+        ma_fields.String(allow_none=True),
+        required=True,
+    )
 
 
 class DatasetExportRequestSchema(IdMappingContentSchema):
@@ -185,6 +189,9 @@ class NotificationContentSchema(BaseSchema):
 
 class DatasetExportResponseSchema(BaseSchema):
     class DatasetContentInternalExportSchema(DatasetContentInternalSchema):
+        class Meta(DatasetContentInternalSchema.Meta):
+            exclude = ("rls",)  # not exporting rls at all, only rls2
+
         name = ma_fields.String()
 
     dataset = ma_fields.Nested(DatasetContentInternalExportSchema)
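For reference, a self-contained sketch of what the relaxed `id_mapping` field accepts after this change (the demo schema name is illustrative):

```python
from marshmallow import Schema, fields as ma_fields

class IdMappingDemoSchema(Schema):  # stand-in for IdMappingContentSchema
    id_mapping = ma_fields.Dict(
        ma_fields.String(allow_none=True),  # keys
        ma_fields.String(allow_none=True),  # values
        required=True,
    )

# An entry mapped to None now validates instead of being rejected.
loaded = IdMappingDemoSchema().load({"id_mapping": {"old_conn_id": None}})
assert loaded["id_mapping"]["old_conn_id"] is None
```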
@@ -1,6 +1,12 @@
 from dl_api_client.dsmaker.api.data_api import SyncHttpDataApiV1
 from dl_api_client.dsmaker.api.dataset_api import SyncHttpDatasetApiV1
+from dl_api_commons.base_models import DLHeadersCommon
+from dl_api_lib.app_settings import ControlApiAppSettings
+from dl_core.us_manager.us_manager_sync import SyncUSManager
+import pytest
+
 from dl_api_client.dsmaker.primitives import (
+    Dataset,
     RequestLegendItem,
     RequestLegendItemRef,
     ResultField,
@@ -163,7 +169,52 @@ def test_dataset_with_deleted_connection(self, saved_dataset, saved_connection_i
         result_resp = data_api.get_result(dataset=saved_dataset, fields=[saved_dataset.result_schema[0]], fail_ok=True)
         assert result_resp.status_code == 400
         assert result_resp.bi_status_code == "ERR.DS_API.REFERENCED_ENTRY_NOT_FOUND"
-        assert result_resp.json["message"] == f"Referenced connection {saved_connection_id} was deleted"
+        assert result_resp.json["message"] == f"Referenced connection does not exist (connection id: {saved_connection_id})"
+
+    def test_dataset_with_null_connection(
+        self,
+        saved_dataset: Dataset,
+        control_api: SyncHttpDatasetApiV1,
+        control_api_app_settings: ControlApiAppSettings,
+        data_api: SyncHttpDataApiV1,
+        sync_us_manager: SyncUSManager,
+    ) -> None:
+        # the only intended way to create such a dataset is via export-import, so let's create it that way
+        export_import_headers = {
+            DLHeadersCommon.US_MASTER_TOKEN.value: control_api_app_settings.US_MASTER_TOKEN,
+        }
+        export_req_data = {"id_mapping": {}}
+        export_resp = control_api.export_dataset(dataset=saved_dataset, data=export_req_data, headers=export_import_headers)
+        assert export_resp.status_code == 200, export_resp.json
+        assert export_resp.json["dataset"]["sources"][0]["connection_id"] == None
+
+        import_req_data: dict = {
+            "id_mapping": {},
+            "data": {"workbook_id": None, "dataset": export_resp.json["dataset"]},
+        }
+        import_resp = control_api.import_dataset(data=import_req_data, headers=export_import_headers)
+        assert import_resp.status_code == 200, f"{import_resp.json} vs {export_resp.json}"
+
+        ds = control_api.serial_adapter.load_dataset_from_response_body(Dataset(), export_resp.json)
+
+        query = data_api.serial_adapter.make_req_data_get_result(
+            dataset=None,
+            fields=[ds.result_schema[0]],
+        )
+        headers = {
+            "Content-Type": "application/json",
+        }
+        result_resp = data_api.get_response_for_dataset_result(
+            dataset_id=import_resp.json["id"],
+            raw_body=query,
+            headers=headers,
+        )
+
+        assert result_resp.status_code == 400, result_resp.json
+        assert result_resp.json["code"] == "ERR.DS_API.REFERENCED_ENTRY_NOT_FOUND"
+        assert result_resp.json["message"] == "Referenced connection does not exist (connection id: empty)"
+
+        control_api.delete_dataset(dataset_id=import_resp.json["id"])
 
 
 class TestDashSQLErrors(DefaultApiTestBase):
@@ -11,6 +11,8 @@
 from dl_core.us_manager.us_manager_sync import SyncUSManager
 from dl_testing.regulated_test import RegulatedTestCase
 
+import pytest
+
 
 class DefaultConnectorDatasetTestSuite(DatasetTestBase, RegulatedTestCase, metaclass=abc.ABCMeta):
     def check_basic_dataset(self, ds: Dataset) -> None:
@@ -88,35 +90,40 @@ def test_replace_connection(
         assert dataset.sources
         assert all(source.connection_id == new_connection_id for source in dataset.sources)
 
-    def test_export_import_dataset(
+    @pytest.fixture(scope="function")
+    def export_import_headers(self, control_api_app_settings: ControlApiAppSettings) -> dict[str, str]:
+        return {
+            DLHeadersCommon.US_MASTER_TOKEN.value: control_api_app_settings.US_MASTER_TOKEN,
+        }
+
+    def test_export_import_invalid_schema(
         self,
         control_api: SyncHttpDatasetApiV1,
         saved_connection_id: str,
         saved_dataset: Dataset,
-        sync_us_manager: SyncUSManager,
-        control_api_app_settings: ControlApiAppSettings,
-        bi_headers: Optional[dict[str, str]],
-    ) -> None:
-        us_master_token = control_api_app_settings.US_MASTER_TOKEN
-        assert us_master_token
+        export_import_headers: dict[str, str],
+    ):
+        export_data = dict()
+        export_resp = control_api.export_dataset(dataset=saved_dataset, data=export_data, headers=export_import_headers)
+        assert export_resp.status_code == 400, export_resp.json
 
-        if bi_headers is None:
-            bi_headers = dict()
+        import_data = dict()
+        import_resp = control_api.import_dataset(data=import_data, headers=export_import_headers)
+        assert import_resp.status_code == 400, import_resp.json
 
-        bi_headers[DLHeadersCommon.US_MASTER_TOKEN.value] = us_master_token
-
-        # test invalid schema
-        export_data: dict = dict()
-        export_resp = control_api.export_dataset(saved_dataset, data=export_data, bi_headers=bi_headers)
-        assert export_resp.status_code == 400
-
-        export_data = {"id_mapping": {}}
-        export_resp = control_api.export_dataset(saved_dataset, data=export_data, bi_headers=bi_headers)
-        assert export_resp.status_code == 400
+        import_data = {"id_mapping": {}}
+        import_resp = control_api.import_dataset(data=import_data, headers=export_import_headers)
+        assert import_resp.status_code == 400, import_resp.json
 
+    def test_export_import_dataset(
+        self,
+        control_api: SyncHttpDatasetApiV1,
+        saved_connection_id: str,
+        saved_dataset: Dataset,
+        export_import_headers: dict[str, str],
+    ) -> None:
         # test common export
         export_data = {"id_mapping": {saved_connection_id: "conn_id_1"}}
-        export_resp = control_api.export_dataset(saved_dataset, data=export_data, bi_headers=bi_headers)
+        export_resp = control_api.export_dataset(dataset=saved_dataset, data=export_data, headers=export_import_headers)
         assert export_resp.status_code == 200
         assert export_resp.json["dataset"]["sources"][0]["connection_id"] == "conn_id_1"
 
@@ -125,7 +132,33 @@ def test_export_import_dataset(
             "id_mapping": {"conn_id_1": saved_connection_id},
             "data": {"workbook_id": None, "dataset": export_resp.json["dataset"]},
         }
-        import_resp = control_api.import_dataset(data=import_data, bi_headers=bi_headers)
-        assert import_resp.status_code == 200, import_resp.json["dataset"] != export_resp.json["dataset"]
+        import_resp = control_api.import_dataset(data=import_data, headers=export_import_headers)
+        assert import_resp.status_code == 200, f"{import_resp.json} vs {export_resp.json}"
 
         control_api.delete_dataset(dataset_id=import_resp.json["id"])
+
+    def test_export_import_dataset_with_no_connection(
+        self,
+        control_api: SyncHttpDatasetApiV1,
+        saved_connection_id: str,
+        sync_us_manager: SyncUSManager,
+        saved_dataset: Dataset,
+        export_import_headers: dict[str, str],
+    ):
+        sync_us_manager.delete(sync_us_manager.get_by_id(saved_connection_id))
+
+        # export with no connection
+        export_req_data = {"id_mapping": {}}
+        export_resp = control_api.export_dataset(dataset=saved_dataset, data=export_req_data, headers=export_import_headers)
+        assert export_resp.status_code == 200, export_resp.json
+        assert export_resp.json["dataset"]["sources"][0]["connection_id"] == None
+
+        # import with no connection
+        import_req_data: dict = {
+            "id_mapping": {},
+            "data": {"workbook_id": None, "dataset": export_resp.json["dataset"]},
+        }
+        import_resp = control_api.import_dataset(data=import_req_data, headers=export_import_headers)
+        assert import_resp.status_code == 200, f"{import_resp.json} vs {export_resp.json}"
+
+        control_api.delete_dataset(dataset_id=import_resp.json["id"])
11 changes: 1 addition & 10 deletions lib/dl_core/dl_core/base_models.py
@@ -33,17 +33,8 @@ class DefaultConnectionRef(ConnectionRef):
     conn_id: str = attr.ib(kw_only=True)
 
 
-@attr.s(frozen=True, slots=True)
-class InternalMaterializationConnectionRef(ConnectionRef):
-    pass
-
-
 def connection_ref_from_id(connection_id: Optional[str]) -> ConnectionRef:
-    if connection_id is None:
-        # TODO REMOVE: some sample source code still relies on mat con ref
-        return InternalMaterializationConnectionRef()
-    else:
-        return DefaultConnectionRef(conn_id=connection_id)
+    return DefaultConnectionRef(conn_id=connection_id)
 
 
 @attr.s()
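Every ref, including one without an id, is now a `DefaultConnectionRef`; a small sketch of the resulting semantics. Note the real class annotates `conn_id` as `str`; this sketch widens it to `Optional[str]` to make the None case explicit:

```python
from typing import Optional

import attr

@attr.s(frozen=True, slots=True)
class ConnectionRef:
    pass

@attr.s(frozen=True, slots=True)
class DefaultConnectionRef(ConnectionRef):
    conn_id: Optional[str] = attr.ib(kw_only=True)  # widened here for illustration

def connection_ref_from_id(connection_id: Optional[str]) -> ConnectionRef:
    return DefaultConnectionRef(conn_id=connection_id)

# Frozen attrs classes compare and hash by value, so a None-id ref is
# still a well-behaved dict key and equality target.
assert connection_ref_from_id(None) == DefaultConnectionRef(conn_id=None)
assert connection_ref_from_id("conn_1") != connection_ref_from_id(None)
```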
11 changes: 0 additions & 11 deletions lib/dl_core/dl_core/data_source/base.py
@@ -156,17 +156,6 @@ def spec(self) -> DataSourceSpec:
         return self._spec
 
     def _validate_connection(self) -> None:
-        if self._connection is not None and self._spec.connection_ref is None:  # type: ignore  # TODO: fix
-            # TODO CONSIDER: extraction of connection ref
-            pass
-        elif self._spec.connection_ref is not None and self._connection is None:  # type: ignore  # TODO: fix
-            pass
-        else:
-            raise ValueError(
-                f"Unexpected combination of 'connection' and 'connection_ref':"
-                f" {self._connection} and {self._spec.connection_ref}"  # type: ignore  # TODO: fix no attribute
-            )
-
         if self._connection is not None:
             self._validate_connection_cls(self._connection)
 
3 changes: 0 additions & 3 deletions lib/dl_core/dl_core/data_source/collection.py
@@ -11,7 +11,6 @@
 from dl_core.base_models import (
     ConnectionRef,
     DefaultConnectionRef,
-    InternalMaterializationConnectionRef,
 )
 from dl_core.connection_executors.sync_base import SyncConnExecutorBase
 import dl_core.data_source.base as base
@@ -71,8 +70,6 @@ def get_connection_id(self, role: DataSourceRole | None = None) -> str | None:
         conn_ref = self.get_strict(role=role).connection_ref
         if isinstance(conn_ref, DefaultConnectionRef):
             return conn_ref.conn_id
-        elif isinstance(conn_ref, InternalMaterializationConnectionRef):
-            return None
         else:
             raise TypeError(f"Unexpected conn_ref class: {type(conn_ref)}")
 
2 changes: 1 addition & 1 deletion lib/dl_core/dl_core/data_source_spec/collection.py
@@ -35,6 +35,6 @@ def collect_links(self) -> dict[str, str]:
         assert self.origin is not None
         result: dict[str, str] = {}
         connection_ref = self.origin.connection_ref
-        if isinstance(connection_ref, DefaultConnectionRef):
+        if isinstance(connection_ref, DefaultConnectionRef) and connection_ref.conn_id is not None:
             result[self.id] = connection_ref.conn_id
         return result
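`collect_links` now simply skips sources whose origin connection has no id, keeping the result a strict str-to-str mapping; sketched below as a free function (the real code is a method on the spec collection):

```python
def collect_links(source_id: str, conn_id: str | None) -> dict[str, str]:
    # Only record a link when there is a real connection id to point at.
    result: dict[str, str] = {}
    if conn_id is not None:
        result[source_id] = conn_id
    return result

assert collect_links("src_1", "conn_1") == {"src_1": "conn_1"}
assert collect_links("src_1", None) == {}  # nulled-out connection: no link
```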
3 changes: 2 additions & 1 deletion lib/dl_core/dl_core/us_manager/local_cache.py
@@ -33,7 +33,8 @@ def get_entry(self, entry_id: ConnectionRef) -> USEntry:
         if isinstance(entry, BrokenUSLink):
             if isinstance(entry.reference, DefaultConnectionRef):
                 if entry.error_kind == BrokenUSLinkErrorKind.NOT_FOUND:
-                    raise exc.ReferencedUSEntryNotFound(f"Referenced connection {entry.reference.conn_id} was deleted")
+                    conn_id = entry.reference.conn_id if entry.reference.conn_id is not None else "empty"
+                    raise exc.ReferencedUSEntryNotFound(f"Referenced connection does not exist (connection id: {conn_id})")
                 elif entry.error_kind == BrokenUSLinkErrorKind.ACCESS_DENIED:
                     raise exc.ReferencedUSEntryAccessDenied(
                         f"Referenced connection {entry.reference.conn_id} cannot be loaded: access denied",
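The new message keeps one stable shape whether or not the broken reference carries an id; a tiny sketch of the formatting rule (the exception class here is a local stand-in for the `dl_core` one):

```python
from typing import Optional

class ReferencedUSEntryNotFound(Exception):  # stand-in for dl_core's exc class
    pass

def not_found_error(conn_id: Optional[str]) -> ReferencedUSEntryNotFound:
    # Fall back to the literal "empty" when the reference has no id at all.
    shown = conn_id if conn_id is not None else "empty"
    return ReferencedUSEntryNotFound(f"Referenced connection does not exist (connection id: {shown})")

assert str(not_found_error(None)) == "Referenced connection does not exist (connection id: empty)"
assert str(not_found_error("conn_1")) == "Referenced connection does not exist (connection id: conn_1)"
```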