@@ -214,7 +214,7 @@ class Solver:
214
214
def __init__ (self , problem : Problem , * ,
215
215
abstol : float = 1e-10 , reltol : float = 1e-10 ,
216
216
sens_mode : Optional [str ] = None , scaling_factors : Optional [np .ndarray ] = None ,
217
- constraints : Optional [np .ndarray ] = None , solver = 'BDF' ):
217
+ constraints : Optional [np .ndarray ] = None , solver = 'BDF' , linear_solver = "dense" ):
218
218
self ._problem = problem
219
219
self ._user_data = problem .make_user_data ()
220
220
@@ -242,7 +242,7 @@ def __init__(self, problem: Problem, *,
242
242
self ._constraints_vec = sunode .from_numpy (constraints )
243
243
check (lib .CVodeSetConstraints (self ._ode , self ._constraints_vec .c_ptr ))
244
244
245
- self ._make_linsol ()
245
+ self ._make_linsol (linear_solver )
246
246
247
247
user_data_p = ffi .cast ('void *' , ffi .addressof (ffi .from_buffer (self ._user_data .data )))
248
248
check (lib .CVodeSetUserData (self ._ode , user_data_p ))
@@ -252,12 +252,28 @@ def __init__(self, problem: Problem, *,
252
252
sens_rhs = self ._problem .make_sundials_sensitivity_rhs ()
253
253
self ._init_sens (sens_rhs , sens_mode )
254
254
255
- def _make_linsol (self ) -> None :
256
- linsolver = check (lib .SUNLinSol_Dense (self ._state_buffer .c_ptr , self ._jac ))
257
- check (lib .CVodeSetLinearSolver (self ._ode , linsolver , self ._jac ))
258
-
259
- self ._jac_func = self ._problem .make_sundials_jac_dense ()
260
- check (lib .CVodeSetJacFn (self ._ode , self ._jac_func .cffi ))
255
+ def _make_linsol (self , linear_solver ) -> None :
256
+ if linear_solver == "dense" :
257
+ linsolver = check (lib .SUNLinSol_Dense (self ._state_buffer .c_ptr , self ._jac ))
258
+ check (lib .CVodeSetLinearSolver (self ._ode , linsolver , self ._jac ))
259
+
260
+ self ._jac_func = self ._problem .make_sundials_jac_dense ()
261
+ check (lib .CVodeSetJacFn (self ._ode , self ._jac_func .cffi ))
262
+ elif linear_solver == "dense_finitediff" :
263
+ linsolver = check (lib .SUNLinSol_Dense (self ._state_buffer .c_ptr , self ._jac ))
264
+ check (lib .CVodeSetLinearSolver (self ._ode , linsolver , self ._jac ))
265
+ elif linear_solver == "spgmr_finitediff" :
266
+ linsolver = check (lib .SUNLinSol_SPGMR (self ._state_buffer .c_ptr , lib .PREC_NONE , 5 ))
267
+ check (lib .CVodeSetLinearSolver (self ._ode , linsolver , ffi .NULL ))
268
+ check (lib .SUNLinSolInitialize_SPGMR (linsolver ))
269
+ elif linear_solver == "spgmr" :
270
+ linsolver = check (lib .SUNLinSol_SPGMR (self ._state_buffer .c_ptr , lib .PREC_NONE , 5 ))
271
+ check (lib .CVodeSetLinearSolver (self ._ode , linsolver , ffi .NULL ))
272
+ check (lib .SUNLinSolInitialize_SPGMR (linsolver ))
273
+ jac_prod = self ._problem .make_sundials_jac_prod ()
274
+ check (lib .CVodeSetJacTimes (self ._ode , ffi .NULL , jac_prod .cffi ))
275
+ else :
276
+ raise ValueError (f"Unknown linear solver: { linear_solver } " )
261
277
262
278
def _init_sens (self , sens_rhs , sens_mode , scaling_factors = None ) -> None :
263
279
if sens_mode == 'simultaneous' :
0 commit comments