-
Notifications
You must be signed in to change notification settings - Fork 768
Implement QR decomposition #4250
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
RustyBamboo
wants to merge
1
commit into
tracel-ai:main
Choose a base branch
from
RustyBamboo:qr
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+281
−48
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
145 changes: 145 additions & 0 deletions
145
crates/burn-backend-tests/tests/tensor/float/linalg/qr_decomposition.rs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,145 @@ | ||
| use super::*; | ||
| use burn_tensor::cast::ToElement; | ||
| use burn_tensor::{Tolerance, linalg, s}; | ||
|
|
||
| fn assert_orthonormal(q: TestTensor<2>, tolerance: Tolerance<FloatElem>) { | ||
| let device = q.device(); | ||
| let [_m, k] = q.dims(); | ||
| let eye = TestTensor::<2>::eye(k, &device); | ||
| let qtq = q.clone().transpose().matmul(q); | ||
| qtq.into_data() | ||
| .assert_approx_eq::<FloatElem>(&eye.into_data(), tolerance); | ||
| } | ||
|
|
||
| // QR factors are unique up to column-wise sign flips; align to reference. | ||
| fn align_qr_to_expected( | ||
| mut q: TestTensor<2>, | ||
| mut r: TestTensor<2>, | ||
| q_expected: TestTensor<2>, | ||
| ) -> (TestTensor<2>, TestTensor<2>) { | ||
| let [_m, k] = q_expected.dims(); | ||
| for col in 0..k { | ||
| let q_col = q.clone().slice(s![.., col..(col + 1)]); | ||
| let q_ref = q_expected.clone().slice(s![.., col..(col + 1)]); | ||
| let dot = (q_col.clone() * q_ref).sum().into_scalar().to_f64(); | ||
| if dot < 0.0 { | ||
| q = q.slice_assign(s![.., col..(col + 1)], -q_col); | ||
| let r_row = r.clone().slice(s![col..(col + 1), ..]); | ||
| r = r.slice_assign(s![col..(col + 1), ..], -r_row); | ||
| } | ||
| } | ||
| (q, r) | ||
| } | ||
|
|
||
#[test]
fn test_qr_square_reconstruction() {
    let device = Default::default();
    // Classic Householder-QR textbook example (Golub & Van Loan style).
    let input = TestTensor::<2>::from_data(
        [[12.0, -51.0, 4.0], [6.0, 167.0, -68.0], [-4.0, 24.0, -41.0]],
        &device,
    );
    let (q, r) = linalg::qr_decomposition(input.clone());

    // A square input yields square factors.
    assert_eq!(q.dims(), [3, 3]);
    assert_eq!(r.dims(), [3, 3]);

    let tolerance = Tolerance::permissive();
    // Q * R must reproduce the original matrix.
    q.clone()
        .matmul(r.clone())
        .into_data()
        .assert_approx_eq::<FloatElem>(&input.into_data(), tolerance);

    let q_expected = TestTensor::<2>::from_data(
        [
            [-0.85714287, 0.3942857, 0.33142856],
            [-0.42857143, -0.9028571, -0.034285713],
            [0.2857143, -0.17142858, 0.94285715],
        ],
        &device,
    );
    let r_expected = TestTensor::<2>::from_data(
        [[-14.0, -21.0, 14.0], [0.0, -175.0, 70.0], [0.0, 0.0, -35.0]],
        &device,
    );
    // Normalize column signs before comparing against the references.
    let (q, r) = align_qr_to_expected(q, r, q_expected.clone());
    q.clone()
        .into_data()
        .assert_approx_eq::<FloatElem>(&q_expected.into_data(), tolerance);
    r.into_data()
        .assert_approx_eq::<FloatElem>(&r_expected.into_data(), tolerance);
    assert_orthonormal(q, tolerance);
}
|
|
||
#[test]
fn test_qr_tall_reconstruction() {
    let device = Default::default();
    // Tall matrix (m > n) exercises the thin-QR path.
    let input =
        TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], &device);
    let (q, r) = linalg::qr_decomposition(input.clone());

    // Thin QR: Q is [m, k] and R is [k, n] with k = min(m, n) = 2.
    assert_eq!(q.dims(), [4, 2]);
    assert_eq!(r.dims(), [2, 2]);

    let tolerance = Tolerance::permissive();
    // Q * R must reproduce the original matrix.
    q.clone()
        .matmul(r.clone())
        .into_data()
        .assert_approx_eq::<FloatElem>(&input.into_data(), tolerance);

    let q_expected = TestTensor::<2>::from_data(
        [
            [-0.10910895, -0.82951504],
            [-0.32732683, -0.43915504],
            [-0.54554474, -0.048795003],
            [-0.7637626, 0.341565],
        ],
        &device,
    );
    let r_expected =
        TestTensor::<2>::from_data([[-9.165152, -10.910894], [0.0, -0.97590005]], &device);
    // Normalize column signs before comparing against the references.
    let (q, r) = align_qr_to_expected(q, r, q_expected.clone());
    q.clone()
        .into_data()
        .assert_approx_eq::<FloatElem>(&q_expected.into_data(), tolerance);
    r.into_data()
        .assert_approx_eq::<FloatElem>(&r_expected.into_data(), tolerance);
    assert_orthonormal(q, tolerance);
}
|
|
||
#[test]
fn test_qr_wide_reconstruction() {
    let device = Default::default();
    // Wide matrix (m < n) exercises the upper-trapezoidal R path.
    let input = TestTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], &device);
    let (q, r) = linalg::qr_decomposition(input.clone());

    // With k = min(m, n) = 2: Q is [2, 2], R is [2, 4].
    assert_eq!(q.dims(), [2, 2]);
    assert_eq!(r.dims(), [2, 4]);

    let tolerance = Tolerance::permissive();
    // Q * R must reproduce the original matrix.
    q.clone()
        .matmul(r.clone())
        .into_data()
        .assert_approx_eq::<FloatElem>(&input.into_data(), tolerance);

    let q_expected = TestTensor::<2>::from_data(
        [[-0.19611613, -0.9805807], [-0.9805807, 0.19611613]],
        &device,
    );
    let r_expected = TestTensor::<2>::from_data(
        [
            [-5.0990195, -6.2757163, -7.452413, -8.62911],
            [0.0, -0.78446454, -1.5689291, -2.3533936],
        ],
        &device,
    );
    // Normalize column signs before comparing against the references.
    let (q, r) = align_qr_to_expected(q, r, q_expected.clone());
    q.clone()
        .into_data()
        .assert_approx_eq::<FloatElem>(&q_expected.into_data(), tolerance);
    r.into_data()
        .assert_approx_eq::<FloatElem>(&r_expected.into_data(), tolerance);
    assert_orthonormal(q, tolerance);
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
129 changes: 129 additions & 0 deletions
129
crates/burn-tensor/src/tensor/linalg/qr_decomposition.rs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| use alloc::vec::Vec; | ||
|
|
||
| use crate::{ | ||
| DType, Element, ElementConversion, backend::Backend, cast::ToElement, linalg::outer, s, | ||
| tensor::Tensor, | ||
| }; | ||
|
|
||
/// One Householder reflector, stored so `Q` can be accumulated after all
/// of `R` has been formed.
struct Householder<B: Backend> {
    /// Householder vector (its first component is set to exactly 1
    /// during construction in `qr_decomposition`).
    v: Tensor<B, 1>,
    /// Scalar factor of the rank-1 update `H = I - tau * v * v^T`.
    tau: B::FloatElem,
}
|
|
||
| fn eps_for_dtype(dtype: DType) -> f64 { | ||
| match dtype { | ||
| DType::F16 => 1e-3, | ||
| DType::BF16 => 1e-2, | ||
| DType::F32 => 1e-7, | ||
| DType::F64 => 1e-15, | ||
| _ => 1e-7, | ||
| } | ||
| } | ||
|
|
||
| /// Performs QR decomposition of a matrix using Householder reflections. | ||
| /// | ||
| /// The input matrix `A` is factored into `Q` and `R` such that `A = Q * R`, | ||
| /// where `Q` has orthonormal columns and `R` is upper trapezoidal. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tuple containing: | ||
| /// - `Q`: a matrix of shape `[m, k]` | ||
| /// - `R`: a matrix of shape `[k, n]` | ||
| /// | ||
| /// where `m` and `n` are the input dimensions and `k = min(m, n)`. | ||
| pub fn qr_decomposition<B: Backend>(tensor: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) { | ||
| let device = tensor.device(); | ||
| let [m, n] = tensor.shape().dims::<2>(); | ||
| let k_max = m.min(n); | ||
|
|
||
| if k_max == 0 { | ||
| let q = Tensor::<B, 2>::zeros([m, k_max], &device); | ||
| let r = Tensor::<B, 2>::zeros([k_max, n], &device); | ||
| return (q, r); | ||
| } | ||
|
|
||
| let mut r = tensor; | ||
| // Store Householder vectors to build Q after R is formed. | ||
| let mut reflectors: Vec<Option<Householder<B>>> = Vec::with_capacity(k_max); | ||
| let eps_base = eps_for_dtype(<B::FloatElem as Element>::dtype()); | ||
|
|
||
| for k in 0..k_max { | ||
| let r_sub = r.clone().slice(s![k.., k..]); | ||
| // Current column segment to be zeroed below the diagonal. | ||
| let x = r_sub.clone().slice(s![.., 0..1]).squeeze_dim(1); | ||
| let rows = m - k; | ||
| let x0 = x.clone().slice(s![0]); | ||
| let x0_scalar = x0.clone().into_scalar().to_f64(); | ||
| let xnorm = if rows > 1 { | ||
| x.clone() | ||
| .slice(s![1..]) | ||
| .square() | ||
| .sum() | ||
| .sqrt() | ||
| .into_scalar() | ||
| .to_f64() | ||
| } else { | ||
| 0.0 | ||
| }; | ||
| let scale = x0_scalar.abs().max(xnorm).max(1.0); | ||
| let eps = eps_base * scale; | ||
| if xnorm <= eps { | ||
| reflectors.push(None); | ||
| continue; | ||
| } | ||
|
|
||
| // Choose sign to avoid cancellation in beta. | ||
| let sign = if x0_scalar >= 0.0 { 1.0 } else { -1.0 }; | ||
| let norm = (x0_scalar * x0_scalar + xnorm * xnorm).sqrt(); | ||
| let beta = -sign * norm; | ||
| let denom = x0_scalar - beta; | ||
| if denom.abs() <= eps || !beta.is_finite() { | ||
| reflectors.push(None); | ||
| continue; | ||
| } | ||
| let tau_scalar = (beta - x0_scalar) / beta; | ||
| let tau = <B::FloatElem as ElementConversion>::from_elem(tau_scalar); | ||
| let mut v = x.mul_scalar(1.0 / denom); | ||
| let v0 = x0.clone().mul_scalar(0.0).add_scalar(1.0); | ||
| v = v.slice_assign(s![0], v0); | ||
|
|
||
| // w = R^T * v for the rank-1 update. | ||
| let w = (r_sub.clone().transpose() * v.clone().unsqueeze_dim::<2>(0)) | ||
| .sum_dim(1) | ||
| .squeeze_dim::<1>(1); | ||
| // R = R - tau * v * w^T | ||
| let update = outer::<B, 1, 2, _>(v.clone(), w).mul_scalar(tau); | ||
| let r_sub = r_sub - update; | ||
| r = r.slice_assign(s![k.., k..], r_sub); | ||
|
|
||
| reflectors.push(Some(Householder { v, tau })); | ||
| } | ||
|
|
||
| // Start with identity, then apply reflectors in reverse order. | ||
| let mut q = Tensor::<B, 2>::eye(m, &device); | ||
| if k_max < m { | ||
| q = q.slice(s![.., 0..k_max]); | ||
| } | ||
|
|
||
| for k in (0..k_max).rev() { | ||
| let Some(reflector) = reflectors.get_mut(k).and_then(|r| r.take()) else { | ||
| continue; | ||
| }; | ||
|
|
||
| let v = reflector.v; | ||
| let tau = reflector.tau; | ||
|
|
||
| let q_sub = q.clone().slice(s![k.., ..]); | ||
| // Apply reflector: Q = Q - tau * v * (Q^T v)^T | ||
| let wq = (q_sub.clone().transpose() * v.clone().unsqueeze_dim::<2>(0)) | ||
| .sum_dim(1) | ||
| .squeeze_dim::<1>(1); | ||
| let update_q = outer::<B, 1, 2, _>(v, wq).mul_scalar(tau); | ||
| let q_sub = q_sub - update_q; | ||
| q = q.slice_assign(s![k.., ..], q_sub); | ||
| } | ||
|
|
||
| let r = r.slice(s![0..k_max, ..]); | ||
| (q, r) | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could use the short-hand
`tensor.dims::<2>()` instead.