Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 46 additions & 12 deletions src/pybdy/nemo_bdy_extr_assist.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,51 @@ def get_vertical_weights_zco(dst_dep, dst_len_z, num_bdy, sc_z, sc_z_len):
return z9_dist, z9_ind


def interp_vertical(
sc_bdy, dst_dep, bdy_bathy, z_ind, z_dist, data_ind, num_bdy, zinterp=True
):
def flood_fill(sc_bdy, isslab, logger):
"""
Fill the data horizontally then downwards to remove nans before interpolation.

Parameters
----------
sc_bdy (np.array) : souce data [nz_sc, nbdy, 9]
isslab (bool) : if true data has vertical cells for vertical flood fill
logger : log of statements

Returns
-------
sc_bdy (np.array) : souce data [nz_sc, nbdy, 9]
"""
# identify valid pts
data_ind, nan_ind = valid_index(sc_bdy, logger)

# Set sc land pts to nan
sc_bdy[nan_ind] = np.nan
sc_shape = sc_bdy.shape

for i in range(sc_shape[0]):
while np.isnan(sc_bdy[i, :, 0]).any() & (~np.isnan(sc_bdy[i, :, 0])).any():
# Flood sc land horizontally within the chunk for the centre point.
# This may not be perfect but better than filling with zeros
sc_nan = np.isnan(sc_bdy)
sc_bdy[:, 1:, 0][sc_nan[:, 1:, 0]] = sc_bdy[:, :-1, 0][sc_nan[:, 1:, 0]]
sc_nan = np.isnan(sc_bdy)
sc_bdy[:, :-1, 0][sc_nan[:, :-1, 0]] = sc_bdy[:, 1:, 0][sc_nan[:, :-1, 0]]

if not isslab:
data_ind, nan_ind = valid_index(sc_bdy, logger)
# Fill down using deepest pt
ind_bdy = np.arange(sc_shape[1])
all_bot = np.tile(
sc_bdy[data_ind[:, 0], ind_bdy, 0], (sc_shape[0], sc_shape[2], 1)
).transpose((0, 2, 1))
sc_bdy[:, :, 0][np.isnan(sc_bdy[:, :, 0])] = all_bot[:, :, 0][
np.isnan(sc_bdy[:, :, 0])
]

return sc_bdy


def interp_vertical(sc_bdy, dst_dep, bdy_bathy, z_ind, z_dist, num_bdy, zinterp=True):
"""
Interpolate source data onto destination vertical levels.

Expand All @@ -330,23 +372,15 @@ def interp_vertical(
bdy_bathy (np.array): the destination grid bdy points bathymetry
z_ind (np.array) : the indices of the sc depth above and below bdy
z_dist (np.array) : the distance weights of the selected points
data_ind (np.array) : bool points above bathymetry that are valid
num_bdy (int) : number of boundary points in chunk
zinterp (bool) : vertical interpolation flag

Returns
-------
data_out (np.array) : source data on destination depth levels
"""
# If all else fails fill down using deepest pt
sc_shape = sc_bdy.shape
ind_bdy = np.arange(sc_bdy.shape[1])
all_bot = np.tile(
sc_bdy[data_ind[:, 0], ind_bdy, 0], (sc_bdy.shape[0], sc_bdy.shape[2], 1)
).transpose((0, 2, 1))
sc_bdy[np.isnan(sc_bdy)] = all_bot[np.isnan(sc_bdy)]

if zinterp is True:
sc_shape = sc_bdy.shape
sc_bdy = sc_bdy.flatten("F")

# Weighted averaged on new vertical grid
Expand Down
35 changes: 9 additions & 26 deletions src/pybdy/nemo_bdy_extr_tm3.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,8 +913,8 @@ def extract_month(self, year, month):
else:
sc_z_len = self.sc_z_len

# identify valid pts
data_ind, _ = extr_assist.valid_index(sc_bdy[vn], self.logger)
# Flood fill
sc_bdy[vn] = extr_assist.flood_fill(sc_bdy[vn], isslab, self.logger)

if not isslab:
# Vertical interpolation
Expand All @@ -924,7 +924,6 @@ def extract_month(self, year, month):
self.bdy_z[chunk_d],
self.z_ind[chunk_z, :],
self.z_dist[chunk_z, :],
data_ind,
self.num_bdy_ch[chk],
self.settings["zinterp"],
)
Expand All @@ -933,8 +932,6 @@ def extract_month(self, year, month):
sc_bdy_lev = sc_bdy[vn]
sc_bdy_lev[:, np.isnan(self.bdy_z[chunk_d]), :] = np.NaN

_, nan_ind = extr_assist.valid_index(sc_bdy_lev, self.logger)

