Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import flask

from dl_api_commons.base_models import (
AuthData,
NoAuthData,
TenantCommon,
TenantDef,
Expand All @@ -22,15 +21,14 @@ class TrustAuthService:
fake_user_id: Optional[str] = None
fake_user_name: Optional[str] = None
fake_tenant: Optional[TenantDef] = None
fake_auth_data: Optional[AuthData] = None

def _before_request(self) -> None:
fake_user_id = self.fake_user_id
fake_user_name = self.fake_user_name
fake_tenant = self.fake_tenant

temp_rci = ReqCtxInfoMiddleware.get_temp_rci().clone(
auth_data=NoAuthData() if self.fake_auth_data is None else self.fake_auth_data,
auth_data=NoAuthData(),
tenant=TenantCommon() if fake_tenant is None else fake_tenant,
)
if fake_user_id is not None:
Expand Down
6 changes: 1 addition & 5 deletions lib/dl_api_lib/dl_api_lib/app_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@
import attr
import pydantic

from dl_api_commons.base_models import (
AuthData,
TenantDef,
)
from dl_api_commons.base_models import TenantDef
from dl_api_lib.connector_availability.base import ConnectorAvailabilityConfig
from dl_configs.crypto_keys import CryptoKeysConfig
from dl_configs.enums import RedisMode
Expand Down Expand Up @@ -309,7 +306,6 @@ def jaeger_service_name(self) -> str:
class ControlApiAppTestingsSettings:
us_auth_mode_override: Optional[USAuthMode] = attr.ib(default=None)
fake_tenant: Optional[TenantDef] = attr.ib(default=None)
fake_auth_data: Optional[AuthData] = attr.ib(default=None)


@attr.s(frozen=True)
Expand Down
1 change: 0 additions & 1 deletion lib/dl_api_lib_testing/dl_api_lib_testing/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ def set_up_environment(
fake_user_id=TEST_USER_ID,
fake_user_name=TEST_USER_NAME,
fake_tenant=None if testing_app_settings is None else testing_app_settings.fake_tenant,
fake_auth_data=None if testing_app_settings is None else testing_app_settings.fake_auth_data,
).set_up(app)

us_auth_mode_override = None if testing_app_settings is None else testing_app_settings.us_auth_mode_override
Expand Down
11 changes: 1 addition & 10 deletions lib/dl_api_lib_testing/dl_api_lib_testing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from dl_api_client.dsmaker.api.dataset_api import SyncHttpDatasetApiV1
from dl_api_client.dsmaker.api.http_sync_base import SyncHttpClientBase
from dl_api_commons.base_models import (
AuthData,
RequestContextInfo,
TenantCommon,
TenantDef,
Expand Down Expand Up @@ -167,27 +166,19 @@ def control_api_app_factory(self, control_api_app_settings: ControlApiAppSetting
def fake_tenant(self) -> TenantDef:
return TenantCommon()

@pytest.fixture(scope="function")
def fake_auth_data(self) -> Optional[AuthData]:
return None

@pytest.fixture(scope="function")
def control_api_app(
self,
environment_readiness: None,
control_api_app_factory: ControlApiAppFactory,
connectors_settings: dict[ConnectionType, ConnectorSettingsBase],
fake_tenant: TenantDef,
fake_auth_data: Optional[AuthData],
) -> Generator[Flask, None, None]:
"""Session-wide test `Flask` application."""

app = control_api_app_factory.create_app(
connectors_settings=connectors_settings,
testing_app_settings=ControlApiAppTestingsSettings(
fake_tenant=fake_tenant,
fake_auth_data=fake_auth_data,
),
testing_app_settings=ControlApiAppTestingsSettings(fake_tenant=fake_tenant),
close_loop_after_request=False,
)

Expand Down
26 changes: 3 additions & 23 deletions lib/dl_api_lib_testing/dl_api_lib_testing/connection_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
import contextlib
import json
from typing import (
Any,
ClassVar,
Generator,
Optional,
)
import uuid

import attr
import pytest

from dl_api_client.dsmaker.api.http_sync_base import SyncHttpClientBase
Expand All @@ -23,21 +22,6 @@
from dl_core_testing.testcases.service_base import DbServiceFixtureTextClass


@attr.s(kw_only=True, frozen=True)
class EditConnectionParamsCase:
params: dict[str, Any] = attr.ib(factory=dict) # params that will be sent in the edit request
load_only_field_names: list[str] = attr.ib(
factory=list
) # fields that will be sent in the edit request, but will not be present in the response (passwords, etc)
additional_fields_to_check: dict[str, Any] = attr.ib(
factory=dict
) # fields that must be present in the response after the edit in addition to params. Values from params can be overridden here.
check_absence_of_fields: list[str] = attr.ib(
factory=list
) # fields that must be absent or `None` in the response after the edit
supports_connection_test: bool = attr.ib(default=True)


class ConnectionTestBase(ApiTestBase, DbServiceFixtureTextClass):
conn_type: ClassVar[ConnectionType]

Expand All @@ -48,16 +32,12 @@ class ConnectionTestBase(ApiTestBase, DbServiceFixtureTextClass):
def connection_params(self) -> dict:
raise NotImplementedError

@pytest.fixture(scope="class")
def edit_connection_params_case(self) -> EditConnectionParamsCase | None:
return None

@pytest.fixture(scope="function")
def saved_connection_id(
self,
control_api_sync_client: SyncHttpClientBase,
connection_params: dict,
bi_headers: dict[str, str] | None = None,
bi_headers: Optional[dict[str, str]],
) -> Generator[str, None, None]:
with self.create_connection(
control_api_sync_client=control_api_sync_client,
Expand All @@ -71,7 +51,7 @@ def create_connection(
self,
control_api_sync_client: SyncHttpClientBase,
connection_params: dict,
bi_headers: dict[str, str] | None = None,
bi_headers: Optional[dict[str, str]],
) -> Generator[str, None, None]:
data = dict(
name=f"{self.conn_type.name} conn {uuid.uuid4()}",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import json
from typing import Optional
import uuid

import pytest

from dl_api_client.dsmaker.api.http_sync_base import SyncHttpClientBase
from dl_api_lib.app_settings import ControlApiAppSettings
from dl_api_lib.schemas.connection import GenericConnectionSchema
from dl_api_lib_testing.connection_base import (
ConnectionTestBase,
EditConnectionParamsCase,
)
from dl_api_lib_testing.connection_base import ConnectionTestBase
from dl_constants.api_constants import DLHeadersCommon
from dl_core.connectors.base.export_import import is_export_import_allowed
from dl_core.us_connection_base import ConnectionBase
Expand All @@ -22,7 +20,7 @@ def test_create_connection(
self,
control_api_sync_client: SyncHttpClientBase,
saved_connection_id: str,
bi_headers: dict[str, str] | None,
bi_headers: Optional[dict[str, str]],
) -> None:
assert saved_connection_id
resp = control_api_sync_client.get(
Expand All @@ -31,66 +29,11 @@ def test_create_connection(
)
assert resp.status_code == 200, resp.json

def test_edit_connection(
self,
control_api_sync_client: SyncHttpClientBase,
saved_connection_id: str,
bi_headers: dict[str, str] | None,
edit_connection_params_case: EditConnectionParamsCase | None,
) -> None:
if edit_connection_params_case is None:
pytest.skip("No edit_connection_params_case fixture provided")

resp = control_api_sync_client.put(
url=f"/api/v1/connections/{saved_connection_id}",
content_type="application/json",
data=json.dumps(edit_connection_params_case.params),
headers=bi_headers,
)
assert resp.status_code == 200, resp.json

resp = control_api_sync_client.get(
url=f"/api/v1/connections/{saved_connection_id}",
headers=bi_headers,
)
assert resp.status_code == 200, resp.json
for param_name, val in (
edit_connection_params_case.params | edit_connection_params_case.additional_fields_to_check
).items():
if param_name in edit_connection_params_case.load_only_field_names:
continue
assert resp.json[param_name] == val, resp.json

for field_name in edit_connection_params_case.check_absence_of_fields:
if field_name in resp.json:
assert resp.json[field_name] is None, resp.json

def test_test_on_edit_connection(
self,
control_api_sync_client: SyncHttpClientBase,
saved_connection_id: str,
bi_headers: dict[str, str] | None,
edit_connection_params_case: EditConnectionParamsCase | None,
) -> None:
if edit_connection_params_case is None:
pytest.skip("No edit_connection_params_case fixture provided")

if not edit_connection_params_case.supports_connection_test:
pytest.skip("Connection test is not supported for this connection type")

resp = control_api_sync_client.post(
f"/api/v1/connections/test_connection/{saved_connection_id}",
content_type="application/json",
data=json.dumps(edit_connection_params_case.params),
headers=bi_headers,
)
assert resp.status_code == 200, resp.json

def test_export_connection(
self,
control_api_sync_client: SyncHttpClientBase,
saved_connection_id: str,
bi_headers: dict[str, str] | None,
bi_headers: Optional[dict[str, str]],
sync_us_manager: SyncUSManager,
control_api_app_settings: ControlApiAppSettings,
) -> None:
Expand Down Expand Up @@ -127,7 +70,7 @@ def test_import_connection(
self,
control_api_sync_client: SyncHttpClientBase,
saved_connection_id: str,
bi_headers: dict[str, str] | None,
bi_headers: Optional[dict[str, str]],
sync_us_manager: SyncUSManager,
control_api_app_settings: ControlApiAppSettings,
) -> None:
Expand Down Expand Up @@ -200,7 +143,7 @@ def test_test_connection(
self,
control_api_sync_client: SyncHttpClientBase,
saved_connection_id: str,
bi_headers: dict[str, str] | None,
bi_headers: Optional[dict[str, str]],
) -> None:
resp = control_api_sync_client.post(
f"/api/v1/connections/test_connection/{saved_connection_id}",
Expand All @@ -214,7 +157,7 @@ def test_cache_ttl_sec_override(
self,
control_api_sync_client: SyncHttpClientBase,
saved_connection_id: str,
bi_headers: dict[str, str] | None,
bi_headers: Optional[dict[str, str]],
) -> None:
resp = control_api_sync_client.get(
url=f"/api/v1/connections/{saved_connection_id}",
Expand Down Expand Up @@ -243,7 +186,7 @@ def test_connection_options(
self,
control_api_sync_client: SyncHttpClientBase,
saved_connection_id: str,
bi_headers: dict[str, str] | None,
bi_headers: Optional[dict[str, str]],
) -> None:
resp = control_api_sync_client.get(
url=f"/api/v1/connections/{saved_connection_id}",
Expand All @@ -260,7 +203,7 @@ def test_connection_sources(
self,
control_api_sync_client: SyncHttpClientBase,
saved_connection_id: str,
bi_headers: dict[str, str] | None,
bi_headers: Optional[dict[str, str]],
) -> None:
resp = control_api_sync_client.get(
url=f"/api/v1/connections/{saved_connection_id}/info/sources",
Expand All @@ -275,7 +218,7 @@ def test_create_connections__query_params_in_db_name__error(
self,
control_api_sync_client: SyncHttpClientBase,
saved_connection_id: str,
bi_headers: dict[str, str] | None,
bi_headers: Optional[dict[str, str]],
connection_params: dict,
) -> None:
if "db_name" not in connection_params:
Expand Down