Skip to content

Commit

Permalink
Fix tests, add more tests, review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
jedrazb committed Oct 2, 2023
1 parent a8154c2 commit a4e1815
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 27 deletions.
39 changes: 26 additions & 13 deletions connectors/sources/google_drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ def get_default_configuration(cls):
"use_text_extraction_service": {
"display": "toggle",
"label": "Use text extraction service",
"order": 5,
"order": 8,
"tooltip": "Requires a separate deployment of the Elastic Text Extraction Service. Requires that pipeline settings disable text extraction.",
"type": "bool",
"ui_restrictions": ["advanced"],
Expand All @@ -587,11 +587,25 @@ def get_default_configuration(cls):
}

def google_drive_client(self, impersonate_email=None):
"""Initialize and return the GoogleDriveClient
"""
Initialize and return an instance of the GoogleDriveClient.
This method sets up a Google Drive client using service account credentials.
If an impersonate_email is provided, the client will be set up for domain-wide
delegation, allowing it to impersonate the provided user account within
a Google Workspace domain.
GoogleDriveClient needs to be reinstantiated for different values of impersonate_email,
therefore the client is not cached.
Args:
impersonate_email (str, optional): The email of the user account to impersonate.
Defaults to None, in which case no impersonation is set up (in case domain-wide delegation is disabled).
Returns:
GoogleDriveClient: An instance of the GoogleDriveClient.
GoogleDriveClient: An initialized instance of the GoogleDriveClient.
"""

service_account_credentials = self.configuration["service_account_credentials"]

validate_service_account_json(
Expand Down Expand Up @@ -751,7 +765,6 @@ def _get_google_workspace_admin_email(self):
return None

def _google_google_workspace_email_for_shared_drives_sync(self):
"""Get Google Workspace email for syncing shared drives"""
return self.configuration.get("google_workspace_email_for_shared_drives_sync")

def _dls_enabled(self):
Expand Down Expand Up @@ -1249,14 +1262,6 @@ async def get_docs(self, filtering=None):
seen_ids = set()

if self._domain_wide_delegation_sync_enabled():
email_for_shared_drives_sync = (
self._google_google_workspace_email_for_shared_drives_sync()
)

shared_drives_client = self.google_drive_client(
impersonate_email=email_for_shared_drives_sync
)

# sync personal drives first
async for user in self.google_admin_directory_client.users():
email = user.get(UserFields.EMAIL.value)
Expand All @@ -1273,12 +1278,20 @@ async def get_docs(self, filtering=None):
):
yield file, partial(self.get_content, google_drive_client, file)

email_for_shared_drives_sync = (
self._google_google_workspace_email_for_shared_drives_sync()
)

shared_drives_client = self.google_drive_client(
impersonate_email=email_for_shared_drives_sync
)

# Build a path lookup, parentId -> parent path
resolved_paths = await self.resolve_paths(
google_drive_client=shared_drives_client
)

# sync shared drives impersonating the admin account
# sync shared drives
self._logger.debug(
f"Syncing shared drives using admin account: {email_for_shared_drives_sync}"
)
Expand Down
4 changes: 2 additions & 2 deletions tests/sources/fixtures/google_drive/connector.json
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,13 @@
"advanced"
]
},
"use_text_extraction_service": {
"use_text_extraction_service": {
"default_value": null,
"depends_on": [],
"display": "toggle",
"label": "Use text extraction service",
"options": [],
"order": 7,
"order": 8,
"required": true,
"sensitive": false,
"tooltip": "Requires a separate deployment of the Elastic Text Extraction Service. Requires that pipeline settings disable text extraction.",
Expand Down
71 changes: 59 additions & 12 deletions tests/sources/test_google_drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,46 @@ async def test_raise_on_invalid_configuration():
await gd_object.validate_config()


@pytest.mark.asyncio
async def test_raise_on_invalid_email_configuration_misformatted_email():
"""Test if invalid configuration raises an expected Exception"""

configuration = DataSourceConfiguration(
{
"service_account_credentials": "{'abc':'bcd','cd'}",
"use_domain_wide_delegation_for_sync": True,
"google_workspace_admin_email_for_data_sync": None,
"google_workspace_email_for_shared_drives_sync": "",
}
)
gd_object = GoogleDriveDataSource(configuration=configuration)

with pytest.raises(
ConfigurableFieldValueError,
):
await gd_object._validate_google_workspace_email_for_shared_drives_sync()


@pytest.mark.asyncio
async def test_raise_on_invalid_email_configuration_empty_email():
"""Test if invalid configuration raises an expected Exception"""

configuration = DataSourceConfiguration(
{
"service_account_credentials": "{'abc':'bcd','cd'}",
"use_domain_wide_delegation_for_sync": True,
"google_workspace_admin_email_for_data_sync": "[email protected]",
"google_workspace_email_for_shared_drives_sync": "admin.com",
}
)
gd_object = GoogleDriveDataSource(configuration=configuration)

with pytest.raises(
ConfigurableFieldValueError,
):
await gd_object._validate_google_workspace_email_for_shared_drives_sync()


@pytest.mark.asyncio
async def test_ping_for_successful_connection():
"""Tests the ping functionality for ensuring connection to Google Drive."""
Expand Down Expand Up @@ -549,21 +589,26 @@ async def test_get_docs_with_domain_wide_delegation():
"url": None,
"type": "file",
}
dummy_url = "https://www.googleapis.com/drive/v3/files"

expected_response_object = Response(
status_code=200,
url=dummy_url,
json=expected_response,
req=Request(method="GET", url=dummy_url),
mock_gdrive_client = mock.MagicMock()
mock_gdrive_client.list_files_from_my_drive = mock.MagicMock(
return_value=AsyncIterator([expected_response])
)

with mock.patch.object(
Aiogoogle, "as_service_account", return_value=expected_response_object
):
with mock.patch.object(ServiceAccountManager, "refresh"):
async for file_document in source.get_docs():
assert file_document[0] == expected_file_document
mock_empty_response_future = asyncio.Future()
mock_empty_response_future.set_result(dict())

mock_gdrive_client.get_all_folders = mock.MagicMock(
return_value=mock_empty_response_future
)
mock_gdrive_client.get_all_drives = mock.MagicMock(
return_value=mock_empty_response_future
)

source.google_drive_client = mock.MagicMock(return_value=mock_gdrive_client)

async for file_document in source.get_docs():
assert file_document[0] == expected_file_document


@pytest.mark.asyncio
Expand Down Expand Up @@ -851,6 +896,7 @@ async def test_get_google_workspace_content_with_text_extraction_enabled_adds_bo
return_value=future_file_content_response
)
content = await source.get_content(
client=source.google_drive_client(),
file=file_document,
doit=True,
)
Expand Down Expand Up @@ -974,6 +1020,7 @@ async def test_get_generic_file_content_with_text_extraction_enabled_adds_body()
content = await source.get_content(
file=file_document,
doit=True,
client=source.google_drive_client(),
)
assert content == expected_file_document

Expand Down

0 comments on commit a4e1815

Please sign in to comment.