From b2c6ba625b49d465c786bf4a2d06720c4b2305e6 Mon Sep 17 00:00:00 2001 From: Bhavin Nayak Date: Fri, 1 Apr 2016 12:35:03 -0700 Subject: [PATCH] Changes made as per suggestions. refs #60 --- planetaryimage/image.py | 20 +++++++++++-------- planetaryimage/pds3image.py | 11 +++++----- tests/test_pds3file.py | 40 ++++++++++++++++++++++++++++--------- 3 files changed, 49 insertions(+), 22 deletions(-) diff --git a/planetaryimage/image.py b/planetaryimage/image.py index b1b6ea3..48374d8 100644 --- a/planetaryimage/image.py +++ b/planetaryimage/image.py @@ -21,7 +21,11 @@ def open(cls, filename): (``.gz``) or bzip2 (``.bz2``) compressed. """ if isinstance(filename, numpy.ndarray): - return cls(filename, 'numpy_array') + error_msg = ( + 'A file like object is expected for stream. ' + 'Use PDS3Image(numpy_array) to create a PDS3Image object.' + ) + raise TypeError(error_msg) else: if filename.endswith('.gz'): fp = gzip.open(filename, 'rb') @@ -39,7 +43,7 @@ def open(cls, filename): with open(filename, 'rb') as fp: return cls(fp, filename) - def __init__(self, stream_or_array, filename=None, compression=None): + def __init__(self, stream_string_or_array, filename=None, compression=None): """Create an Image object. Parameters @@ -54,18 +58,18 @@ def __init__(self, stream_or_array, filename=None, compression=None): compression : string an optional string that indicate the compression type 'bz2' or 'gz' """ - if isinstance(stream_or_array, six.string_types): + if isinstance(stream_string_or_array, six.string_types): error_msg = ( 'A file like object is expected for stream. ' 'Use %s.open(filename) to open a image file.' ) raise TypeError(error_msg % type(self).__name__) - if isinstance(stream_or_array, numpy.ndarray): + if isinstance(stream_string_or_array, numpy.ndarray): self.filename = 'numpy_array' self.compression = None - self.data = stream_or_array - self.label = self._create_label(stream_or_array) + self.data = stream_string_or_array + self.label = self._create_label(stream_string_or_array) else: #: The filename if given, otherwise none. self.filename = filename @@ -74,10 +78,10 @@ def __init__(self, stream_or_array, filename=None, compression=None): # TODO: rename to header and add footer? #: The parsed label header in dictionary form. - self.label = self._load_label(stream_or_array) + self.label = self._load_label(stream_string_or_array) #: A numpy array representing the image - self.data = self._load_data(stream_or_array) + self.data = self._load_data(stream_string_or_array) def __repr__(self): # TODO: pick a better repr diff --git a/planetaryimage/pds3image.py b/planetaryimage/pds3image.py index 8b3125e..35263ec 100644 --- a/planetaryimage/pds3image.py +++ b/planetaryimage/pds3image.py @@ -138,8 +138,9 @@ def _save(self, file_to_write, overwrite): if self._sample_bytes != self.data.itemsize: self.label['IMAGE']['SAMPLE_BITS'] = self.data.itemsize * 8 - sample_type_to_save = self.DTYPES[self._sample_type[0] + self.dtype.kind] - self.label['IMAGE']['SAMPLE_TYPE'] = sample_type_to_save + + sample_type_to_save = self.DTYPES[self._sample_type[0] + self.dtype.kind] + self.label['IMAGE']['SAMPLE_TYPE'] = sample_type_to_save if len(self.data.shape) == 3: self.label['IMAGE']['BANDS'] = self.data.shape[0] @@ -197,10 +198,10 @@ def _create_label(self, array): def _update_label(self, label, array): maximum = float(numpy.max(array)) - mean = numpy.mean(array) - median = numpy.median(array) + mean = float(numpy.mean(array)) + median = float(numpy.median(array)) minimum = float(numpy.min(array)) - stdev = numpy.std(array, ddof=1) + stdev = float(numpy.std(array, ddof=1)) encoder = pvl.encoder.PDSLabelEncoder serial_label = pvl.dumps(label, cls=encoder) diff --git a/tests/test_pds3file.py b/tests/test_pds3file.py index a487d65..ed6a0cf 100644 --- a/tests/test_pds3file.py +++ b/tests/test_pds3file.py @@ -182,18 +182,40 @@ def test_image_save_float_to_int(): os.remove('Temp_Image.IMG') -def test_numpy_array_save(): +def test_numpy_array_save_i2(): image = PDS3Image.open(filename) - temp = PDS3Image.open(image.data) + array = numpy.arange(100, dtype='>i2') + array = array.reshape(10, 10) + temp = PDS3Image(array) temp.save('Temp_Image.IMG') image_temp = PDS3Image.open('Temp_Image.IMG') - assert image_temp.bands == image.bands - assert image_temp.lines == image.lines - assert image_temp.samples == image.samples - assert image_temp.format == image.format - assert image_temp.dtype == image.dtype - assert image_temp.shape == image.shape - assert image_temp.size == image.size + assert image_temp.bands == 1 + assert image_temp.lines == 10 + assert image_temp.samples == 10 + assert image_temp.format == 'BAND_SEQUENTIAL' + assert image_temp.dtype == '>i2' + assert image_temp.shape == (1, 10, 10) + assert image_temp.size == 100 + assert_almost_equal(image_temp.data, image.data) + os.remove('Temp_Image.IMG') + + +def test_numpy_array_save_f4(): + image = PDS3Image.open(filename_float) + array = numpy.arange(100) + array = array.reshape(10, 10) + array = array * 1.5 + array = array.astype('>f4') + temp = PDS3Image(array) + temp.save('Temp_Image.IMG') + image_temp = PDS3Image.open('Temp_Image.IMG') + assert image_temp.bands == 1 + assert image_temp.lines == 10 + assert image_temp.samples == 10 + assert image_temp.format == 'BAND_SEQUENTIAL' + assert image_temp.dtype == '>f4' + assert image_temp.shape == (1, 10, 10) + assert image_temp.size == 100 assert_almost_equal(image_temp.data, image.data) os.remove('Temp_Image.IMG')