From ff2e8b6eac1ef5e21882f65c8056d411a27703ba Mon Sep 17 00:00:00 2001
From: RustyBamboo
Date: Thu, 25 Dec 2025 14:46:06 -0800
Subject: [PATCH] Implement QR decomposition

- replace Gram-Schmidt initializer
- add qr to linalg functions in book
---
 burn-book/src/building-blocks/tensor.md      |   1 +
 .../tests/tensor/float/linalg/mod.rs         |   1 +
 .../tensor/float/linalg/qr_decomposition.rs  | 147 ++++++++++++++++++
 crates/burn-core/src/module/initializer.rs   |  51 +-----
 crates/burn-tensor/src/tensor/linalg/mod.rs  |   2 +
 .../src/tensor/linalg/qr_decomposition.rs    | 129 ++++++++++++++++
 6 files changed, 283 insertions(+), 48 deletions(-)
 create mode 100644 crates/burn-backend-tests/tests/tensor/float/linalg/qr_decomposition.rs
 create mode 100644 crates/burn-tensor/src/tensor/linalg/qr_decomposition.rs

diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md
index 8dab138034..45688535b7 100644
--- a/burn-book/src/building-blocks/tensor.md
+++ b/burn-book/src/building-blocks/tensor.md
@@ -408,6 +408,7 @@ strategies.
 | `linalg::trace(tensor)` | `torch.trace(tensor)` |
 | `linalg::outer(x, y)` | `torch.outer(x, y)` / `einsum("bi,bj->bij", …)` |
 | `linalg::lu_decomposition(tensor)` | `torch.linalg.lu(tensor)` |
+| `linalg::qr_decomposition(tensor)` | `torch.linalg.qr(tensor)` |
 | `linalg::matvec(matrix, vector)` | `torch.matmul(matrix, vector)` / `@` operator |
 
 ## Displaying Tensor Details
diff --git a/crates/burn-backend-tests/tests/tensor/float/linalg/mod.rs b/crates/burn-backend-tests/tests/tensor/float/linalg/mod.rs
index 495dbd8209..e5156d3bb1 100644
--- a/crates/burn-backend-tests/tests/tensor/float/linalg/mod.rs
+++ b/crates/burn-backend-tests/tests/tensor/float/linalg/mod.rs
@@ -3,6 +3,7 @@ use super::*;
 pub(crate) mod cosine_similarity;
 pub(crate) mod diag;
 pub(crate) mod lu_decomposition;
+pub(crate) mod qr_decomposition;
 pub(crate) mod matvec;
 pub(crate) mod outer;
 pub(crate) mod trace;
diff --git a/crates/burn-backend-tests/tests/tensor/float/linalg/qr_decomposition.rs b/crates/burn-backend-tests/tests/tensor/float/linalg/qr_decomposition.rs
new file mode 100644
index 0000000000..6877cf0ad4
--- /dev/null
+++ b/crates/burn-backend-tests/tests/tensor/float/linalg/qr_decomposition.rs
@@ -0,0 +1,147 @@
+use super::*;
+use burn_tensor::cast::ToElement;
+use burn_tensor::{Tolerance, linalg, s};
+
+fn assert_orthonormal(q: TestTensor<2>, tolerance: Tolerance<FT>) {
+    let device = q.device();
+    let [_m, k] = q.dims();
+    let eye = TestTensor::<2>::eye(k, &device);
+    let qtq = q.clone().transpose().matmul(q);
+    qtq.into_data()
+        .assert_approx_eq::<FT>(&eye.into_data(), tolerance);
+}
+
+// QR factors are unique up to column-wise sign flips; align to reference.
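+// E.g. with D = diag(±1), (Q·D)(D·R) = Q·R, so both factor pairs reconstruct
+// A; aligning per-column signs lets tests compare against fixed references.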
+fn align_qr_to_expected(
+    mut q: TestTensor<2>,
+    mut r: TestTensor<2>,
+    q_expected: TestTensor<2>,
+) -> (TestTensor<2>, TestTensor<2>) {
+    let [_m, k] = q_expected.dims();
+    for col in 0..k {
+        let q_col = q.clone().slice(s![.., col..(col + 1)]);
+        let q_ref = q_expected.clone().slice(s![.., col..(col + 1)]);
+        let dot = (q_col.clone() * q_ref).sum().into_scalar().to_f64();
+        if dot < 0.0 {
+            q = q.slice_assign(s![.., col..(col + 1)], -q_col);
+            let r_row = r.clone().slice(s![col..(col + 1), ..]);
+            r = r.slice_assign(s![col..(col + 1), ..], -r_row);
+        }
+    }
+    (q, r)
+}
+
+#[test]
+fn test_qr_square_reconstruction() {
+    let device = Default::default();
+    let tensor = TestTensor::<2>::from_data(
+        [[12.0, -51.0, 4.0], [6.0, 167.0, -68.0], [-4.0, 24.0, -41.0]],
+        &device,
+    );
+    let (q, r) = linalg::qr_decomposition(tensor.clone());
+
+    assert_eq!(q.dims(), [3, 3]);
+    assert_eq!(r.dims(), [3, 3]);
+
+    let reconstructed = q.clone().matmul(r.clone());
+    let tolerance = Tolerance::permissive();
+    reconstructed
+        .into_data()
+        .assert_approx_eq::<FT>(&tensor.into_data(), tolerance);
+    let q_expected = TestTensor::<2>::from_data(
+        [
+            [-0.85714287, 0.3942857, 0.33142856],
+            [-0.42857143, -0.9028571, -0.034285713],
+            [0.2857143, -0.17142858, 0.94285715],
+        ],
+        &device,
+    );
+    let r_expected = TestTensor::<2>::from_data(
+        [[-14.0, -21.0, 14.0], [0.0, -175.0, 70.0], [0.0, 0.0, -35.0]],
+        &device,
+    );
+    let (q_aligned, r_aligned) = align_qr_to_expected(q, r, q_expected.clone());
+    q_aligned
+        .clone()
+        .into_data()
+        .assert_approx_eq::<FT>(&q_expected.into_data(), tolerance);
+    r_aligned
+        .into_data()
+        .assert_approx_eq::<FT>(&r_expected.into_data(), tolerance);
+    assert_orthonormal(q_aligned, tolerance);
+}
+
+#[test]
+fn test_qr_tall_reconstruction() {
+    let device = Default::default();
+    let tensor =
+        TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], &device);
+    let (q, r) = linalg::qr_decomposition(tensor.clone());
+
+    assert_eq!(q.dims(), [4, 2]);
+    assert_eq!(r.dims(), [2, 2]);
+
+    let reconstructed = q.clone().matmul(r.clone());
+    let tolerance = Tolerance::permissive();
+    reconstructed
+        .into_data()
+        .assert_approx_eq::<FT>(&tensor.into_data(), tolerance);
+    let q_expected = TestTensor::<2>::from_data(
+        [
+            [-0.10910895, -0.82951504],
+            [-0.32732683, -0.43915504],
+            [-0.54554474, -0.048795003],
+            [-0.7637626, 0.341565],
+        ],
+        &device,
+    );
+    let r_expected =
+        TestTensor::<2>::from_data([[-9.165152, -10.910894], [0.0, -0.97590005]], &device);
+    let (q_aligned, r_aligned) = align_qr_to_expected(q, r, q_expected.clone());
+    q_aligned
+        .clone()
+        .into_data()
+        .assert_approx_eq::<FT>(&q_expected.into_data(), tolerance);
+    r_aligned
+        .into_data()
+        .assert_approx_eq::<FT>(&r_expected.into_data(), tolerance);
+    assert_orthonormal(q_aligned, tolerance);
+}
+
+#[test]
+fn test_qr_wide_reconstruction() {
+    let device = Default::default();
+    let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], &device);
+    let (q, r) = linalg::qr_decomposition(tensor.clone());
+
+    assert_eq!(q.dims(), [2, 2]);
+    assert_eq!(r.dims(), [2, 4]);
+
+    let reconstructed = q.clone().matmul(r.clone());
+    let tolerance = Tolerance::permissive();
+    reconstructed
+        .into_data()
+        .assert_approx_eq::<FT>(&tensor.into_data(), tolerance);
+    let q_expected = TestTensor::<2>::from_data(
+        [[-0.19611613, -0.9805807], [-0.9805807, 0.19611613]],
+        &device,
+    );
+    let r_expected = TestTensor::<2>::from_data(
+        [
+            [-5.0990195, -6.2757163, -7.452413, -8.62911],
+            [0.0, -0.78446454, -1.5689291, -2.3533936],
+        ],
+        &device,
+    );
+    let (q_aligned, r_aligned) = align_qr_to_expected(q, r, q_expected.clone());
+    q_aligned
+        .clone()
+        .into_data()
+        .assert_approx_eq::<FT>(&q_expected.into_data(), tolerance);
+    r_aligned
+        .into_data()
+        .assert_approx_eq::<FT>(&r_expected.into_data(), tolerance);
+    assert_orthonormal(q_aligned, tolerance);
+}
diff --git a/crates/burn-core/src/module/initializer.rs b/crates/burn-core/src/module/initializer.rs
index 8903bc9361..38267c42cc 100644
--- a/crates/burn-core/src/module/initializer.rs
+++ b/crates/burn-core/src/module/initializer.rs
@@ -3,7 +3,7 @@ use crate::tensor::Shape;
 use crate::config::Config;
 use crate::module::{Param, ParamId};
 use crate::tensor::backend::Backend;
-use crate::tensor::{Distribution, Tensor, s};
+use crate::tensor::{Distribution, Tensor};
 
 use crate as burn;
 
@@ -176,7 +176,7 @@ impl Initializer {
             t = t.transpose();
         }
 
-        let (q, r) = qr_decomposition(t, device);
+        let (q, r) = crate::tensor::linalg::qr_decomposition(t);
 
         let [r_rows, r_cols] = r.clone().dims();
         let diag_r = Tensor::<B, 2>::ones([1, r_rows], device)
@@ -242,51 +242,6 @@ fn normal_draw<B: Backend, const D: usize, S: Into<Shape>>(
     Tensor::<B, D>::random(shape, distribution, device)
 }
 
-fn qr_decomposition<B: Backend>(
-    a: Tensor<B, 2>,
-    device: &B::Device,
-) -> (Tensor<B, 2>, Tensor<B, 2>) {
-    // Calculate the QR decomposition using Gram-Schmidt-process: https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process
-
-    let [m, n] = a.clone().dims();
-    let mut q = Tensor::<B, 2>::zeros([m, n], device);
-    let mut r = Tensor::<B, 2>::zeros([n, n], device);
-
-    for j in 0..n {
-        let mut v: Tensor<B, 1> = a.clone().slice(s![.., j..=j]).squeeze_dim(1);
-
-        for i in 0..j {
-            let q_i: Tensor<B, 1> = q.clone().slice(s![.., i..=i]).squeeze_dim(1);
-            let r_ij = q_i.clone().mul(v.clone()).sum();
-
-            r = r
-                .clone()
-                .slice_assign([i..i + 1, j..j + 1], r_ij.clone().unsqueeze());
-
-            v = v - q_i.mul(r_ij);
-        }
-
-        // norm of v
-        let r_jj = v
-            .clone()
-            .powf(Tensor::from_floats([2.0], device))
-            .sum()
-            .sqrt();
-
-        r = r
-            .clone()
-            .slice_assign([j..j + 1, j..j + 1], r_jj.clone().unsqueeze());
-
-        let q_j = v / r_jj;
-
-        q = q
-            .clone()
-            .slice_assign([0..m, j..j + 1], q_j.unsqueeze_dim(1));
-    }
-
-    (q, r)
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -547,7 +502,7 @@ mod tests {
             [[12., -51., 4.], [6., 167., -68.], [-4., 24., -41.]],
             &Default::default(),
         );
-        let qr = qr_decomposition(a.clone(), &Default::default());
+        let qr = crate::tensor::linalg::qr_decomposition(a.clone());
 
         // Q @ R should reconstruct input `a`
         let q_matmul_r = qr.0.clone().matmul(qr.1.clone());
diff --git a/crates/burn-tensor/src/tensor/linalg/mod.rs b/crates/burn-tensor/src/tensor/linalg/mod.rs
index ac235bb378..6ab7d1f216 100644
--- a/crates/burn-tensor/src/tensor/linalg/mod.rs
+++ b/crates/burn-tensor/src/tensor/linalg/mod.rs
@@ -1,6 +1,7 @@
 mod cosine_similarity;
 mod diag;
 mod lu_decomposition;
+mod qr_decomposition;
 mod matvec;
 mod outer;
 mod trace;
@@ -9,6 +10,7 @@ mod vector_norm;
 pub use cosine_similarity::*;
 pub use diag::*;
 pub use lu_decomposition::*;
+pub use qr_decomposition::*;
 pub use matvec::*;
 pub use outer::*;
 pub use trace::*;
diff --git a/crates/burn-tensor/src/tensor/linalg/qr_decomposition.rs b/crates/burn-tensor/src/tensor/linalg/qr_decomposition.rs
new file mode 100644
index 0000000000..bc99ac094f
--- /dev/null
+++ b/crates/burn-tensor/src/tensor/linalg/qr_decomposition.rs
@@ -0,0 +1,129 @@
+use alloc::vec::Vec;
+
+use crate::{
+    DType, Element, ElementConversion, backend::Backend, cast::ToElement, linalg::outer, s,
+    tensor::Tensor,
+};
+
+struct Householder<B: Backend> {
+    v: Tensor<B, 1>,
+    tau: B::FloatElem,
+}
+
+fn eps_for_dtype(dtype: DType) -> f64 {
+    match dtype {
+        DType::F16 => 1e-3,
+        DType::BF16 => 1e-2,
+        DType::F32 => 1e-7,
+        DType::F64 => 1e-15,
+        _ => 1e-7,
+    }
+}
+
+/// Performs QR decomposition of a matrix using Householder reflections.
+///
+/// The input matrix `A` is factored into `Q` and `R` such that `A = Q * R`,
+/// where `Q` has orthonormal columns and `R` is upper trapezoidal.
+///
+/// # Returns
+///
+/// A tuple containing:
+/// - `Q`: a matrix of shape `[m, k]`
+/// - `R`: a matrix of shape `[k, n]`
+///
+/// where `m` and `n` are the input dimensions and `k = min(m, n)`.
+pub fn qr_decomposition<B: Backend>(tensor: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
+    let device = tensor.device();
+    let [m, n] = tensor.shape().dims::<2>();
+    let k_max = m.min(n);
+
+    if k_max == 0 {
+        let q = Tensor::<B, 2>::zeros([m, k_max], &device);
+        let r = Tensor::<B, 2>::zeros([k_max, n], &device);
+        return (q, r);
+    }
+
+    let mut r = tensor;
+    // Store Householder vectors to build Q after R is formed.
+    let mut reflectors: Vec<Option<Householder<B>>> = Vec::with_capacity(k_max);
+    let eps_base = eps_for_dtype(<B::FloatElem as Element>::dtype());
+
+    for k in 0..k_max {
+        let r_sub = r.clone().slice(s![k.., k..]);
+        // Current column segment to be zeroed below the diagonal.
+        let x = r_sub.clone().slice(s![.., 0..1]).squeeze_dim(1);
+        let rows = m - k;
+        let x0 = x.clone().slice(s![0]);
+        let x0_scalar = x0.clone().into_scalar().to_f64();
+        let xnorm = if rows > 1 {
+            x.clone()
+                .slice(s![1..])
+                .square()
+                .sum()
+                .sqrt()
+                .into_scalar()
+                .to_f64()
+        } else {
+            0.0
+        };
+        let scale = x0_scalar.abs().max(xnorm).max(1.0);
+        let eps = eps_base * scale;
+        if xnorm <= eps {
+            reflectors.push(None);
+            continue;
+        }
+
+        // Choose sign to avoid cancellation in beta.
+        let sign = if x0_scalar >= 0.0 { 1.0 } else { -1.0 };
+        let norm = (x0_scalar * x0_scalar + xnorm * xnorm).sqrt();
+        let beta = -sign * norm;
+        let denom = x0_scalar - beta;
+        if denom.abs() <= eps || !beta.is_finite() {
+            reflectors.push(None);
+            continue;
+        }
+        let tau_scalar = (beta - x0_scalar) / beta;
+        let tau = <B::FloatElem as ElementConversion>::from_elem(tau_scalar);
+        let mut v = x.mul_scalar(1.0 / denom);
+        let v0 = x0.clone().mul_scalar(0.0).add_scalar(1.0);
+        v = v.slice_assign(s![0], v0);
+
+        // w = R^T * v for the rank-1 update.
+        let w = (r_sub.clone().transpose() * v.clone().unsqueeze_dim::<2>(0))
+            .sum_dim(1)
+            .squeeze_dim::<1>(1);
+        // R = R - tau * v * w^T
+        let update = outer(v.clone(), w).mul_scalar(tau);
+        let r_sub = r_sub - update;
+        r = r.slice_assign(s![k.., k..], r_sub);
+
+        reflectors.push(Some(Householder { v, tau }));
+    }
+
+    // Start with identity, then apply reflectors in reverse order.
+    let mut q = Tensor::<B, 2>::eye(m, &device);
+    if k_max < m {
+        q = q.slice(s![.., 0..k_max]);
+    }
+
+    for k in (0..k_max).rev() {
+        let Some(reflector) = reflectors.get_mut(k).and_then(|r| r.take()) else {
+            continue;
+        };
+
+        let v = reflector.v;
+        let tau = reflector.tau;
+
+        let q_sub = q.clone().slice(s![k.., ..]);
+        // Apply reflector: Q = Q - tau * v * (Q^T v)^T
+        let wq = (q_sub.clone().transpose() * v.clone().unsqueeze_dim::<2>(0))
+            .sum_dim(1)
+            .squeeze_dim::<1>(1);
+        let update_q = outer(v, wq).mul_scalar(tau);
+        let q_sub = q_sub - update_q;
+        q = q.slice_assign(s![k.., ..], q_sub);
+    }
+
+    let r = r.slice(s![0..k_max, ..]);
+    (q, r)
+}
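
Usage sketch for reviewers (illustrative, not part of the diff): a minimal
example of calling the new API, assuming the `NdArray` backend; any backend
works, since the routine is written against the generic `Tensor` API:

    use burn_ndarray::NdArray;
    use burn_tensor::{Tensor, linalg};

    fn main() {
        let device = Default::default();
        // 3x3 example from the tests; expect Q: [3, 3], R: [3, 3].
        let a = Tensor::<NdArray, 2>::from_data(
            [[12.0, -51.0, 4.0], [6.0, 167.0, -68.0], [-4.0, 24.0, -41.0]],
            &device,
        );
        let (q, r) = linalg::qr_decomposition(a.clone());
        // Q has orthonormal columns and Q * R reconstructs A up to float error.
        let reconstructed = q.matmul(r);
        println!("{reconstructed}");
    }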