diff --git a/src/pybdy/nemo_bdy_extr_assist.py b/src/pybdy/nemo_bdy_extr_assist.py
index b92a717e..a742f747 100644
--- a/src/pybdy/nemo_bdy_extr_assist.py
+++ b/src/pybdy/nemo_bdy_extr_assist.py
@@ -317,9 +317,51 @@ def get_vertical_weights_zco(dst_dep, dst_len_z, num_bdy, sc_z, sc_z_len):
     return z9_dist, z9_ind
 
 
-def interp_vertical(
-    sc_bdy, dst_dep, bdy_bathy, z_ind, z_dist, data_ind, num_bdy, zinterp=True
-):
+def flood_fill(sc_bdy, isslab, logger):
+    """
+    Fill the data horizontally then downwards to remove NaNs before interpolation.
+
+    Parameters
+    ----------
+    sc_bdy (np.array) : source data [nz_sc, nbdy, 9]
+    isslab (bool) : if true the data is a 2D slab and the downward fill is skipped
+    logger : log of statements
+
+    Returns
+    -------
+    sc_bdy (np.array) : source data [nz_sc, nbdy, 9]
+    """
+    # identify valid pts
+    data_ind, nan_ind = valid_index(sc_bdy, logger)
+
+    # Set sc land pts to nan
+    sc_bdy[nan_ind] = np.nan
+    sc_shape = sc_bdy.shape
+
+    for i in range(sc_shape[0]):
+        while np.isnan(sc_bdy[i, :, 0]).any() & (~np.isnan(sc_bdy[i, :, 0])).any():
+            # Flood sc land horizontally within the chunk for the centre point.
+            # This may not be perfect but better than filling with zeros
+            sc_nan = np.isnan(sc_bdy)
+            sc_bdy[:, 1:, 0][sc_nan[:, 1:, 0]] = sc_bdy[:, :-1, 0][sc_nan[:, 1:, 0]]
+            sc_nan = np.isnan(sc_bdy)
+            sc_bdy[:, :-1, 0][sc_nan[:, :-1, 0]] = sc_bdy[:, 1:, 0][sc_nan[:, :-1, 0]]
+
+    if not isslab:
+        data_ind, nan_ind = valid_index(sc_bdy, logger)
+        # Fill down using deepest pt
+        ind_bdy = np.arange(sc_shape[1])
+        all_bot = np.tile(
+            sc_bdy[data_ind[:, 0], ind_bdy, 0], (sc_shape[0], sc_shape[2], 1)
+        ).transpose((0, 2, 1))
+        sc_bdy[:, :, 0][np.isnan(sc_bdy[:, :, 0])] = all_bot[:, :, 0][
+            np.isnan(sc_bdy[:, :, 0])
+        ]
+
+    return sc_bdy
+
+
+def interp_vertical(sc_bdy, dst_dep, bdy_bathy, z_ind, z_dist, num_bdy, zinterp=True):
     """
     Interpolate source data onto destination vertical levels.
 
@@ -330,7 +372,6 @@ def interp_vertical(
     bdy_bathy (np.array): the destination grid bdy points bathymetry
     z_ind (np.array) : the indices of the sc depth above and below bdy
     z_dist (np.array) : the distance weights of the selected points
-    data_ind (np.array) : bool points above bathymetry that are valid
     num_bdy (int) : number of boundary points in chunk
     zinterp (bool) : vertical interpolation flag
 
@@ -338,15 +379,8 @@ def interp_vertical(
     -------
     data_out (np.array) : source data on destination depth levels
     """
-    # If all else fails fill down using deepest pt
-    sc_shape = sc_bdy.shape
-    ind_bdy = np.arange(sc_bdy.shape[1])
-    all_bot = np.tile(
-        sc_bdy[data_ind[:, 0], ind_bdy, 0], (sc_bdy.shape[0], sc_bdy.shape[2], 1)
-    ).transpose((0, 2, 1))
-    sc_bdy[np.isnan(sc_bdy)] = all_bot[np.isnan(sc_bdy)]
-
     if zinterp is True:
+        sc_shape = sc_bdy.shape
         sc_bdy = sc_bdy.flatten("F")
 
         # Weighted averaged on new vertical grid
diff --git a/src/pybdy/nemo_bdy_extr_tm3.py b/src/pybdy/nemo_bdy_extr_tm3.py
index 10a8dd41..edbdf2db 100644
--- a/src/pybdy/nemo_bdy_extr_tm3.py
+++ b/src/pybdy/nemo_bdy_extr_tm3.py
@@ -913,8 +913,8 @@ def extract_month(self, year, month):
                 else:
                     sc_z_len = self.sc_z_len
 