# distance weightings for averaging source data to destination
dist_wei, dist_fac = extr_assist.distance_weights(
sc_bdy_lev,
Expand All @@ -951,6 +948,13 @@ def extract_month(self, year, month):

# weight vector array and rotate onto dest grid
if self.key_vec:
# Do the same for both components of u and v velocities

# Flood fill
sc_bdy[vn + 1] = extr_assist.flood_fill(
sc_bdy[vn + 1], isslab, self.logger
)

if not isslab:
# Vertical interpolation
sc_bdy_lev2 = extr_assist.interp_vertical(
Expand All @@ -959,7 +963,6 @@ def extract_month(self, year, month):
self.bdy_z[chunk_d],
self.z_ind[chunk_z, :],
self.z_dist[chunk_z, :],
data_ind,
self.num_bdy_ch[chk],
self.settings["zinterp"],
)
Expand Down Expand Up @@ -1009,26 +1012,6 @@ def extract_month(self, year, month):
# Finished first run operations
# self.first = False

if self.settings["zinterp"]:
# Set land pts to zero
self.logger.info(
" pre dst_bdy[nan_ind] %s %s",
np.nanmin(dst_bdy),
np.nanmax(dst_bdy),
)

dst_bdy[nan_ind] = 0
self.logger.info(
" post dst_bdy %s %s",
np.nanmin(dst_bdy),
np.nanmax(dst_bdy),
)
# Remove any data on dst grid that is in land
dst_bdy[:, np.isnan(self.bdy_z[chunk_d])] = 0
self.logger.info(
" 3 dst_bdy %s %s", np.nanmin(dst_bdy), np.nanmax(dst_bdy)
)

data_out = dst_bdy

# add data to self.d_bdy
Expand Down
4 changes: 2 additions & 2 deletions src/pybdy/nemo_bdy_zgrv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def get_bdy_depths(DstCoord, bdy_i, grd):
bdy_e3 = np.ma.zeros((m_e.shape[0], len(g_ind)))
for k in range(m_w.shape[0]):
tmp_w = np.ma.masked_where(mbathy + 1 < k + 1, m_w[k, :, :])
tmp_t = np.ma.masked_where(mbathy + 1 < k + 1, m_t[k, :, :])
tmp_e = np.ma.masked_where(mbathy + 1 < k + 1, m_e[k, :, :])
tmp_t = np.ma.masked_where(mbathy < k + 1, m_t[k, :, :])
tmp_e = np.ma.masked_where(mbathy < k + 1, m_e[k, :, :])

tmp_w = tmp_w.flatten("F")
tmp_t = tmp_t.flatten("F")
Expand Down
65 changes: 63 additions & 2 deletions tests/test_nemo_bdy_extr_assist.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,67 @@ def test_get_vertical_weights_sco():
assert not errors, "errors occured:\n{}".format("\n".join(errors))


def test_flood_fill():
# Test the flood_fill function
logger = logging.getLogger(__name__)
max_depth = 100
num_bdy = 3
bdy_bathy = np.array([0, 0, 80.5, 100, 60])
sc_z_len = 15
gdept, _ = synth_zgrid.synth_zco(max_depth, sc_z_len)
sc_z = np.tile(gdept, (3, 5, 1)).T
sc_bdy = np.tile(np.linspace(12, 5, num=sc_z_len), (3, 5, 1)).T # Temperature data

# Centre then clockwise from 12 then corners
ind_g = np.array([[1, 2, 1, 0, 1, 2, 0, 0, 2], [1, 1, 2, 1, 0, 2, 2, 0, 0]])
ind = np.zeros((num_bdy, 9), dtype=int)
ind[0, :] = np.ravel_multi_index(ind_g, (3, 5), order="F")
ind_g[1, :] = ind_g[1, :] + 1
ind[1, :] = np.ravel_multi_index(ind_g, (3, 5), order="F")
ind_g[1, :] = ind_g[1, :] + 1
ind[2, :] = np.ravel_multi_index(ind_g, (3, 5), order="F")

bathy_tile = np.transpose(np.tile(bdy_bathy, (sc_z_len, 3, 1)), (0, 2, 1))
sc_bdy = np.ma.masked_where(sc_z > bathy_tile, sc_bdy)
sc_bdy = sc_bdy.filled(np.nan)

sc_bdy = sc_bdy.reshape((sc_bdy.shape[0], sc_bdy.shape[1] * sc_bdy.shape[2]))[
:, ind
]

# Run function
sc_bdy = extr_assist.flood_fill(sc_bdy, False, logger)

# Check results
lev_test = np.array(
[
12.0,
11.5,
11.0,
10.5,
10.0,
9.5,
9.0,
8.5,
8.0,
7.5,
7.0,
6.5,
6.0,
5.5,
5.5,
]
)
print(sc_bdy[:, 0, 0])
errors = []
if not (sc_bdy.shape == (sc_z_len, num_bdy, 9)):
errors.append("Error with output sc_bdy shape.")
elif not np.isclose(sc_bdy[:, 0, 0], lev_test, atol=1e-4).all():
errors.append("Error with sc_bdy_lev.")
# assert no error message has been registered, else print messages
assert not errors, "errors occured:\n{}".format("\n".join(errors))


def test_interp_vertical():
# Test the interp_vertical function
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -247,11 +308,11 @@ def test_interp_vertical():
z_dist, z_ind = extr_assist.get_vertical_weights(
dst_dep, dst_len_z, num_bdy, sc_z, sc_z_len, ind, zco
)
data_ind, _ = extr_assist.valid_index(sc_bdy, logger)
sc_bdy = extr_assist.flood_fill(sc_bdy, False, logger)

# Run function
sc_bdy_lev = extr_assist.interp_vertical(
sc_bdy, dst_dep, bdy_bathy, z_ind, z_dist, data_ind, num_bdy
sc_bdy, dst_dep, bdy_bathy, z_ind, z_dist, num_bdy
)
sc_bdy_lev[np.isnan(sc_bdy_lev)] = -1

Expand Down
56 changes: 28 additions & 28 deletions tests/test_zz_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,18 @@ def test_zco_zco():
"Num_var_co": 21,
"Num_var_t": 11,
"Min_gdept": 41.66666793823242,
"Max_gdept": 1041.6666259765625,
"Max_gdept": 958.3333129882812,
"Shape_temp": (30, 25, 1, 1584),
"Shape_ssh": (30, 1, 1584),
"Shape_mask": (60, 50),
"Mean_temp": 16.437015533447266,
"Mean_sal": 30.8212833404541,
"Sum_unmask": 495030,
"Sum_mask": 692970,
"Mean_temp": 18.003202438354492,
"Mean_sal": 34.08450698852539,
"Sum_unmask": 447510,
"Sum_mask": 740490,
"Shape_u": (30, 25, 1, 1566),
"Shape_v": (30, 25, 1, 1566),
"Mean_u": 0.9023050665855408,
"Mean_v": 0.8996582627296448,
"Mean_u": 0.9975701570510864,
"Mean_v": 0.9893443584442139,
}

