diff --git a/include/signedheat3d/signed_heat_3d.h b/include/signedheat3d/signed_heat_3d.h index c3e8f49..f0dce88 100644 --- a/include/signedheat3d/signed_heat_3d.h +++ b/include/signedheat3d/signed_heat_3d.h @@ -76,6 +76,11 @@ bool isResolutionValid(const std::array& resolution); std::pair computeBBox(VertexPositionGeometry& geometry); std::pair computeBBox(pointcloud::PointPositionNormalGeometry& pointGeom); +Vector solveSquareSystem(SparseMatrix& LHS, const Vector& RHS, bool verbose = false); +Vector solvePositiveDefiniteSystem(SparseMatrix& LHS, const Vector& RHS, bool verbose = false); +Vector solvePositiveDefiniteSystem(SparseMatrix& LHS, const Vector& RHS, + std::unique_ptr>& solver, bool factorize, + bool verbose = false); Vector AMGCL_solve(SparseMatrix& LHS, const Vector& RHS, bool& success, bool verbose = false); Vector AMGCL_blockSolve(const SparseMatrix& L, const SparseMatrix& A, const SparseMatrix& Z, const Vector& rhs, bool verbose = false); \ No newline at end of file diff --git a/src/signed_heat_3d.cpp b/src/signed_heat_3d.cpp index 2652182..f5dc503 100644 --- a/src/signed_heat_3d.cpp +++ b/src/signed_heat_3d.cpp @@ -121,6 +121,97 @@ void setFaceVectorAreas(VertexPositionGeometry& geometry, FaceData& area } } +/* A wrapper function around solveSquare, to handle occasional failures of direct solver. */ +Vector solveSquareSystem(SparseMatrix& LHS, const Vector& RHS, bool verbose) { + bool success = true; + Vector soln; + try { + soln = solveSquare(LHS, RHS); + } catch (const std::exception& e) { + if (verbose) std::cerr << "Caught exception: " << e.what() << std::endl; + success = false; + } + + double solnNorm = soln.norm(); + double error = (LHS * soln - RHS).norm(); + double tol = 1.0; + if (verbose) std::cerr << "Direct solver residual: " << error << std::endl; + if (std::isinf(solnNorm) || std::isnan(solnNorm) || abs(error) > tol) { + if (verbose) std::cerr << "Direct solver failed, using iterative solver" << std::endl; + success = false; + + Eigen::BiCGSTAB> solver; + solver.compute(LHS); + soln = solver.solve(RHS); + if (verbose) { + std::cout << "\t#iterations: " << solver.iterations() << std::endl; + std::cout << "\testimated error: " << solver.error() << std::endl; + } + } + return soln; +} + +Vector solvePositiveDefiniteSystem(SparseMatrix& LHS, const Vector& RHS, bool verbose) { + bool success = true; + Vector soln; + try { + soln = solvePositiveDefinite(LHS, RHS); + } catch (const std::exception& e) { + if (verbose) std::cerr << "Caught exception: " << e.what() << std::endl; + success = false; + } + + double solnNorm = soln.norm(); + double error = (LHS * soln - RHS).norm(); + double tol = 1.0; + if (verbose) std::cerr << "Direct solver residual: " << error << std::endl; + if (std::isinf(solnNorm) || std::isnan(solnNorm) || abs(error) > tol) { + if (verbose) std::cerr << "Direct solver failed, using iterative solver" << std::endl; + success = false; + + Eigen::ConjugateGradient, Eigen::Lower | Eigen::Upper> solver; + solver.compute(LHS); + soln = solver.solve(RHS); + if (verbose) { + std::cout << "\t#iterations: " << solver.iterations() << std::endl; + std::cout << "\testimated error: " << solver.error() << std::endl; + } + } + return soln; +} + +Vector solvePositiveDefiniteSystem(SparseMatrix& LHS, const Vector& RHS, + std::unique_ptr>& solver, bool factorize, + bool verbose) { + bool success = true; + Vector soln; + try { + solver.reset(new PositiveDefiniteSolver(LHS)); + soln = solver->solve(RHS); + } catch (const std::exception& e) { + if (verbose) std::cerr << "Caught exception: " << e.what() << std::endl; + success = false; + } + + double solnNorm = soln.norm(); + double error = (LHS * soln - RHS).norm(); + double tol = 1.0; + if (verbose) std::cerr << "Direct solver residual: " << error << std::endl; + if (std::isinf(solnNorm) || std::isnan(solnNorm) || abs(error) > tol) { + if (verbose) std::cerr << "Direct solver failed, using iterative solver" << std::endl; + success = false; + + Eigen::ConjugateGradient, Eigen::Lower | Eigen::Upper> cg; + cg.compute(LHS); + soln = cg.solve(RHS); + if (verbose) { + std::cout << "\t#iterations: " << cg.iterations() << std::endl; + std::cout << "\testimated error: " << cg.error() << std::endl; + } + } + return soln; +} + #ifndef SHM_NO_AMGCL Vector AMGCL_solve(SparseMatrix& L, const Vector& RHS, bool& success, bool verbose) { @@ -152,10 +243,10 @@ Vector AMGCL_solve(SparseMatrix& L, const Vector& RHS, b std::tie(iters, error) = solve(LHS, RHS, x); } catch (const std::exception& e) { if (verbose) { - std::cerr << "Caught exception: '" << e.what() << std::endl; + std::cerr << "Caught exception: " << e.what() << std::endl; std::cerr << "Use direct solver" << std::endl; - success = false; } + success = false; return x; } if (verbose) std::cerr << "AMGCL # iters: " << iters << "\tAMGCL residual: " << error << std::endl; diff --git a/src/signed_heat_grid_solver.cpp b/src/signed_heat_grid_solver.cpp index e9f7a2d..374d78c 100644 --- a/src/signed_heat_grid_solver.cpp +++ b/src/signed_heat_grid_solver.cpp @@ -115,9 +115,9 @@ Vector SignedHeatGridSolver::computeDistance(VertexPositionGeometry& geo #ifndef SHM_NO_AMGCL bool success; Vector soln = AMGCL_solve(LHS, RHS, success, VERBOSE); - if (!success) soln = solveSquare(LHS, RHS); + if (!success) soln = solveSquareSystem(LHS, RHS, VERBOSE); #else - Vector soln = solveSquare(LHS, RHS); + Vector soln = solveSquareSystem(LHS, RHS, VERBOSE); #endif // clang-format on phi = -soln.head(totalNodes); @@ -227,9 +227,9 @@ Vector SignedHeatGridSolver::computeDistance(pointcloud::PointPositionNo #ifndef SHM_NO_AMGCL bool success; Vector soln = AMGCL_solve(LHS, RHS, success, VERBOSE); - if (!success) soln = solveSquare(LHS, RHS); + if (!success) soln = solveSquareSystem(LHS, RHS, VERBOSE); #else - Vector soln = solveSquare(LHS, RHS); + Vector soln = solveSquareSystem(LHS, RHS, VERBOSE); #endif // clang-format on phi = -soln.head(totalNodes); diff --git a/src/signed_heat_tet_solver.cpp b/src/signed_heat_tet_solver.cpp index c28b1e8..8f2f962 100755 --- a/src/signed_heat_tet_solver.cpp +++ b/src/signed_heat_tet_solver.cpp @@ -210,9 +210,9 @@ Vector SignedHeatTetSolver::integrateVectorField(VertexPositionGeometry& #ifndef SHM_NO_AMGCL bool success; Vector Aresult = AMGCL_solve(decomp.AA, combinedRHS, success, VERBOSE); - if (!success) Aresult = solvePositiveDefinite(decomp.AA, combinedRHS); // success + if (!success) Aresult = solvePositiveDefiniteSystem(decomp.AA, combinedRHS); // success #else - Vector Aresult = solvePositiveDefinite(decomp.AA, combinedRHS); + Vector Aresult = solvePositiveDefiniteSystem(decomp.AA, combinedRHS); #endif // clang-format on phi = reassembleVector(decomp, Aresult, bcVals); @@ -256,20 +256,18 @@ Vector SignedHeatTetSolver::integrateVectorField(VertexPositionGeometry& #ifndef SHM_NO_AMGCL bool success; Vector soln = AMGCL_solve(LHS, RHS, success, VERBOSE); - if (!success) soln = solveSquare(LHS, RHS); // direct solver + if (!success) soln = solveSquareSystem(LHS, RHS); // direct solver #else - Vector soln = solveSquare(LHS, RHS); + Vector soln = solveSquareSystem(LHS, RHS); #endif // clang-format on phi = soln.head(nVertices); double shift = averageVertexDataOnSource(geometry, phi); phi -= shift * Vector::Ones(nVertices); } else { - auto solveDirect = [&]() -> Vector { - if (rebuild || poissonSolver == nullptr) { - if (VERBOSE) std::cerr << "\tFactorizing..." << std::endl; - poissonSolver.reset(new PositiveDefiniteSolver(laplaceMat)); - } + auto solveFallback = [&]() -> Vector { + phi = solvePositiveDefiniteSystem(laplaceMat, div, poissonSolver, rebuild || poissonSolver == nullptr, + VERBOSE); phi = poissonSolver->solve(div); double shift = averageVertexDataOnSource(geometry, phi); phi -= shift * Vector::Ones(nVertices); @@ -279,9 +277,9 @@ Vector SignedHeatTetSolver::integrateVectorField(VertexPositionGeometry& #ifndef SHM_NO_AMGCL bool success; phi = AMGCL_solve(laplaceMat, div, success, VERBOSE); - if (!success) phi = solveDirect(); + if (!success) phi = solveFallback(); #else - phi = solveDirect(); + phi = solveFallback(); #endif // clang-format on } @@ -358,20 +356,18 @@ Vector SignedHeatTetSolver::integrateVectorFieldToFaces(VertexPositionGe #ifndef SHM_NO_AMGCL bool success; Vector soln = AMGCL_solve(LHS, RHS, success, VERBOSE); - if (!success) soln = solveSquare(LHS, RHS); + if (!success) soln = solveSquareSystem(LHS, RHS); #else - Vector soln = solveSquare(LHS, RHS); + Vector soln = solveSquareSystem(LHS, RHS); #endif // clang-format on phi = soln.head(nFaces); double shift = averageFaceDataOnSource(geometry, phi); phi -= shift * Vector::Ones(nFaces); } else { - auto solveDirect = [&]() -> Vector { - if (rebuild || poissonSolverCR == nullptr) { - if (VERBOSE) std::cerr << "\tFactorizing..." << std::endl; - poissonSolverCR.reset(new PositiveDefiniteSolver(laplaceCR)); - } + auto solveFallback = [&]() -> Vector { + phi = solvePositiveDefiniteSystem(laplaceCR, div, poissonSolverCR, rebuild || poissonSolverCR == nullptr, + VERBOSE); phi = poissonSolverCR->solve(div); double shift = averageFaceDataOnSource(geometry, phi); phi -= shift * Vector::Ones(nFaces); @@ -381,9 +377,9 @@ Vector SignedHeatTetSolver::integrateVectorFieldToFaces(VertexPositionGe #ifndef SHM_NO_AMGCL bool success; phi = AMGCL_solve(laplaceCR, div, success, VERBOSE); - if (!success) phi = solveDirect(); + if (!success) phi = solveFallback(); #else - phi = solveDirect(); + phi = solveFallback(); #endif // clang-format on } @@ -410,11 +406,9 @@ Vector SignedHeatTetSolver::integrateVectorField(pointcloud::PointPositi case (LevelSetConstraint::None): { Vector div = vertexDivergence(Yt); - auto solveDirect = [&]() -> Vector { - if (rebuild || poissonSolver == nullptr) { - if (VERBOSE) std::cerr << "\tFactorizing..." << std::endl; - poissonSolver.reset(new PositiveDefiniteSolver(laplaceMat)); - } + auto solveFallback = [&]() -> Vector { + phi = solvePositiveDefiniteSystem(laplaceMat, div, poissonSolver, rebuild || poissonSolver == nullptr, + VERBOSE); phi = poissonSolver->solve(div); double shift = averageVertexDataOnSource(pointGeom, phi); phi -= shift * Vector::Ones(nVertices); @@ -425,9 +419,9 @@ Vector SignedHeatTetSolver::integrateVectorField(pointcloud::PointPositi #ifndef SHM_NO_AMGCL bool success; phi = AMGCL_solve(laplaceMat, div, success, VERBOSE); - if (!success) phi = solveDirect(); + if (!success) phi = solveFallback(); #else - phi = solveDirect(); + phi = solveFallback(); #endif // clang-format on break; @@ -447,9 +441,9 @@ Vector SignedHeatTetSolver::integrateVectorField(pointcloud::PointPositi #ifndef SHM_NO_AMGCL bool success; Vector Aresult = AMGCL_solve(decomp.AA, combinedRHS, success, VERBOSE); - if (!success) Aresult = solvePositiveDefinite(decomp.AA, combinedRHS); // direct solver + if (!success) Aresult = solvePositiveDefiniteSystem(decomp.AA, combinedRHS); // direct solver #else - Vector Aresult = solvePositiveDefinite(decomp.AA, combinedRHS); + Vector Aresult = solvePositiveDefiniteSystem(decomp.AA, combinedRHS); #endif // clang-format on phi = reassembleVector(decomp, Aresult, bcVals); @@ -495,9 +489,9 @@ Vector SignedHeatTetSolver::integrateVectorField(pointcloud::PointPositi #ifndef SHM_NO_AMGCL bool success; Vector soln = AMGCL_solve(LHS, RHS, success, VERBOSE); - if (!success) soln = solveSquare(LHS, RHS); // direct solver + if (!success) soln = solveSquareSystem(LHS, RHS); // direct solver #else - Vector soln = solveSquare(LHS, RHS); + Vector soln = solveSquareSystem(LHS, RHS); #endif #// clang-format on phi = soln.head(nVertices);