-                # identify valid pts
-                data_ind, _ = extr_assist.valid_index(sc_bdy[vn], self.logger)
+                # Flood fill
+                sc_bdy[vn] = extr_assist.flood_fill(sc_bdy[vn], isslab, self.logger)
 
                 if not isslab:
                     # Vertical interpolation
@@ -924,7 +924,6 @@ def extract_month(self, year, month):
                         self.bdy_z[chunk_d],
                         self.z_ind[chunk_z, :],
                         self.z_dist[chunk_z, :],
-                        data_ind,
                         self.num_bdy_ch[chk],
                         self.settings["zinterp"],
                     )
@@ -933,8 +932,6 @@ def extract_month(self, year, month):
                     sc_bdy_lev = sc_bdy[vn]
                     sc_bdy_lev[:, np.isnan(self.bdy_z[chunk_d]), :] = np.NaN
 
-                _, nan_ind = extr_assist.valid_index(sc_bdy_lev, self.logger)
-
                 # distance weightings for averaging source data to destination
                 dist_wei, dist_fac = extr_assist.distance_weights(
                     sc_bdy_lev,
@@ -951,6 +948,13 @@ def extract_month(self, year, month):
 
                 # weight vector array and rotate onto dest grid
                 if self.key_vec:
+                    # Do the same for both components of u and v velocities
+
+                    # Flood fill
+                    sc_bdy[vn + 1] = extr_assist.flood_fill(
+                        sc_bdy[vn + 1], isslab, self.logger
+                    )
+
                     if not isslab:
                         # Vertical interpolation
                         sc_bdy_lev2 = extr_assist.interp_vertical(
@@ -959,7 +963,6 @@ def extract_month(self, year, month):
                             self.bdy_z[chunk_d],
                             self.z_ind[chunk_z, :],
                             self.z_dist[chunk_z, :],
-                            data_ind,
                             self.num_bdy_ch[chk],
                             self.settings["zinterp"],
                         )
@@ -1009,26 +1012,6 @@ def extract_month(self, year, month):
                 # Finished first run operations #
                 self.first = False
 
-                if self.settings["zinterp"]:
-                    # Set land pts to zero
-                    self.logger.info(
-                        " pre dst_bdy[nan_ind] %s %s",
-                        np.nanmin(dst_bdy),
-                        np.nanmax(dst_bdy),
-                    )
-
-                    dst_bdy[nan_ind] = 0
-                    self.logger.info(
-                        " post dst_bdy %s %s",
-                        np.nanmin(dst_bdy),
-                        np.nanmax(dst_bdy),
-                    )
-                # Remove any data on dst grid that is in land
-                dst_bdy[:, np.isnan(self.bdy_z[chunk_d])] = 0
-                self.logger.info(
-                    " 3 dst_bdy %s %s", np.nanmin(dst_bdy), np.nanmax(dst_bdy)
-                )
-
                 data_out = dst_bdy
 
                 # add data to self.d_bdy
diff --git a/src/pybdy/nemo_bdy_zgrv2.py b/src/pybdy/nemo_bdy_zgrv2.py
index ecbac295..c82c0e03 100644
--- a/src/pybdy/nemo_bdy_zgrv2.py
+++ b/src/pybdy/nemo_bdy_zgrv2.py
@@ -175,8 +175,8 @@ def get_bdy_depths(DstCoord, bdy_i, grd):
     bdy_e3 = np.ma.zeros((m_e.shape[0], len(g_ind)))
     for k in range(m_w.shape[0]):
         tmp_w = np.ma.masked_where(mbathy + 1 < k + 1, m_w[k, :, :])
-        tmp_t = np.ma.masked_where(mbathy + 1 < k + 1, m_t[k, :, :])
-        tmp_e = np.ma.masked_where(mbathy + 1 < k + 1, m_e[k, :, :])
+        tmp_t = np.ma.masked_where(mbathy < k + 1, m_t[k, :, :])
+        tmp_e = np.ma.masked_where(mbathy < k + 1, m_e[k, :, :])
 
         tmp_w = tmp_w.flatten("F")
         tmp_t = tmp_t.flatten("F")
diff --git a/tests/test_nemo_bdy_extr_assist.py b/tests/test_nemo_bdy_extr_assist.py
index 8528d0c2..a1da774c 100644
--- a/tests/test_nemo_bdy_extr_assist.py
+++ b/tests/test_nemo_bdy_extr_assist.py
@@ -215,6 +215,67 @@ def test_get_vertical_weights_sco():
     assert not errors, "errors occured:\n{}".format("\n".join(errors))
 
 
+def test_flood_fill():
+    # Test the flood_fill function
+    logger = logging.getLogger(__name__)
+    max_depth = 100
+    num_bdy = 3
+    bdy_bathy = np.array([0, 0, 80.5, 100, 60])
+    sc_z_len = 15
+    gdept, _ = synth_zgrid.synth_zco(max_depth, sc_z_len)
+    sc_z = np.tile(gdept, (3, 5, 1)).T
+    sc_bdy = np.tile(np.linspace(12, 5, num=sc_z_len), (3, 5, 1)).T  # Temperature data
+
+    # Centre then clockwise from 12 then corners
+    ind_g = np.array([[1, 2, 1, 0, 1, 2, 0, 0, 2], [1, 1, 2, 1, 0, 2, 2, 0, 0]])
+    ind = np.zeros((num_bdy, 9), dtype=int)
+    ind[0, :] = np.ravel_multi_index(ind_g, (3, 5), order="F")
+    ind_g[1, :] = ind_g[1, :] + 1
+    ind[1, :] = np.ravel_multi_index(ind_g, (3, 5), order="F")
+    ind_g[1, :] = ind_g[1, :] + 1
+    ind[2, :] = np.ravel_multi_index(ind_g, (3, 5), order="F")
+
+    bathy_tile = np.transpose(np.tile(bdy_bathy, (sc_z_len, 3, 1)), (0, 2, 1))
+    sc_bdy = np.ma.masked_where(sc_z > bathy_tile, sc_bdy)
+    sc_bdy = sc_bdy.filled(np.nan)
+
+    sc_bdy = sc_bdy.reshape((sc_bdy.shape[0], sc_bdy.shape[1] * sc_bdy.shape[2]))[
+        :, ind
+    ]
+
+    # Run function
+    sc_bdy = extr_assist.flood_fill(sc_bdy, False, logger)
+
+    # Check results
+    lev_test = np.array(
+        [
+            12.0,
+            11.5,
+            11.0,
+            10.5,
+            10.0,
+            9.5,
+            9.0,
+            8.5,
+            8.0,
+            7.5,
+            7.0,
+            6.5,
+            6.0,
+            5.5,
+            5.5,
+        ]
+    )
+    print(sc_bdy[:, 0, 0])
+    errors = []
+    if not (sc_bdy.shape == (sc_z_len, num_bdy, 9)):
+        errors.append("Error with output sc_bdy shape.")
+    elif not np.isclose(sc_bdy[:, 0, 0], lev_test, atol=1e-4).all():
+        errors.append("Error with sc_bdy_lev.")
+    # assert no error message has been registered, else print messages
+    assert not errors, "errors occurred:\n{}".format("\n".join(errors))
+
+
 def test_interp_vertical():
     # Test the interp_vertical function
     logger = logging.getLogger(__name__)
@@ -247,11 +308,11 @@ def test_interp_vertical():
     z_dist, z_ind = extr_assist.get_vertical_weights(
         dst_dep, dst_len_z, num_bdy, sc_z, sc_z_len, ind, zco
     )
-    data_ind, _ = extr_assist.valid_index(sc_bdy, logger)
+    sc_bdy = extr_assist.flood_fill(sc_bdy, False, logger)
 
     # Run function
     sc_bdy_lev = extr_assist.interp_vertical(
-        sc_bdy, dst_dep, bdy_bathy, z_ind, z_dist, data_ind, num_bdy
+        sc_bdy, dst_dep, bdy_bathy, z_ind, z_dist, num_bdy
     )
 
     sc_bdy_lev[np.isnan(sc_bdy_lev)] = -1
diff --git a/tests/test_zz_end_to_end.py b/tests/test_zz_end_to_end.py
index 63458c76..cfce8411 100644
--- a/tests/test_zz_end_to_end.py
+++ b/tests/test_zz_end_to_end.py
@@ -101,18 +101,18 @@ def test_zco_zco():
         "Num_var_co": 21,
         "Num_var_t": 11,
         "Min_gdept": 41.66666793823242,
-        "Max_gdept": 1041.6666259765625,
+        "Max_gdept": 958.3333129882812,
         "Shape_temp": (30, 25, 1, 1584),
         "Shape_ssh": (30, 1, 1584),
         "Shape_mask": (60, 50),
-        "Mean_temp": 16.437015533447266,
-        "Mean_sal": 30.8212833404541,
-        "Sum_unmask": 495030,
-        "Sum_mask": 692970,
+        "Mean_temp": 18.003202438354492,
+        "Mean_sal": 34.08450698852539,
+        "Sum_unmask": 447510,
+        "Sum_mask": 740490,
         "Shape_u": (30, 25, 1, 1566),
         "Shape_v": (30, 25, 1, 1566),
-        "Mean_u": 0.9023050665855408,
-        "Mean_v": 0.8996582627296448,
+        "Mean_u": 0.9975701570510864,
+        "Mean_v": 0.9893443584442139,
     }
 
     assert summary_grid == test_grid, "May need to update regression values."
