Skip to content

Conversation

syedzayyan
Copy link
Contributor

@syedzayyan syedzayyan commented Sep 26, 2025

Checklist

  • I've formatted the new code by running uv run poe format before committing.
  • I've added tests for new code.
  • I've added docstrings for the new code.

Description

SVGP with Graphs

In #551 , problems around the graph kernel and SVGP was discussed. The PR deals with allowing current VariationalFamily code to work with GraphKernel. Recasting and casting the inducing points variables around jax grad messes up the pytree. Thus inducing point variables are kept out of the tree altogether by presenting them as just jnp arrays instead of Real variables.

There are further casting issues with the matrix multiplications and the output from the GraphKernel which the ensure2d function mitigates and there are conditional checks to fire only when a GraphKernel is used.

The attached notebook demonstrates SVGP on a small sample graph, which could be adopted into a doc if need be.

All tests pass and the code is formatted.

graph_svgp.ipynb

Comment on lines 131 to 137
def ensure_2d(mat):
mat = jnp.asarray(mat)
if mat.ndim == 0:
return mat[None, None]
if mat.ndim == 1:
return mat[:, None]
return mat
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you recycle JAX's at_least_2d here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked into it and it seems jax's implementation reverses the shape, in contract to what I would need. Then I would need to do a reshape after passing through jax's at_least2d. Maybe a helper function doesn't hurt if I subclass?

From the docs:
an array or list of arrays corresponding to the input values. Arrays of shape () are converted to shape (1, 1), 1D arrays of shape (N,) are converted to shape (1, N), and arrays of all other shapes are returned unchanged.

return mat


class VariationalGaussian(AbstractVariationalGaussian[L]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The number of if-statements herein makes the code a little messy. Could you subclass VariationalGaussian and create a GraphVariationalGaussian - it would enhance readability.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is all subclassed now. I'd reckon with further Graph Kernels things are likely to become messier with if statements.

@thomaspinder
Copy link
Collaborator

Thanks for raising this PR @syedzayyan . Before merging, I would like to see my comments above addressed, and also we need some unit tests for this new functionality. Let me know if you need a hand with any of this!

@syedzayyan
Copy link
Contributor Author

Hi, @thomaspinder I copied in the VariationalGaussian unit tests but I wouldn't know what to test beyond this. And I am running into typechecking errors 😬. Any help is much appreciated!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants