Skip to content

Commit

Permalink
Change the ScoreGradELBO objective to use the VarGrad estimator underneath
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Nov 5, 2024
1 parent 35ea7ec commit dc23a02
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 57 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down Expand Up @@ -50,7 +49,6 @@ Functors = "0.4"
LinearAlgebra = "1"
LogDensityProblems = "2"
Mooncake = "0.4"
OnlineStats = "1"
Optimisers = "0.2.16, 0.3"
ProgressMeter = "1.6"
Random = "1"
Expand Down
71 changes: 16 additions & 55 deletions src/objectives/elbo/scoregradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,63 +37,35 @@ To reduce the variance of the gradient estimator, we use a baseline computed fro
Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
"""
struct ScoreGradELBO{EntropyEst<:AbstractEntropyEstimator} <: AbstractVariationalObjective
entropy::EntropyEst
n_samples::Int
baseline_window_size::Int
end

function ScoreGradELBO(
n_samples::Int;
entropy::AbstractEntropyEstimator=MonteCarloEntropy(),
baseline_window_size::Int=10,
)
return ScoreGradELBO(entropy, n_samples, baseline_window_size)
end

# Initialize the state of the `ScoreGradELBO` objective.
#
# Returns a `MovingWindow` (from OnlineStats.jl) of element type `T` holding up to
# `obj.baseline_window_size` values. Per the surrounding docstring, this window
# presumably accumulates recent objective values used as a variance-reduction
# baseline for the score-gradient estimator — TODO(review): confirm against the
# corresponding `estimate_gradient!` implementation.
#
# Arguments:
# - `::Random.AbstractRNG`: RNG (unused here; part of the generic `init` interface).
# - `obj::ScoreGradELBO`: the objective whose state is being initialized.
# - `prob`: target problem (unused here).
# - `params::AbstractVector{T}`: variational parameters; only supplies eltype `T`.
# - `restructure`: parameter-to-distribution reconstructor (unused here).
function init(
    ::Random.AbstractRNG, obj::ScoreGradELBO, prob, params::AbstractVector{T}, restructure
) where {T<:Real}
    return MovingWindow(T, obj.baseline_window_size)
end

function Base.show(io::IO, obj::ScoreGradELBO)
print(io, "ScoreGradELBO(entropy=")
print(io, obj.entropy)
print(io, ", n_samples=")
print(io, "ScoreGradELBO(n_samples=")
print(io, obj.n_samples)
print(io, ", baseline_window_size=")
print(io, obj.baseline_window_size)
return print(io, ")")
end

function estimate_objective(
rng::Random.AbstractRNG, obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples
)
samples = rand(rng, q, n_samples)
entropy = estimate_entropy(obj.entropy, samples, q)
energy = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
return mean(energy) + entropy
ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
return mean(ℓπ - ℓq)
end

"""
    estimate_objective(obj::ScoreGradELBO, q, prob; n_samples=obj.n_samples)

Convenience method that estimates the objective using the global default RNG.
Forwards to the RNG-accepting method with `Random.default_rng()`.
"""
function estimate_objective(obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples)
    rng = Random.default_rng()
    return estimate_objective(rng, obj, q, prob; n_samples=n_samples)
end

function estimate_scoregradelbo_ad_forward(params′, aux)
@unpack rng, obj, logprobs, adtype, restructure, samples, q_stop, baseline = aux
@unpack rng, obj, logprob, adtype, restructure, samples = aux
q = restructure_ad_forward(adtype, restructure, params′)

ℓπ = logprob
ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
ℓq_stop = logpdf.(Ref(q_stop), AdvancedVI.eachsample(samples))
ℓπ_mean = mean(logprobs)
score_grad = mean(@. ℓq * (logprobs - baseline))
score_grad_stop = mean(@. ℓq_stop * (logprobs - baseline))

energy = ℓπ_mean + (score_grad - score_grad_stop)
entropy = estimate_entropy(obj.entropy, samples, q)

elbo = energy + entropy
return -elbo
f = ℓq - ℓπ
return var(f) / 2
end

function AdvancedVI.estimate_gradient!(
Expand All @@ -106,33 +78,22 @@ function AdvancedVI.estimate_gradient!(
restructure,
state,
)
baseline_buf = state
baseline_history = OnlineStats.value(baseline_buf)
baseline = if isempty(baseline_history)
zero(eltype(params))
else
mean(baseline_history)
end
q_stop = restructure(params)
samples = rand(rng, q_stop, obj.n_samples)
ℓprobs = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
q = restructure(params)
samples = rand(rng, q, obj.n_samples)
ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
aux = (
rng=rng,
adtype=adtype,
obj=obj,
logprobs=ℓprobs,
logprob=ℓπ,
restructure=restructure,
baseline=baseline,
samples=samples,
q_stop=q_stop,
)
AdvancedVI.value_and_gradient!(
adtype, estimate_scoregradelbo_ad_forward, params, aux, out
)
nelbo = DiffResults.value(out)
stat = (elbo=-nelbo,)
if obj.baseline_window_size > 0
fit!(baseline_buf, -nelbo)
end
return out, baseline_buf, stat
ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
elbo = mean(ℓπ - ℓq)
stat = (elbo=elbo,)
return out, nothing, stat
end

0 comments on commit dc23a02

Please sign in to comment.