assert summary_grid == test_grid, "May need to update regression values."
Expand Down Expand Up @@ -181,18 +181,18 @@ def test_sco_sco():
"Num_var_co": 21,
"Num_var_t": 11,
"Min_gdept": 3.8946874141693115,
"Max_gdept": 991.9130859375,
"Max_gdept": 966.4112548828125,
"Shape_temp": (30, 15, 1, 1584),
"Shape_ssh": (30, 1, 1584),
"Shape_mask": (60, 50),
"Mean_temp": 17.26908302307129,
"Mean_sal": 31.9005126953125,
"Sum_unmask": 712800,
"Sum_mask": 0,
"Mean_temp": 18.50259017944336,
"Mean_sal": 34.17912292480469,
"Sum_unmask": 665280,
"Sum_mask": 47520,
"Shape_u": (30, 15, 1, 1566),
"Shape_v": (30, 15, 1, 1566),
"Mean_u": 0.7552474737167358,
"Mean_v": 0.7674047350883484,
"Mean_u": 0.8442354202270508,
"Mean_v": 0.832240879535675,
}

assert summary_grid == test_grid, "May need to update regression values."
Expand Down Expand Up @@ -261,18 +261,18 @@ def test_sco_zco():
"Num_var_co": 21,
"Num_var_t": 11,
"Min_gdept": 41.66666793823242,
"Max_gdept": 1041.6666259765625,
"Max_gdept": 958.3333129882812,
"Shape_temp": (30, 25, 1, 1584),
"Shape_ssh": (30, 1, 1584),
"Shape_mask": (60, 50),
"Mean_temp": 16.77924156188965,
"Mean_sal": 30.870267868041992,
"Sum_unmask": 495030,
"Sum_mask": 692970,
"Mean_temp": 18.56098747253418,
"Mean_sal": 34.14830780029297,
"Sum_unmask": 447510,
"Sum_mask": 740490,
"Shape_u": (30, 25, 1, 1566),
"Shape_v": (30, 25, 1, 1566),
"Mean_u": 0.7566075921058655,
"Mean_v": 0.7544463276863098,
"Mean_u": 0.8378004431724548,
"Mean_v": 0.8349438905715942,
}

assert summary_grid == test_grid, "May need to update regression values."
Expand Down Expand Up @@ -344,18 +344,18 @@ def test_wrap_sc():
"Num_var_co": 21,
"Num_var_t": 11,
"Min_gdept": 41.66666793823242,
"Max_gdept": 1041.6666259765625,
"Max_gdept": 958.3333129882812,
"Shape_temp": (30, 25, 1, 1584),
"Shape_ssh": (30, 1, 1584),
"Shape_mask": (60, 50),
"Mean_temp": 17.540178298950195,
"Mean_sal": 30.812232971191406,
"Sum_unmask": 495030,
"Sum_mask": 692970,
"Mean_temp": 19.402727127075195,
"Mean_sal": 34.084110260009766,
"Sum_unmask": 447510,
"Sum_mask": 740490,
"Shape_u": (30, 25, 1, 1566),
"Shape_v": (30, 25, 1, 1566),
"Mean_u": 0.9030880331993103,
"Mean_v": 0.9035892486572266,
"Mean_u": 1.0,
"Mean_v": 1.0,
"Temp_Strip1": [
19.80881690979004,
24.904409408569336,
Expand Down
Loading