Merge pull request #20 from LCOGT/fix/check-for-source-archive
Checks the source of each input to determine where to fetch the FITS file from
LTDakin authored Jul 18, 2024
2 parents e4b54ed + 803150f commit c39fb22
Showing 3 changed files with 39 additions and 33 deletions.
datalab/datalab_session/data_operations/data_operation.py (17 changes: 6 additions & 11 deletions)
@@ -9,7 +9,7 @@
 import numpy as np
 
 from datalab.datalab_session.tasks import execute_data_operation
-from datalab.datalab_session.util import add_file_to_bucket, get_archive_from_basename
+from datalab.datalab_session.util import add_file_to_bucket, get_hdu
 
 CACHE_DURATION = 60 * 60 * 24 * 30 # cache for 30 days
@@ -58,7 +58,7 @@ def operate(self):
     def perform_operation(self):
         """ The generic method to perform perform the operation if its not in progress """
         status = self.get_status()
-        if status == 'PENDING':
+        if status == 'PENDING' or status == 'FAILED':
             self.set_status('IN_PROGRESS')
             self.set_percent_completion(0.0)
             # This asynchronous task will call the operate() method on the proper operation
@@ -126,6 +126,7 @@ def create_and_store_fits(self, hdu_list: fits.HDUList, percent=None, cur_percen
             thumbnail_jpg_url = add_file_to_bucket(f'{self.cache_key}/{self.cache_key}-{index}-small.jpg', thumbnail_jpg_path)
 
             output.append({
+                'fits_url': fits_url,
                 'large_url': large_jpg_url,
                 'thumbnail_url': thumbnail_jpg_url,
                 'basename': f'{self.cache_key}-{index}',
Expand All @@ -144,16 +145,10 @@ def get_fits_npdata(self, input_files: list[dict], percent=None, cur_percent=Non
# get the fits urls and extract the image data
for index, file_info in enumerate(input_files, start=1):
basename = file_info.get('basename', 'No basename found')
archive_record = get_archive_from_basename(basename)
source = file_info.get('source', 'No source found')

try:
fits_url = archive_record[0].get('url', 'No URL found')
except IndexError as e:
raise FileNotFoundError(f"No image found with specified basename: {basename} Error: {e}")

with fits.open(fits_url) as hdu_list:
data = hdu_list['SCI'].data
image_data_list.append(data)
sci_hdu = get_hdu(basename, 'SCI', source)
image_data_list.append(sci_hdu.data)

if percent is not None and cur_percent is not None:
self.set_percent_completion(cur_percent + index/total_files * percent)
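
Note on the reshaped loop above: each entry in input_files is now expected to carry a 'source' alongside its 'basename', so get_fits_npdata() can hand the lookup straight to get_hdu(). A minimal sketch of the calling convention (the basenames here are made up for illustration):

# Hypothetical input_data for an operation after this change: every input file
# declares where it lives so the right store is queried for its FITS file.
input_data = {
    'input_files': [
        {'basename': 'cpt1m013-fa14-20240710-0042-e91', 'source': 'archive'},  # raw frame in the science archive
        {'basename': 'd1e2f3a4-0', 'source': 'datalab'},                       # output of an earlier datalab operation
    ]
}

# Inside get_fits_npdata() each entry then reduces to:
#   sci_hdu = get_hdu(basename, 'SCI', source)
#   image_data_list.append(sci_hdu.data)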
datalab/datalab_session/tasks.py (5 changes: 5 additions & 0 deletions)
@@ -1,8 +1,12 @@
 import dramatiq
+import logging
 
 from datalab.datalab_session.data_operations.utils import available_operations
 from requests.exceptions import RequestException
 
+log = logging.getLogger()
+log.setLevel(logging.INFO)
+
 # Retry network connection errors 3 times, all other exceptions are not retried
 def should_retry(retries_so_far, exception):
     return retries_so_far < 3 and isinstance(exception, RequestException)
Expand All @@ -16,4 +20,5 @@ def execute_data_operation(data_operation_name: str, input_data: dict):
try:
operation_class(input_data).operate()
except Exception as e:
log.error(f"Error executing {data_operation_name}: {type(e).__name__}:{e}")
operation_class(input_data).set_failed(str(e))
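
For context, should_retry() has the signature dramatiq's Retries middleware expects from a retry predicate; a rough sketch of how it is presumably bound to the actor (the decorator sits outside this hunk, so the retry_when wiring below is an assumption):

import logging

import dramatiq
from requests.exceptions import RequestException

log = logging.getLogger()
log.setLevel(logging.INFO)

def should_retry(retries_so_far, exception):
    # Retry network connection errors 3 times, all other exceptions are not retried
    return retries_so_far < 3 and isinstance(exception, RequestException)

# Assumed wiring, not shown in this diff: dramatiq calls retry_when(retries_so_far,
# exception) to decide whether a failed message should be re-enqueued.
@dramatiq.actor(retry_when=should_retry)
def execute_data_operation(data_operation_name: str, input_data: dict):
    ...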
datalab/datalab_session/util.py (50 changes: 28 additions & 22 deletions)
@@ -24,8 +24,6 @@ def add_file_to_bucket(item_key: str, path: object) -> str:
     Returns:
       A presigned url for the object just added to the bucket
     """
-    log.info(f'Adding {item_key} to {settings.DATALAB_OPERATION_BUCKET}')
-
     s3 = boto3.client('s3')
     try:
         response = s3.upload_file(
Expand All @@ -37,11 +35,11 @@ def add_file_to_bucket(item_key: str, path: object) -> str:
log.error(f'Error uploading the operation output: {e}')
raise ClientError(f'Error uploading the operation output')

return get_presigned_url(item_key)
return get_s3_url(item_key)

def get_presigned_url(key: str) -> str:
def get_s3_url(key: str, bucket: str = settings.DATALAB_OPERATION_BUCKET) -> str:
"""
Gets a presigned url from the operation bucket using the key
Gets a presigned url from the bucket using the key
Args:
item_key -- name to look up in the bucket
Expand All @@ -55,14 +53,14 @@ def get_presigned_url(key: str) -> str:
url = s3.generate_presigned_url(
ClientMethod='get_object',
Params={
'Bucket': settings.DATALAB_OPERATION_BUCKET,
'Bucket': bucket,
'Key': key
},
ExpiresIn = 60 * 60 * 24 * 30 # URL will be valid for 30 days
)
except ClientError as e:
log.error(f'Could not find the image for {key}: {e}')
raise ClientError(f'Could not find the image for {key}')
log.error(f'Could not generate url for {key}: {e}')
raise ClientError(f'Could not create url for {key}')

return url

Expand All @@ -81,7 +79,7 @@ def key_exists(key: str) -> bool:
response = s3.list_objects_v2(Bucket=settings.DATALAB_OPERATION_BUCKET, Prefix=key, MaxKeys=1)
return 'Contents' in response

def get_archive_from_basename(basename: str) -> dict:
def get_archive_url(basename: str, archive: str = settings.ARCHIVE_API) -> dict:
"""
Looks for the key as a prefix in the operations s3 bucket
Expand All @@ -97,7 +95,7 @@ def get_archive_from_basename(basename: str) -> dict:
'Authorization': f'Token {settings.ARCHIVE_API_TOKEN}'
}

response = requests.get(settings.ARCHIVE_API + '/frames/', params=query_params, headers=headers)
response = requests.get(archive + '/frames/', params=query_params, headers=headers)

try:
response.raise_for_status()
Expand All @@ -110,11 +108,12 @@ def get_archive_from_basename(basename: str) -> dict:
if not results:
raise FileNotFoundError(f"Could not find {basename} in the archive")

return results
fits_url = results[0].get('url', 'No URL found')
return fits_url

def get_hdu(basename: str, extension: str = 'SCI') -> list[fits.HDUList]:
def get_hdu(basename: str, extension: str = 'SCI', source: str = 'archive') -> list[fits.HDUList]:
"""
Returns a HDU for the given basename
Returns a HDU for the given basename from the source
Will download the file to a tmp directory so future calls can open it directly
Warning: this function returns an opened file that must be closed after use
"""
Expand All @@ -123,30 +122,37 @@ def get_hdu(basename: str, extension: str = 'SCI') -> list[fits.HDUList]:
basename = basename.replace('-large', '').replace('-small', '')
basename_file_path = os.path.join(settings.TEMP_FITS_DIR, basename)

# download the file if it isn't already downloaded in our temp directory
if not os.path.isfile(basename_file_path):

# create the tmp directory if it doesn't exist
if not os.path.exists(settings.TEMP_FITS_DIR):
os.makedirs(settings.TEMP_FITS_DIR)

archive_record = get_archive_from_basename(basename)
fits_url = archive_record[0].get('url', 'No URL found')
match source:
case 'archive':
fits_url = get_archive_url(basename)
case 'datalab':
s3_folder_path = f'{basename.split("-")[0]}/{basename}.fits'
fits_url = get_s3_url(s3_folder_path)
case _:
raise ValueError(f"Source {source} not recognized")

urllib.request.urlretrieve(fits_url, basename_file_path)

hdu = fits.open(basename_file_path)
extension = hdu[extension]

if not extension:
log.error(f"{extension} Header not found in fits file {basename}")
raise ValueError(f"{extension} Header not found in fits file {basename}")
try:
extension = hdu[extension]
except KeyError:
raise KeyError(f"{extension} Header not found in fits file {basename}")

return extension

def create_fits(key: str, image_arr: np.ndarray) -> fits.HDUList:

header = fits.Header([('KEY', key)])
primary_hdu = fits.PrimaryHDU(header=header)
image_hdu = fits.ImageHDU(image_arr)
image_hdu = fits.ImageHDU(data=image_arr, name='SCI')

hdu_list = fits.HDUList([primary_hdu, image_hdu])

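
With get_archive_url() and get_s3_url() behind it, get_hdu() now resolves a basename against either store. A small usage sketch with made-up basenames (for source='datalab' the diff builds the S3 key as '<prefix>/<basename>.fits', where the prefix is everything before the first hyphen):

from datalab.datalab_session.util import get_hdu

# Science frame fetched through the archive's /frames API (the default source).
archive_sci = get_hdu('cpt1m013-fa14-20240710-0042-e91', 'SCI', source='archive')

# Intermediate product fetched from the datalab operation bucket,
# e.g. key 'd1e2f3a4/d1e2f3a4-0.fits' in DATALAB_OPERATION_BUCKET.
datalab_sci = get_hdu('d1e2f3a4-0', 'SCI', source='datalab')

print(archive_sci.data.shape, datalab_sci.data.shape)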
