diff --git a/aotools/functions/ell_zernike.py b/aotools/functions/ell_zernike.py index c92d138..bb08222 100644 --- a/aotools/functions/ell_zernike.py +++ b/aotools/functions/ell_zernike.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 import numpy as np -from zernike import RZern import matplotlib.pyplot as plt from scipy.linalg import eig +from aotools.functions import zernike ''' Normalised "Zernike" polynomials for the elliptical aperture. @@ -11,14 +11,13 @@ ''' class ZernikeEllipticalaperture: - def __init__(self, rmax, npix, a, b, l, ell_aperture=True, coeff=None): - self.rmax = rmax #maximum radial order + def __init__(self, nterms, npix, a, b, ell_aperture=True, coeff=None): + self.nterms = nterms #maximum radial order self.npix = npix #number of pixels for the map self.a = a #major semi-axis (normalised) self.b = b #minor semi-axis (normalised) - self.l = l #number of elliptical modes (polynomials) - self.ell_aperture = ell_aperture # Specify whether to use the elliptical aperture - self.coeff = coeff # Coefficients for the Zernike modes (optional) + self.ell_aperture = ell_aperture # Specify whether to use the elliptical aperture. False leads to a standard circular aperture + self.coeff = coeff # Coefficients for the Zernike modes (optional). Default random. self.ell_aperture_mask = self.GenerateEllipticalAperture() self.circular_zern = self.GetCircZernikeValue() self.E = self.CalculateEllipticalZernike() @@ -29,18 +28,7 @@ def GetCircZernikeValue(self): Return: zern_value ''' - - zernike = RZern(self.rmax) - xx, yy = np.meshgrid(np.linspace(-1, 1, self.npix), np.linspace(-1, 1, self.npix)) - rho = np.sqrt(xx**2 + yy**2) - theta = np.arctan2(yy, xx) - zernike.make_cart_grid(xx, yy) - - zern_value = [] - nterms = int((self.rmax + 1) * (self.rmax + 2) / 2) - for i in range(nterms): - zern_value.append(zernike.Zk(i, rho, theta)) - + zern_value = zernike.zernikeArray(self.nterms, self.npix) zern_value = np.array(zern_value / np.linalg.norm(zern_value)).squeeze() return zern_value @@ -50,18 +38,19 @@ def CalculateEllipticalZernike(self): Return: E ''' - Z = self.GetCircZernikeValue() + + Z = self.circular_zern M = self.M_matrix() E = [] # Initialize a list to store E arrays for each l - for i in range(1, self.l + 1): + for i in range(1, self.nterms + 1): E_l = np.zeros(Z[0].shape) # Initialize E with the same shape as Z[0] for j in range(1, i + 1): E_l += M[i - 1, j - 1] * Z[j - 1] E.append(E_l) E = np.array(E) - if self.ell_aperture: + if self.ell_aperture == True: E[:, np.logical_not(self.ell_aperture_mask)] = 0 return E @@ -87,15 +76,14 @@ def C_zern(self): Return: C ''' - nterms = int((self.rmax + 1) * (self.rmax + 2) / 2) # Initialize the C matrix - C = np.zeros((nterms, nterms)) + C = np.zeros((self.nterms, self.nterms)) # Calculate the area of each grid cell dx = (2 * self.a) / 10000 dy = (2 * self.b) / 10000 - for i in range(nterms): - for j in range(i, nterms): + for i in range(self.nterms): + for j in range(i, self.nterms): product_Zern = np.dot(self.circular_zern[i], self.circular_zern[j]) * dx * dy C[i, j] += np.sum(product_Zern) if i != j: @@ -119,13 +107,16 @@ def EllZernikeMap(self, coeff=None): Return: phi ''' xx, yy = np.meshgrid(np.linspace(-1, 1, self.npix), np.linspace(-1, 1, self.npix)) - E_ell = np.zeros((xx.size, self.l)) + E_ell = np.zeros((xx.size, self.nterms)) - for k in range(self.l): + for k in range(self.nterms): E_ell[:, k] = np.ravel(self.E[k]) if coeff is None: - coeff = np.random.random(self.l) + coeff = np.random.random(self.nterms) + + if len(coeff) != self.nterms: + raise ValueError(f"Coefficient array must have length {self.nterms}, but got {len(coeff)}.") phi = np.dot(E_ell, coeff) phi = phi.reshape(xx.shape) diff --git a/test/test_ell_zernike.py b/test/test_ell_zernike.py index 03472ed..e26f937 100644 --- a/test/test_ell_zernike.py +++ b/test/test_ell_zernike.py @@ -1,20 +1,34 @@ -from aotools import functions +from aotools.functions import ell_zernike import matplotlib.pyplot as plt +import numpy as np -if __name__ == '__main__': - a = 1 - b = 0.8 - rmax = 7 - npix = 256 - l = 35 +def test_ZernikeEllipticalaperture(): + # Define parameters for the ZernikeEllipticalaperture instance + nterms = 6 # Number of Zernike terms + npix = 256 # Number of pixels in each dimension + a = 1.0 # Semi-major axis of the elliptical aperture + b = 0.5 # Semi-minor axis of the elliptical aperture - ell_zern = functions.ZernikeEllipticalaperture(rmax, npix, a, b, l) + zernike_instance = ell_zernike.ZernikeEllipticalaperture(nterms, npix, a, b) - Ell = ell_zern.CalculateEllipticalZernike() - plt.imshow(Ell[2]) - plt.show() + assert zernike_instance.ell_aperture_mask.shape == (npix, npix), "Aperture mask shape is incorrect" - phi = ell_zern.EllZernikeMap() - plt.imshow(phi) - plt.show() + assert np.all(np.isin(zernike_instance.ell_aperture_mask, [0, 1])), "Aperture mask should contain only 0s and 1s" + + assert zernike_instance.E.shape == (nterms, npix, npix), "Zernike modes shape is incorrect" + + assert np.any(zernike_instance.E[0][zernike_instance.GenerateEllipticalAperture() == 1]) != 0, "First Zernike mode should have non-zero values in the aperture" + + expected_number_of_modes = nterms + assert zernike_instance.E.shape[0] == expected_number_of_modes, f"Expected {expected_number_of_modes} Zernike modes, got {zernike_instance.E.shape[0]}" + + phi = zernike_instance.EllZernikeMap() + assert phi.shape == (npix, npix), "Output shape is incorrect when no coefficients are provided." + + coeff = np.random.random(nterms) + phi_with_coeff = zernike_instance.EllZernikeMap(coeff) + assert phi_with_coeff.shape == (npix, npix), "Output shape is incorrect with provided coefficients." + + +test_ZernikeEllipticalaperture() diff --git a/test/test_zernike.py b/test/test_zernike.py index c236acb..1306322 100644 --- a/test/test_zernike.py +++ b/test/test_zernike.py @@ -1,5 +1,6 @@ from aotools import functions import numpy +import matplotlib.pyplot as plt def test_zernIndex(): @@ -8,7 +9,6 @@ def test_zernIndex(): index = functions.zernIndex(i) assert(index == results[i-1]) - def test_makegammas(): gammas = functions.makegammas(5) assert(gammas.shape == (2, 21, 21)) @@ -39,6 +39,8 @@ def test_zernike(): def test_zernikeArray_single(): zernike_array = functions.zernikeArray(10, 32) assert(zernike_array.shape == (10, 32, 32)) + plt.imshow(zernike_array[0]) + plt.show() def test_zernikeArray_list(): @@ -55,3 +57,5 @@ def test_zernikeArray_comparison(): def test_phaseFromZernikes(): phase_map = functions.phaseFromZernikes([1, 2, 3, 4, 5], 32) assert(phase_map.shape == (32, 32)) + +