Skip to content

Commit

Permalink
Merge pull request #21 from LCOGT/feature/rgb-composite
Browse files Browse the repository at this point in the history
Feature/rgb composite
  • Loading branch information
LTDakin authored Aug 1, 2024
2 parents c39fb22 + 8ac1c7c commit 5de3961
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 60 deletions.
44 changes: 1 addition & 43 deletions datalab/datalab_session/data_operations/data_operation.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from abc import ABC, abstractmethod
import hashlib
import json
import tempfile

from django.core.cache import cache
from fits2image.conversions import fits_to_jpg
from astropy.io import fits
import numpy as np

from datalab.datalab_session.tasks import execute_data_operation
from datalab.datalab_session.util import add_file_to_bucket, get_hdu
from datalab.datalab_session.util import get_hdu

CACHE_DURATION = 60 * 60 * 24 * 30 # cache for 30 days

Expand Down Expand Up @@ -99,45 +96,6 @@ def set_failed(self, message: str):
self.set_status('FAILED')
self.set_message(message)

# percent lets you allocate a fraction of the operation that this takes up in time
# cur_percent is the current completion of the operation
def create_and_store_fits(self, hdu_list: fits.HDUList, percent=None, cur_percent=None) -> list:
    """
    Write each HDUList to a temp fits file, render a large and a small jpg from it,
    upload all three artifacts to the S3 bucket, and return one metadata dict per input.

    Returns a list of dicts with 'fits_url', 'large_url', 'thumbnail_url',
    'basename' and 'source' keys.
    """
    # accept a single HDUList as well as a list of them
    if not isinstance(hdu_list, list):
        hdu_list = [hdu_list]

    output = []
    total_files = len(hdu_list)

    # Create temp file paths for storing the products (paths are reused for every HDU)
    fits_path = tempfile.NamedTemporaryFile(suffix=f'{self.cache_key}.fits').name
    large_jpg_path = tempfile.NamedTemporaryFile(suffix=f'{self.cache_key}-large.jpg').name
    thumbnail_jpg_path = tempfile.NamedTemporaryFile(suffix=f'{self.cache_key}-small.jpg').name

    for index, hdu in enumerate(hdu_list, start=1):
        # assumes extension 1 holds a 2-D image array — TODO confirm for all producers
        height, width = hdu[1].shape

        # overwrite=True: the same temp path is reused on every iteration, and
        # writeto raises if the target file already exists
        hdu.writeto(fits_path, overwrite=True)
        fits_to_jpg(fits_path, large_jpg_path, width=width, height=height)
        fits_to_jpg(fits_path, thumbnail_jpg_path)

        # Save Fits and Thumbnails in S3 Buckets
        fits_url = add_file_to_bucket(f'{self.cache_key}/{self.cache_key}-{index}.fits', fits_path)
        large_jpg_url = add_file_to_bucket(f'{self.cache_key}/{self.cache_key}-{index}-large.jpg', large_jpg_path)
        thumbnail_jpg_url = add_file_to_bucket(f'{self.cache_key}/{self.cache_key}-{index}-small.jpg', thumbnail_jpg_path)

        output.append({
            'fits_url': fits_url,
            'large_url': large_jpg_url,
            'thumbnail_url': thumbnail_jpg_url,
            'basename': f'{self.cache_key}-{index}',
            'source': 'datalab'}
        )

        # report progress proportionally to how many files have been processed
        if percent is not None and cur_percent is not None:
            self.set_percent_completion(cur_percent + index/total_files * percent)

    return output

def get_fits_npdata(self, input_files: list[dict], percent=None, cur_percent=None) -> list[np.memmap]:
total_files = len(input_files)
image_data_list = []
Expand Down
2 changes: 1 addition & 1 deletion datalab/datalab_session/data_operations/long.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def wizard_description():
'description': 'The input files to operate on',
'type': 'file',
'minimum': 1,
'maxmimum': 999
'maximum': 999
},
'duration': {
'name': 'Duration',
Expand Down
12 changes: 7 additions & 5 deletions datalab/datalab_session/data_operations/median.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from datalab.datalab_session.data_operations.data_operation import BaseDataOperation
from datalab.datalab_session.util import create_fits, stack_arrays
from datalab.datalab_session.util import create_fits, stack_arrays, create_jpgs, save_fits_and_thumbnails

log = logging.getLogger()
log.setLevel(logging.INFO)
Expand Down Expand Up @@ -33,7 +33,7 @@ def wizard_description():
'description': 'The input files to operate on',
'type': 'file',
'minimum': 1,
'maxmimum': 999
'maximum': 999
}
}
}
Expand All @@ -52,11 +52,13 @@ def operate(self):
# using the numpy library's median method
median = np.median(stacked_data, axis=2)

hdu_list = create_fits(self.cache_key, median)
fits_file = create_fits(self.cache_key, median)

