Skip to content

Commit

Permalink
make scipy bicgstab call dependant on scipy version
Browse files Browse the repository at this point in the history
  • Loading branch information
jcapriot committed Jul 7, 2024
1 parent 4c7bf0a commit e290d94
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
24 changes: 20 additions & 4 deletions pymatsolver/iterative.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,20 @@
SolverCG = WrapIterative(cg, name="SolverCG")
SolverBiCG = WrapIterative(bicgstab, name="SolverBiCG")

import scipy
_rtol_call = False
scipy_major, scipy_minor, scipy_patch = scipy.__version__.split(".")
if int(scipy_major) >= 1 and int(scipy_minor) >= 12:
_rtol_call = True

class BiCGJacobi(Base):
"""Bicg Solver with Jacobi preconditioner"""

_factored = False
solver = None
maxiter = 1000
tol = 1E-6
rtol = 1E-6
atol = 0.0

def __init__(self, A, symmetric=True):
self.A = A
Expand All @@ -30,13 +37,21 @@ def factor(self):
self.M = sp.linalg.interface.aslinearoperator(Ainv)
self._factored = True

@property
def _tols(self):
if _rtol_call:
return {'rtol': self.rtol, 'atol': self.atol}
else:
return {'tol': self.rtol, 'atol': self.atol}


def _solve1(self, rhs):
self.factor()
sol, info = self.solver(
self.A, rhs,
atol=self.tol,
maxiter=self.maxiter,
M=self.M
M=self.M,
**self._tols,
)
return sol

Expand All @@ -45,7 +60,8 @@ def _solveM(self, rhs):
sol = []
for icol in range(rhs.shape[1]):
sol.append(self.solver(self.A, rhs[:, icol].flatten(),
atol=self.tol, maxiter=self.maxiter, M=self.M)[0])
maxiter=self.maxiter, M=self.M,
**self._tols,)[0])
out = np.hstack(sol)
out.shape
return out
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
setup(
name="pymatsolver",
version="0.2.0",
packages=find_packages(exclude=["*mumps", "tests"]),
packages=find_packages(exclude=["tests"]),
install_requires=[
'numpy>=1.7',
'scipy>=0.13',
Expand Down
2 changes: 1 addition & 1 deletion tests/test_BicgJacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_T(self):
Ainv.clean()


class TestPardisoComplex(unittest.TestCase):
class TestBicgJacobiComplex(unittest.TestCase):

def setUp(self):
nSize = 100
Expand Down

0 comments on commit e290d94

Please sign in to comment.