
Something like "MethodError: PeriodicTransform{Vector{Float32}} with SubArray in Zygote gradient of TransformedKernel (D_feats=1)" #573

@yuchenxiao95

Description


julia> versioninfo()
Julia Version 1.11.5
Commit 760b2e5b73 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 24 × Intel(R) Xeon(R) Silver 4410Y
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, sapphirerapids)
Threads: 24 default, 0 interactive, 12 GC (on 24 virtual cores)
Environment:
  JULIA_NUM_THREADS = 24
  JULIA_EDITOR = code
  JULIA_VSCODE_REPL = 1

using KernelFunctions, Optimisers, Zygote, LinearAlgebra, Statistics
using CUDA # To enable GPU arrays and operations
using PDMats # For handling positive definite matrices if needed
using Functors # Needed for traversing kernel parameters (though its role is minimized now)
using Base.Iterators: only # Keep if you want to use it for other purposes, but not for θ_current

# --- Kernel and Parameter Management ---

# ---------------- Composite Kernel Suggestion ----------------
# All parameters are now explicitly Float32 scalars from the start.
function suggest_composite_kernel(N::Int; period_guess::Real=12.0f0) # Ensure period_guess is Float32
    l_trend = Float32(0.1 * N)
    l_seasonal_period = Float32(0.7 * period_guess)
    l_local_scale = Float32(0.3 * period_guess)

    # Explicitly ensure scalar Float32 for all transform parameters.
    # This is a defensive measure against type-widening issues.
    k_trend = TransformedKernel(Matern52Kernel(), ScaleTransform(Float32(1.0f0 / l_trend)))
    k_seasonal_base = RationalQuadraticKernel()
    k_seasonal = k_seasonal_base ∘ PeriodicTransform(Float32(l_seasonal_period))
    k_local_periodic_base = PeriodicKernel()
    k_local_periodic = k_local_periodic_base ∘ ScaleTransform(Float32(1.0f0 / l_local_scale))
    k_linear = LinearKernel()

    return k_trend + k_linear + k_seasonal + k_local_periodic
end
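For what it's worth, the seasonal component evaluates fine pointwise on scalars, since PeriodicTransform has a method for Real inputs. A minimal sketch of my own (8.4f0 stands in for l_seasonal_period; not part of the original script):

k_seasonal_demo = RationalQuadraticKernel() ∘ PeriodicTransform(8.4f0) # hypothetical demo kernel
k_seasonal_demo(1.0f0, 2.0f0) # works: the transform is applied to two scalar inputs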

# ---------------- Kernel Matrix Construction ----------------
# This function forces the kernel matrix computation to happen on the CPU
# if the input data is a CuArray, then moves the result back to the GPU.
function build_kernel_matrix(kernel, X::AbstractMatrix{Float32}, log_noise::Float32)
    local K_raw
    if X isa CuArray{Float32}
        # FORCE KERNEL MATRIX COMPUTATION ON CPU AS A WORKAROUND
        # This is inefficient due to data transfer, but bypasses the Distances.jl/CUDA.jl issue.
        println("Warning: Forcing kernel matrix computation on CPU due to Distances.jl/CUDA.jl integration issue.")
        println("         Performance will be severely impacted for this step.")

        # Convert the CuArray X back to a CPU Array
        X_cpu = Array(X)
        # Compute the kernel matrix on the CPU
        K_raw_cpu = kernelmatrix(kernel, ColVecs(X_cpu), ColVecs(X_cpu))
        # Move the computed kernel matrix back to the GPU
        K_raw = CuArray(K_raw_cpu)
    else
        # Original path for CPU arrays
        K_raw = kernelmatrix(kernel, ColVecs(X), ColVecs(X))
    end

    noise_var = exp(2 * log_noise)
    # Note: `fill` returns a CPU Vector, so this Diagonal is built on the CPU
    # even when K_raw is a CuArray.
    noise_diag = Diagonal(fill(noise_var, size(K_raw, 1)))

    return K_raw + noise_diag
end
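As an aside, a device-agnostic way to build the noise diagonal. This is a sketch of my own, not from the original script, and it assumes CUDA.jl's wrapped-array support extends to a Diagonal of a CuVector:

# Hypothetical alternative: similar(K_raw, n) allocates a vector with the same
# storage type as K_raw (CuVector when K_raw is a CuArray, Vector otherwise),
# and fill! then sets every entry in place on that device.
noise_diag = Diagonal(fill!(similar(K_raw, size(K_raw, 1)), noise_var))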

# ---------------- Log Marginal Likelihood ----------------
function log_marginal_likelihood(K::AbstractMatrix{Float32}, y::AbstractVector{Float32})
    n = length(y)
    try
        F = cholesky(Hermitian(K))
        α = F \ y
        # log|K| = 2 * sum(log, diag(F.U)); note that diag must be taken on the
        # Cholesky factor F.U, not on the factorization object itself.
        return -0.5f0 * (dot(y, α) + 2f0 * sum(log, diag(F.U)) + n * log(2f0 * π))
    catch e
        if isa(e, PosDefException) || isa(e, ArgumentError)
            return -Inf32
        else
            rethrow(e)
        end
    end
end
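A quick CPU sanity check of my own (not in the original script) that the Cholesky-based expression agrees with the logdet form of the Gaussian log marginal likelihood:

K_test = Float32[2.0 0.5; 0.5 2.0] # small positive definite test matrix
y_test = Float32[1.0, -1.0]
F_test = cholesky(Hermitian(K_test))
lml_chol = -0.5f0 * (dot(y_test, F_test \ y_test) + 2f0 * sum(log, diag(F_test.U)) + 2 * log(2f0 * π))
@assert lml_chol ≈ -0.5f0 * (dot(y_test, F_test \ y_test) + logdet(F_test) + 2 * log(2f0 * π))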

# ---------------- GPU-Compatible Training Loop ----------------
function train_gp_gpu(X::AbstractMatrix{Float32}, y::AbstractVector{Float32};
                      max_epochs::Int=100, lr::Float32=0.01f0, use_gpu::Bool=CUDA.has_cuda())

    # Ensure X is in the (features x observations) layout expected by ColVecs.
    if size(X, 1) == length(y)
        # X arrived as (observations x features); transpose it.
        X_reshaped = permutedims(X)
    else
        @assert size(X, 2) == length(y) "X must have dimensions (features x observations) or (observations x features)."
        X_reshaped = X
    end

    # Move data to the GPU (CuArray) if use_gpu is true, otherwise keep it on the CPU.
    X_device = use_gpu ? CuArray(X_reshaped) : X_reshaped
    y_device = use_gpu ? CuArray(y) : y

    initial_N = size(X_reshaped, 2)

    # The kernel object stays on the CPU.
    kernel_cpu = suggest_composite_kernel(initial_N; period_guess=12.0f0)

    # Assign the CPU kernel directly; no explicit GPU transfer for the kernel itself.
    kernel_device = kernel_cpu
    if use_gpu
        println("Warning: Kernel object is on CPU. Using CuArray for data inputs. Performance might be impacted.")
        println("This bypasses the `ScaleTransform` MethodError for now.")
    end

    # Initialize the trainable parameter: log_noise.
    log_noise = log(0.1f0 * std(Vector(y_device)))
    θ = Float32[log_noise]              # our only trainable parameter for this example
    θ_device = use_gpu ? CuArray(θ) : θ # move the parameter vector to the device

    # Set up the Adam optimizer.
    opt = Optimisers.Adam(lr)
    state = Optimisers.setup(opt, θ_device) # optimizer state lives on the device

    # Initialize the best parameters and log marginal likelihood (LML).
    best_θ_device = copy(θ_device)
    best_lml = -Inf32

    # Training loop.
    for epoch in 1:max_epochs
        # Gradient of the negative log marginal likelihood with respect to θ_device.
        grads = Zygote.gradient(θ_device) do θ_current
            # `sum` extracts the single parameter without GPU scalar indexing.
            current_log_noise = sum(θ_current)
            K = build_kernel_matrix(kernel_device, X_device, current_log_noise)
            return -log_marginal_likelihood(K, y_device) # minimize the negative LML
        end

        # Update the parameters with the optimizer if the gradient is valid.
        if !isnothing(grads[1])
            state, θ_device = Optimisers.update(state, θ_device, grads[1])

            # Recompute the LML (not its negative) to track the true value after the update.
            current_lml = log_marginal_likelihood(
                build_kernel_matrix(kernel_device, X_device, sum(θ_device)),
                y_device,
            )

            # Keep the best parameters seen so far.
            if current_lml > best_lml
                best_lml = current_lml
                best_θ_device = copy(θ_device)
            end
        end

        # Print progress every 10 epochs.
        if epoch % 10 == 0
            println("[Epoch $epoch] LML = $(round(best_lml, digits=2))")
        end
    end

    # Bring the best learned log_noise back to the CPU for the return value.
    final_log_noise_cpu = Array(best_θ_device)[1] # fine: indexing happens after Array conversion

    # Return the kernel structure (CPU) and the best learned noise.
    return (kernel=kernel_cpu, log_noise=final_log_noise_cpu)
end
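For reference, the Optimisers.jl calling pattern used in the loop, shown in isolation. A minimal sketch of my own: setup builds the optimizer state, and the non-mutating update returns a new state and new parameters.

θ_demo = Float32[0.0]
state_demo = Optimisers.setup(Optimisers.Adam(0.01f0), θ_demo)
g_demo = Float32[1.0] # a stand-in gradient
state_demo, θ_demo = Optimisers.update(state_demo, θ_demo, g_demo)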

# --- Example Usage ---

D_feats = 1     # number of features (e.g., time)
N_points = 2000 # number of data points (can be larger for GPU benefits)

# Create synthetic time data (e.g., from 0 to 100)
X_train_time = Float32.(rand(N_points, D_feats) * 100.0f0)

# Create synthetic observations with a sine wave, a linear trend, and noise
y_train = Float32.(sin.(X_train_time[:, 1] * 0.5f0) .+ (X_train_time[:, 1] * 0.01f0) .+ 0.5f0 .* randn(Float32, N_points))

println("--- Starting GPU-compatible Gaussian Process Training ---")
println("Dataset size: $(N_points) observations, $(D_feats) feature.")
println("Using GPU: $(CUDA.has_cuda() ? "Yes" : "No (CUDA not detected or enabled)")")

# Call the training function.
trained_params = train_gp_gpu(X_train_time, y_train; max_epochs=200, lr=0.005f0, use_gpu=true)

println("\n--- Training Complete ---")
println("Best trained log_noise: $(trained_params.log_noise)")
println("Inferred noise variance: $(exp(2 * trained_params.log_noise))")
println("Final kernel structure: $(trained_params.kernel)")

ERROR: MethodError: no method matching (::PeriodicTransform{Vector{Float32}})(::SubArray{Float32, 1, Matrix{Float32}, Tuple{Base.Slice{…}, Int64}, true})
The object of type PeriodicTransform{Vector{Float32}} exists, but no method is defined for this combination of argument types when trying to treat it as a callable object.

Closest candidates are:
(::PeriodicTransform)(::Real)
@ KernelFunctions C:\Users\yuchen\.julia\packages\KernelFunctions\A0P7n\src\transform\periodic_transform.jl:28

Stacktrace:
[1] macro expansion
@ C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0 [inlined]
[2] _pullback(ctx::Zygote.Context{…}, f::PeriodicTransform{…}, args::SubArray{…})
@ Zygote C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:81
[3] (::Zygote.var"#676#680"{Zygote.Context{…}, PeriodicTransform{…}})(args::SubArray{Float32, 1, Matrix{…}, Tuple{…}, true})
@ Zygote C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\lib\array.jl:188
[4] iterate
@ .\generator.jl:48 [inlined]
[5] _collect
@ .\array.jl:811 [inlined]
[6] collect_similar
@ .\array.jl:720 [inlined]
[7] map
@ .\abstractarray.jl:3371 [inlined]
[8] ∇map
@ C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\lib\array.jl:188 [inlined]
[9] adjoint
@ C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\lib\array.jl:214 [inlined]
[10] _pullback
@ C:\Users\yuchen\.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined]
[11] _map
@ C:\Users\yuchen\.julia\packages\KernelFunctions\A0P7n\src\transform\transform.jl:21 [inlined]
[12] _pullback(::Zygote.Context{…}, ::typeof(KernelFunctions._map), ::PeriodicTransform{…}, ::ColVecs{…})
@ Zygote C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0
[13] kernelmatrix
@ C:\Users\yuchen\.julia\packages\KernelFunctions\A0P7n\src\kernels\transformedkernel.jl:117 [inlined]
[14] _pullback(::Zygote.Context{…}, ::typeof(kernelmatrix), ::TransformedKernel{…}, ::ColVecs{…}, ::ColVecs{…})
@ Zygote C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0
[15] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:946
[16] adjoint
@ C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\lib\lib.jl:199 [inlined]
[17] _pullback
@ C:\Users\yuchen\.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined]
[18] _sum
@ C:\Users\yuchen\.julia\packages\KernelFunctions\A0P7n\src\kernels\kernelsum.jl:46 [inlined]
[19] _pullback(::Zygote.Context{…}, ::typeof(KernelFunctions._sum), ::typeof(kernelmatrix), ::Tuple{…}, ::ColVecs{…}, ::ColVecs{…})
@ Zygote C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0
[20] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:946
[21] adjoint
@ C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\lib\lib.jl:199 [inlined]
[22] _pullback
@ C:\Users\yuchen\.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined]
[23] _sum
@ C:\Users\yuchen\.julia\packages\KernelFunctions\A0P7n\src\kernels\kernelsum.jl:46 [inlined]
[24] _pullback(::Zygote.Context{…}, ::typeof(KernelFunctions._sum), ::typeof(kernelmatrix), ::Tuple{…}, ::ColVecs{…}, ::ColVecs{…})
@ Zygote C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0
[25] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:946
[26] adjoint
@ C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\lib\lib.jl:199 [inlined]
[27] _pullback
@ C:\Users\yuchen\.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined]
[28] _sum
@ C:\Users\yuchen\.julia\packages\KernelFunctions\A0P7n\src\kernels\kernelsum.jl:46 [inlined]
[29] _pullback(::Zygote.Context{…}, ::typeof(KernelFunctions._sum), ::typeof(kernelmatrix), ::Tuple{…}, ::ColVecs{…}, ::ColVecs{…})
@ Zygote C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0
[30] kernelmatrix
@ C:\Users\yuchen\.julia\packages\KernelFunctions\A0P7n\src\kernels\kernelsum.jl:57 [inlined]
[31] _pullback(::Zygote.Context{…}, ::typeof(kernelmatrix), ::KernelSum{…}, ::ColVecs{…}, ::ColVecs{…})
@ Zygote C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0
[32] build_kernel_matrix
@ .\Untitled-1:41 [inlined]
[33] _pullback(::Zygote.Context{…}, ::typeof(build_kernel_matrix), ::KernelSum{…}, ::CuArray{…}, ::Float32)
@ Zygote C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0
[34] #13
@ .\Untitled-1:120 [inlined]
[35] _pullback(ctx::Zygote.Context{…}, f::var"#13#14"{…}, args::CuArray{…})
@ Zygote C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0
[36] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Float32, 1, CUDA.DeviceMemory})
@ Zygote C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\compiler\interface.jl:96
[37] pullback
@ C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\compiler\interface.jl:94 [inlined]
[38] gradient(f::Function, args::CuArray{Float32, 1, CUDA.DeviceMemory})
@ Zygote C:\Users\yuchen\.julia\packages\Zygote\wfLOG\src\compiler\interface.jl:153
[39] train_gp_gpu(X::Matrix{Float32}, y::Vector{Float32}; max_epochs::Int64, lr::Float32, use_gpu::Bool)
@ Main .\Untitled-1:117
[40] top-level scope
@ Untitled-1:168
Some type information was truncated. Use show(err) to see complete types.
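The failure does not appear to be GPU-specific: ColVecs over a 1×N matrix yields length-1 SubArray columns, while PeriodicTransform only has a method for Real. A minimal sketch that seems to reproduce the same MethodError on the CPU (my distillation, not part of the original report):

using KernelFunctions

t = PeriodicTransform(1.0f0)
t(0.5f0)                # ok: matches the (::PeriodicTransform)(::Real) method
X = rand(Float32, 1, 4) # one feature, four observations
map(t, ColVecs(X))      # MethodError: no method matching t(::SubArray)

# Passing a plain vector of reals instead of ColVecs stays on the scalar path:
kernelmatrix(RationalQuadraticKernel() ∘ t, vec(X)) # works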
