1 change: 1 addition & 0 deletions burn-book/src/building-blocks/tensor.md
@@ -408,6 +408,7 @@ strategies.
| `linalg::trace(tensor)` | `torch.trace(tensor)` |
| `linalg::outer(x, y)` | `torch.outer(x, y)` / `einsum("bi,bj->bij", …)` |
| `linalg::lu_decomposition(tensor)` | `torch.linalg.lu(tensor)` |
| `linalg::qr_decomposition(tensor)` | `torch.linalg.qr(tensor)` |
| `linalg::matvec(matrix, vector)` | `torch.matmul(matrix, vector)` / `@` operator |

## Displaying Tensor Details
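A minimal usage sketch for the new book-table entry above (the backend `B`, the device handling, and the example matrix are illustrative; the call shape mirrors the tests added in this PR):

```rust
use burn_tensor::{Tensor, backend::Backend, linalg};

// Factor a small square matrix with the new linalg routine and rebuild it,
// the Burn-side counterpart of `torch.linalg.qr(tensor)`.
fn qr_example<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
    let a = Tensor::<B, 2>::from_data(
        [[12.0, -51.0, 4.0], [6.0, 167.0, -68.0], [-4.0, 24.0, -41.0]],
        device,
    );
    // `q` has orthonormal columns and `r` is upper trapezoidal, with a ≈ q * r.
    let (q, r) = linalg::qr_decomposition(a);
    q.matmul(r) // should match the original matrix up to rounding error
}
```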
1 change: 1 addition & 0 deletions crates/burn-backend-tests/tests/tensor/float/linalg/mod.rs
@@ -3,6 +3,7 @@ use super::*;
pub(crate) mod cosine_similarity;
pub(crate) mod diag;
pub(crate) mod lu_decomposition;
pub(crate) mod qr_decomposition;
pub(crate) mod matvec;
pub(crate) mod outer;
pub(crate) mod trace;
145 changes: 145 additions & 0 deletions crates/burn-backend-tests/tests/tensor/float/linalg/qr_decomposition.rs
@@ -0,0 +1,145 @@
use super::*;
use burn_tensor::cast::ToElement;
use burn_tensor::{Tolerance, linalg, s};

fn assert_orthonormal(q: TestTensor<2>, tolerance: Tolerance<FloatElem>) {
let device = q.device();
let [_m, k] = q.dims();
let eye = TestTensor::<2>::eye(k, &device);
let qtq = q.clone().transpose().matmul(q);
qtq.into_data()
.assert_approx_eq::<FloatElem>(&eye.into_data(), tolerance);
}

// QR factors are unique up to column-wise sign flips; align to reference.
fn align_qr_to_expected(
mut q: TestTensor<2>,
mut r: TestTensor<2>,
q_expected: TestTensor<2>,
) -> (TestTensor<2>, TestTensor<2>) {
let [_m, k] = q_expected.dims();
for col in 0..k {
let q_col = q.clone().slice(s![.., col..(col + 1)]);
let q_ref = q_expected.clone().slice(s![.., col..(col + 1)]);
let dot = (q_col.clone() * q_ref).sum().into_scalar().to_f64();
if dot < 0.0 {
q = q.slice_assign(s![.., col..(col + 1)], -q_col);
let r_row = r.clone().slice(s![col..(col + 1), ..]);
r = r.slice_assign(s![col..(col + 1), ..], -r_row);
}
}
(q, r)
}

#[test]
fn test_qr_square_reconstruction() {
let device = Default::default();
let tensor = TestTensor::<2>::from_data(
[[12.0, -51.0, 4.0], [6.0, 167.0, -68.0], [-4.0, 24.0, -41.0]],
&device,
);
let (q, r) = linalg::qr_decomposition(tensor.clone());

assert_eq!(q.dims(), [3, 3]);
assert_eq!(r.dims(), [3, 3]);

let reconstructed = q.clone().matmul(r.clone());
let tolerance = Tolerance::permissive();
reconstructed
.into_data()
.assert_approx_eq::<FloatElem>(&tensor.into_data(), tolerance);
let q_expected = TestTensor::<2>::from_data(
[
[-0.85714287, 0.3942857, 0.33142856],
[-0.42857143, -0.9028571, -0.034285713],
[0.2857143, -0.17142858, 0.94285715],
],
&device,
);
let r_expected = TestTensor::<2>::from_data(
[[-14.0, -21.0, 14.0], [0.0, -175.0, 70.0], [0.0, 0.0, -35.0]],
&device,
);
let (q_aligned, r_aligned) = align_qr_to_expected(q, r, q_expected.clone());
q_aligned
.clone()
.into_data()
.assert_approx_eq::<FloatElem>(&q_expected.into_data(), tolerance);
r_aligned
.into_data()
.assert_approx_eq::<FloatElem>(&r_expected.into_data(), tolerance);
assert_orthonormal(q_aligned, tolerance);
}

#[test]
fn test_qr_tall_reconstruction() {
let device = Default::default();
let tensor =
TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], &device);
let (q, r) = linalg::qr_decomposition(tensor.clone());

assert_eq!(q.dims(), [4, 2]);
assert_eq!(r.dims(), [2, 2]);

let reconstructed = q.clone().matmul(r.clone());
let tolerance = Tolerance::permissive();
reconstructed
.into_data()
.assert_approx_eq::<FloatElem>(&tensor.into_data(), tolerance);
let q_expected = TestTensor::<2>::from_data(
[
[-0.10910895, -0.82951504],
[-0.32732683, -0.43915504],
[-0.54554474, -0.048795003],
[-0.7637626, 0.341565],
],
&device,
);
let r_expected =
TestTensor::<2>::from_data([[-9.165152, -10.910894], [0.0, -0.97590005]], &device);
let (q_aligned, r_aligned) = align_qr_to_expected(q, r, q_expected.clone());
q_aligned
.clone()
.into_data()
.assert_approx_eq::<FloatElem>(&q_expected.into_data(), tolerance);
r_aligned
.into_data()
.assert_approx_eq::<FloatElem>(&r_expected.into_data(), tolerance);
assert_orthonormal(q_aligned, tolerance);
}

#[test]
fn test_qr_wide_reconstruction() {
let device = Default::default();
let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], &device);
let (q, r) = linalg::qr_decomposition(tensor.clone());

assert_eq!(q.dims(), [2, 2]);
assert_eq!(r.dims(), [2, 4]);

let reconstructed = q.clone().matmul(r.clone());
let tolerance = Tolerance::permissive();
reconstructed
.into_data()
.assert_approx_eq::<FloatElem>(&tensor.into_data(), tolerance);
let q_expected = TestTensor::<2>::from_data(
[[-0.19611613, -0.9805807], [-0.9805807, 0.19611613]],
&device,
);
let r_expected = TestTensor::<2>::from_data(
[
[-5.0990195, -6.2757163, -7.452413, -8.62911],
[0.0, -0.78446454, -1.5689291, -2.3533936],
],
&device,
);
let (q_aligned, r_aligned) = align_qr_to_expected(q, r, q_expected.clone());
q_aligned
.clone()
.into_data()
.assert_approx_eq::<FloatElem>(&q_expected.into_data(), tolerance);
r_aligned
.into_data()
.assert_approx_eq::<FloatElem>(&r_expected.into_data(), tolerance);
assert_orthonormal(q_aligned, tolerance);
}
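A brief justification for the `align_qr_to_expected` helper in the tests above (standard linear algebra, stated here for context): for any diagonal sign matrix $S = \mathrm{diag}(s_1, \dots, s_k)$ with $s_i \in \{-1, +1\}$, we have $S^2 = I$ and therefore

$$ A = QR = (QS)(SR), $$

so $Q$ is only determined up to per-column sign flips, with the matching rows of $R$ flipping along with them; the helper resolves this ambiguity before comparing against the hard-coded reference factors, and orthonormality of $Q$ is unaffected.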
51 changes: 3 additions & 48 deletions crates/burn-core/src/module/initializer.rs
@@ -3,7 +3,7 @@ use crate::tensor::Shape;
use crate::config::Config;
use crate::module::{Param, ParamId};
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor, s};
use crate::tensor::{Distribution, Tensor};

use crate as burn;

@@ -176,7 +176,7 @@ impl Initializer {
t = t.transpose();
}

let (q, r) = qr_decomposition(t, device);
let (q, r) = crate::tensor::linalg::qr_decomposition(t);
let [r_rows, r_cols] = r.clone().dims();

let diag_r = Tensor::<B, 2>::ones([1, r_rows], device)
@@ -242,51 +242,6 @@ fn normal_draw<B: Backend, const D: usize, S: Into<Shape>>(
Tensor::<B, D>::random(shape, distribution, device)
}

fn qr_decomposition<B: Backend>(
a: Tensor<B, 2>,
device: &B::Device,
) -> (Tensor<B, 2>, Tensor<B, 2>) {
// Calculate the QR decomposition using Gram-Schmidt-process: https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process

let [m, n] = a.clone().dims();
let mut q = Tensor::<B, 2>::zeros([m, n], device);
let mut r = Tensor::<B, 2>::zeros([n, n], device);

for j in 0..n {
let mut v: Tensor<B, 1> = a.clone().slice(s![.., j..=j]).squeeze_dim(1);

for i in 0..j {
let q_i: Tensor<B, 1> = q.clone().slice(s![.., i..=i]).squeeze_dim(1);
let r_ij = q_i.clone().mul(v.clone()).sum();

r = r
.clone()
.slice_assign([i..i + 1, j..j + 1], r_ij.clone().unsqueeze());

v = v - q_i.mul(r_ij);
}

// norm of v
let r_jj = v
.clone()
.powf(Tensor::from_floats([2.0], device))
.sum()
.sqrt();

r = r
.clone()
.slice_assign([j..j + 1, j..j + 1], r_jj.clone().unsqueeze());

let q_j = v / r_jj;

q = q
.clone()
.slice_assign([0..m, j..j + 1], q_j.unsqueeze_dim(1));
}

(q, r)
}

#[cfg(test)]
mod tests {
use super::*;
@@ -547,7 +502,7 @@ mod tests {
[[12., -51., 4.], [6., 167., -68.], [-4., 24., -41.]],
&Default::default(),
);
let qr = qr_decomposition(a.clone(), &Default::default());
let qr = crate::tensor::linalg::qr_decomposition(a.clone());

// Q @ R should reconstruct input `a`
let q_matmul_r = qr.0.clone().matmul(qr.1.clone());
2 changes: 2 additions & 0 deletions crates/burn-tensor/src/tensor/linalg/mod.rs
@@ -1,6 +1,7 @@
mod cosine_similarity;
mod diag;
mod lu_decomposition;
mod qr_decomposition;
mod matvec;
mod outer;
mod trace;
@@ -9,6 +10,7 @@ mod vector_norm;
pub use cosine_similarity::*;
pub use diag::*;
pub use lu_decomposition::*;
pub use qr_decomposition::*;
pub use matvec::*;
pub use outer::*;
pub use trace::*;
129 changes: 129 additions & 0 deletions crates/burn-tensor/src/tensor/linalg/qr_decomposition.rs
@@ -0,0 +1,129 @@
use alloc::vec::Vec;

use crate::{
DType, Element, ElementConversion, backend::Backend, cast::ToElement, linalg::outer, s,
tensor::Tensor,
};

struct Householder<B: Backend> {
v: Tensor<B, 1>,
tau: B::FloatElem,
}

fn eps_for_dtype(dtype: DType) -> f64 {
match dtype {
DType::F16 => 1e-3,
DType::BF16 => 1e-2,
DType::F32 => 1e-7,
DType::F64 => 1e-15,
_ => 1e-7,
}
}

/// Performs QR decomposition of a matrix using Householder reflections.
///
/// The input matrix `A` is factored into `Q` and `R` such that `A = Q * R`,
/// where `Q` has orthonormal columns and `R` is upper trapezoidal.
///
/// # Returns
///
/// A tuple containing:
/// - `Q`: a matrix of shape `[m, k]`
/// - `R`: a matrix of shape `[k, n]`
///
/// where `m` and `n` are the input dimensions and `k = min(m, n)`.
pub fn qr_decomposition<B: Backend>(tensor: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
let device = tensor.device();
let [m, n] = tensor.shape().dims::<2>();
Member: Could use the short-hand tensor.dims::<2>() instead
let k_max = m.min(n);

if k_max == 0 {
let q = Tensor::<B, 2>::zeros([m, k_max], &device);
let r = Tensor::<B, 2>::zeros([k_max, n], &device);
return (q, r);
}

let mut r = tensor;
// Store Householder vectors to build Q after R is formed.
let mut reflectors: Vec<Option<Householder<B>>> = Vec::with_capacity(k_max);
let eps_base = eps_for_dtype(<B::FloatElem as Element>::dtype());

for k in 0..k_max {
let r_sub = r.clone().slice(s![k.., k..]);
// Current column segment to be zeroed below the diagonal.
let x = r_sub.clone().slice(s![.., 0..1]).squeeze_dim(1);
let rows = m - k;
let x0 = x.clone().slice(s![0]);
let x0_scalar = x0.clone().into_scalar().to_f64();
let xnorm = if rows > 1 {
x.clone()
.slice(s![1..])
.square()
.sum()
.sqrt()
.into_scalar()
.to_f64()
} else {
0.0
};
let scale = x0_scalar.abs().max(xnorm).max(1.0);
let eps = eps_base * scale;
if xnorm <= eps {
reflectors.push(None);
continue;
}

// Choose sign to avoid cancellation in beta.
let sign = if x0_scalar >= 0.0 { 1.0 } else { -1.0 };
let norm = (x0_scalar * x0_scalar + xnorm * xnorm).sqrt();
let beta = -sign * norm;
let denom = x0_scalar - beta;
if denom.abs() <= eps || !beta.is_finite() {
reflectors.push(None);
continue;
}
let tau_scalar = (beta - x0_scalar) / beta;
let tau = <B::FloatElem as ElementConversion>::from_elem(tau_scalar);
let mut v = x.mul_scalar(1.0 / denom);
let v0 = x0.clone().mul_scalar(0.0).add_scalar(1.0);
v = v.slice_assign(s![0], v0);

// w = R^T * v for the rank-1 update.
let w = (r_sub.clone().transpose() * v.clone().unsqueeze_dim::<2>(0))
.sum_dim(1)
.squeeze_dim::<1>(1);
// R = R - tau * v * w^T
let update = outer::<B, 1, 2, _>(v.clone(), w).mul_scalar(tau);
let r_sub = r_sub - update;
r = r.slice_assign(s![k.., k..], r_sub);

reflectors.push(Some(Householder { v, tau }));
}

// Start with identity, then apply reflectors in reverse order.
let mut q = Tensor::<B, 2>::eye(m, &device);
if k_max < m {
q = q.slice(s![.., 0..k_max]);
}

for k in (0..k_max).rev() {
let Some(reflector) = reflectors.get_mut(k).and_then(|r| r.take()) else {
continue;
};

let v = reflector.v;
let tau = reflector.tau;

let q_sub = q.clone().slice(s![k.., ..]);
// Apply reflector: Q = Q - tau * v * (Q^T v)^T
let wq = (q_sub.clone().transpose() * v.clone().unsqueeze_dim::<2>(0))
.sum_dim(1)
.squeeze_dim::<1>(1);
let update_q = outer::<B, 1, 2, _>(v, wq).mul_scalar(tau);
let q_sub = q_sub - update_q;
q = q.slice_assign(s![k.., ..], q_sub);
}

let r = r.slice(s![0..k_max, ..]);
(q, r)
}
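For readers unfamiliar with the LAPACK-style conventions used above, a compact summary of what the two loops compute (derived from the code, with $k_{\max} = \min(m, n)$): each step $k$ builds a Householder reflector

$$ H_k = I - \tau_k v_k v_k^{\top}, \qquad v_k[0] = 1, \quad \tau_k = \frac{\beta_k - x_0}{\beta_k}, \quad \beta_k = -\operatorname{sign}(x_0)\,\lVert x \rVert, $$

where $x$ is the current column of $R$ from row $k$ down. The forward loop applies $M \mapsto M - \tau_k v_k (M^{\top} v_k)^{\top} = H_k M$ to the trailing block of $R$, giving $R = H_{k_{\max}-1} \cdots H_0 A$; the reverse loop applies the same reflectors to the first $k_{\max}$ columns of the identity, giving $Q = H_0 \cdots H_{k_{\max}-1} \, I_{:, :k_{\max}}$. Because each $H_k$ is symmetric and its own inverse, this yields $A = QR$.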