Skip to content

Commit ec4a094

Browse files
Anna Störikoaseyboldt
authored andcommitted
Add linear_solver keyword argument to the solver
1 parent 94134d2 commit ec4a094

File tree

1 file changed

+24
-8
lines changed

1 file changed

+24
-8
lines changed

sunode/solver.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ class Solver:
214214
def __init__(self, problem: Problem, *,
215215
abstol: float = 1e-10, reltol: float = 1e-10,
216216
sens_mode: Optional[str] = None, scaling_factors: Optional[np.ndarray] = None,
217-
constraints: Optional[np.ndarray] = None, solver='BDF'):
217+
constraints: Optional[np.ndarray] = None, solver='BDF', linear_solver="dense"):
218218
self._problem = problem
219219
self._user_data = problem.make_user_data()
220220

@@ -242,7 +242,7 @@ def __init__(self, problem: Problem, *,
242242
self._constraints_vec = sunode.from_numpy(constraints)
243243
check(lib.CVodeSetConstraints(self._ode, self._constraints_vec.c_ptr))
244244

245-
self._make_linsol()
245+
self._make_linsol(linear_solver)
246246

247247
user_data_p = ffi.cast('void *', ffi.addressof(ffi.from_buffer(self._user_data.data)))
248248
check(lib.CVodeSetUserData(self._ode, user_data_p))
@@ -252,12 +252,28 @@ def __init__(self, problem: Problem, *,
252252
sens_rhs = self._problem.make_sundials_sensitivity_rhs()
253253
self._init_sens(sens_rhs, sens_mode)
254254

255-
def _make_linsol(self) -> None:
256-
linsolver = check(lib.SUNLinSol_Dense(self._state_buffer.c_ptr, self._jac))
257-
check(lib.CVodeSetLinearSolver(self._ode, linsolver, self._jac))
258-
259-
self._jac_func = self._problem.make_sundials_jac_dense()
260-
check(lib.CVodeSetJacFn(self._ode, self._jac_func.cffi))
255+
def _make_linsol(self, linear_solver) -> None:
256+
if linear_solver == "dense":
257+
linsolver = check(lib.SUNLinSol_Dense(self._state_buffer.c_ptr, self._jac))
258+
check(lib.CVodeSetLinearSolver(self._ode, linsolver, self._jac))
259+
260+
self._jac_func = self._problem.make_sundials_jac_dense()
261+
check(lib.CVodeSetJacFn(self._ode, self._jac_func.cffi))
262+
elif linear_solver == "dense_finitediff":
263+
linsolver = check(lib.SUNLinSol_Dense(self._state_buffer.c_ptr, self._jac))
264+
check(lib.CVodeSetLinearSolver(self._ode, linsolver, self._jac))
265+
elif linear_solver == "spgmr_finitediff":
266+
linsolver = check(lib.SUNLinSol_SPGMR(self._state_buffer.c_ptr, lib.PREC_NONE, 5))
267+
check(lib.CVodeSetLinearSolver(self._ode, linsolver, ffi.NULL))
268+
check(lib.SUNLinSolInitialize_SPGMR(linsolver))
269+
elif linear_solver == "spgmr":
270+
linsolver = check(lib.SUNLinSol_SPGMR(self._state_buffer.c_ptr, lib.PREC_NONE, 5))
271+
check(lib.CVodeSetLinearSolver(self._ode, linsolver, ffi.NULL))
272+
check(lib.SUNLinSolInitialize_SPGMR(linsolver))
273+
jac_prod = self._problem.make_sundials_jac_prod()
274+
check(lib.CVodeSetJacTimes(self._ode, ffi.NULL, jac_prod.cffi))
275+
else:
276+
raise ValueError(f"Unknown linear solver: {linear_solver}")
261277

262278
def _init_sens(self, sens_rhs, sens_mode, scaling_factors=None) -> None:
263279
if sens_mode == 'simultaneous':

0 commit comments

Comments
 (0)