output = self.create_and_store_fits(hdu_list, percent=0.6, cur_percent=0.4)
large_jpg_path, small_jpg_path = create_jpgs(self.cache_key, fits_file)

output = {'output_files': output}
output_file = save_fits_and_thumbnails(self.cache_key, fits_file, large_jpg_path, small_jpg_path)

output = {'output_files': [output_file]}
else:
output = {'output_files': []}

Expand Down
2 changes: 1 addition & 1 deletion datalab/datalab_session/data_operations/noop.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def wizard_description():
'description': 'The input files to operate on',
'type': 'file',
'minimum': 1,
'maxmimum': 999
'maximum': 999
},
'scalar_parameter_1': {
'name': 'Scalar Parameter 1',
Expand Down
83 changes: 83 additions & 0 deletions datalab/datalab_session/data_operations/rgb_stack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import logging

from astropy.io import fits

from datalab.datalab_session.data_operations.data_operation import BaseDataOperation
from datalab.datalab_session.util import get_fits, stack_arrays, create_fits, save_fits_and_thumbnails, create_jpgs

log = logging.getLogger()
log.setLevel(logging.INFO)


class RGB_Stack(BaseDataOperation):
    """Composites a red-, green-, and blue-filtered image into one color image."""

    @staticmethod
    def name():
        return 'RGB Stack'

    @staticmethod
    def description():
        return """The RGB Stack operation takes in 3 input images which have red, green, and blue filters and creates a colored image by compositing them on top of each other."""

    @staticmethod
    def wizard_description():
        return {
            'name': RGB_Stack.name(),
            'description': RGB_Stack.description(),
            'category': 'image',
            'inputs': {
                'red_input': {
                    'name': 'Red Filter',
                    # each input takes exactly one image (minimum == maximum == 1)
                    'description': 'One image taken with a red filter',
                    'type': 'file',
                    'minimum': 1,
                    'maximum': 1,
                    'filter': ['rp', 'r']
                },
                'green_input': {
                    'name': 'Green Filter',
                    'description': 'One image taken with a green filter',
                    'type': 'file',
                    'minimum': 1,
                    'maximum': 1,
                    'filter': ['V', 'gp']
                },
                'blue_input': {
                    'name': 'Blue Filter',
                    'description': 'One image taken with a blue filter',
                    'type': 'file',
                    'minimum': 1,
                    'maximum': 1,
                    'filter': ['B']
                }
            }
        }

    def operate(self):
        """
        Fetch the three input fits files, render color jpgs from them, stack their
        SCI arrays into one 3-D fits file, upload everything, and record the output.

        Raises ValueError if exactly 3 input files are not supplied.
        """
        rgb_input_list = self.input_data['red_input'] + self.input_data['green_input'] + self.input_data['blue_input']

        # guard clause: fail fast instead of building an unused empty output first
        if len(rgb_input_list) != 3:
            raise ValueError('RGB Stack operation requires exactly 3 input files')

        log.info(f'Executing RGB Stack operation on files: {rgb_input_list}')

        fits_paths = []
        for file in rgb_input_list:
            fits_paths.append(get_fits(file.get('basename')))
            # each of the three downloads accounts for 20% of the operation
            self.set_percent_completion(self.get_percent_completion() + 0.2)

        large_jpg_path, small_jpg_path = create_jpgs(self.cache_key, fits_paths, color=True)

        # color photos take three files, so we store it as one fits file with a 3d SCI ndarray
        arrays = [fits.open(file)['SCI'].data for file in fits_paths]
        stacked_data = stack_arrays(arrays)
        fits_file = create_fits(self.cache_key, stacked_data)

        output_file = save_fits_and_thumbnails(self.cache_key, fits_file, large_jpg_path, small_jpg_path)

        self.set_percent_completion(1.0)
        self.set_output({'output_files': [output_file]})
        log.info(f'RGB Stack output: {self.get_output()}')
79 changes: 69 additions & 10 deletions datalab/datalab_session/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import tempfile
import requests
import logging
import os
Expand All @@ -6,9 +7,10 @@
import boto3
from astropy.io import fits
import numpy as np
from botocore.exceptions import ClientError

from django.conf import settings
from botocore.exceptions import ClientError
from fits2image.conversions import fits_to_jpg

log = logging.getLogger()
log.setLevel(logging.INFO)
Expand Down Expand Up @@ -111,14 +113,10 @@ def get_archive_url(basename: str, archive: str = settings.ARCHIVE_API) -> dict:
fits_url = results[0].get('url', 'No URL found')
return fits_url

def get_hdu(basename: str, extension: str = 'SCI', source: str = 'archive') -> list[fits.HDUList]:
def get_fits(basename: str, source: str = 'archive'):
"""
Returns a HDU for the given basename from the source
Will download the file to a tmp directory so future calls can open it directly
Warning: this function returns an opened file that must be closed after use
Returns a Fits File for the given basename from the source
"""

# use the basename to fetch and create a list of hdu objects
basename = basename.replace('-large', '').replace('-small', '')
basename_file_path = os.path.join(settings.TEMP_FITS_DIR, basename)

Expand All @@ -139,6 +137,17 @@ def get_hdu(basename: str, extension: str = 'SCI', source: str = 'archive') -> l
raise ValueError(f"Source {source} not recognized")

urllib.request.urlretrieve(fits_url, basename_file_path)

return basename_file_path

def get_hdu(basename: str, extension: str = 'SCI', source: str = 'archive') -> list[fits.HDUList]:
"""
Returns a HDU for the given basename from the source
Will download the file to a tmp directory so future calls can open it directly
Warning: this function returns an opened file that must be closed after use
"""

basename_file_path = get_fits(basename, source)

hdu = fits.open(basename_file_path)
try:
Expand All @@ -148,15 +157,65 @@ def get_hdu(basename: str, extension: str = 'SCI', source: str = 'archive') -> l

return extension

def create_fits(key: str, image_arr: np.ndarray) -> fits.HDUList:
def get_fits_dimensions(fits_file, extension: str = 'SCI') -> tuple:
    """Return the data shape of the given extension in the fits file at fits_file."""
    # use a context manager so the file handle is closed instead of leaked to GC
    with fits.open(fits_file) as hdu_list:
        return hdu_list[extension].shape

def create_fits(key: str, image_arr: np.ndarray) -> str:
    """
    Creates a fits file containing a PrimaryHDU carrying the key and a
    'SCI' ImageHDU holding the given image array.
    Returns the path to the fits file.
    """

    header = fits.Header([('KEY', key)])
    primary_hdu = fits.PrimaryHDU(header=header)
    image_hdu = fits.ImageHDU(data=image_arr, name='SCI')

    hdu_list = fits.HDUList([primary_hdu, image_hdu])

    # NamedTemporaryFile is used only to mint a unique path; the handle is dropped
    # immediately (CPython then removes the empty file) and writeto recreates it.
    # NOTE(review): this is race-prone — consider tempfile.mkstemp if it bites.
    fits_path = tempfile.NamedTemporaryFile(suffix=f'{key}.fits').name
    hdu_list.writeto(fits_path)

    return fits_path

def create_jpgs(cache_key, fits_paths: str, color=False) -> list:
    """
    Create a full-size and a thumbnail jpg in temp files from one or more fits files.
    If using the color option fits_paths need to be in order R, G, B.
    Returns (large_jpg_path, thumbnail_jpg_path).
    """

    # accept a single path as well as a list of paths
    if not isinstance(fits_paths, list):
        fits_paths = [fits_paths]

    # unique temp paths for the jpg outputs
    large_jpg_path = tempfile.NamedTemporaryFile(suffix=f'{cache_key}-large.jpg').name
    thumbnail_jpg_path = tempfile.NamedTemporaryFile(suffix=f'{cache_key}-small.jpg').name

    # size the large jpg to the per-axis maximum across all inputs;
    # max() over whole shape tuples would compare lexicographically and could
    # return a tuple whose width is not the largest width
    dimensions = [get_fits_dimensions(path) for path in fits_paths]
    max_height = max(dim[0] for dim in dimensions)
    max_width = max(dim[1] for dim in dimensions)

    fits_to_jpg(fits_paths, large_jpg_path, width=max_width, height=max_height, color=color)
    fits_to_jpg(fits_paths, thumbnail_jpg_path, color=color)

    return large_jpg_path, thumbnail_jpg_path

def save_fits_and_thumbnails(cache_key, fits_path, large_jpg_path, thumbnail_jpg_path, index=None):
    """
    Save Fits and Thumbnails in S3 Buckets, Returns the URLs in an output object
    """
    # 'is not None' so an explicit index of 0 is not silently dropped
    file_stem = f'{cache_key}-{index}' if index is not None else cache_key
    bucket_key = f'{cache_key}/{file_stem}'

    fits_url = add_file_to_bucket(f'{bucket_key}.fits', fits_path)
    large_jpg_url = add_file_to_bucket(f'{bucket_key}-large.jpg', large_jpg_path)
    thumbnail_jpg_url = add_file_to_bucket(f'{bucket_key}-small.jpg', thumbnail_jpg_path)

    output_file = {
        'fits_url': fits_url,
        'large_url': large_jpg_url,
        'thumbnail_url': thumbnail_jpg_url,
        # include the index when present so the basename matches the stored object names
        'basename': file_stem,
        'source': 'datalab',
    }

    return output_file

def stack_arrays(array_list: list):
"""
Expand Down

0 comments on commit 5de3961

Please sign in to comment.