Expand extraction service to more connectors #3 #1694

Merged: 13 commits, Sep 29, 2023
connectors/source.py: 48 additions & 17 deletions

@@ -693,6 +693,11 @@ def get_file_extension(self, filename):
         return get_file_extension(filename)
 
     def can_file_be_downloaded(self, file_extension, filename, file_size):
+        return self.is_valid_file_type(
+            file_extension, filename
+        ) and self.is_file_size_within_limit(file_size, filename)
+
+    def is_valid_file_type(self, file_extension, filename):
         if file_extension == "":
             self._logger.debug(
                 f"Files without extension are not supported, skipping {filename}."
@@ -705,6 +710,9 @@ def can_file_be_downloaded(self, file_extension, filename, file_size):
             )
             return False
 
+        return True
+
+    def is_file_size_within_limit(self, file_size, filename):
         if file_size > FILE_SIZE_LIMIT and not self.configuration.get(
             "use_text_extraction_service"
         ):
@@ -716,26 +724,49 @@ def can_file_be_downloaded(self, file_extension, filename, file_size):
             return True
 
     async def download_and_extract_file(
-        self, doc, source_filename, file_extension, download_func
+        self,
+        doc,
+        source_filename,
+        file_extension,
+        download_func,
+        return_doc_if_failed=False,
     ):
-        # 1 create tempfile
-        async with self.create_temp_file(file_extension) as async_buffer:
-            temp_filename = async_buffer.name
-
-            # 2 download to tempfile
-            await self.download_to_temp_file(
-                temp_filename,
-                source_filename,
-                async_buffer,
-                download_func,
-            )
-
-            # 3 extract or convert content
-            doc = await self.handle_file_content_extraction(
-                doc, source_filename, temp_filename
-            )
-
-        return doc
+        """
+        Performs all the steps required for handling binary content:
+        1. Make temp file
+        2. Download content to temp file
+        3. Extract using local service or convert to b64
+
+        Will return the doc with either `_attachment` or `body` added.
+        Returns `None` if any step fails.
+
+        If the optional arg `return_doc_if_failed` is `True`,
+        will return the original doc upon failure
+        """
+        try:
+            async with self.create_temp_file(file_extension) as async_buffer:
+                temp_filename = async_buffer.name
+
+                await self.download_to_temp_file(
+                    temp_filename,
+                    source_filename,
+                    async_buffer,
+                    download_func,
+                )
+
+                doc = await self.handle_file_content_extraction(
+                    doc, source_filename, temp_filename
+                )
+            return doc
+        except Exception as e:
+            self._logger.warning(
+                f"File download and extraction or conversion for file {source_filename} failed: {e}",
+                exc_info=True,
+            )
+            if return_doc_if_failed:
+                return doc
+            else:
+                return
 
     @asynccontextmanager
     async def create_temp_file(self, file_extension):
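For orientation, a minimal sketch (not code from this PR) of how a connector composes the refactored helpers. The `ExampleDataSource` class, the `item` shape, and `self.download_func` are hypothetical stand-ins; the real call sites updated below follow the same pattern.

```python
from functools import partial

from connectors.source import BaseDataSource


class ExampleDataSource(BaseDataSource):
    """Hypothetical connector showing the new helper contract."""

    async def get_content(self, item, timestamp=None, doit=None):
        if not doit:
            return

        filename = item["name"]
        file_extension = self.get_file_extension(filename)
        # can_file_be_downloaded now combines is_valid_file_type
        # and is_file_size_within_limit
        if not self.can_file_be_downloaded(file_extension, filename, int(item["size"])):
            return

        document = {"_id": item["id"], "_timestamp": item["_timestamp"]}
        # Returns the doc with `_attachment` or `body` set, or None if any
        # step failed; return_doc_if_failed=True would fall back to the
        # metadata-only doc instead of None.
        return await self.download_and_extract_file(
            document,
            filename,
            file_extension,
            partial(self.download_func, filename),  # hypothetical downloader
        )
```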
connectors/sources/azure_blob_storage.py: 1 addition & 3 deletions

@@ -178,15 +178,13 @@ async def get_content(self, blob, timestamp=None, doit=None):
             return
 
         document = {"_id": blob["id"], "_timestamp": blob["_timestamp"]}
-        document = await self.download_and_extract_file(
+        return await self.download_and_extract_file(
             document,
             filename,
             file_extension,
             partial(self.blob_download_func, filename, blob["container"]),
         )
 
-        return document
-
     async def blob_download_func(self, blob_name, container_name):
         async with BlobClient.from_connection_string(
             conn_str=self.connection_string,
connectors/sources/confluence.py: 1 addition & 3 deletions

@@ -653,7 +653,7 @@ async def download_attachment(self, url, attachment, timestamp=None, doit=False)
             return
 
         document = {"_id": attachment["_id"], "_timestamp": attachment["_timestamp"]}
-        document = await self.download_and_extract_file(
+        return await self.download_and_extract_file(
             document,
             filename,
             file_extension,
@@ -666,8 +666,6 @@ async def download_attachment(self, url, attachment, timestamp=None, doit=False)
             ),
         )
 
-        return document
-
     async def _attachment_coro(self, document, access_control):
         """Coroutine to add attachments to Queue and download content
connectors/sources/dropbox.py: 1 addition & 3 deletions

@@ -671,7 +671,7 @@ async def get_content(
             "_id": attachment["id"],
             "_timestamp": attachment["server_modified"],
         }
-        document = await self.download_and_extract_file(
+        return await self.download_and_extract_file(
             document,
             filename,
             file_extension,
@@ -681,8 +681,6 @@ async def get_content(
             ),
         )
 
-        return document
-
     def download_func(self, is_shared, attachment, filename):
         if is_shared:
             return partial(
connectors/sources/google_cloud_storage.py: 24 additions & 29 deletions

@@ -10,9 +10,6 @@
 import urllib.parse
 from functools import cached_property, partial
 
-import aiofiles
-from aiofiles.os import remove
-from aiofiles.tempfile import NamedTemporaryFile
 from aiogoogle import Aiogoogle
 from aiogoogle.auth.creds import ServiceAccountCreds
 
@@ -22,7 +19,7 @@
     load_service_account_json,
     validate_service_account_json,
 )
-from connectors.utils import TIKA_SUPPORTED_FILETYPES, convert_to_b64, get_pem_format
+from connectors.utils import get_pem_format
 
 CLOUD_STORAGE_READ_ONLY_SCOPE = "https://www.googleapis.com/auth/devstorage.read_only"
 CLOUD_STORAGE_BASE_URL = "https://console.cloud.google.com/storage/browser/_details/"
@@ -229,6 +226,14 @@ def get_default_configuration(cls):
             "type": "int",
             "ui_restrictions": ["advanced"],
         },
+        "use_text_extraction_service": {
+            "display": "toggle",
+            "label": "Use text extraction service",
+            "order": 3,
+            "tooltip": "Requires a separate deployment of the Elastic Text Extraction Service. Requires that pipeline settings disable text extraction.",
+            "type": "bool",
+            "value": False,
+        },
     }
 
     async def validate_config(self):
@@ -370,50 +375,40 @@ async def get_content(self, blob, timestamp=None, doit=None):
         Returns:
             dictionary: Content document with id, timestamp & text
         """
-        blob_size = int(blob["size"])
-        if not (doit and blob_size):
+        file_size = int(blob["size"])
+        if not (doit and file_size):
             return
 
-        blob_name = blob["name"]
-        if (os.path.splitext(blob_name)[-1]).lower() not in TIKA_SUPPORTED_FILETYPES:
-            self._logger.debug(f"{blob_name} can't be extracted")
+        filename = blob["name"]
+        file_extension = self.get_file_extension(filename)
+        if not self.can_file_be_downloaded(file_extension, filename, file_size):
             return
 
-        if blob_size > DEFAULT_FILE_SIZE_LIMIT:
-            self._logger.warning(
-                f"File size {blob_size} of file {blob_name} is larger than {DEFAULT_FILE_SIZE_LIMIT} bytes. Discarding the file content"
-            )
-            return
-        self._logger.debug(f"Downloading {blob_name}")
         document = {
             "_id": blob["id"],
             "_timestamp": blob["_timestamp"],
         }
-        source_file_name = ""
-        async with NamedTemporaryFile(mode="wb", delete=False) as async_buffer:
+
+        # gcs has a unique download method so we can't utilize
+        # the generic download_and_extract_file func
Comment on lines +392 to +393:

seanstory (Member): Useful comment 👍

Is there anything we could generify, so that we don't have usages of create_temp_file in multiple places? Like would it help if the base function could take a proc as an optional arg or something? Not for this PR, but something we can think about.

Author (Contributor): What is the issue with calling create_temp_file here?

I think we can look into generifying some of the non-standard downloads. Mostly the issue is they pipe directly to a file, but my generic download func doesn't support that. I felt strapped for time to make two different versions of it, so for now I've not generified downloads that pipe directly to files.

seanstory (Member): My concern with calling create_temp_file is that every time it's called, there's a chance that the author does it in such a way that the temp file won't be cleaned up. We've had this issue before, where like 8/10 connectors were cleaning up their tempfiles appropriately, but due to copy-paste errors, occasionally some wouldn't. These types of bugs can be hard to catch, and it's easier to keep them from propagating if you just don't have numerous usages of the risky code.

But again, not necessary to solve right now.

Author (Contributor): @seanstory I was also concerned about that. It may alleviate your concerns, but I think I have this covered with the way tempfiles are being created now. If a connector uses create_temp_file, it will clean itself up after everything is done, including deleting the file and outputting an error if the file deletion failed.

The code in question: https://github.com/elastic/connectors-python/blob/b9fe37744bd9724b3b4b82104f0c124d70bf3b02/connectors/source.py#L771-L783

Of course we should properly check to see if this is actually the case.
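For reference, a minimal sketch of the self-cleaning pattern the thread describes, assuming an asynccontextmanager wrapped around NamedTemporaryFile; this is an assumption about the shape, the actual implementation is at the connectors/source.py link above.

```python
import logging
from contextlib import asynccontextmanager

from aiofiles.os import remove
from aiofiles.tempfile import NamedTemporaryFile

logger = logging.getLogger(__name__)


@asynccontextmanager
async def create_temp_file(file_extension):
    # delete=False because the extraction step reopens the file by name;
    # cleanup is guaranteed by the finally block instead.
    async_buffer = await NamedTemporaryFile(
        mode="wb", suffix=file_extension, delete=False
    )
    try:
        yield async_buffer
    finally:
        await async_buffer.close()  # closing an already-closed file is a no-op
        try:
            await remove(async_buffer.name)
        except OSError:
            logger.error(f"Could not delete temp file {async_buffer.name}")
```

Because every caller goes through the one context manager, a copy-paste mistake in an individual connector can no longer leak temp files.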

+        async with self.create_temp_file(file_extension) as async_buffer:
             await anext(
                 self._google_storage_client.api_call(
                     resource="objects",
                     method="get",
                     bucket=blob["bucket_name"],
-                    object=blob_name,
+                    object=filename,
                     alt="media",
                     userProject=self._google_storage_client.user_project_id,
                     pipe_to=async_buffer,
                 )
             )
-            source_file_name = async_buffer.name
             await async_buffer.close()
 
-        self._logger.debug(f"Calling convert_to_b64 for file : {blob_name}")
-        await asyncio.to_thread(
-            convert_to_b64,
-            source=source_file_name,
-        )
-        async with aiofiles.open(file=source_file_name, mode="r") as target_file:
-            # base64 on macOS will add a EOL, so we strip() here
-            document["_attachment"] = (await target_file.read()).strip()
-        await remove(str(source_file_name))
-        self._logger.debug(f"Downloaded {blob_name} for {blob_size} bytes ")
+            document = await self.handle_file_content_extraction(
+                document, filename, async_buffer.name
+            )
+
         return document
 
     async def get_docs(self, filtering=None):
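To make the "proc as an optional arg" idea from the review thread concrete, here is one possible shape of that generification, purely an assumption for future work: `pipe_func` is an invented parameter, and none of this is part of the PR.

```python
# Assumed sketch of a generified BaseDataSource method; `pipe_func` is
# invented here and does not exist in the codebase.
async def download_and_extract_file(
    self,
    doc,
    source_filename,
    file_extension,
    download_func=None,
    pipe_func=None,
    return_doc_if_failed=False,
):
    try:
        async with self.create_temp_file(file_extension) as async_buffer:
            if pipe_func is not None:
                # Sources like GCS stream straight into the buffer
                # (e.g. via the API's pipe_to= argument).
                await pipe_func(async_buffer)
                await async_buffer.close()
            else:
                await self.download_to_temp_file(
                    async_buffer.name,
                    source_filename,
                    async_buffer,
                    download_func,
                )
            doc = await self.handle_file_content_extraction(
                doc, source_filename, async_buffer.name
            )
        return doc
    except Exception as e:
        self._logger.warning(
            f"Download and extraction for {source_filename} failed: {e}",
            exc_info=True,
        )
        return doc if return_doc_if_failed else None
```

With something like this, the GCS get_content could shrink to building the metadata document and passing a `pipe_func`, keeping create_temp_file usage in a single place.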