Skip to content

Commit dba2ac1

Browse files
reint-fischerreint-fischer
authored andcommitted
fix field creation from code review suggestions
1 parent 5f41847 commit dba2ac1

File tree

1 file changed

+21
-35
lines changed

1 file changed

+21
-35
lines changed

docs/user_guide/examples/tutorial_diffusion.ipynb

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,6 @@
114114
"metadata": {},
115115
"outputs": [],
116116
"source": [
117-
"import random\n",
118-
"\n",
119117
"import matplotlib.pyplot as plt\n",
120118
"import numpy as np\n",
121119
"import trajan as ta\n",
@@ -197,7 +195,7 @@
197195
"ds = simple_UV_dataset(dims=(1, 1, Ny, 1), mesh=\"flat\").isel(time=0, depth=0)\n",
198196
"ds[\"lat\"][:] = np.linspace(-0.01, 1.01, Ny)\n",
199197
"ds[\"lon\"][:] = np.ones(len(ds.XG))\n",
200-
"ds[\"Kh_meridional\"] = ds[\"U\"] + Kh_meridional[:, None]\n",
198+
"ds[\"Kh_meridional\"] = ([\"YG\",\"XG\"], Kh_meridional[:, None])\n",
201199
"ds"
202200
]
203201
},
@@ -345,7 +343,7 @@
345343
" chunks=(len(testParticles), 50),\n",
346344
" outputdt=np.timedelta64(1, \"ms\"),\n",
347345
")\n",
348-
"random.seed(1636) # Random seed for reproducibility\n",
346+
"np.random.seed(1636) # Random seed for reproducibility\n",
349347
"testParticles.execute(\n",
350348
" parcels.kernels.AdvectionDiffusionEM,\n",
351349
" runtime=np.timedelta64(300, \"ms\"),\n",
@@ -499,45 +497,32 @@
499497
"source": [
500498
"fieldset = parcels.FieldSet.from_copernicusmarine(ds_fields)\n",
501499
"\n",
500+
"def degree_lat_to_meter(d):\n",
501+
" return d * 1000.0 * 1.852 * 60.0\n",
502502
"\n",
503-
"def calc_cell_edge_sizes(grid):\n",
504-
" \"\"\"Calculate cell sizes based on numpy.gradient method.\n",
505503
"\n",
506-
" Currently only works for Rectilinear Grids. Operates in place adding a `cell_edge_sizes`\n",
507-
" attribute to the grid.\n",
508-
" \"\"\"\n",
509-
" # TODO: check for gridtypes once added in v4\n",
510-
" # assert grid._gtype in (GridType.RectilinearZGrid, GridType.RectilinearSGrid), (\n",
511-
" # f\"_cell_edge_sizes() not implemented for {grid._gtype} grids. \"\n",
512-
" # \"You can provide cell_edge_sizes yourself by in, e.g., \"\n",
513-
" # \"NEMO using the e1u fields etc from the mesh_mask.nc file.\"\n",
514-
" # )\n",
504+
"def degree_lon_to_meter(d, lat):\n",
505+
" return d * 1000.0 * 1.852 * 60.0 * np.cos(lat * np.pi / 180)\n",
515506
"\n",
516-
" cell_edge_sizes_x = np.zeros((grid.ydim + 1, grid.xdim + 1), dtype=np.float32)\n",
517-
" cell_edge_sizes_y = np.zeros((grid.ydim + 1, grid.xdim + 1), dtype=np.float32)\n",
518507
"\n",
519-
" x_conv = (\n",
520-
" parcels.GeographicPolar()\n",
521-
" ) # if grid._mesh == \"spherical\" else parcels.UnitConverter()\n",
522-
" y_conv = (\n",
523-
" parcels.Geographic()\n",
524-
" ) # if grid._mesh == \"spherical\" else parcels.UnitConverter()\n",
525-
" for y, (lat, dy) in enumerate(zip(grid.lat, np.gradient(grid.lat), strict=False)):\n",
526-
" for x, (lon, dx) in enumerate(\n",
527-
" zip(grid.lon, np.gradient(grid.lon), strict=False)\n",
528-
" ):\n",
529-
" cell_edge_sizes_x[y, x] = x_conv.to_source(dx, grid.depth[0], lat, lon)\n",
530-
" cell_edge_sizes_y[y, x] = y_conv.to_source(dy, grid.depth[0], lat, lon)\n",
531-
" return cell_edge_sizes_x, cell_edge_sizes_y\n",
508+
"def calc_cell_areas(ds):\n",
509+
" \"\"\"calculate cell areas for rectilinear grids\"\"\"\n",
510+
" lon, lat = ds[\"longitude\"], ds[\"latitude\"]\n",
511+
" assert \"degrees\" in lon.attrs[\"units\"]\n",
512+
" assert \"degrees\" in lat.attrs[\"units\"]\n",
532513
"\n",
514+
" LON, LAT = np.meshgrid(lon, lat)\n",
515+
" X, Y = degree_lon_to_meter(LON, LAT), degree_lat_to_meter(LAT)\n",
533516
"\n",
534-
"def calc_cell_areas(grid):\n",
535-
" cell_edge_sizes_x, cell_edge_sizes_y = calc_cell_edge_sizes(grid)\n",
536-
" return cell_edge_sizes_x * cell_edge_sizes_y\n",
517+
" dX = np.gradient(X, axis=1)\n",
518+
" dY = np.gradient(Y, axis=0)\n",
519+
" cell_areas = dX * dY\n",
520+
"\n",
521+
" return cell_areas\n",
537522
"\n",
538523
"\n",
539524
"da_cell_areas = xr.DataArray(\n",
540-
" data=calc_cell_areas(fieldset.U.grid),\n",
525+
" data=calc_cell_areas(ds_fields),\n",
541526
" coords=dict(\n",
542527
" latitude=([\"lat\"], ds_fields.latitude.values),\n",
543528
" longitude=([\"lon\"], ds_fields.longitude.values),\n",
@@ -612,11 +597,12 @@
612597
"outputs": [],
613598
"source": [
614599
"ds_particles = xr.open_zarr(\"smagdiff.zarr\")\n",
600+
"\n",
615601
"temperature = ds_fields.isel(time=0, depth=0).thetao.plot(cmap=\"magma\")\n",
616602
"velocity = ds_fields.isel(time=0, depth=0).plot.quiver(\n",
617603
" x=\"longitude\", y=\"latitude\", u=\"uo\", v=\"vo\"\n",
618604
")\n",
619-
"ds_particles.traj.plot(color=\"blue\")\n",
605+
"particles = ds_particles.traj.plot(color=\"blue\")\n",
620606
"plt.ylim(-31, -30)\n",
621607
"plt.xlim(31, 32.1)\n",
622608
"plt.show()"

0 commit comments

Comments
 (0)