Looking at the source code rather than the docs, the built-in kernels have an additional check in their constructors:

        if isinstance(lengthscale, nnx.Variable):
            self.lengthscale = lengthscale
        else:
            self.lengthscale = PositiveReal(lengthscale)

which seems to get round this. If I implement this logic:

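        import jax
        import jax.numpy as jnp
        import gpjax as gpx
        from flax import nnx


        class MyKernel(gpx.kernels.AbstractKernel):
            def __init__(self, a: float | nnx.Variable, *args, **kwargs):
                super().__init__(*args, **kwargs)

                # Keep a parameter object passed in by the caller; wrap raw values
                if isinstance(a, nnx.Variable):
                    self.a = a
                else:
                    self.a = gpx.parameters.PositiveReal(jnp.array(a))

            def __call__(self, x1: jax.Array, x2: jax.Array) -> jax.Array:
                ...  # kernel body not shown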
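The effect of the `isinstance` check can be illustrated without GPJax installed. This is a minimal sketch that uses hypothetical stand-in classes (`Variable`, `PositiveReal`, `Static`) in place of `nnx.Variable` and the GPJax parameter types; it only demonstrates the dispatch logic, not GPJax's actual behaviour.

```python
from typing import Union


class Variable:
    """Hypothetical stand-in for a parameter container such as nnx.Variable."""

    def __init__(self, value: float):
        self.value = value


class PositiveReal(Variable):
    """Hypothetical stand-in for a positivity-constrained parameter."""

    def __init__(self, value: float):
        if value <= 0:
            raise ValueError(f"expected a positive value, got {value}")
        super().__init__(value)


class Static(Variable):
    """Hypothetical stand-in for a parameter the user wants held fixed."""


class MyKernel:
    def __init__(self, a: Union[float, Variable]):
        # A parameter object supplied by the caller is stored as-is, so its
        # type (e.g. Static) and identity (e.g. shared with another kernel)
        # are preserved. Only raw numbers get wrapped in PositiveReal.
        if isinstance(a, Variable):
            self.a = a
        else:
            self.a = PositiveReal(float(a))


# Usage: a raw float is wrapped, a pre-built parameter is passed through.
k1 = MyKernel(2.0)
shared = PositiveReal(0.5)
k2, k3 = MyKernel(shared), MyKernel(shared)
k4 = MyKernel(Static(1.0))
```

Without the check, a pre-built parameter would be re-wrapped (in this sketch, `float(a)` would instead raise a `TypeError`), so both its type and its shared identity would be lost.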

Replies: 1 comment, 5 replies (from @thomaspinder and @theo-brown)

Answer selected by theo-brown

Category: Q&A
2 participants