|
9 | 9 |
|
10 | 10 | import parcels |
11 | 11 | from parcels import AdvectionRK4, Field, FieldSet, Particle, ParticleSet, Variable, VectorField |
| 12 | +from parcels._core.utils.time import TimeInterval |
12 | 13 | from parcels._datasets.structured.generic import datasets |
| 14 | +from parcels.particle import Particle, create_particle_data |
13 | 15 | from parcels.particlefile import ParticleFile |
14 | | -from parcels.particle import create_particle_data |
15 | | -from parcels._core.utils.time import TimeInterval |
16 | 16 | from parcels.xgrid import XGrid |
17 | 17 | from tests.common_kernels import DoNothing |
18 | 18 |
|
@@ -366,4 +366,48 @@ def test_particlefile_init_invalid(store): # TODO: Add test for read only store |
366 | 366 |
|
367 | 367 |
|
368 | 368 | @pytest.mark.new |
369 | | -def test_particlefile_writing(store): ... |
| 369 | +def test_particlefile_write_particle_data(store): |
| 370 | + nparticles = 100 |
| 371 | + |
| 372 | + pfile = ParticleFile(store, outputdt=np.timedelta64(1, "s"), chunks=(nparticles, 40)) |
| 373 | + pclass = Particle |
| 374 | + |
| 375 | + left, right = np.datetime64("2020-01-01", "ns"), np.datetime64("2020-01-02", "ns") |
| 376 | + time_interval = TimeInterval(left=left, right=right) |
| 377 | + |
| 378 | + initial_lon = np.linspace(0, 1, nparticles) |
| 379 | + data = create_particle_data( |
| 380 | + pclass=pclass, |
| 381 | + nparticles=nparticles, |
| 382 | + ngrids=4, |
| 383 | + time_interval=time_interval, |
| 384 | + initial={ |
| 385 | + "time": np.full(nparticles, fill_value=left), |
| 386 | + "lon": initial_lon, |
| 387 | + "dt": np.full(nparticles, fill_value=1.0), |
| 388 | + "trajectory": np.arange(nparticles), |
| 389 | + }, |
| 390 | + ) |
| 391 | + np.testing.assert_array_equal(data["time"], left) |
| 392 | + pfile._write_particle_data( |
| 393 | + particle_data=data, |
| 394 | + pclass=pclass, |
| 395 | + time_interval=time_interval, |
| 396 | + time=left, |
| 397 | + ) |
| 398 | + ds = xr.open_zarr(store, decode_cf=False) |
| 399 | + # ds.time.attrs |
| 400 | + |
| 401 | + def patched_time_attrs(attrs): |
| 402 | + units = attrs["units"] |
| 403 | + units = units[: units.index("T")] |
| 404 | + attrs["units"] = units |
| 405 | + return attrs |
| 406 | + |
| 407 | + ds.time.attrs = patched_time_attrs(ds.time.attrs) |
| 408 | + breakpoint() |
| 409 | + ds = xr.decode_cf(ds) |
| 410 | + assert ds.sizes["trajectory"] == nparticles |
| 411 | + assert ds.time.dtype == "datetime64[ns]" |
| 412 | + np.testing.assert_equal(ds["time"].isel(obs=0).values, left) |
| 413 | + np.testing.assert_equal(ds["lon"].isel(obs=0).values, initial_lon) |
0 commit comments