-
Notifications
You must be signed in to change notification settings - Fork 68
SVGP with GraphKernel Fix #552
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
gpjax/variational_families.py
Outdated
def ensure_2d(mat): | ||
mat = jnp.asarray(mat) | ||
if mat.ndim == 0: | ||
return mat[None, None] | ||
if mat.ndim == 1: | ||
return mat[:, None] | ||
return mat |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you recycle JAX's at_least_2d
here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I looked into it and it seems jax's implementation reverses the shape, in contract to what I would need. Then I would need to do a reshape after passing through jax's at_least2d. Maybe a helper function doesn't hurt if I subclass?
From the docs:
an array or list of arrays corresponding to the input values. Arrays of shape () are converted to shape (1, 1), 1D arrays of shape (N,) are converted to shape (1, N), and arrays of all other shapes are returned unchanged.
return mat | ||
|
||
|
||
class VariationalGaussian(AbstractVariationalGaussian[L]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The number of if-statements herein makes the code a little messy. Could you subclass VariationalGaussian
and create a GraphVariationalGaussian
- it would enhance readability.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it is all subclassed now. I'd reckon with further Graph Kernels things are likely to become messier with if statements.
Thanks for raising this PR @syedzayyan . Before merging, I would like to see my comments above addressed, and also we need some unit tests for this new functionality. Let me know if you need a hand with any of this! |
Hi, @thomaspinder I copied in the VariationalGaussian unit tests but I wouldn't know what to test beyond this. And I am running into typechecking errors 😬. Any help is much appreciated! |
… kernel to allow SVGP shapes and dtypes
Checklist
uv run poe format
before committing.Description
SVGP with Graphs
In #551 , problems around the graph kernel and SVGP was discussed. The PR deals with allowing current VariationalFamily code to work with GraphKernel. Recasting and casting the inducing points variables around jax grad messes up the pytree. Thus inducing point variables are kept out of the tree altogether by presenting them as just
jnp arrays
instead ofReal
variables.There are further casting issues with the matrix multiplications and the output from the GraphKernel which the ensure2d function mitigates and there are conditional checks to fire only when a GraphKernel is used.
The attached notebook demonstrates SVGP on a small sample graph, which could be adopted into a doc if need be.
All tests pass and the code is formatted.
graph_svgp.ipynb