|
7 | 7 | mod csc_serde;
|
8 | 8 |
|
9 | 9 | use crate::cs;
|
10 |
| -use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix}; |
| 10 | +use crate::cs::{CsBuilder, CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix}; |
11 | 11 | use crate::csr::CsrMatrix;
|
12 |
| -use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter}; |
| 12 | +use crate::pattern::{ |
| 13 | + BuilderInsertError, SparsityPattern, SparsityPatternFormatError, SparsityPatternIter, |
| 14 | +}; |
13 | 15 | use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
14 | 16 |
|
15 |
| -use nalgebra::Scalar; |
| 17 | +use nalgebra::{RealField, Scalar}; |
16 | 18 | use num_traits::One;
|
17 | 19 | use std::slice::{Iter, IterMut};
|
18 | 20 |
|
@@ -553,6 +555,269 @@ impl<T> CscMatrix<T> {
|
553 | 555 | self.filter(|i, j, _| i >= j)
|
554 | 556 | }
|
555 | 557 |
|
| 558 | + /// Solves a lower triangular system, `self` is a matrix of NxN, and `b` is a column vector of size N |
| 559 | + /// Assuming that b is dense. |
| 560 | + pub fn dense_lower_triangular_solve(&self, b: &[T], out: &mut [T], unit_diagonal: bool) |
| 561 | + where |
| 562 | + T: RealField + Copy, |
| 563 | + { |
| 564 | + assert_eq!(self.nrows(), self.ncols()); |
| 565 | + assert_eq!(self.ncols(), b.len()); |
| 566 | + assert_eq!(out.len(), b.len()); |
| 567 | + out.copy_from_slice(b); |
| 568 | + let n = b.len(); |
| 569 | + |
| 570 | + for i in 0..n { |
| 571 | + let col = self.col(i); |
| 572 | + let mut iter = col.row_indices().iter().zip(col.values().iter()).peekable(); |
| 573 | + while iter.next_if(|n| *n.0 < i).is_some() {} |
| 574 | + if let Some(n) = iter.peek() { |
| 575 | + if *n.0 == i && !unit_diagonal { |
| 576 | + assert!(*n.0 <= i); |
| 577 | + out[i] /= *n.1; |
| 578 | + iter.next(); |
| 579 | + } |
| 580 | + } |
| 581 | + let mul = out[i]; |
| 582 | + for (&ri, &v) in col.row_indices().iter().zip(col.values().iter()) { |
| 583 | + use std::cmp::Ordering::*; |
| 584 | + // ensure that only using the lower part |
| 585 | + match ri.cmp(&i) { |
| 586 | + Greater => out[ri] -= v * mul, |
| 587 | + Equal | Less => {} |
| 588 | + } |
| 589 | + } |
| 590 | + } |
| 591 | + } |
| 592 | + |
| 593 | + /// Solves an upper triangular system, `self` is a matrix of NxN, and `b` is a column vector of size N |
| 594 | + /// Assuming that b is dense. |
| 595 | + pub fn dense_upper_triangular_solve(&self, b: &[T], out: &mut [T]) |
| 596 | + where |
| 597 | + T: RealField + Copy, |
| 598 | + { |
| 599 | + assert_eq!(self.nrows(), self.ncols()); |
| 600 | + assert_eq!(self.ncols(), b.len()); |
| 601 | + assert_eq!(out.len(), b.len()); |
| 602 | + out.copy_from_slice(b); |
| 603 | + let n = b.len(); |
| 604 | + |
| 605 | + for i in (0..n).rev() { |
| 606 | + let col = self.col(i); |
| 607 | + let mut iter = col |
| 608 | + .row_indices() |
| 609 | + .iter() |
| 610 | + .zip(col.values().iter()) |
| 611 | + .rev() |
| 612 | + .peekable(); |
| 613 | + while iter.next_if(|n| *n.0 > i).is_some() {} |
| 614 | + if let Some(n) = iter.peek() { |
| 615 | + if *n.0 == i { |
| 616 | + out[i] /= *n.1; |
| 617 | + iter.next(); |
| 618 | + } |
| 619 | + } |
| 620 | + // introduce a NaN, intentionally, if the diagonal doesn't have a value. |
| 621 | + let mul = out[i]; |
| 622 | + for (&row, &v) in iter { |
| 623 | + use std::cmp::Ordering::*; |
| 624 | + match row.cmp(&i) { |
| 625 | + Less => out[row] -= v * mul, |
| 626 | + Equal | Greater => {} |
| 627 | + } |
| 628 | + } |
| 629 | + } |
| 630 | + } |
| 631 | + |
| 632 | + /// Solves a sparse lower triangular system `Ax = b`, with both the matrix and vector |
| 633 | + /// sparse. |
| 634 | + /// sparsity_idxs should be precomputed using the sparse_lower_triangle. |
| 635 | + pub fn sparse_upper_triangular_solve_sorted( |
| 636 | + &self, |
| 637 | + b_idxs: &[usize], |
| 638 | + b: &[T], |
| 639 | + |
| 640 | + out_sparsity_pattern: &[usize], |
| 641 | + out: &mut [T], |
| 642 | + ) where |
| 643 | + T: RealField + Copy, |
| 644 | + { |
| 645 | + assert_eq!(self.nrows(), self.ncols()); |
| 646 | + assert_eq!(b_idxs.len(), b.len()); |
| 647 | + assert!(b_idxs.iter().all(|&bi| bi < self.ncols())); |
| 648 | + |
| 649 | + assert_eq!(out_sparsity_pattern.len(), out.len()); |
| 650 | + assert!(out_sparsity_pattern.iter().all(|&bi| bi < self.ncols())); |
| 651 | + |
| 652 | + // initialize out with b |
| 653 | + out.fill(T::zero()); |
| 654 | + for (&bv, &bi) in b.iter().zip(b_idxs.iter()) { |
| 655 | + let out_pos = out_sparsity_pattern.iter().position(|&p| p == bi).unwrap(); |
| 656 | + out[out_pos] = bv; |
| 657 | + } |
| 658 | + |
| 659 | + for (i, &row) in out_sparsity_pattern.iter().enumerate().rev() { |
| 660 | + let col = self.col(row); |
| 661 | + let mut iter = col |
| 662 | + .row_indices() |
| 663 | + .iter() |
| 664 | + .zip(col.values().iter()) |
| 665 | + .rev() |
| 666 | + .peekable(); |
| 667 | + |
| 668 | + while iter.next_if(|n| *n.0 > row).is_some() {} |
| 669 | + match iter.peek() { |
| 670 | + Some((&r, &l_val)) if r == row => out[i] /= l_val, |
| 671 | + // here it now becomes implicitly 0, |
| 672 | + // likely this should introduce NaN or some other behavior. |
| 673 | + _ => {} |
| 674 | + } |
| 675 | + let mul = out[i]; |
| 676 | + for (ni, &nrow) in out_sparsity_pattern[i + 1..].iter().enumerate().rev() { |
| 677 | + assert!(nrow > row); |
| 678 | + while iter.next_if(|n| *n.0 > nrow).is_some() {} |
| 679 | + let l_val = match iter.peek() { |
| 680 | + Some((&r, &l_val)) if r == nrow => l_val, |
| 681 | + _ => continue, |
| 682 | + }; |
| 683 | + out[ni] -= l_val * mul; |
| 684 | + } |
| 685 | + } |
| 686 | + } |
| 687 | + |
| 688 | + /// Solves a sparse lower triangular system `Ax = b`, with both the matrix and vector |
| 689 | + /// sparse. |
| 690 | + /// sparsity_idxs should be precomputed using the sparse_lower_triangle. |
| 691 | + /// Assumes that the diagonal of the sparse matrix is all 1 if `assume_unit` is true. |
| 692 | + pub fn sparse_lower_triangular_solve( |
| 693 | + &self, |
| 694 | + b_idxs: &[usize], |
| 695 | + b: &[T], |
| 696 | + // idx -> row |
| 697 | + // for now, is permitted to be unsorted |
| 698 | + // TODO maybe would be better to enforce sorted, but would have to sort internally. |
| 699 | + out_sparsity_pattern: &[usize], |
| 700 | + out: &mut [T], |
| 701 | + assume_unit: bool, |
| 702 | + ) where |
| 703 | + T: RealField + Copy, |
| 704 | + { |
| 705 | + assert_eq!(self.nrows(), self.ncols()); |
| 706 | + assert_eq!(b.len(), b_idxs.len()); |
| 707 | + assert!(b_idxs.iter().all(|&bi| bi < self.ncols())); |
| 708 | + |
| 709 | + assert_eq!(out_sparsity_pattern.len(), out.len()); |
| 710 | + assert!(out_sparsity_pattern.iter().all(|&i| i < self.ncols())); |
| 711 | + |
| 712 | + let is_sorted = (0..out_sparsity_pattern.len() - 1) |
| 713 | + .all(|i| out_sparsity_pattern[i] < out_sparsity_pattern[i + 1]); |
| 714 | + if is_sorted { |
| 715 | + return self.sparse_lower_triangular_solve_sorted( |
| 716 | + b_idxs, |
| 717 | + b, |
| 718 | + out_sparsity_pattern, |
| 719 | + out, |
| 720 | + assume_unit, |
| 721 | + ); |
| 722 | + } |
| 723 | + |
| 724 | + // initialize out with b |
| 725 | + out.fill(T::zero()); |
| 726 | + for (&bv, &bi) in b.iter().zip(b_idxs.iter()) { |
| 727 | + let out_pos = out_sparsity_pattern.iter().position(|&p| p == bi).unwrap(); |
| 728 | + out[out_pos] = bv; |
| 729 | + } |
| 730 | + |
| 731 | + for (i, &row) in out_sparsity_pattern.iter().enumerate() { |
| 732 | + let col = self.col(row); |
| 733 | + if !assume_unit { |
| 734 | + if let Some(l_val) = col.get_entry(row) { |
| 735 | + out[i] /= l_val.into_value(); |
| 736 | + } else { |
| 737 | + // diagonal is 0, non-invertible |
| 738 | + out[i] /= T::zero(); |
| 739 | + } |
| 740 | + } |
| 741 | + let mul = out[i]; |
| 742 | + for (ni, &nrow) in out_sparsity_pattern.iter().enumerate() { |
| 743 | + if nrow <= row { |
| 744 | + continue; |
| 745 | + } |
| 746 | + // TODO in a sorted version may be able to iterate without |
| 747 | + // having the cost of binary search at each iteration |
| 748 | + let l_val = if let Some(l_val) = col.get_entry(nrow) { |
| 749 | + l_val.into_value() |
| 750 | + } else { |
| 751 | + continue; |
| 752 | + }; |
| 753 | + out[ni] -= l_val * mul; |
| 754 | + } |
| 755 | + } |
| 756 | + } |
| 757 | + /// Solves a sparse lower triangular system `Ax = b`, with both the matrix and vector |
| 758 | + /// sparse. |
| 759 | + /// sparsity_idxs should be precomputed using the sparse_lower_triangle pattern. |
| 760 | + /// |
| 761 | + /// `out_sparsity_pattern` must also be pre-sorted. |
| 762 | + /// |
| 763 | + /// Assumes that the diagonal of the sparse matrix is all 1 if `assume_unit` is true. |
| 764 | + pub fn sparse_lower_triangular_solve_sorted( |
| 765 | + &self, |
| 766 | + // input vector idxs & values |
| 767 | + b_idxs: &[usize], |
| 768 | + b: &[T], |
| 769 | + // idx -> row |
| 770 | + // for now, is permitted to be unsorted |
| 771 | + // TODO maybe would be better to enforce sorted, but would have to sort internally. |
| 772 | + out_sparsity_pattern: &[usize], |
| 773 | + out: &mut [T], |
| 774 | + assume_unit: bool, |
| 775 | + ) where |
| 776 | + T: RealField + Copy, |
| 777 | + { |
| 778 | + assert_eq!(self.nrows(), self.ncols()); |
| 779 | + assert_eq!(b.len(), b_idxs.len()); |
| 780 | + assert!(b_idxs.iter().all(|&bi| bi < self.ncols())); |
| 781 | + |
| 782 | + assert_eq!(out_sparsity_pattern.len(), out.len()); |
| 783 | + assert!(out_sparsity_pattern.iter().all(|&i| i < self.ncols())); |
| 784 | + |
| 785 | + // initialize out with b |
| 786 | + // TODO can make this more efficient by keeping two iterators in sorted order |
| 787 | + out.fill(T::zero()); |
| 788 | + for (&bv, &bi) in b.iter().zip(b_idxs.iter()) { |
| 789 | + let out_pos = out_sparsity_pattern.iter().position(|&p| p == bi).unwrap(); |
| 790 | + out[out_pos] = bv; |
| 791 | + } |
| 792 | + // end init |
| 793 | + |
| 794 | + // assuming that the output sparsity pattern is sorted |
| 795 | + // iterate thru |
| 796 | + for (i, &row) in out_sparsity_pattern.iter().enumerate() { |
| 797 | + let col = self.col(row); |
| 798 | + let mut iter = col.row_indices().iter().zip(col.values().iter()).peekable(); |
| 799 | + if !assume_unit { |
| 800 | + while iter.next_if(|n| *n.0 < row).is_some() {} |
| 801 | + match iter.peek() { |
| 802 | + Some((&r, &l_val)) if r == row => out[i] /= l_val, |
| 803 | + // here it now becomes implicitly 0, |
| 804 | + // likely this should introduce NaN or some other behavior. |
| 805 | + _ => {} |
| 806 | + } |
| 807 | + } |
| 808 | + let mul = out[i]; |
| 809 | + for (ni, &nrow) in out_sparsity_pattern.iter().enumerate().skip(i + 1) { |
| 810 | + assert!(nrow > row); |
| 811 | + while iter.next_if(|n| *n.0 < nrow).is_some() {} |
| 812 | + let l_val = match iter.peek() { |
| 813 | + Some((&r, &l_val)) if r == nrow => l_val, |
| 814 | + _ => continue, |
| 815 | + }; |
| 816 | + out[ni] -= l_val * mul; |
| 817 | + } |
| 818 | + } |
| 819 | + } |
| 820 | + |
556 | 821 | /// Returns the diagonal of the matrix as a sparse matrix.
|
557 | 822 | #[must_use]
|
558 | 823 | pub fn diagonal_as_csc(&self) -> Self
|
@@ -784,3 +1049,30 @@ where
|
784 | 1049 | self.lane_iter.next().map(|lane| CscColMut { lane })
|
785 | 1050 | }
|
786 | 1051 | }
|
| 1052 | + |
| 1053 | +/// An incremental builder for a Csc matrix. |
| 1054 | +#[derive(Debug, Clone, PartialEq, Eq)] |
| 1055 | +pub struct CscBuilder<T>(CsBuilder<T>); |
| 1056 | + |
| 1057 | +impl<T> CscBuilder<T> { |
| 1058 | + /// Constructs a new instance of a Csc Builder. |
| 1059 | + pub fn new(rows: usize, cols: usize) -> Self { |
| 1060 | + Self(CsBuilder::new(cols, rows)) |
| 1061 | + } |
| 1062 | + /// Convert back from a matrix to a CscBuilder. |
| 1063 | + pub fn from_mat(mat: CscMatrix<T>) -> Self { |
| 1064 | + Self(CsBuilder::from_mat(mat.cs)) |
| 1065 | + } |
| 1066 | + /// Backtracks back to column `col`, deleting all entries ahead of it. |
| 1067 | + pub fn revert_to_col(&mut self, col: usize) -> bool { |
| 1068 | + self.0.revert_to_major(col) |
| 1069 | + } |
| 1070 | + /// Inserts a value into the builder. Must be called in ascending col, row order. |
| 1071 | + pub fn insert(&mut self, row: usize, col: usize, val: T) -> Result<(), BuilderInsertError> { |
| 1072 | + self.0.insert(col, row, val) |
| 1073 | + } |
| 1074 | + /// Converts this builder into a valid CscMatrix. |
| 1075 | + pub fn build(self) -> CscMatrix<T> { |
| 1076 | + CscMatrix { cs: self.0.build() } |
| 1077 | + } |
| 1078 | +} |
0 commit comments