Description
julia> versioninfo()
Julia Version 1.11.5
Commit 760b2e5b73 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Windows (x86_64-w64-mingw32)
CPU: 24 × Intel(R) Xeon(R) Silver 4410Y
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, sapphirerapids)
Threads: 24 default, 0 interactive, 12 GC (on 24 virtual cores)
Environment:
JULIA_NUM_THREADS = 24
JULIA_EDITOR = code
JULIA_VSCODE_REPL = 1
using KernelFunctions, Optimisers, Zygote, LinearAlgebra, Statistics
using CUDA # To enable GPU arrays and operations
using PDMats # For handling positive definite matrices if needed
using Functors # Needed for traversing kernel parameters (though its role is minimized now)
using Base.Iterators: only # Keep if you want to use it for other purposes, but not for θ_current
# --- Kernel and Parameter Management ---

# ---------------- Composite Kernel Suggestion ----------------
# All parameters are now explicitly Float32 scalars from the start.
function suggest_composite_kernel(N::Int; period_guess::Real=12.0f0) # Ensure period_guess is Float32
l_trend = Float32(0.1 * N)
l_seasonal_period = Float32(0.7 * period_guess)
l_local_scale = Float32(0.3 * period_guess)
# Explicitly ensure scalar Float32 for all transform parameters
# This is a defensive measure against type widening issues
k_trend = TransformedKernel(Matern52Kernel(), ScaleTransform(Float32(1.0f0 / l_trend)))
k_seasonal_base = RationalQuadraticKernel()
k_seasonal = k_seasonal_base ∘ PeriodicTransform(Float32(l_seasonal_period))
k_local_periodic_base = PeriodicKernel()
k_local_periodic = k_local_periodic_base ∘ ScaleTransform(Float32(1.0f0 / l_local_scale))
k_linear = LinearKernel()
return k_trend + k_linear + k_seasonal + k_local_periodic
end
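# Quick sanity check (sketch): constructing the composite kernel itself succeeds (the error
# reported below only appears once kernelmatrix is called) and yields a single KernelSum.
demo_kernel = suggest_composite_kernel(100)
@assert demo_kernel isa KernelSum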
# ---------------- Kernel Matrix Construction ----------------
# This function forces the kernel matrix computation to happen on the CPU
# if the input data is a CuArray, then moves the result back to GPU.
function build_kernel_matrix(kernel, X::AbstractMatrix{Float32}, log_noise::Float32)
local K_raw
if X isa CuArray{Float32}
# FORCE KERNEL MATRIX COMPUTATION ON CPU AS A WORKAROUND
# This will be inefficient due to data transfer, but bypasses the Distances.jl/CUDA.jl issue
println("Warning: Forcing kernel matrix computation on CPU due to Distances.jl/CUDA.jl integration issue.")
println(" Performance will be severely impacted for this step.")
# Convert CuArray X back to CPU Array
X_cpu = Array(X)
# Compute kernel matrix on CPU
K_raw_cpu = kernelmatrix(kernel, ColVecs(X_cpu), ColVecs(X_cpu))
# Move the computed kernel matrix back to GPU
K_raw = CuArray(K_raw_cpu)
else
# Original path for CPU arrays or if X is not a CuArray
K_raw = kernelmatrix(kernel, ColVecs(X), ColVecs(X))
end
noise_var = exp(2 * log_noise)
# Note: `fill` produces a CPU vector; when K_raw is a CuArray the noise diagonal may also
# need to be moved to the GPU (not handled here).
noise_diag = Diagonal(fill(noise_var, size(K_raw, 1)))
return K_raw + noise_diag
end
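# Small spot-check of build_kernel_matrix on the plain CPU branch (sketch; SqExponentialKernel
# is used here only because it avoids the PeriodicTransform path):
X_check = rand(Float32, 1, 5)  # 1 feature x 5 observations
K_check = build_kernel_matrix(SqExponentialKernel(), X_check, log(0.1f0))
@assert size(K_check) == (5, 5)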
# ---------------- Log Marginal Likelihood ----------------
function log_marginal_likelihood(K::AbstractMatrix{Float32}, y::AbstractVector{Float32})
n = length(y)
try
F = cholesky(Hermitian(K))
α = F \ y
return -0.5f0 * (dot(y, α) + logdet(F) + n * log(2f0 * π))  # logdet(F) == 2 * sum(log, diag(F.L))
catch e
if isa(e, PosDefException) || isa(e, ArgumentError)
return -Inf32
else
rethrow(e)
end
end
end
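# For reference, the quantity computed above is the standard GP log marginal likelihood
#   log p(y | X, θ) = -1/2 * yᵀ K⁻¹ y - 1/2 * log|K| - (n/2) * log(2π),
# with log|K| obtained from the Cholesky factor via logdet(F) (= 2 * Σᵢ log Lᵢᵢ).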
# ---------------- GPU-Compatible Training Loop ----------------
function train_gp_gpu(X::AbstractMatrix{Float32}, y::AbstractVector{Float32};
max_epochs::Int=100, lr::Float32=0.01f0, use_gpu::Bool=CUDA.has_cuda())
# Ensure X is in the correct format for KernelFunctions.jl (features x observations).
if size(X, 2) == length(y)
X_reshaped = X  # already (features x observations)
else
@assert size(X, 1) == length(y) "X must have dimensions (features x observations) or (observations x features)."
X_reshaped = permutedims(X)  # transpose (observations x features) to (features x observations)
end
# Move data to GPU (CuArray) if use_gpu is true, otherwise keep on CPU.
X_device = use_gpu ? CuArray(X_reshaped) : X_reshaped
y_device = use_gpu ? CuArray(y) : y
initial_N = size(X_reshaped, 2)
# Kernel remains on CPU.
kernel_cpu = suggest_composite_kernel(initial_N, period_guess=12.0f0)
# Assign the CPU kernel directly. No explicit GPU transfer for the kernel itself.
kernel_device = kernel_cpu
if use_gpu
println("Warning: Kernel object is on CPU. Using CuArray for data inputs. Performance might be impacted.")
println("This bypasses the `ScaleTransform` MethodError for now.")
end
# Initialize the trainable parameter: log_noise.
log_noise = log(0.1f0 * std(Vector(y_device)))
θ = Float32[log_noise] # Our only trainable parameter for this example
θ_device = use_gpu ? CuArray(θ) : θ # Move parameters to GPU
# Setup the Adam optimizer.
opt = Optimisers.Adam(lr)
state = Optimisers.setup(opt, θ_device) # Optimizer state is on the device
# Initialize best parameters and log marginal likelihood (LML).
best_θ_device = copy(θ_device)
best_lml = -Inf32
# Training loop.
for epoch in 1:max_epochs
# Compute gradients of the negative log marginal likelihood with respect to θ_device.
grads = Zygote.gradient(θ_device) do θ_current
# Use `sum()` for GPU-compatible scalar extraction within Zygote
current_log_noise = sum(θ_current)
K = build_kernel_matrix(kernel_device, X_device, current_log_noise)
return -log_marginal_likelihood(K, y_device) # Minimize negative LML
end
# Update parameters using the optimizer if gradients are valid.
if !isnothing(grads[1])
state, θ_device = Optimisers.update(state, θ_device, grads[1])
# Recalculate LML to track the true value after update.
current_lml = log_marginal_likelihood(
build_kernel_matrix(kernel_device, X_device, sum(θ_device)),
y_device
)
# Update best parameters if current LML is better.
if current_lml > best_lml
best_lml = current_lml
best_θ_device = copy(θ_device)
end
end
# Print progress every 10 epochs.
if epoch % 10 == 0
println("[Epoch $epoch] LML = $(round(best_lml, digits=2))")
end
end
# Convert the best learned log_noise parameter back to a CPU array for return.
final_log_noise_cpu = Array(best_θ_device)[1] # This is fine as it's after Array conversion
# Return the final kernel structure (from CPU) and the best learned noise.
return (kernel=kernel_cpu, log_noise=final_log_noise_cpu) # Return kernel_cpu
end
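# Note on the `sum(θ_current)` trick inside the gradient closure above: θ_device is a
# 1-element array, and scalar extraction via θ_current[1] or `only(θ_current)` relies on
# scalar getindex, which CUDA.jl restricts for device arrays; a reduction like `sum` stays
# device-friendly and has an AD rule. A tiny check outside of Zygote (sketch, assuming a
# working CUDA setup whenever CUDA.functional() is true):
θ_check = CUDA.functional() ? CuArray(Float32[-1.5f0]) : Float32[-1.5f0]
@assert sum(θ_check) ≈ -1.5f0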
# --- Example Usage ---
D_feats = 1 # Number of features (e.g., time)
N_points = 2000 # Number of data points (can be larger for GPU benefits)
# Create synthetic time data (e.g., from 0 to 100)
X_train_time = Float32.(rand(N_points, D_feats) * 100.0f0)
# Create synthetic observations with a sine wave, linear trend, and noise
y_train = Float32.(sin.(X_train_time[:,1] * 0.5f0) .+ (X_train_time[:,1] * 0.01f0) .+ 0.5f0 .* randn(Float32, N_points))
println("--- Starting GPU-compatible Gaussian Process Training ---")
println("Dataset size:
println("Using GPU: $(CUDA.has_cuda() ? "Yes" : "No (CUDA not detected or enabled)")")
# Call the training function.
trained_params = train_gp_gpu(X_train_time, y_train, max_epochs=200, lr=0.005f0, use_gpu=true)
println("\n--- Training Complete ---")
println("Best trained log_noise: $(trained_params.log_noise)")
println("Inferred noise variance: $(exp(2 * trained_params.log_noise))")
println("Final kernel structure: $(trained_params.kernel)")
ERROR: MethodError: no method matching (::PeriodicTransform{Vector{Float32}})(::SubArray{Float32, 1, Matrix{Float32}, Tuple{Base.Slice{…}, Int64}, true})
The object of type PeriodicTransform{Vector{Float32}}
exists, but no method is defined for this combination of argument types when trying to treat it as a callable object.
Closest candidates are:
(::PeriodicTransform)(::Real)
@ KernelFunctions C:\Users\yuchen.julia\packages\KernelFunctions\A0P7n\src\transform\periodic_transform.jl:28
Stacktrace:
[1] macro expansion
@ C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0 [inlined]
[2] _pullback(ctx::Zygote.Context{…}, f::PeriodicTransform{…}, args::SubArray{…})
@ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:81
[3] (::Zygote.var"#676#680"{Zygote.Context{…}, PeriodicTransform{…}})(args::SubArray{Float32, 1, Matrix{…}, Tuple{…}, true})
@ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\lib\array.jl:188
[4] iterate
@ .\generator.jl:48 [inlined]
[5] _collect
@ .\array.jl:811 [inlined]
[6] collect_similar
@ .\array.jl:720 [inlined]
[7] map
@ .\abstractarray.jl:3371 [inlined]
[8] ∇map
@ C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\lib\array.jl:188 [inlined]
[9] adjoint
@ C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\lib\array.jl:214 [inlined]
[10] _pullback
@ C:\Users\yuchen.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined]
[11] _map
@ C:\Users\yuchen.julia\packages\KernelFunctions\A0P7n\src\transform\transform.jl:21 [inlined]
[12] _pullback(::Zygote.Context{…}, ::typeof(KernelFunctions._map), ::PeriodicTransform{…}, ::ColVecs{…})
@ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0
[13] kernelmatrix
@ C:\Users\yuchen.julia\packages\KernelFunctions\A0P7n\src\kernels\transformedkernel.jl:117 [inlined]
[14] _pullback(::Zygote.Context{…}, ::typeof(kernelmatrix), ::TransformedKernel{…}, ::ColVecs{…}, ::ColVecs{…})
@ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0
[15] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:946
[16] adjoint
@ C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\lib\lib.jl:199 [inlined]
[17] _pullback
@ C:\Users\yuchen.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined]
[18] _sum
@ C:\Users\yuchen.julia\packages\KernelFunctions\A0P7n\src\kernels\kernelsum.jl:46 [inlined]
[19] _pullback(::Zygote.Context{…}, ::typeof(KernelFunctions._sum), ::typeof(kernelmatrix), ::Tuple{…}, ::ColVecs{…}, ::ColVecs{…})
@ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0
[20] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:946
[21] adjoint
@ C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\lib\lib.jl:199 [inlined]
[22] _pullback
@ C:\Users\yuchen.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined]
[23] _sum
@ C:\Users\yuchen.julia\packages\KernelFunctions\A0P7n\src\kernels\kernelsum.jl:46 [inlined]
[24] _pullback(::Zygote.Context{…}, ::typeof(KernelFunctions._sum), ::typeof(kernelmatrix), ::Tuple{…}, ::ColVecs{…}, ::ColVecs{…})
@ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0
[25] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:946
[26] adjoint
@ C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\lib\lib.jl:199 [inlined]
[27] _pullback
@ C:\Users\yuchen.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined]
[28] _sum
@ C:\Users\yuchen.julia\packages\KernelFunctions\A0P7n\src\kernels\kernelsum.jl:46 [inlined]
[29] _pullback(::Zygote.Context{…}, ::typeof(KernelFunctions._sum), ::typeof(kernelmatrix), ::Tuple{…}, ::ColVecs{…}, ::ColVecs{…})
@ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0
[30] kernelmatrix
@ C:\Users\yuchen.julia\packages\KernelFunctions\A0P7n\src\kernels\kernelsum.jl:57 [inlined]
[31] _pullback(::Zygote.Context{…}, ::typeof(kernelmatrix), ::KernelSum{…}, ::ColVecs{…}, ::ColVecs{…})
@ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0
[32] build_kernel_matrix
@ .\Untitled-1:41 [inlined]
[33] _pullback(::Zygote.Context{…}, ::typeof(build_kernel_matrix), ::KernelSum{…}, ::CuArray{…}, ::Float32)
@ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0
[34] #13
@ .\Untitled-1:120 [inlined]
[35] _pullback(ctx::Zygote.Context{…}, f::var"#13#14"{…}, args::CuArray{…})
@ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0
[36] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Float32, 1, CUDA.DeviceMemory})
@ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface.jl:96
[37] pullback
@ C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface.jl:94 [inlined]
[38] gradient(f::Function, args::CuArray{Float32, 1, CUDA.DeviceMemory})
@ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface.jl:153
[39] train_gp_gpu(X::Matrix{Float32}, y::Vector{Float32}; max_epochs::Int64, lr::Float32, use_gpu::Bool)
@ Main .\Untitled-1:117
[40] top-level scope
@ Untitled-1:168
Some type information was truncated. Use `show(err)` to see complete types.
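The closest-candidate line suggests the root cause: `PeriodicTransform` is only callable on scalar (`Real`) inputs, while the generic `_map` fallback (transform.jl:21 in the trace) applies it elementwise to the `SubArray` columns produced by `ColVecs`. If that reading is right, the failure should be reproducible in a plain forward `kernelmatrix` call, independent of Zygote and CUDA; a minimal sketch (untested, using the same `RationalQuadraticKernel ∘ PeriodicTransform` composition as `k_seasonal` above):

using KernelFunctions
k = RationalQuadraticKernel() ∘ PeriodicTransform(8.4f0)  # 0.7f0 * 12.0f0, as in suggest_composite_kernel
X = rand(Float32, 1, 10)                                  # 1 feature x 10 observations
kernelmatrix(k, ColVecs(X), ColVecs(X))                   # expected to raise the same MethodError

Passing the single-feature inputs as a plain Vector{Float32} instead of a ColVecs-wrapped matrix might dispatch to a scalar-input code path and sidestep the error for the 1-D case, though it is untested here whether that composes with the other kernel components or with the CUDA workaround above.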