@@ -181,18 +181,18 @@ def test_sco_sco():
         "Num_var_co": 21,
         "Num_var_t": 11,
         "Min_gdept": 3.8946874141693115,
-        "Max_gdept": 991.9130859375,
+        "Max_gdept": 966.4112548828125,
         "Shape_temp": (30, 15, 1, 1584),
         "Shape_ssh": (30, 1, 1584),
         "Shape_mask": (60, 50),
-        "Mean_temp": 17.26908302307129,
-        "Mean_sal": 31.9005126953125,
-        "Sum_unmask": 712800,
-        "Sum_mask": 0,
+        "Mean_temp": 18.50259017944336,
+        "Mean_sal": 34.17912292480469,
+        "Sum_unmask": 665280,
+        "Sum_mask": 47520,
         "Shape_u": (30, 15, 1, 1566),
         "Shape_v": (30, 15, 1, 1566),
-        "Mean_u": 0.7552474737167358,
-        "Mean_v": 0.7674047350883484,
+        "Mean_u": 0.8442354202270508,
+        "Mean_v": 0.832240879535675,
     }
 
     assert summary_grid == test_grid, "May need to update regression values."
@@ -261,18 +261,18 @@ def test_sco_zco():
         "Num_var_co": 21,
         "Num_var_t": 11,
         "Min_gdept": 41.66666793823242,
