Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Particle filter example #53

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions examples/particle-filter/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
162 changes: 162 additions & 0 deletions examples/particle-filter/script.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# # Partilce Filter with adaptive resampling
using StatsBase
using AbstractMCMC
using Random
using SSMProblems
using Distributions
using Plots
using StatsFuns
using Metal

# Filter
ess(weights) = inv(sum(abs2, weights))
get_weights(logweights::T) where {T<:AbstractVector{<:Real}} = StatsFuns.softmax(logweights)
logZ(arr::AbstractArray) = StatsFuns.logsumexp(arr)


function rejection_resampling(rng::AbstractRNG, weights::AbstractArray{T}, n::Int=length(weights)) where {T<:Real}
Comment on lines +16 to +17
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function rejection_resampling(rng::AbstractRNG, weights::AbstractArray{T}, n::Int=length(weights)) where {T<:Real}
function rejection_resampling(
rng::AbstractRNG, weights::AbstractArray{T}, n::Int=length(weights)
) where {T<:Real}

w_max = maximum(weights) # Inefficient
a = zeros(Int, n)
for i in 1:n
j = i
u = rand(rng)
while log(u) > log(weights[j]) - log(w_max)
j = rand(rng, 1:n)
u = rand(rng)
end
a[i] = j
end
return a
end


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

"""
resample(rng::AbstractRNG, weights::AbstractArray, particles::AbstractArray)

Resample `particles` in-place
"""
function resample!(rng::AbstractRNG, weights::AbstractArray, particles::AbstractArray)
idx = rejection_resampling(rng, weights)
num_resamples = zeros(length(idx))
for i in idx
num_resamples[i] += 1
end
removed = findall(num_resamples .== 0)

for (i, num_children) in enumerate(num_resamples)
if num_children > 1
for _ in 2:num_children
j = popfirst!(removed)
particles[j] = particles[i]
end
end
end
end


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

"""
filter(rng::AbstractRNG, model::StateSpaceModel, N::Int, observations::AbstractArray, threshold::Real)

Estimate log-evidence using `N` particles. Resample particles when ESS falls below `N * threshold`.
"""
function filter(
rng::AbstractRNG,
model::StateSpaceModel,
N::Int,
observations::AbstractArray{T},
threshold::Real=0.5,
) where {T<:Real}
gpu_state = Metal.zeros(N; storage=Metal.Shared)
gpu_logweights = Metal.zeros(N; storage=Metal.Shared)

# Use unified memory option to avoid moving states and weights back and forth from the GPU
cpu_state = unsafe_wrap(Array{Float32}, gpu_state, size(gpu_state))
cpu_logweights = unsafe_wrap(Array{Float32}, gpu_logweights, size(gpu_logweights))

logevidence = 0
for (step, observation) in enumerate(observations)
#println(step)
weights = get_weights(cpu_logweights)
if ess(weights) <= threshold * N
resample!(rng, weights, cpu_state)
fill!(cpu_logweights, 0.0)
end

logZ0 = logZ(cpu_logweights)
Metal.@sync simulate!(model.dyn, step, gpu_state, nothing)
Metal.@sync logdensity!(
model.obs, gpu_logweights, step, gpu_state, observation, nothing
)
logZ1 = logZ(cpu_logweights)

logevidence += logZ1 - logZ0
end
return logevidence
end

# Model definition
struct LinearGaussianLatentDynamics{T<:Real} <: LatentDynamics
σ::T
end

struct LinearGaussianObservationProcess{T<:Real} <: ObservationProcess
σ::T
end

const LinearGaussianSSM{T} = StateSpaceModel{
<:LinearGaussianLatentDynamics{T},<:LinearGaussianObservationProcess{T}
};

function SSMProblems.distribution(
dyn::LinearGaussianLatentDynamics{T}, extra::Nothing
) where {T<:Real}
return Normal{T}(T(0), dyn.σ)
end

function SSMProblems.distribution(
dyn::LinearGaussianLatentDynamics{T}, step::Int, state::Real, extra::Nothing
) where {T<:Real}
return Normal{T}(state, dyn.σ)
end

function SSMProblems.distribution(
obs::LinearGaussianObservationProcess{T}, step::Int, state::Real, extra::Nothing
) where {T<:Real}
return Normal{T}(state, dyn.σ)
end

function simulate!(
dyn::LinearGaussianLatentDynamics{T}, step::Int, state::AbstractArray{T}, extra::Nothing
) where {T}
return state .= state .+ dyn.σ * Metal.randn(size(state)...)
end

function logdensity!(
obs::LinearGaussianObservationProcess{T},
arr::AbstractArray{T},
timestep::Int,
state::AbstractArray{T},
observation::T,
extra::Nothing,
) where {T<:Real}
return arr .+= normlogpdf.(state, (obs.σ,), (observation,))
end

# Simulation / Inference
Tn = 10
seed = 1
N = 500_000
rng = MersenneTwister(seed)

# Float32 required for GPU but leads to numerical instability in the resampling algorithm
# See https://www.sciencedirect.com/science/article/abs/pii/S0167819122000837
# for example of numerical instability with single precision
# For a numerically stable version of resampling algos https://arxiv.org/pdf/1301.4019
T = Float32
dyn = LinearGaussianLatentDynamics{T}(T(0.2))
obs = LinearGaussianObservationProcess{T}(T(0.7))
model = StateSpaceModel(dyn, obs)
xs, ys = sample(rng, model, Tn)

logevidence = filter(rng, model, N, ys)
println(logevidence)
Loading