
Conversation

@RustyBamboo

Pull Request Template

Checklist

  • Confirmed that the cargo run-checks command has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#1538

Changes

It would be nice to have more linalg operations available out-of-the-box in burn. This PR adds a Householder QR decomposition.

Also replaced the Gram-Schmidt method in the initializer module with a call to this one.
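To illustrate the algorithm being added, here is a minimal, self-contained sketch of Householder QR on plain `Vec<Vec<f64>>` matrices. This is not the PR's burn-tensor implementation (which operates on `Tensor<B, 2>`); it only shows the reflector-based factorization the PR implements:

```rust
// Minimal Householder QR sketch on a row-major Vec<Vec<f64>> matrix.
// Illustrative only; the PR implements this on burn tensors.
fn householder_qr(a: &[Vec<f64>]) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
    let m = a.len();
    let n = a[0].len();
    let mut r: Vec<Vec<f64>> = a.to_vec();
    // Q starts as the m x m identity and accumulates the reflectors.
    let mut q: Vec<Vec<f64>> = (0..m)
        .map(|i| (0..m).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
        .collect();

    for k in 0..n.min(m.saturating_sub(1)) {
        // Build the Householder vector v for column k.
        let norm_x: f64 = (k..m).map(|i| r[i][k] * r[i][k]).sum::<f64>().sqrt();
        if norm_x == 0.0 {
            continue;
        }
        // Choose the sign that avoids cancellation.
        let alpha = if r[k][k] >= 0.0 { -norm_x } else { norm_x };
        let mut v: Vec<f64> = (k..m).map(|i| r[i][k]).collect();
        v[0] -= alpha;
        let v_norm_sq: f64 = v.iter().map(|x| x * x).sum();
        if v_norm_sq == 0.0 {
            continue;
        }
        // Apply H = I - 2 v v^T / (v^T v) from the left to R, and accumulate into q.
        for j in 0..n {
            let dot: f64 = (k..m).map(|i| v[i - k] * r[i][j]).sum();
            for i in k..m {
                r[i][j] -= 2.0 * v[i - k] * dot / v_norm_sq;
            }
        }
        for j in 0..m {
            let dot: f64 = (k..m).map(|i| v[i - k] * q[i][j]).sum();
            for i in k..m {
                q[i][j] -= 2.0 * v[i - k] * dot / v_norm_sq;
            }
        }
    }
    // q now holds Q^T (the product of reflectors); transpose it to get Q.
    let qt: Vec<Vec<f64>> = (0..m).map(|i| (0..m).map(|j| q[j][i]).collect()).collect();
    (qt, r)
}

fn main() {
    let a = vec![
        vec![12.0, -51.0, 4.0],
        vec![6.0, 167.0, -68.0],
        vec![-4.0, 24.0, -41.0],
    ];
    let (q, r) = householder_qr(&a);
    // Sanity check: A should reconstruct as Q * R.
    for i in 0..3 {
        for j in 0..3 {
            let s: f64 = (0..3).map(|k| q[i][k] * r[k][j]).sum();
            assert!((s - a[i][j]).abs() < 1e-9);
        }
    }
    println!("ok");
}
```

Unlike classical Gram-Schmidt, the reflector-based approach keeps Q orthogonal to near machine precision, which is why it is the standard dense QR kernel.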

Testing

Generated examples in Python and decomposed them via numpy.linalg.qr, then compared the burn implementation against the expected numpy output (up to a degree of freedom in sign). Also checked that our Q is indeed orthogonal.
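The orthogonality check mentioned above amounts to verifying Q^T Q ≈ I within a tolerance. A hedged sketch on plain `f64` matrices (the PR's tests use burn tensors and its `Tolerance` helper instead):

```rust
// Sketch of an orthogonality check: verify Q^T Q ≈ I entrywise within `tol`.
// Plain f64 matrices for illustration; not the PR's burn-based test code.
fn is_orthogonal(q: &[Vec<f64>], tol: f64) -> bool {
    let m = q.len();
    for i in 0..m {
        for j in 0..m {
            // Entry (i, j) of Q^T Q is the dot product of columns i and j.
            let dot: f64 = (0..m).map(|k| q[k][i] * q[k][j]).sum();
            let expected = if i == j { 1.0 } else { 0.0 };
            if (dot - expected).abs() > tol {
                return false;
            }
        }
    }
    true
}

fn main() {
    // A 45-degree rotation matrix is orthogonal.
    let c = std::f64::consts::FRAC_1_SQRT_2;
    let rot = vec![vec![c, -c], vec![c, c]];
    assert!(is_orthogonal(&rot, 1e-12));
    // A matrix with non-unit columns is not.
    let bad = vec![vec![2.0, 0.0], vec![0.0, 1.0]];
    assert!(!is_orthogonal(&bad, 1e-12));
    println!("ok");
}
```

Comparing "up to a degree of freedom in sign" means each column of Q (and the matching row of R) may be negated relative to numpy's output, so a test either normalizes signs first or compares |Q| column-wise.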

Notes

  • There is a lot of syncing with the host here; this will eventually benefit from a block-based QR decomposition or something lower-level
  • Can't be autodiffed
  • Wasn't sure of the best way to check against numerical precision... mimicked safemin from LAPACK
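For context on the last note: LAPACK's "safe minimum" (`dlamch('S')`) is the smallest positive value whose reciprocal does not overflow, commonly used as a threshold below which a column norm is treated as zero. A hedged sketch of that computation for `f64` (not the PR's exact code):

```rust
// Sketch of a LAPACK-style safe minimum (dlamch('S')) for f64:
// the smallest positive number x such that 1/x is still finite.
// Not the PR's exact implementation; shown only to explain the note above.
fn safe_min() -> f64 {
    let small = 1.0 / f64::MAX;
    if small >= f64::MIN_POSITIVE {
        // 1 / MIN_POSITIVE would overflow; pad the reciprocal by epsilon.
        small * (1.0 + f64::EPSILON)
    } else {
        f64::MIN_POSITIVE
    }
}

fn main() {
    let sfmin = safe_min();
    // The defining property: positive, and its reciprocal is finite.
    assert!(sfmin > 0.0 && (1.0 / sfmin).is_finite());
    println!("{sfmin:e}");
}
```

For IEEE 754 `f64`, `1.0 / f64::MAX` is smaller than `f64::MIN_POSITIVE`, so the safe minimum is simply `f64::MIN_POSITIVE`.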

- replace gram-schmidt initializer
- add qr to linalg functions in book
Member

@laggui laggui left a comment


  • There is a lot of syncing with the host here; this will eventually benefit from a block-based QR decomposition or something lower-level

In the meantime, we should add a performance note similar to the one in LU decomposition:

/// # Performance note (synchronization / device transfers)
/// This function may involve multiple synchronizations and device transfers, especially
/// when determining pivot elements and performing row swaps. This can impact performance,

  • Wasn't sure of the best way to check against numerical precision... mimicked safemin from LAPACK

For tests, you correctly used the Tolerance in checks. But I'm unsure what the typical error tolerance is for such algorithms.

Also, please fix the formatting issue as pointed out by the CI. You can just use cargo fmt.

/// where `m` and `n` are the input dimensions and `k = min(m, n)`.
pub fn qr_decomposition<B: Backend>(tensor: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
let device = tensor.device();
let [m, n] = tensor.shape().dims::<2>();
Member


Could use the shorthand tensor.dims::<2>() instead.