-        "Max_gdept": 1041.6666259765625,
+        "Max_gdept": 958.3333129882812,
         "Shape_temp": (30, 25, 1, 1584),
         "Shape_ssh": (30, 1, 1584),
         "Shape_mask": (60, 50),
-        "Mean_temp": 16.77924156188965,
-        "Mean_sal": 30.870267868041992,
-        "Sum_unmask": 495030,
-        "Sum_mask": 692970,
+        "Mean_temp": 18.56098747253418,
+        "Mean_sal": 34.14830780029297,
+        "Sum_unmask": 447510,
+        "Sum_mask": 740490,
         "Shape_u": (30, 25, 1, 1566),
         "Shape_v": (30, 25, 1, 1566),
-        "Mean_u": 0.7566075921058655,
-        "Mean_v": 0.7544463276863098,
+        "Mean_u": 0.8378004431724548,
+        "Mean_v": 0.8349438905715942,
     }
 
     assert summary_grid == test_grid, "May need to update regression values."
@@ -344,18 +344,18 @@ def test_wrap_sc():
         "Num_var_co": 21,
         "Num_var_t": 11,
         "Min_gdept": 41.66666793823242,
-        "Max_gdept": 1041.6666259765625,
+        "Max_gdept": 958.3333129882812,
         "Shape_temp": (30, 25, 1, 1584),
         "Shape_ssh": (30, 1, 1584),
         "Shape_mask": (60, 50),
-        "Mean_temp": 17.540178298950195,
-        "Mean_sal": 30.812232971191406,
-        "Sum_unmask": 495030,
-        "Sum_mask": 692970,
+        "Mean_temp": 19.402727127075195,
+        "Mean_sal": 34.084110260009766,
+        "Sum_unmask": 447510,
+        "Sum_mask": 740490,
         "Shape_u": (30, 25, 1, 1566),
         "Shape_v": (30, 25, 1, 1566),
-        "Mean_u": 0.9030880331993103,
-        "Mean_v": 0.9035892486572266,
+        "Mean_u": 1.0,
+        "Mean_v": 1.0,
         "Temp_Strip1": [
             19.80881690979004,
             24.904409408569336,