Skip to content

Commit

Permalink
use pygetm.simulation.BaseSimulation, which adds load_restart
Browse files Browse the repository at this point in the history
  • Loading branch information
jornbr committed Aug 27, 2024
1 parent 0e7f906 commit 91871ec
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 100 deletions.
2 changes: 1 addition & 1 deletion environment-win.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ dependencies:
- cmake
- m2w64-toolchain
- pip
- pygetm
- pygetm>=0.9.2
- scipy
- h5py
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ dependencies:
- cmake
- fortran-compiler
- pip
- pygetm
- pygetm>=0.9.2
- scipy
- h5py
11 changes: 9 additions & 2 deletions src/fabmos/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def _update_coordinates(grid: pygetm.domain.Grid, area: np.ndarray, h: Optional[
grid.domain.depth.attrs["_time_varying"] = False

grid.area.values[slc_loc] = area[slc_glob]
grid.iarea.values[slc_loc] = 1.0 / grid.area.values[slc_loc]
grid.D.values[slc_loc] = grid.H.values[slc_loc]


Expand Down Expand Up @@ -189,8 +190,8 @@ def compress(full_domain: Optional[Domain], comm: Optional[MPI.Comm] = None) ->
spherical=full_domain.spherical,
tiling=tiling,
logger=full_domain.root_logger,
halox=full_domain.halox,
haloy=full_domain.haloy
halox=0,
haloy=0
)

slc_loc, slc_glob, _, _ = domain.tiling.subdomain2slices()
Expand All @@ -213,3 +214,9 @@ def compress(full_domain: Optional[Domain], comm: Optional[MPI.Comm] = None) ->
domain.uncompressed_area = area

return domain


def drop_grids(domain: Domain, *grids: Grid):
for name in list(domain.fields):
if domain.fields[name].grid in grids:
del domain.fields[name]
111 changes: 18 additions & 93 deletions src/fabmos/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,8 @@
from . import environment, Array, __version__


def log_exceptions(method):
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
try:
return method(self, *args, **kwargs)
except Exception as e:
logger = getattr(self, "logger", None)
domain = getattr(self, "domain", None)
if logger is None or domain is None or domain.tiling.n == 1:
raise
logger.exception(str(e), stack_info=True, stacklevel=3)
domain.tiling.comm.Abort(1)

return wrapper


class Simulator:
@log_exceptions
class Simulator(pygetm.simulation.BaseSimulation):
@pygetm.simulation.log_exceptions
def __init__(
self,
domain: pygetm.domain.Domain,
Expand All @@ -36,9 +20,8 @@ def __init__(
log_level: Optional[int] = None,
use_virtual_flux: bool = False,
):
self.logger = domain.root_logger
if log_level is not None:
self.logger.setLevel(log_level)
super().__init__(domain, log_level=log_level)

self.logger.info(f"fabmos {__version__}")

self.fabm = pygetm.fabm.FABM(
Expand All @@ -48,17 +31,6 @@ def __init__(
squeeze=True,
)

self.domain = domain

self.output_manager = pygetm.output.OutputManager(
self.domain.fields,
rank=domain.tiling.rank,
logger=self.logger.getChild("output_manager"),
)

self.input_manager = self.domain.input_manager
self.input_manager.set_logger(self.logger.getChild("input_manager"))

self.domain.initialize(pygetm.BAROCLINIC)
self.domain.depth.fabm_standard_name = "pressure"

Expand Down Expand Up @@ -95,10 +67,7 @@ def __init__(
self.unmasked2d = None
self.unmasked3d = None

def __getitem__(self, key: str) -> Array:
return self.output_manager.fields[key]

@log_exceptions
@pygetm.simulation.log_exceptions
def start(
self,
time: Union[cftime.datetime, datetime.datetime],
Expand All @@ -123,7 +92,17 @@ def start(
f"The transport timestep of {transport_timestep} s must be an"
f" exact multiple of the biogeochemical timestep of {timestep} s"
)
self.nstep_transport = nstep_transport
super().start(
time,
timestep,
nstep_transport,
report=report,
report_totals=report_totals,
profile=profile,
)

def _start(self):
self.tracers_with_virtual_flux: List[pygetm.tracer.Tracer] = []
if self.use_virtual_flux:
for tracer in self.tracers:
Expand All @@ -137,59 +116,20 @@ def start(
else:
self.logger.info("Virtual tracer flux due to net freshwater flux not used")

self.time = pygetm.simulation.to_cftime(time)
self.logger.info(f"Starting simulation at {self.time}")
self.timestep = timestep
self.timedelta = datetime.timedelta(seconds=timestep)
self.nstep_transport = nstep_transport
self.istep = 0
self.report = int(report.total_seconds() / timestep)
if isinstance(report_totals, datetime.timedelta):
report_totals = int(round(report_totals.total_seconds() / self.timestep))
self.report_totals = report_totals

self.fabm.start(self.time)
self.update_diagnostics(macro=True)
self.output_manager.start(self.istep, self.time)
self._start_time = timeit.default_timer()

# Start profiling if requested
self._profile = None
if profile:
import cProfile

pr = cProfile.Profile()
self._profile = (profile, pr)
pr.enable()

@log_exceptions
def advance(self):
self.time += self.timedelta
self.istep += 1
apply_transport = self.istep % self.nstep_transport == 0
if self.report != 0 and self.istep % self.report == 0:
self.logger.info(self.time)

self.output_manager.prepare_save(
self.timestep * self.istep, self.istep, self.time, macro=apply_transport
)

def _advance_state(self, macro_active: bool):
self.logger.debug(f"fabm advancing to {self.time} (dt={self.timestep} s)")
self.advance_fabm(self.timestep)

if apply_transport:
if macro_active:
timestep_transport = self.nstep_transport * self.timestep
self.logger.debug(
f"transport advancing to {self.time} (dt={timestep_transport} s)"
)
self.transport(timestep_transport)

self.update_diagnostics(apply_transport)

self.output_manager.save(self.timestep * self.istep, self.istep, self.time)

def update_diagnostics(self, macro: bool):
self.input_manager.update(self.time, macro=macro)
def _update_forcing_and_diagnostics(self, macro_active: bool):
self.radiation.update(self.time)
self.fabm.update_sources(self.timestep * self.istep, self.time)
self.fabm.add_vertical_movement_to_sources()
Expand All @@ -209,21 +149,6 @@ def advance_fabm(self, timestep: float):
def transport(self, timestep: float):
pass

@log_exceptions
def finish(self):
if self._profile:
name, pr = self._profile
pr.disable()
profile_path = f"{name}-{self.domain.tiling.rank:03}.prof"
self.logger.info(f"Writing profiling report to {profile_path}")
with open(profile_path, "w") as f:
ps = pstats.Stats(pr, stream=f).sort_stats(pstats.SortKey.TIME)
ps.print_stats()

nsecs = timeit.default_timer() - self._start_time
self.logger.info(f"Time spent in main loop: {nsecs:.3f} s")
self.output_manager.close(self.timestep * self.istep, self.time)

@property
def totals(
self,
Expand Down
18 changes: 16 additions & 2 deletions src/fabmos/transport/null.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@

from typing import Optional
import os

import pygetm
from .. import simulator
from ..domain import compress, _update_coordinates
from ..domain import compress, _update_coordinates, drop_grids


class Simulator(simulator.Simulator):
Expand All @@ -18,6 +17,21 @@ def __init__(

domain = compress(domain)

# Drop unused domain variables. Some of these will be NaN,
# which causes check_finite to fail.
drop_grids(
domain,
domain.U,
domain.V,
domain.X,
domain.UU,
domain.UV,
domain.VU,
domain.VV,
)
for name in ("dxt", "dyt", "idxt", "idyt"):
del domain.fields[name]

super().__init__(
domain,
fabm_config,
Expand Down
18 changes: 17 additions & 1 deletion src/fabmos/transport/tmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
CompressedToFullGrid,
map_input_to_compressed_grid,
_update_coordinates,
drop_grids,
)

# Note: mpi4py components should be imported after pygetm.parallel,
Expand Down Expand Up @@ -521,6 +522,20 @@ def __init__(
use_virtual_flux=True,
)

# Drop unused domain variables. Some of these will be NaN, which causes check_finite to fail.
drop_grids(
domain,
domain.U,
domain.V,
domain.X,
domain.UU,
domain.UV,
domain.VU,
domain.VV,
)
for name in ("dxt", "dyt", "idxt", "idyt"):
del domain.fields[name]

self.tmm_logger = self.logger.getChild("TMM")
_update_coordinates(self.domain.T, self.domain.da, self.domain.dz)
if self.domain.glob and self.domain.glob is not self.domain:
Expand Down Expand Up @@ -656,9 +671,10 @@ def add_atmospheric_gas(
underwater_scale_factor: float = 1.0,
) -> Array:
atm_array = self.fabm.get_dependency(atmospheric_name)
atm_array.attrs["_part_of_state"] = True
flux_array = self._get_existing_fabm_variable(underwater_name + "_sfl")
self.logger.info(
f"Atmopsheric gas: {atm_array.name} will be updated"
f"Atmospheric gas: {atm_array.name} will be updated"
f" based on air-sea flux of {underwater_name}."
)
self._atmospheric_gases.append(
Expand Down

0 comments on commit 91871ec

Please sign in to comment.