From 84da2c0911f2c79cfd298e37026c0d943fe95937 Mon Sep 17 00:00:00 2001 From: Navarone Feekery <13634519+navarone-feekery@users.noreply.github.com> Date: Wed, 27 Sep 2023 13:11:15 +0200 Subject: [PATCH 01/12] Expand to dropbox --- connectors/sources/dropbox.py | 1 + 1 file changed, 1 insertion(+) diff --git a/connectors/sources/dropbox.py b/connectors/sources/dropbox.py index b565d92c5..1cef3782c 100644 --- a/connectors/sources/dropbox.py +++ b/connectors/sources/dropbox.py @@ -667,6 +667,7 @@ async def get_content( ) return + self._logger.debug(f"Downloading {filename}") document = { "_id": attachment["id"], "_timestamp": attachment["server_modified"], From d260ce865c6c2d85d3e712460cdcfd53b8313975 Mon Sep 17 00:00:00 2001 From: Navarone Feekery <13634519+navarone-feekery@users.noreply.github.com> Date: Wed, 27 Sep 2023 14:18:26 +0200 Subject: [PATCH 02/12] Add extraction service to Jira --- connectors/sources/dropbox.py | 1 - 1 file changed, 1 deletion(-) diff --git a/connectors/sources/dropbox.py b/connectors/sources/dropbox.py index 1cef3782c..b565d92c5 100644 --- a/connectors/sources/dropbox.py +++ b/connectors/sources/dropbox.py @@ -667,7 +667,6 @@ async def get_content( ) return - self._logger.debug(f"Downloading {filename}") document = { "_id": attachment["id"], "_timestamp": attachment["server_modified"], From 179eba66a2bac5a2abb24e543c171828a85115eb Mon Sep 17 00:00:00 2001 From: Navarone Feekery <13634519+navarone-feekery@users.noreply.github.com> Date: Wed, 27 Sep 2023 15:18:22 +0200 Subject: [PATCH 03/12] Add extraction service to GCS --- connectors/sources/google_cloud_storage.py | 48 +++++++-------- connectors/sources/s3.py | 2 + .../google_cloud_storage/connector.json | 15 +++++ tests/sources/test_google_cloud_storage.py | 60 ++++++++++++++++++- 4 files changed, 99 insertions(+), 26 deletions(-) diff --git a/connectors/sources/google_cloud_storage.py b/connectors/sources/google_cloud_storage.py index 392a4508e..2a4acefe8 100644 --- a/connectors/sources/google_cloud_storage.py +++ b/connectors/sources/google_cloud_storage.py @@ -229,6 +229,14 @@ def get_default_configuration(cls): "type": "int", "ui_restrictions": ["advanced"], }, + "use_text_extraction_service": { + "display": "toggle", + "label": "Use text extraction service", + "order": 3, + "tooltip": "Requires a separate deployment of the Elastic Text Extraction Service. Requires that pipeline settings disable text extraction.", + "type": "bool", + "value": False, + }, } async def validate_config(self): @@ -370,50 +378,40 @@ async def get_content(self, blob, timestamp=None, doit=None): Returns: dictionary: Content document with id, timestamp & text """ - blob_size = int(blob["size"]) - if not (doit and blob_size): + file_size = int(blob["size"]) + if not (doit and file_size): return - blob_name = blob["name"] - if (os.path.splitext(blob_name)[-1]).lower() not in TIKA_SUPPORTED_FILETYPES: - self._logger.debug(f"{blob_name} can't be extracted") + filename = blob["name"] + file_extension = self.get_file_extension(filename) + if not self.can_file_be_downloaded(file_extension, filename, file_size): return - if blob_size > DEFAULT_FILE_SIZE_LIMIT: - self._logger.warning( - f"File size {blob_size} of file {blob_name} is larger than {DEFAULT_FILE_SIZE_LIMIT} bytes. 
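[editor's note] The per-connector extension and size checks deleted in this hunk are what the shared BaseDataSource helpers replace across the series. A minimal, stdlib-only sketch of the gating logic, reflecting behaviour visible in these diffs; the set contents are illustrative and only the names come from the patches:

    TIKA_SUPPORTED_FILETYPES = {".txt", ".pdf", ".docx"}  # illustrative subset only
    FILE_SIZE_LIMIT = 10485760  # ~10 MB, the limit named throughout this series

    def can_file_be_downloaded(file_extension, filename, file_size, use_text_extraction_service=False):
        # `filename` kept for parity with the real helper, which uses it for logging
        if file_extension == "" or file_extension.lower() not in TIKA_SUPPORTED_FILETYPES:
            return False  # no extension, or a type Tika cannot parse
        if file_size > FILE_SIZE_LIMIT and not use_text_extraction_service:
            return False  # the size cap is waived when the extraction service is enabled
        return True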
Discarding the file content" - ) - return - self._logger.debug(f"Downloading {blob_name}") document = { "_id": blob["id"], "_timestamp": blob["_timestamp"], } - source_file_name = "" - async with NamedTemporaryFile(mode="wb", delete=False) as async_buffer: + + # gcs has a unique download method so we can't utilize + # the generic download_and_extract_file func + async with self.create_temp_file(file_extension) as async_buffer: await anext( self._google_storage_client.api_call( resource="objects", method="get", bucket=blob["bucket_name"], - object=blob_name, + object=filename, alt="media", userProject=self._google_storage_client.user_project_id, pipe_to=async_buffer, ) ) - source_file_name = async_buffer.name + await async_buffer.close() + + document = await self.handle_file_content_extraction( + document, filename, async_buffer.name + ) - self._logger.debug(f"Calling convert_to_b64 for file : {blob_name}") - await asyncio.to_thread( - convert_to_b64, - source=source_file_name, - ) - async with aiofiles.open(file=source_file_name, mode="r") as target_file: - # base64 on macOS will add a EOL, so we strip() here - document["_attachment"] = (await target_file.read()).strip() - await remove(str(source_file_name)) - self._logger.debug(f"Downloaded {blob_name} for {blob_size} bytes ") return document async def get_docs(self, filtering=None): diff --git a/connectors/sources/s3.py b/connectors/sources/s3.py index 34c17d410..0c45254c1 100644 --- a/connectors/sources/s3.py +++ b/connectors/sources/s3.py @@ -250,6 +250,8 @@ async def get_content(self, doc, s3_client, timestamp=None, doit=None): await s3_client.download_fileobj( Bucket=bucket, Key=filename, Fileobj=async_buffer ) + await async_buffer.close() + document = await self.handle_file_content_extraction( document, filename, async_buffer.name ) diff --git a/tests/sources/fixtures/google_cloud_storage/connector.json b/tests/sources/fixtures/google_cloud_storage/connector.json index 80397805d..bd7c6295b 100644 --- a/tests/sources/fixtures/google_cloud_storage/connector.json +++ b/tests/sources/fixtures/google_cloud_storage/connector.json @@ -37,6 +37,21 @@ "ui_restrictions": [ "advanced" ] + }, + "use_text_extraction_service": { + "default_value": null, + "depends_on": [], + "display": "toggle", + "label": "Use text extraction service", + "options": [], + "order": 7, + "required": true, + "sensitive": false, + "tooltip": "Requires a separate deployment of the Elastic Text Extraction Service. 
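[editor's note] Because the Aiogoogle client pipes blob bytes itself, the GCS connector above inlines the three steps that download_and_extract_file performs for other connectors: create a temp file, stream into it, then extract or convert. A self-contained sketch of that shape, where download() is a hypothetical stand-in for the api_call(..., pipe_to=async_buffer) media request:

    import base64
    import os
    import tempfile

    async def get_content_sketch(blob, download):
        document = {"_id": blob["id"], "_timestamp": blob["_timestamp"]}
        fd, path = tempfile.mkstemp(suffix=os.path.splitext(blob["name"])[-1])
        os.close(fd)
        try:
            await download(path)           # 1. stream the blob into the temp file
            with open(path, "rb") as f:    # 2. read the downloaded bytes back
                raw = f.read()
            # 3. extraction service enabled -> plain text in "body";
            #    otherwise base64 in "_attachment" for the ingest pipeline
            document["_attachment"] = base64.b64encode(raw).decode()
            return document
        finally:
            os.remove(path)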
Requires that pipeline settings disable text extraction.", + "type": "bool", + "ui_restrictions": [], + "validations": [], + "value": false } }, "custom_scheduling": {}, diff --git a/tests/sources/test_google_cloud_storage.py b/tests/sources/test_google_cloud_storage.py index d9a9d7c8b..92cf84cbb 100644 --- a/tests/sources/test_google_cloud_storage.py +++ b/tests/sources/test_google_cloud_storage.py @@ -8,6 +8,7 @@ import asyncio from contextlib import asynccontextmanager from unittest import mock +from unittest.mock import patch import pytest from aiogoogle import Aiogoogle @@ -24,11 +25,12 @@ @asynccontextmanager -async def create_gcs_source(): +async def create_gcs_source(use_text_extraction_service=False): async with create_source( GoogleCloudStorageDataSource, service_account_credentials=SERVICE_ACCOUNT_CREDENTIALS, retry_count=0, + use_text_extraction_service=use_text_extraction_service, ) as source: yield source @@ -364,6 +366,62 @@ async def test_get_content(): assert content == expected_blob_document +@pytest.mark.asyncio +@patch( + "connectors.content_extraction.ContentExtraction._check_configured", + lambda *_: True, +) +async def test_get_content_with_text_extraction_enabled_adds_body(): + """Test the module responsible for fetching the content of the file if it is extractable.""" + + with patch( + "connectors.content_extraction.ContentExtraction.extract_text", + return_value="file content", + ), patch( + "connectors.content_extraction.ContentExtraction.get_extraction_config", + return_value={"host": "http://localhost:8090"}, + ): + async with create_gcs_source(use_text_extraction_service=True) as source: + blob_document = { + "id": "bucket_1/blob_1/123123123", + "component_count": None, + "content_encoding": None, + "content_language": None, + "created_at": None, + "last_updated": "2011-10-12T00:01:00Z", + "metadata": None, + "name": "blob_1.txt", + "size": "15", + "storage_class": None, + "_timestamp": "2011-10-12T00:01:00Z", + "type": None, + "url": "https://console.cloud.google.com/storage/browser/_details/bucket_1/blob_1;tab=live_object?project=dummy123", + "version": None, + "bucket_name": "bucket_1", + } + expected_blob_document = { + "_id": "bucket_1/blob_1/123123123", + "_timestamp": "2011-10-12T00:01:00Z", + "body": "file content", + } + blob_content_response = "" + + with mock.patch.object( + Aiogoogle, "as_service_account", return_value=blob_content_response + ): + async with Aiogoogle( + service_account_creds=source._google_storage_client.service_account_credentials + ) as google_client: + storage_client = await google_client.discover( + api_name=API_NAME, api_version=API_VERSION + ) + storage_client.objects = mock.MagicMock() + content = await source.get_content( + blob=blob_document, + doit=True, + ) + assert content == expected_blob_document + @pytest.mark.asyncio async def test_get_content_with_upper_extension(): """Test the module responsible for fetching the content of the file if it is extractable.""" From 1e8da573d3007dca38be523e0e21f673d7c4b0b5 Mon Sep 17 00:00:00 2001 From: Navarone Feekery <13634519+navarone-feekery@users.noreply.github.com> Date: Wed, 27 Sep 2023 15:54:23 +0200 Subject: [PATCH 04/12] Add extraction service to google drive --- connectors/source.py | 8 ++ connectors/sources/google_cloud_storage.py | 5 +- connectors/sources/google_drive.py | 88 ++++++--------- .../fixtures/google_drive/connector.json | 15 +++ tests/sources/test_google_cloud_storage.py | 1 + tests/sources/test_google_drive.py | 105 +++++++++++++++++- 6 files changed, 159 
insertions(+), 63 deletions(-) diff --git a/connectors/source.py b/connectors/source.py index 930636226..2b98d32dd 100644 --- a/connectors/source.py +++ b/connectors/source.py @@ -693,6 +693,11 @@ def get_file_extension(self, filename): return get_file_extension(filename) def can_file_be_downloaded(self, file_extension, filename, file_size): + return self.is_valid_file_type( + file_extension, filename + ) and self.is_file_size_within_limit(file_size, filename) + + def is_valid_file_type(self, file_extension, filename): if file_extension == "": self._logger.debug( f"Files without extension are not supported, skipping {filename}." @@ -705,6 +710,9 @@ def can_file_be_downloaded(self, file_extension, filename, file_size): ) return False + return True + + def is_file_size_within_limit(self, file_size, filename): if file_size > FILE_SIZE_LIMIT and not self.configuration.get( "use_text_extraction_service" ): diff --git a/connectors/sources/google_cloud_storage.py b/connectors/sources/google_cloud_storage.py index 2a4acefe8..c17acbbb5 100644 --- a/connectors/sources/google_cloud_storage.py +++ b/connectors/sources/google_cloud_storage.py @@ -10,9 +10,6 @@ import urllib.parse from functools import cached_property, partial -import aiofiles -from aiofiles.os import remove -from aiofiles.tempfile import NamedTemporaryFile from aiogoogle import Aiogoogle from aiogoogle.auth.creds import ServiceAccountCreds @@ -22,7 +19,7 @@ load_service_account_json, validate_service_account_json, ) -from connectors.utils import TIKA_SUPPORTED_FILETYPES, convert_to_b64, get_pem_format +from connectors.utils import get_pem_format CLOUD_STORAGE_READ_ONLY_SCOPE = "https://www.googleapis.com/auth/devstorage.read_only" CLOUD_STORAGE_BASE_URL = "https://console.cloud.google.com/storage/browser/_details/" diff --git a/connectors/sources/google_drive.py b/connectors/sources/google_drive.py index 15099efec..66f29bfe1 100644 --- a/connectors/sources/google_drive.py +++ b/connectors/sources/google_drive.py @@ -7,9 +7,6 @@ import os from functools import cached_property, partial -import aiofiles -from aiofiles.os import remove, stat -from aiofiles.tempfile import NamedTemporaryFile from aiogoogle import Aiogoogle, HTTPError from aiogoogle.auth.creds import ServiceAccountCreds from aiogoogle.sessions.aiohttp_session import AiohttpSession @@ -27,9 +24,7 @@ ) from connectors.utils import ( EMAIL_REGEX_PATTERN, - TIKA_SUPPORTED_FILETYPES, RetryStrategy, - convert_to_b64, retryable, validate_email_address, ) @@ -39,7 +34,6 @@ RETRIES = 3 RETRY_INTERVAL = 2 -FILE_SIZE_LIMIT = 10485760 # ~ 10 Megabytes GOOGLE_API_MAX_CONCURRENCY = 25 # Max open connections to Google API @@ -513,6 +507,14 @@ def get_default_configuration(cls): "ui_restrictions": ["advanced"], "validations": [{"type": "greater_than", "constraint": 0}], }, + "use_text_extraction_service": { + "display": "toggle", + "label": "Use text extraction service", + "order": 5, + "tooltip": "Requires a separate deployment of the Elastic Text Extraction Service. 
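[editor's note] Splitting can_file_be_downloaded into is_valid_file_type and is_file_size_within_limit lets a connector apply only the size check where the extension is meaningless; the Google Drive hunk below does exactly that for exported Workspace documents. Illustrative call sites, assuming the BaseDataSource methods from this hunk:

    # Generic binary file: gate on extension and size together.
    if not self.can_file_be_downloaded(file_extension, filename, file_size):
        return
    # Exported Google Workspace document: Google converts it to text/plain,
    # so only the size limit is relevant.
    if not self.is_file_size_within_limit(file_size, file_name):
        return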
Requires that pipeline settings disable text extraction.", + "type": "bool", + "value": False, + }, } @cached_property @@ -778,43 +780,23 @@ async def _download_content(self, file, download_func): attachment, file_size (tuple): base64 encoded contnet of the file and size in bytes of the attachment """ - temp_file_name = "" file_name = file["name"] - attachment, file_size = None, 0 - - self._logger.debug(f"Downloading {file_name}") - - try: - async with NamedTemporaryFile(mode="wb", delete=False) as async_buffer: - await download_func( - pipe_to=async_buffer, - ) - - temp_file_name = async_buffer.name + file_extension = self.get_file_extension(file_name) + attachment, body, file_size = None, None, 0 - await asyncio.to_thread( - convert_to_b64, - source=temp_file_name, + async with self.create_temp_file(file_extension) as async_buffer: + await download_func( + pipe_to=async_buffer, ) + await async_buffer.close() - file_stat = await stat(temp_file_name) - file_size = file_stat.st_size - - async with aiofiles.open(file=temp_file_name, mode="r") as target_file: - attachment = (await target_file.read()).strip() - - self._logger.debug( - f"Downloaded {file_name} with the size of {file_size} bytes " + doc = await self.handle_file_content_extraction( + {}, file_name, async_buffer.name ) - except Exception as e: - self._logger.error( - f"Exception encountered when processing file: {file_name}. Exception: {e}" - ) - finally: - if temp_file_name: - await remove(str(temp_file_name)) + attachment = doc.get("_attachment") + body = doc.get("body") - return attachment, file_size + return attachment, body, file_size async def get_google_workspace_content(self, file, timestamp=None): """Exports Google Workspace documents to an allowed file type and extracts its text content. @@ -837,8 +819,7 @@ async def get_google_workspace_content(self, file, timestamp=None): "_id": file_id, "_timestamp": file["_timestamp"], } - - attachment, file_size = await self._download_content( + attachment, body, file_size = await self._download_content( file=file, download_func=partial( self.google_drive_client.api_call, @@ -854,13 +835,14 @@ async def get_google_workspace_content(self, file, timestamp=None): # into text/plain format. We usually we end up with tiny .txt files. # 2. Google will ofter report the Google Workspace shared documents to have size 0 # as they don't count against user's storage quota. - if file_size > FILE_SIZE_LIMIT: - self._logger.warning( - f"File size {file_size} of file {file_name} is larger than {FILE_SIZE_LIMIT} bytes. Discarding the file content" - ) + if not self.is_file_size_within_limit(file_size, file_name): return - document["_attachment"] = attachment + if attachment: + document["_attachment"] = attachment + elif body: + document["body"] = body + return document async def get_generic_file_content(self, file, timestamp=None): @@ -885,22 +867,14 @@ async def get_generic_file_content(self, file, timestamp=None): f".{file['file_extension']}", ) - if file_extension not in TIKA_SUPPORTED_FILETYPES: - self._logger.debug(f"{file_name} can't be extracted") - return - - if file_size > FILE_SIZE_LIMIT: - self._logger.warning( - f"File size {file_size} of file {file_name} is larger than {FILE_SIZE_LIMIT} bytes. 
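[editor's note] _download_content now returns an (attachment, body, file_size) triple rather than a pair: at most one of attachment (base64 for the ingest pipeline) or body (plain text from the extraction service) is populated, depending on configuration. Both callers consume it the same way (the is-not-None form is the one patch 05 settles on):

    attachment, body, file_size = await self._download_content(file=file, download_func=download_func)
    if attachment is not None:
        document["_attachment"] = attachment
    elif body is not None:
        document["body"] = body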
Discarding the file content" - ) + if not self.can_file_be_downloaded(file_extension, file_name, file_size): return document = { "_id": file_id, "_timestamp": file["_timestamp"], } - - attachment, _ = await self._download_content( + attachment, body, _ = await self._download_content( file=file, download_func=partial( self.google_drive_client.api_call, @@ -912,7 +886,11 @@ async def get_generic_file_content(self, file, timestamp=None): ), ) - document["_attachment"] = attachment + if attachment: + document["_attachment"] = attachment + elif body: + document["body"] = body + return document async def get_content(self, file, timestamp=None, doit=None): diff --git a/tests/sources/fixtures/google_drive/connector.json b/tests/sources/fixtures/google_drive/connector.json index e13e6793d..8b99d08e9 100644 --- a/tests/sources/fixtures/google_drive/connector.json +++ b/tests/sources/fixtures/google_drive/connector.json @@ -77,6 +77,21 @@ "ui_restrictions": [ "advanced" ] + }, + "use_text_extraction_service": { + "default_value": null, + "depends_on": [], + "display": "toggle", + "label": "Use text extraction service", + "options": [], + "order": 7, + "required": true, + "sensitive": false, + "tooltip": "Requires a separate deployment of the Elastic Text Extraction Service. Requires that pipeline settings disable text extraction.", + "type": "bool", + "ui_restrictions": [], + "validations": [], + "value": false } }, "custom_scheduling": {}, diff --git a/tests/sources/test_google_cloud_storage.py b/tests/sources/test_google_cloud_storage.py index 92cf84cbb..5977986be 100644 --- a/tests/sources/test_google_cloud_storage.py +++ b/tests/sources/test_google_cloud_storage.py @@ -422,6 +422,7 @@ async def test_get_content_with_text_extraction_enabled_adds_body(): ) assert content == expected_blob_document + @pytest.mark.asyncio async def test_get_content_with_upper_extension(): """Test the module responsible for fetching the content of the file if it is extractable.""" diff --git a/tests/sources/test_google_drive.py b/tests/sources/test_google_drive.py index a794ec162..a15f9f336 100644 --- a/tests/sources/test_google_drive.py +++ b/tests/sources/test_google_drive.py @@ -26,12 +26,13 @@ @asynccontextmanager -async def create_gdrive_source(): +async def create_gdrive_source(use_text_extraction_service=False): async with create_source( GoogleDriveDataSource, service_account_credentials=SERVICE_ACCOUNT_CREDENTIALS, use_document_level_security=False, google_workspace_admin_email="admin@your-organization.com", + use_text_extraction_service=use_text_extraction_service, ) as source: yield source @@ -719,7 +720,7 @@ async def test_get_google_workspace_content(): "_timestamp": "2023-06-28T07:46:28.000Z", "_attachment": "I love unit tests", } - file_content_response = ("I love unit tests", 1234) + file_content_response = ("I love unit tests", None, 1234) future_file_content_response = asyncio.Future() future_file_content_response.set_result(file_content_response) @@ -733,6 +734,52 @@ async def test_get_google_workspace_content(): assert content == expected_file_document +@pytest.mark.asyncio +@patch( + "connectors.content_extraction.ContentExtraction._check_configured", + lambda *_: True, +) +async def test_get_google_workspace_content_with_text_extraction_enabled_adds_body(): + """Test the module responsible for fetching the content of the Google Suite document.""" + with patch( + "connectors.content_extraction.ContentExtraction.extract_text", + return_value="I love unit tests", + ), patch( + 
"connectors.content_extraction.ContentExtraction.get_extraction_config", + return_value={"host": "http://localhost:8090"}, + ): + async with create_gdrive_source(use_text_extraction_service=True) as source: + file_document = { + "id": "id1", + "created_at": None, + "last_updated": "2023-06-28T07:46:28.000Z", + "name": "test.txt", + "size": 28, + "_timestamp": "2023-06-28T07:46:28.000Z", + "mime_type": "application/vnd.google-apps.document", + "file_extension": None, + "url": None, + "type": "file", + } + expected_file_document = { + "_id": "id1", + "_timestamp": "2023-06-28T07:46:28.000Z", + "body": "I love unit tests", + } + file_content_response = (None, "I love unit tests", 1234) + future_file_content_response = asyncio.Future() + future_file_content_response.set_result(file_content_response) + + source._download_content = mock.MagicMock( + return_value=future_file_content_response + ) + content = await source.get_content( + file=file_document, + doit=True, + ) + assert content == expected_file_document + + @pytest.mark.asyncio async def test_get_google_workspace_content_size_limit(): """Test the module responsible for fetching the content of the Google Suite document if its size @@ -752,7 +799,11 @@ async def test_get_google_workspace_content_size_limit(): "type": "file", } - file_content_response = ("I love unit tests", MORE_THAN_DEFAULT_FILE_SIZE_LIMIT) + file_content_response = ( + "I love unit tests", + None, + MORE_THAN_DEFAULT_FILE_SIZE_LIMIT, + ) future_file_content_response = asyncio.Future() future_file_content_response.set_result(file_content_response) @@ -788,7 +839,7 @@ async def test_get_generic_file_content(): "_timestamp": "2023-06-28T07:46:28.000Z", "_attachment": "I love unit tests generic file", } - file_content_response = ("I love unit tests generic file", 1234) + file_content_response = ("I love unit tests generic file", None, 1234) future_file_content_response = asyncio.Future() future_file_content_response.set_result(file_content_response) @@ -802,6 +853,52 @@ async def test_get_generic_file_content(): assert content == expected_file_document +@pytest.mark.asyncio +@patch( + "connectors.content_extraction.ContentExtraction._check_configured", + lambda *_: True, +) +async def test_get_generic_file_content_with_text_extraction_enabled_adds_body(): + """Test the module responsible for fetching the content of the file if it is extractable.""" + with patch( + "connectors.content_extraction.ContentExtraction.extract_text", + return_value="I love unit tests generic file", + ), patch( + "connectors.content_extraction.ContentExtraction.get_extraction_config", + return_value={"host": "http://localhost:8090"}, + ): + async with create_gdrive_source(use_text_extraction_service=True) as source: + file_document = { + "id": "id1", + "created_at": None, + "last_updated": "2023-06-28T07:46:28.000Z", + "name": "test.txt", + "size": 28, + "_timestamp": "2023-06-28T07:46:28.000Z", + "mime_type": "text/plain", + "file_extension": "txt", + "url": None, + "type": "file", + } + expected_file_document = { + "_id": "id1", + "_timestamp": "2023-06-28T07:46:28.000Z", + "body": "I love unit tests generic file", + } + file_content_response = (None, "I love unit tests generic file", 1234) + future_file_content_response = asyncio.Future() + future_file_content_response.set_result(file_content_response) + + source._download_content = mock.MagicMock( + return_value=future_file_content_response + ) + content = await source.get_content( + file=file_document, + doit=True, + ) + assert content == 
expected_file_document + + @pytest.mark.asyncio async def test_get_generic_file_content_size_limit(): """Test the module responsible for fetching the content of the file size is above the limit.""" From d2b0443d9f5da3bfeb13466d3eb7118fb58f6f7e Mon Sep 17 00:00:00 2001 From: Navarone Feekery <13634519+navarone-feekery@users.noreply.github.com> Date: Wed, 27 Sep 2023 16:43:52 +0200 Subject: [PATCH 05/12] Add extraction service to servicenow --- connectors/source.py | 35 +++--- connectors/sources/azure_blob_storage.py | 4 +- connectors/sources/confluence.py | 4 +- connectors/sources/dropbox.py | 4 +- connectors/sources/google_drive.py | 8 +- connectors/sources/jira.py | 3 +- connectors/sources/servicenow.py | 116 ++++++------------ .../fixtures/servicenow/connector.json | 15 +++ 8 files changed, 80 insertions(+), 109 deletions(-) diff --git a/connectors/source.py b/connectors/source.py index 2b98d32dd..694217c0b 100644 --- a/connectors/source.py +++ b/connectors/source.py @@ -727,23 +727,28 @@ async def download_and_extract_file( self, doc, source_filename, file_extension, download_func ): # 1 create tempfile - async with self.create_temp_file(file_extension) as async_buffer: - temp_filename = async_buffer.name - - # 2 download to tempfile - await self.download_to_temp_file( - temp_filename, - source_filename, - async_buffer, - download_func, - ) + try: + async with self.create_temp_file(file_extension) as async_buffer: + temp_filename = async_buffer.name - # 3 extract or convert content - doc = await self.handle_file_content_extraction( - doc, source_filename, temp_filename - ) + # 2 download to tempfile + await self.download_to_temp_file( + temp_filename, + source_filename, + async_buffer, + download_func, + ) - return doc + # 3 extract or convert content + doc = await self.handle_file_content_extraction( + doc, source_filename, temp_filename + ) + return doc + except Exception as e: + self._logger.warning( + f"File download and extraction or conversion for file {source_filename} failed: {e}" + ) + return @asynccontextmanager async def create_temp_file(self, file_extension): diff --git a/connectors/sources/azure_blob_storage.py b/connectors/sources/azure_blob_storage.py index ed305131f..d99838887 100644 --- a/connectors/sources/azure_blob_storage.py +++ b/connectors/sources/azure_blob_storage.py @@ -178,15 +178,13 @@ async def get_content(self, blob, timestamp=None, doit=None): return document = {"_id": blob["id"], "_timestamp": blob["_timestamp"]} - document = await self.download_and_extract_file( + return await self.download_and_extract_file( document, filename, file_extension, partial(self.blob_download_func, filename, blob["container"]), ) - return document - async def blob_download_func(self, blob_name, container_name): async with BlobClient.from_connection_string( conn_str=self.connection_string, diff --git a/connectors/sources/confluence.py b/connectors/sources/confluence.py index e34e10957..7a0d10520 100644 --- a/connectors/sources/confluence.py +++ b/connectors/sources/confluence.py @@ -653,7 +653,7 @@ async def download_attachment(self, url, attachment, timestamp=None, doit=False) return document = {"_id": attachment["_id"], "_timestamp": attachment["_timestamp"]} - document = await self.download_and_extract_file( + return await self.download_and_extract_file( document, filename, file_extension, @@ -666,8 +666,6 @@ async def download_attachment(self, url, attachment, timestamp=None, doit=False) ), ) - return document - async def _attachment_coro(self, document, access_control): 
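[editor's note] Wrapping download_and_extract_file in try/except changes its contract: it now returns the enriched document on success and None on any download or extraction failure, logging a warning instead of raising. That is why the call sites in this patch collapse to a direct return of the helper, as sketched here:

    doc = await self.download_and_extract_file(document, filename, file_extension, download_func)
    # doc is either `document` enriched with "_attachment" or "body",
    # or None, which the sync machinery treats as "no content for this document"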
"""Coroutine to add attachments to Queue and download content diff --git a/connectors/sources/dropbox.py b/connectors/sources/dropbox.py index b565d92c5..1fc91ee74 100644 --- a/connectors/sources/dropbox.py +++ b/connectors/sources/dropbox.py @@ -671,7 +671,7 @@ async def get_content( "_id": attachment["id"], "_timestamp": attachment["server_modified"], } - document = await self.download_and_extract_file( + return await self.download_and_extract_file( document, filename, file_extension, @@ -681,8 +681,6 @@ async def get_content( ), ) - return document - def download_func(self, is_shared, attachment, filename): if is_shared: return partial( diff --git a/connectors/sources/google_drive.py b/connectors/sources/google_drive.py index 66f29bfe1..95d2911a3 100644 --- a/connectors/sources/google_drive.py +++ b/connectors/sources/google_drive.py @@ -838,9 +838,9 @@ async def get_google_workspace_content(self, file, timestamp=None): if not self.is_file_size_within_limit(file_size, file_name): return - if attachment: + if attachment is not None: document["_attachment"] = attachment - elif body: + elif body is not None: document["body"] = body return document @@ -886,9 +886,9 @@ async def get_generic_file_content(self, file, timestamp=None): ), ) - if attachment: + if attachment is not None: document["_attachment"] = attachment - elif body: + elif body is not None: document["body"] = body return document diff --git a/connectors/sources/jira.py b/connectors/sources/jira.py index f2fb70265..373e190fc 100644 --- a/connectors/sources/jira.py +++ b/connectors/sources/jira.py @@ -645,7 +645,7 @@ async def get_content(self, issue_key, attachment, timestamp=None, doit=False): "_id": f"{issue_key}-{attachment['id']}", "_timestamp": attachment["created"], } - document = await self.download_and_extract_file( + return await self.download_and_extract_file( document, filename, file_extension, @@ -659,7 +659,6 @@ async def get_content(self, issue_key, attachment, timestamp=None, doit=False): ), ), ) - return document async def ping(self): """Verify the connection with Jira""" diff --git a/connectors/sources/servicenow.py b/connectors/sources/servicenow.py index 606d25e14..60351ed27 100644 --- a/connectors/sources/servicenow.py +++ b/connectors/sources/servicenow.py @@ -4,7 +4,6 @@ # you may not use this file except in compliance with the Elastic License 2.0. # """ServiceNow source module responsible to fetch documents from ServiceNow.""" -import asyncio import base64 import json import os @@ -13,12 +12,9 @@ from functools import cached_property, partial from urllib.parse import urlencode -import aiofiles import aiohttp import dateutil.parser as parser import fastjsonschema -from aiofiles.os import remove -from aiofiles.tempfile import NamedTemporaryFile from connectors.filtering.validation import ( AdvancedRulesValidator, @@ -27,12 +23,10 @@ from connectors.logger import logger from connectors.source import BaseDataSource, ConfigurableFieldValueError from connectors.utils import ( - TIKA_SUPPORTED_FILETYPES, CancellableSleeps, ConcurrentTasks, MemQueue, RetryStrategy, - convert_to_b64, iso_utc, retryable, ) @@ -278,6 +272,10 @@ async def _api_call(self, url, params, actions, method): url=url, params=params, json=actions ) + async def download_func(self, url): + response = await self._api_call(url, {}, {}, "get") + yield response + async def filter_services(self, configured_service): """Filter services based on service mappings. 
@@ -318,78 +316,6 @@ async def filter_services(self, configured_service): ) raise - async def fetch_attachment_content(self, metadata, timestamp=None, doit=False): - """Fetch attachment content via metadata. - - Args: - metadata (dict): Attachment metadata. - timestamp (timestamp, None): Attachment last modified timestamp. Defaults to None. - doit (bool, False): Whether to get content or not. Defaults to False. - - Returns: - dict: Document with id, timestamp & content. - """ - - attachment_size = int(metadata["size_bytes"]) - if not (doit and attachment_size > 0): - return - - attachment_name = metadata["file_name"] - attachment_extension = os.path.splitext(attachment_name)[-1] - if attachment_extension == "": - self._logger.warning( - f"Files without extension are not supported by TIKA, skipping {attachment_name}." - ) - return - elif attachment_extension.lower() not in TIKA_SUPPORTED_FILETYPES: - self._logger.warning( - f"Files with the extension {attachment_extension} are not supported by TIKA, skipping {attachment_name}." - ) - return - - if attachment_size > FILE_SIZE_LIMIT: - self._logger.warning( - f"File size {attachment_size} of file {attachment_name} is larger than {FILE_SIZE_LIMIT} bytes. Discarding file content." - ) - return - - document = {"_id": metadata["id"], "_timestamp": metadata["_timestamp"]} - - temp_filename = "" - async with NamedTemporaryFile(mode="wb", delete=False) as async_buffer: - temp_filename = str(async_buffer.name) - - try: - response = await self._api_call( - url=ENDPOINTS["DOWNLOAD"].format(sys_id=metadata["id"]), - params={}, - actions={}, - method="get", - ) - async for data in response.content.iter_chunked(CHUNK_SIZE): - await async_buffer.write(data) - - except Exception as exception: - self._logger.warning( - f"Skipping content for {attachment_name}. Exception: {exception}." - ) - return - - self._logger.debug(f"Calling convert_to_b64 for file : {attachment_name}.") - await asyncio.to_thread(convert_to_b64, source=temp_filename) - - async with aiofiles.open(file=temp_filename, mode="r") as async_buffer: - document["_attachment"] = (await async_buffer.read()).strip() - - try: - await remove(temp_filename) - except Exception as exception: - self._logger.warning( - f"Error while deleting the file: {temp_filename} from disk. Error: {exception}" - ) - - return document - async def ping(self): await self.get_table_length(table_name="sys_db_object") @@ -539,6 +465,14 @@ def get_default_configuration(cls): "type": "int", "ui_restrictions": ["advanced"], }, + "use_text_extraction_service": { + "display": "toggle", + "label": "Use text extraction service", + "order": 7, + "tooltip": "Requires a separate deployment of the Elastic Text Extraction Service. 
Requires that pipeline settings disable text extraction.", + "type": "bool", + "value": False, + }, } async def _remote_validation(self): @@ -614,7 +548,7 @@ async def _fetch_attachment_metadata(self, batched_apis): ( # pyright: ignore serialized_attachment_metadata, partial( - self.servicenow_client.fetch_attachment_content, + self.get_content, serialized_attachment_metadata, ), ) @@ -769,3 +703,27 @@ async def get_docs(self, filtering=None): yield item await self.fetchers.join() + + async def get_content(self, metadata, timestamp=None, doit=False): + file_size = int(metadata["size_bytes"]) + if not (doit and file_size > 0): + return + + filename = metadata["file_name"] + file_extension = self.get_file_extension(filename) + if not self.can_file_be_downloaded(file_extension, filename, file_size): + return + + document = {"_id": metadata["id"], "_timestamp": metadata["_timestamp"]} + return await self.download_and_extract_file( + document, + filename, + file_extension, + partial( + self.generic_chunked_download_func, + partial( + self.servicenow_client.download_func, + ENDPOINTS["DOWNLOAD"].format(sys_id=metadata["id"]), + ), + ), + ) diff --git a/tests/sources/fixtures/servicenow/connector.json b/tests/sources/fixtures/servicenow/connector.json index c76b692d7..5989f167b 100644 --- a/tests/sources/fixtures/servicenow/connector.json +++ b/tests/sources/fixtures/servicenow/connector.json @@ -94,6 +94,21 @@ "value": "some_test_user", "order": 2, "ui_restrictions": [] + }, + "use_text_extraction_service": { + "default_value": null, + "depends_on": [], + "display": "toggle", + "label": "Use text extraction service", + "options": [], + "order": 7, + "required": true, + "sensitive": false, + "tooltip": "Requires a separate deployment of the Elastic Text Extraction Service. 
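[editor's note] The nested functools.partial calls in get_content above compose the downloader lazily, so nothing touches the network until the framework actually requests content. The same composition, unrolled for readability:

    from functools import partial

    stream_response = partial(
        self.servicenow_client.download_func,
        ENDPOINTS["DOWNLOAD"].format(sys_id=metadata["id"]),  # bind the attachment URL
    )
    download = partial(self.generic_chunked_download_func, stream_response)
    return await self.download_and_extract_file(document, filename, file_extension, download)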
Requires that pipeline settings disable text extraction.", + "type": "bool", + "ui_restrictions": [], + "validations": [], + "value": false } }, "custom_scheduling": {}, From 8cc6d1446fcebde78fa682d9d5aa127f2933c2b6 Mon Sep 17 00:00:00 2001 From: Navarone Feekery <13634519+navarone-feekery@users.noreply.github.com> Date: Wed, 27 Sep 2023 16:45:20 +0200 Subject: [PATCH 06/12] Fix missed tests --- tests/sources/test_servicenow.py | 50 +++++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/tests/sources/test_servicenow.py b/tests/sources/test_servicenow.py index 375e64472..4e3354956 100644 --- a/tests/sources/test_servicenow.py +++ b/tests/sources/test_servicenow.py @@ -28,13 +28,14 @@ @asynccontextmanager -async def create_service_now_source(): +async def create_service_now_source(use_text_extraction_service=False): async with create_source( ServiceNowDataSource, url="http://127.0.0.1:1234", username="admin", password="changeme", services="*", + use_text_extraction_service=use_text_extraction_service, ) as source: yield source @@ -402,7 +403,7 @@ async def test_fetch_attachment_content_with_doit(): return_value=MockResponse(res=b"Attachment Content", headers={}) ) - response = await source.servicenow_client.fetch_attachment_content( + response = await source.get_content( metadata={ "id": "id_1", "_timestamp": "1212-12-12 12:12:12", @@ -419,6 +420,39 @@ async def test_fetch_attachment_content_with_doit(): } +@pytest.mark.asyncio +async def test_fetch_attachment_content_with_extraction_service(): + with patch( + "connectors.content_extraction.ContentExtraction.extract_text", + return_value="Attachment Content", + ), patch( + "connectors.content_extraction.ContentExtraction.get_extraction_config", + return_value={"host": "http://localhost:8090"}, + ): + async with create_service_now_source( + use_text_extraction_service=True + ) as source: + source.servicenow_client._api_call = mock.AsyncMock( + return_value=MockResponse(res=b"Attachment Content", headers={}) + ) + + response = await source.get_content( + metadata={ + "id": "id_1", + "_timestamp": "1212-12-12 12:12:12", + "file_name": "file_1.txt", + "size_bytes": "2048", + }, + doit=True, + ) + + assert response == { + "_id": "id_1", + "_timestamp": "1212-12-12 12:12:12", + "body": "Attachment Content", + } + + @pytest.mark.asyncio async def test_fetch_attachment_content_with_upper_extension(): async with create_service_now_source() as source: @@ -426,7 +460,7 @@ async def test_fetch_attachment_content_with_upper_extension(): return_value=MockResponse(res=b"Attachment Content", headers={}) ) - response = await source.servicenow_client.fetch_attachment_content( + response = await source.get_content( metadata={ "id": "id_1", "_timestamp": "1212-12-12 12:12:12", @@ -450,7 +484,7 @@ async def test_fetch_attachment_content_without_doit(): return_value=MockResponse(res=b"Attachment Content", headers={}) ) - response = await source.servicenow_client.fetch_attachment_content( + response = await source.get_content( metadata={ "id": "id_1", "_timestamp": "1212-12-12 12:12:12", @@ -469,7 +503,7 @@ async def test_fetch_attachment_content_with_exception(): side_effect=Exception("Something went wrong") ) - response = await source.servicenow_client.fetch_attachment_content( + response = await source.get_content( metadata={ "id": "id_1", "_timestamp": "1212-12-12 12:12:12", @@ -489,7 +523,7 @@ async def test_fetch_attachment_content_with_unsupported_extension_then_skip(): return_value=MockResponse(res=b"Attachment 
Content", headers={}) ) - response = await source.servicenow_client.fetch_attachment_content( + response = await source.get_content( metadata={ "id": "id_1", "_timestamp": "1212-12-12 12:12:12", @@ -509,7 +543,7 @@ async def test_fetch_attachment_content_without_extension_then_skip(): return_value=MockResponse(res=b"Attachment Content", headers={}) ) - response = await source.servicenow_client.fetch_attachment_content( + response = await source.get_content( metadata={ "id": "id_1", "_timestamp": "1212-12-12 12:12:12", @@ -529,7 +563,7 @@ async def test_fetch_attachment_content_with_unsupported_file_size_then_skip(): return_value=MockResponse(res=b"Attachment Content", headers={}) ) - response = await source.servicenow_client.fetch_attachment_content( + response = await source.get_content( metadata={ "id": "id_1", "_timestamp": "1212-12-12 12:12:12", From 95dbb2f098bc203c8d9fccf6c768e754e72b7776 Mon Sep 17 00:00:00 2001 From: Navarone Feekery <13634519+navarone-feekery@users.noreply.github.com> Date: Wed, 27 Sep 2023 17:04:32 +0200 Subject: [PATCH 07/12] Add extraction service for servicenow --- connectors/sources/onedrive.py | 65 +++++----------- .../sources/fixtures/onedrive/connector.json | 15 ++++ tests/sources/test_onedrive.py | 77 ++++++++++++++----- 3 files changed, 95 insertions(+), 62 deletions(-) diff --git a/connectors/sources/onedrive.py b/connectors/sources/onedrive.py index ac8ee4a36..5adf9b041 100644 --- a/connectors/sources/onedrive.py +++ b/connectors/sources/onedrive.py @@ -31,7 +31,6 @@ from connectors.logger import logger from connectors.source import BaseDataSource from connectors.utils import ( - TIKA_SUPPORTED_FILETYPES, CacheWithTimeout, CancellableSleeps, RetryStrategy, @@ -451,6 +450,14 @@ def get_default_configuration(cls): "type": "bool", "value": False, }, + "use_text_extraction_service": { + "display": "toggle", + "label": "Use text extraction service", + "order": 7, + "tooltip": "Requires a separate deployment of the Elastic Text Extraction Service. Requires that pipeline settings disable text extraction.", + "type": "bool", + "value": False, + }, } def tweak_bulk_options(self, options): @@ -479,28 +486,6 @@ async def ping(self): self._logger.exception("Error while connecting to OneDrive") raise - def _pre_checks_for_get_content( - self, attachment_extension, attachment_name, attachment_size - ): - if attachment_extension == "": - self._logger.warning( - f"Files without extension are not supported, skipping {attachment_name}." - ) - return False - - if attachment_extension.lower() not in TIKA_SUPPORTED_FILETYPES: - self._logger.warning( - f"Files with the extension {attachment_extension} are not supported, skipping {attachment_name}." - ) - return False - - if attachment_size > FILE_SIZE_LIMIT: - self._logger.warning( - f"File size {attachment_size} of file {attachment_name} is larger than {FILE_SIZE_LIMIT} bytes. Discarding file content" - ) - return - return True - async def _get_document_with_content(self, attachment_name, document, url): temp_filename = "" @@ -541,36 +526,28 @@ async def get_content(self, file, download_url, timestamp=None, doit=False): dictionary: Content document with _id, _timestamp and file content """ - attachment_size = int(file["size"]) - if not (doit and attachment_size > 0): + file_size = int(file["size"]) + if not (doit and file_size > 0): return - attachment_name = file["title"] + filename = file["title"] - attachment_extension = ( - attachment_name[attachment_name.rfind(".") :] # noqa - if "." 
in attachment_name - else "" - ) - - if not self._pre_checks_for_get_content( - attachment_extension=attachment_extension, - attachment_name=attachment_name, - attachment_size=attachment_size, - ): + file_extension = self.get_file_extension(filename) + if not self.can_file_be_downloaded(file_extension, filename, file_size): return - self._logger.debug(f"Downloading {attachment_name}") - document = { "_id": file["_id"], "_timestamp": file["_timestamp"], } - - return await self._get_document_with_content( - attachment_name=attachment_name, - document=document, - url=download_url, + return await self.download_and_extract_file( + document, + filename, + file_extension, + partial( + self.generic_chunked_download_func, + partial(self.client.get, url=download_url), + ), ) def prepare_doc(self, file): diff --git a/tests/sources/fixtures/onedrive/connector.json b/tests/sources/fixtures/onedrive/connector.json index 4443219fa..23c774511 100644 --- a/tests/sources/fixtures/onedrive/connector.json +++ b/tests/sources/fixtures/onedrive/connector.json @@ -87,6 +87,21 @@ "value": false, "order": 6, "ui_restrictions": [] + }, + "use_text_extraction_service": { + "default_value": null, + "depends_on": [], + "display": "toggle", + "label": "Use text extraction service", + "options": [], + "order": 7, + "required": true, + "sensitive": false, + "tooltip": "Requires a separate deployment of the Elastic Text Extraction Service. Requires that pipeline settings disable text extraction.", + "type": "bool", + "ui_restrictions": [], + "validations": [], + "value": false } }, "custom_scheduling": {}, diff --git a/tests/sources/test_onedrive.py b/tests/sources/test_onedrive.py index d1bebb1bd..a72a8da8f 100644 --- a/tests/sources/test_onedrive.py +++ b/tests/sources/test_onedrive.py @@ -4,6 +4,7 @@ # you may not use this file except in compliance with the Elastic License 2.0. 
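[editor's note] The EXPECTED_CONTENT and EXPECTED_CONTENT_EXTRACTED fixtures added just below capture the behavioural contract of the whole series: with the extraction service off, the document carries base64 in "_attachment"; with it on, the plain text lands in "body". In miniature:

    import base64

    raw = b"# This is the dummy file"
    pipeline_doc = {"_attachment": base64.b64encode(raw).decode()}  # "IyBUaGlzIGlzIHRoZSBkdW1teSBmaWxl"
    extracted_doc = {"body": raw.decode()}                          # extraction-service path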
# """Tests the OneDrive source class methods""" +from contextlib import asynccontextmanager from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch import pytest @@ -288,6 +289,11 @@ "_timestamp": "2023-05-01T09:10:21Z", "_attachment": "IyBUaGlzIGlzIHRoZSBkdW1teSBmaWxl", } +EXPECTED_CONTENT_EXTRACTED = { + "_id": "01DABHRNUACUYC4OM3GJG2NVHDI2ABGP4E", + "_timestamp": "2023-05-01T09:10:21Z", + "body": RESPONSE_CONTENT, +} EXPECTED_FILES_FOLDERS = {} @@ -423,6 +429,18 @@ def test_get_configuration(): assert config["client_id"] == "" +@asynccontextmanager +async def create_onedrive_source(use_text_extraction_service=False): + async with create_source( + OneDriveDataSource, + client_id="foo", + client_secret="bar", + tenant_id="faa", + use_text_extraction_service=use_text_extraction_service, + ) as source: + yield source + + @pytest.mark.asyncio @pytest.mark.parametrize( "extras", @@ -446,7 +464,7 @@ async def test_validate_configuration_with_invalid_dependency_fields_raises_erro @pytest.mark.asyncio async def test_close_with_client_session(): - async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: source.client.access_token = "dummy" await source.close() @@ -456,7 +474,7 @@ async def test_close_with_client_session(): @pytest.mark.asyncio async def test_set_access_token(): - async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: mock_token = {"access_token": "msgraphtoken", "expires_in": "1234555"} async_response = AsyncMock() async_response.__aenter__ = AsyncMock( @@ -471,7 +489,7 @@ async def test_set_access_token(): @pytest.mark.asyncio async def test_ping_for_successful_connection(): - async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: DUMMY_RESPONSE = {} source.client.get = AsyncIterator([[DUMMY_RESPONSE]]) @@ -481,7 +499,7 @@ async def test_ping_for_successful_connection(): @pytest.mark.asyncio @patch("aiohttp.ClientSession.get") async def test_ping_for_failed_connection_exception(mock_get): - async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: with patch.object( OneDriveClient, "get", side_effect=Exception("Something went wrong") ): @@ -559,7 +577,7 @@ async def test_get_with_429_status(): payload = {"value": "Test rate limit"} retried_response.__aenter__ = AsyncMock(return_value=JSONAsyncMock(payload)) - async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: with patch.object(AccessToken, "get", return_value="abc"): with patch( "aiohttp.ClientSession.get", @@ -585,7 +603,7 @@ async def test_get_with_429_status_without_retry_after_header(): retried_response.__aenter__ = AsyncMock(return_value=JSONAsyncMock(payload)) with patch("connectors.sources.onedrive.DEFAULT_RETRY_SECONDS", 0.3): - async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: with patch.object(AccessToken, "get", return_value="abc"): with patch( "aiohttp.ClientSession.get", @@ -604,7 +622,7 @@ async def test_get_with_404_status(): error = ClientResponseError(None, None) error.status = 404 - async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: with patch.object(AccessToken, "get", return_value="abc"): with patch( "aiohttp.ClientSession.get", @@ -623,7 +641,7 @@ async def test_get_with_500_status(): error = ClientResponseError(None, None) error.status = 500 - 
async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: with patch.object(AccessToken, "get", return_value="abc"): with patch( "aiohttp.ClientSession.get", @@ -638,7 +656,7 @@ async def test_get_with_500_status(): @pytest.mark.asyncio async def test_get_owned_files(): - async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: async_response = AsyncMock() async_response.__aenter__ = AsyncMock( return_value=JSONAsyncMock(RESPONSE_FILES) @@ -659,7 +677,7 @@ async def test_get_owned_files(): @pytest.mark.asyncio async def test_list_users(): - async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: response = [] async_response = AsyncMock() async_response.__aenter__ = AsyncMock( @@ -686,7 +704,7 @@ async def test_list_users(): async def test_get_content_when_is_downloadable_is_true( file, download_url, expected_content ): - async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: with patch.object(AccessToken, "get", return_value="abc"): with patch("aiohttp.ClientSession.get", return_value=get_stream_reader()): with patch( @@ -701,6 +719,29 @@ async def test_get_content_when_is_downloadable_is_true( assert response == expected_content +@pytest.mark.asyncio +async def test_get_content_with_extraction_service(): + with patch( + "connectors.content_extraction.ContentExtraction.extract_text", + return_value=RESPONSE_CONTENT, + ), patch( + "connectors.content_extraction.ContentExtraction.get_extraction_config", + return_value={"host": "http://localhost:8090"}, + ): + async with create_onedrive_source(use_text_extraction_service=True) as source: + with patch.object(AccessToken, "get", return_value="abc"): + with patch("aiohttp.ClientSession.get", return_value=get_stream_reader()): + with patch( + "aiohttp.StreamReader.iter_chunked", + return_value=AsyncIterator([bytes(RESPONSE_CONTENT, "utf-8")]), + ): + response = await source.get_content( + file=MOCK_ATTACHMENT, + download_url="https://content1", + doit=True, + ) + assert response == EXPECTED_CONTENT_EXTRACTED + @patch.object(OneDriveClient, "list_users", return_value=AsyncIterator(EXPECTED_USERS)) @patch.object( OneDriveClient, @@ -726,7 +767,7 @@ async def test_get_content_when_is_downloadable_is_true( ) @pytest.mark.asyncio async def test_get_docs(users_patch, files_patch): - async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: expected_responses = [*EXPECTED_USER1_FILES, *EXPECTED_USER2_FILES] source.get_content = AsyncMock(return_value=EXPECTED_CONTENT) @@ -835,7 +876,7 @@ async def test_get_docs(users_patch, files_patch): ) @pytest.mark.asyncio async def test_advanced_rules_validation(advanced_rules, expected_validation_result): - async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: validation_result = await OneDriveAdvancedRulesValidator(source).validate( advanced_rules ) @@ -866,7 +907,7 @@ async def test_advanced_rules_validation(advanced_rules, expected_validation_res ) @pytest.mark.asyncio async def test_get_docs_with_advanced_rules(filtering): - async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: with patch.object(AccessToken, "get", return_value="abc"): with patch.object( OneDriveClient, "list_users", return_value=AsyncIterator(EXPECTED_USERS) @@ -898,7 +939,7 @@ async def 
test_get_docs_with_advanced_rules(filtering): @pytest.mark.asyncio async def test_get_access_control_dls_disabled(): - async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: source._dls_enabled = MagicMock(return_value=False) acl = [] @@ -925,7 +966,7 @@ async def test_get_access_control_dls_enabled(): ], ] - async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: source._dls_enabled = MagicMock(return_value=True) with patch.object(AccessToken, "get", return_value="abc"): @@ -967,7 +1008,7 @@ async def test_get_access_control_dls_enabled(): ) @pytest.mark.asyncio async def test_get_docs_without_dls_enabled(users_patch, files_patch): - async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: source._dls_enabled = MagicMock(return_value=False) expected_responses = [*EXPECTED_USER1_FILES, *EXPECTED_USER2_FILES] @@ -1020,7 +1061,7 @@ async def test_get_docs_without_dls_enabled(users_patch, files_patch): ) @pytest.mark.asyncio async def test_get_docs_with_dls_enabled(users_patch, files_patch, permissions_patch): - async with create_source(OneDriveDataSource) as source: + async with create_onedrive_source() as source: source._dls_enabled = MagicMock(return_value=True) expected_responses = [ From 7a4614ce8b5367bbf5a81cb5313640c89ba1cd09 Mon Sep 17 00:00:00 2001 From: Navarone Feekery <13634519+navarone-feekery@users.noreply.github.com> Date: Wed, 27 Sep 2023 17:04:47 +0200 Subject: [PATCH 08/12] Autoformat --- tests/sources/test_onedrive.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/sources/test_onedrive.py b/tests/sources/test_onedrive.py index a72a8da8f..f8b866b33 100644 --- a/tests/sources/test_onedrive.py +++ b/tests/sources/test_onedrive.py @@ -730,7 +730,9 @@ async def test_get_content_with_extraction_service(): ): async with create_onedrive_source(use_text_extraction_service=True) as source: with patch.object(AccessToken, "get", return_value="abc"): - with patch("aiohttp.ClientSession.get", return_value=get_stream_reader()): + with patch( + "aiohttp.ClientSession.get", return_value=get_stream_reader() + ): with patch( "aiohttp.StreamReader.iter_chunked", return_value=AsyncIterator([bytes(RESPONSE_CONTENT, "utf-8")]), @@ -742,6 +744,7 @@ async def test_get_content_with_extraction_service(): ) assert response == EXPECTED_CONTENT_EXTRACTED + @patch.object(OneDriveClient, "list_users", return_value=AsyncIterator(EXPECTED_USERS)) @patch.object( OneDriveClient, From bec7ff02eb7a79ff2859ffb2afa2613d5be10a3e Mon Sep 17 00:00:00 2001 From: Navarone Feekery <13634519+navarone-feekery@users.noreply.github.com> Date: Wed, 27 Sep 2023 18:32:19 +0200 Subject: [PATCH 09/12] Add content extraction for salesforce --- connectors/source.py | 22 +++++-- connectors/sources/salesforce.py | 66 +++++++++++++++---- .../fixtures/salesforce/connector.json | 15 +++++ tests/sources/test_salesforce.py | 63 +++++++++++++++++- 4 files changed, 144 insertions(+), 22 deletions(-) diff --git a/connectors/source.py b/connectors/source.py index 694217c0b..eec9aa439 100644 --- a/connectors/source.py +++ b/connectors/source.py @@ -724,14 +724,24 @@ def is_file_size_within_limit(self, file_size, filename): return True async def download_and_extract_file( - self, doc, source_filename, file_extension, download_func + self, doc, source_filename, file_extension, download_func, return_doc_if_failed=False ): - # 1 create tempfile + """ + Performs all the 
steps required for handling binary content: + 1. Make temp file + 2. Download content to temp file + 3. Extract using local service or convert to b64 + + Will return the doc with either `_attachment` or `body` added. + Returns `None` if any step fails. + + If the optional arg `return_doc_if_failed` is `True`, + will return the original doc upon failure + """ try: async with self.create_temp_file(file_extension) as async_buffer: temp_filename = async_buffer.name - # 2 download to tempfile await self.download_to_temp_file( temp_filename, source_filename, @@ -739,7 +749,6 @@ async def download_and_extract_file( download_func, ) - # 3 extract or convert content doc = await self.handle_file_content_extraction( doc, source_filename, temp_filename ) @@ -748,7 +757,10 @@ async def download_and_extract_file( self._logger.warning( f"File download and extraction or conversion for file {source_filename} failed: {e}" ) - return + if return_doc_if_failed: + return doc + else: + return @asynccontextmanager async def create_temp_file(self, file_extension): diff --git a/connectors/sources/salesforce.py b/connectors/sources/salesforce.py index 9da40b28b..a0b214943 100644 --- a/connectors/sources/salesforce.py +++ b/connectors/sources/salesforce.py @@ -6,7 +6,7 @@ """Salesforce source module responsible to fetch documents from Salesforce.""" import asyncio import os -from functools import cached_property +from functools import cached_property, partial from itertools import groupby import aiofiles @@ -293,14 +293,14 @@ async def get_case_feeds(self, case_ids): return all_case_feeds - async def download_content_documents(self, content_documents): - for content_document in content_documents: - content_version = content_document.get("LatestPublishedVersion", {}) or {} - download_url = content_version.get("VersionDataUrl") - if download_url: - content_document["_attachment"] = await self._download(download_url) - - yield content_document + # async def download_content_documents(self, content_documents): + # for content_document in content_documents: + # content_version = content_document.get("LatestPublishedVersion", {}) or {} + # download_url = content_version.get("VersionDataUrl") + # if download_url: + # content_document["_attachment"] = await self._download(download_url) + # + # yield content_document async def queryable_sobjects(self): """Cached async property""" @@ -484,6 +484,10 @@ async def _get(self, url, params=None): params=params, ) + async def _download(self, url): + response = await self._get(url) + yield response + async def _handle_client_response_error(self, response_body, e): exception_details = f"status: {e.status}, message: {e.message}" @@ -1201,7 +1205,6 @@ def map_content_document(self, content_document): return { "_id": content_document.get("Id"), - "_attachment": content_document.get("_attachment"), "content_size": content_document.get("ContentSize"), "created_at": content_document.get("CreatedDate"), "created_by": created_by.get("Name"), @@ -1374,6 +1377,14 @@ def get_default_configuration(cls): "tooltip": "The client secret for your OAuth2-enabled connected app. Also called 'consumer secret'", "type": "str", }, + "use_text_extraction_service": { + "display": "toggle", + "label": "Use text extraction service", + "order": 7, + "tooltip": "Requires a separate deployment of the Elastic Text Extraction Service. 
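[editor's note] return_doc_if_failed loosens the helper's all-or-nothing behaviour for Salesforce: a failed download degrades to indexing the metadata-only document rather than dropping it, matching the inline comment at the Salesforce call site. The resulting control flow, as a sketch:

    async def download_and_extract_sketch(doc, do_steps, return_doc_if_failed=False):
        try:
            await do_steps()   # temp file, download, extract or convert
            return doc         # now enriched with "_attachment" or "body"
        except Exception:
            # the real helper logs a warning here
            if return_doc_if_failed:
                return doc     # Salesforce: metadata-only document is still ingested
            return None        # default: the document gets no content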
diff --git a/tests/sources/fixtures/salesforce/connector.json b/tests/sources/fixtures/salesforce/connector.json
index 5e9988552..7db98dfe2 100644
--- a/tests/sources/fixtures/salesforce/connector.json
+++ b/tests/sources/fixtures/salesforce/connector.json
@@ -15,6 +15,21 @@
             "label": "Domain",
             "type": "str",
             "value": "fake.sandbox"
+        },
+        "use_text_extraction_service": {
+            "default_value": null,
+            "depends_on": [],
+            "display": "toggle",
+            "label": "Use text extraction service",
+            "options": [],
+            "order": 7,
+            "required": true,
+            "sensitive": false,
+            "tooltip": "Requires a separate deployment of the Elastic Text Extraction Service. Requires that pipeline settings disable text extraction.",
+            "type": "bool",
+            "ui_restrictions": [],
+            "validations": [],
+            "value": false
         }
     },
     "custom_scheduling": {},
diff --git a/tests/sources/test_salesforce.py b/tests/sources/test_salesforce.py
index f4664cb04..cc2a3f7e7 100644
--- a/tests/sources/test_salesforce.py
+++ b/tests/sources/test_salesforce.py
@@ -8,6 +8,7 @@
 from contextlib import asynccontextmanager
 from copy import deepcopy
 from unittest import TestCase, mock
+from unittest.mock import patch
 
 import pytest
 from aiohttp.client_exceptions import ClientConnectionError
@@ -421,12 +422,13 @@
 
 
 @asynccontextmanager
-async def create_salesforce_source(mock_token=True, mock_queryables=True):
+async def create_salesforce_source(use_text_extraction_service=False, mock_token=True, mock_queryables=True):
     async with create_source(
         SalesforceDataSource,
         domain=TEST_DOMAIN,
         client_id=TEST_CLIENT_ID,
         client_secret=TEST_CLIENT_SECRET,
+        use_text_extraction_service=use_text_extraction_service,
     ) as source:
         if mock_token is True:
             source.salesforce_client.api_token.token = mock.AsyncMock(
@@ -1041,12 +1043,13 @@ async def test_get_all_with_content_docs_when_success(
             "url": f"{TEST_BASE_URL}/content_document_id",
             "version_number": "2",
             "version_url": f"{TEST_BASE_URL}/content_version_id",
-            "_attachment": expected_attachment,
         }
+        if expected_attachment is not None:
+            expected_doc["_attachment"] = expected_attachment
 
         mock_responses.get(
             f"{TEST_FILE_DOWNLOAD_URL}/sfc/servlet.shepherd/version/download/download_id",
-            status=response_body,
+            status=response_status,
             body=response_body,
         )
         mock_responses.get(
@@ -1061,6 +1064,60 @@
         TestCase().assertCountEqual(content_document_records, [expected_doc])
 
 
+@pytest.mark.asyncio
+async def test_get_all_with_content_docs_and_extraction_service(mock_responses):
+    with patch(
+        "connectors.content_extraction.ContentExtraction.extract_text",
+        return_value="chunk1",
+    ), patch(
+        "connectors.content_extraction.ContentExtraction.get_extraction_config",
+        return_value={"host": "http://localhost:8090"},
+    ):
+        async with create_salesforce_source(use_text_extraction_service=True) as source:
+            expected_doc = {
+                "_id": "content_document_id",
+                "content_size": 1000,
+                "created_at": "",
+                "created_by": "Frodo",
+                "created_by_email": "frodo@tlotr.com",
+                "body": "chunk1",
+                "description": "A file about a ring.",
+                "file_extension": "txt",
+                "last_updated": "",
+                "linked_ids": [
+                    "account_id",
+                    "campaign_id",
+                    "case_id",
+                    "contact_id",
+                    "lead_id",
+                    "opportunity_id",
+                ],  # contains every SObject that is connected to this doc
+                "owner": "Frodo",
+                "owner_email": "frodo@tlotr.com",
+                "title": "the_ring.txt",
+                "type": "content_document",
+                "url": f"{TEST_BASE_URL}/content_document_id",
+                "version_number": "2",
+                "version_url": f"{TEST_BASE_URL}/content_version_id",
+            }
+
+            mock_responses.get(
+                f"{TEST_FILE_DOWNLOAD_URL}/sfc/servlet.shepherd/version/download/download_id",
+                status=200,
+                body=b"chunk1",
+            )
+            mock_responses.get(
+                TEST_QUERY_MATCH_URL, repeat=True, callback=salesforce_query_callback
+            )
+
+            content_document_records = []
+            async for record, _ in source.get_docs():
+                if record["type"] == "content_document":
+                    content_document_records.append(record)
+
+            TestCase().assertCountEqual(content_document_records, [expected_doc])
+
+
 @pytest.mark.asyncio
 async def test_prepare_sobject_cache(mock_responses):
     async with create_salesforce_source() as source:
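The new test stubs the extraction boundary rather than running a real Elastic Text Extraction Service. Here is a reduced sketch of the same patching pattern, against a local stand-in class instead of `connectors.content_extraction.ContentExtraction` (the `extract_text` signature is assumed for illustration):

    import asyncio
    from unittest.mock import patch

    class ContentExtraction:
        # Stand-in; the real test patches the class via its import path.
        def get_extraction_config(self):
            return {}

        def extract_text(self, filepath, original_filename):
            raise RuntimeError("would call the deployed extraction service")

    async def extract(filepath, filename):
        extractor = ContentExtraction()
        if not extractor.get_extraction_config():
            return None
        return extractor.extract_text(filepath, filename)

    async def main():
        with patch.object(
            ContentExtraction, "extract_text", return_value="chunk1"
        ), patch.object(
            ContentExtraction,
            "get_extraction_config",
            return_value={"host": "http://localhost:8090"},
        ):
            # With both boundaries stubbed, the downloaded bytes come back
            # as extracted text without any HTTP traffic.
            assert await extract("/tmp/the_ring.txt", "the_ring.txt") == "chunk1"

    asyncio.run(main())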
From 439efa2f54017436de40e6072d55305c7215e778 Mon Sep 17 00:00:00 2001
From: Navarone Feekery <13634519+navarone-feekery@users.noreply.github.com>
Date: Wed, 27 Sep 2023 18:34:04 +0200
Subject: [PATCH 10/12] Lint

---
 connectors/source.py             |  7 ++++-
 connectors/sources/salesforce.py | 52 +++++---------------------------
 tests/sources/test_salesforce.py |  4 ++-
 3 files changed, 16 insertions(+), 47 deletions(-)

diff --git a/connectors/source.py b/connectors/source.py
index eec9aa439..f15df3507 100644
--- a/connectors/source.py
+++ b/connectors/source.py
@@ -724,7 +724,12 @@ def is_file_size_within_limit(self, file_size, filename):
         return True
 
     async def download_and_extract_file(
-        self, doc, source_filename, file_extension, download_func, return_doc_if_failed=False
+        self,
+        doc,
+        source_filename,
+        file_extension,
+        download_func,
+        return_doc_if_failed=False,
     ):
         """
         Performs all the steps required for handling binary content:
diff --git a/connectors/sources/salesforce.py b/connectors/sources/salesforce.py
index a0b214943..840453c0c 100644
--- a/connectors/sources/salesforce.py
+++ b/connectors/sources/salesforce.py
@@ -4,15 +4,11 @@
 # you may not use this file except in compliance with the Elastic License 2.0.
 #
 """Salesforce source module responsible to fetch documents from Salesforce."""
-import asyncio
 import os
 from functools import cached_property, partial
 from itertools import groupby
 
-import aiofiles
 import aiohttp
-from aiofiles.os import remove
-from aiofiles.tempfile import NamedTemporaryFile
 from aiohttp.client_exceptions import ClientResponseError
 
 from connectors.logger import logger
@@ -20,7 +16,6 @@
 from connectors.utils import (
     TIKA_SUPPORTED_FILETYPES,
     CancellableSleeps,
-    convert_to_b64,
     retryable,
 )
 
@@ -293,15 +288,6 @@ async def get_case_feeds(self, case_ids):
 
         return all_case_feeds
 
-    # async def download_content_documents(self, content_documents):
-    #     for content_document in content_documents:
-    #         content_version = content_document.get("LatestPublishedVersion", {}) or {}
-    #         download_url = content_version.get("VersionDataUrl")
-    #         if download_url:
-    #             content_document["_attachment"] = await self._download(download_url)
-    #
-    #         yield content_document
-
     async def queryable_sobjects(self):
         """Cached async property"""
         if self._queryable_sobjects is not None:
@@ -425,34 +411,6 @@ async def _execute_non_paginated_query(self, soql_query):
         )
         return response.get("records")
 
-    async def _download(self, download_url):
-        attachment = None
-        source_file_name = ""
-
-        try:
-            async with NamedTemporaryFile(mode="wb", delete=False) as async_buffer:
-                resp = await self._get(download_url)
-                async for data in resp.content.iter_chunked(CHUNK_SIZE):
-                    await async_buffer.write(data)
-                source_file_name = async_buffer.name
-
-            await asyncio.to_thread(
-                convert_to_b64,
-                source=source_file_name,
-            )
-
-            async with aiofiles.open(file=source_file_name, mode="r") as target_file:
-                attachment = (await target_file.read()).strip()
-        except Exception as e:
-            self._logger.error(
-                f"Exception encountered when processing file: {source_file_name}. Exception: {e}"
-            )
-        finally:
-            if source_file_name:
-                await remove(str(source_file_name))
-
-        return attachment
-
     async def _auth_headers(self):
         token = await self.api_token.token()
         return {"authorization": f"Bearer {token}"}
@@ -1432,9 +1390,13 @@ async def get_docs(self, filtering=None):
         # Note: this could possibly be done on the fly if memory becomes an issue
         content_docs = self._combine_duplicate_content_docs(content_docs)
         for content_doc in content_docs:
-            download_url = (content_doc.get("LatestPublishedVersion", {}) or {}).get("VersionDataUrl")
+            download_url = (content_doc.get("LatestPublishedVersion", {}) or {}).get(
+                "VersionDataUrl"
+            )
             if not download_url:
-                self._logger.debug(f"No download URL found for {content_doc.get('title')}, skipping.")
+                self._logger.debug(
+                    f"No download URL found for {content_doc.get('title')}, skipping."
+                )
                 continue
 
             doc = self.doc_mapper.map_content_document(content_doc)
@@ -1460,7 +1422,7 @@ async def get_content(self, doc, download_url):
                     download_url,
                 ),
             ),
-            return_doc_if_failed=True, # we still ingest on download failure for Salesforce
+            return_doc_if_failed=True,  # we still ingest on download failure for Salesforce
         )
 
     def _parse_content_documents(self, record):
diff --git a/tests/sources/test_salesforce.py b/tests/sources/test_salesforce.py
index cc2a3f7e7..a4b5b4ab1 100644
--- a/tests/sources/test_salesforce.py
+++ b/tests/sources/test_salesforce.py
@@ -422,7 +422,9 @@
 
 
 @asynccontextmanager
-async def create_salesforce_source(use_text_extraction_service=False, mock_token=True, mock_queryables=True):
+async def create_salesforce_source(
+    use_text_extraction_service=False, mock_token=True, mock_queryables=True
+):
     async with create_source(
         SalesforceDataSource,
         domain=TEST_DOMAIN,
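For reference, the `_download` helper deleted above implemented the pattern that the shared `generic_chunked_download_func` path now centralizes: stream the HTTP body to a temp file in fixed-size chunks, then base64 the result. A standalone approximation with a fake response object follows (the chunk size, file suffix, and cleanup are illustrative, and the stdlib temp file stands in for the aiofiles one):

    import asyncio
    import base64
    import os
    from tempfile import NamedTemporaryFile

    CHUNK_SIZE = 1024  # stand-in; the connector defines its own constant

    class FakeResponse:
        # Stand-in for an aiohttp response whose body arrives in chunks.
        async def iter_chunked(self, size):
            for chunk in (b"hello ", b"world"):
                yield chunk

    async def stream_to_temp_file(resp):
        # Write each chunk as it arrives so large files never sit in memory.
        with NamedTemporaryFile(mode="wb", delete=False, suffix=".bin") as buffer:
            async for data in resp.iter_chunked(CHUNK_SIZE):
                buffer.write(data)
            return buffer.name

    async def main():
        path = await stream_to_temp_file(FakeResponse())
        try:
            with open(path, "rb") as f:
                attachment = base64.b64encode(f.read()).decode()
            print(attachment)  # aGVsbG8gd29ybGQ=
        finally:
            os.remove(path)

    asyncio.run(main())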
From cf4ef15d4dd37175a20490ab3eb3adcbad9de057 Mon Sep 17 00:00:00 2001
From: Navarone Feekery <13634519+navarone-feekery@users.noreply.github.com>
Date: Thu, 28 Sep 2023 18:02:37 +0200
Subject: [PATCH 11/12] Include stack trace in failed download logs

---
 connectors/source.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/connectors/source.py b/connectors/source.py
index f15df3507..c4dd7341a 100644
--- a/connectors/source.py
+++ b/connectors/source.py
@@ -760,7 +760,7 @@ async def download_and_extract_file(
             return doc
         except Exception as e:
             self._logger.warning(
-                f"File download and extraction or conversion for file {source_filename} failed: {e}"
+                f"File download and extraction or conversion for file {source_filename} failed: {e}", exc_info=True
             )
             if return_doc_if_failed:
                 return doc

From b9fe37744bd9724b3b4b82104f0c124d70bf3b02 Mon Sep 17 00:00:00 2001
From: Navarone Feekery <13634519+navarone-feekery@users.noreply.github.com>
Date: Thu, 28 Sep 2023 18:06:06 +0200
Subject: [PATCH 12/12] Placate linter

---
 connectors/source.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/connectors/source.py b/connectors/source.py
index c4dd7341a..5be97303c 100644
--- a/connectors/source.py
+++ b/connectors/source.py
@@ -760,7 +760,8 @@ async def download_and_extract_file(
             return doc
         except Exception as e:
             self._logger.warning(
-                f"File download and extraction or conversion for file {source_filename} failed: {e}", exc_info=True
+                f"File download and extraction or conversion for file {source_filename} failed: {e}",
+                exc_info=True,
             )
             if return_doc_if_failed:
                 return doc
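The `exc_info=True` argument added in PATCH 11 and reformatted in PATCH 12 makes the warning carry the full traceback. A tiny demonstration of the logging behavior (the logger name and the raised error are made up):

    import logging

    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger("connectors.sketch")

    def flaky_download():
        raise OSError("connection reset")

    try:
        flaky_download()
    except Exception as e:
        # exc_info=True attaches the current traceback to the log record,
        # so a failed download is debuggable without re-running the sync.
        logger.warning("File download and extraction failed: %s", e, exc_info=True)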