TRPO #747

Merged: 14 commits into JuliaReinforcementLearning:master on Sep 11, 2022

Conversation

@baedan (Contributor) commented on Aug 8, 2022:

This PR implements Trust-Region Policy Optimization (TRPO) and adds a CartPole experiment for it.

To this end, I wrote a few utility functions that are shared among policy-gradient policies (#737). But perhaps a better way to go about it would be to have a `PolicyGradientPolicy` type and have it wrap different learners.

@findmyway self-requested a review on August 8, 2022 at 10:26.

@findmyway (Member) commented:

Looks fine to me in general. I think there's still room for improvement in the gradient part. I'll add more detailed comments this weekend.


gps = gradient(params(A.model)) do
    old_logits[] = A.model(s)
    total_loss = map(eachcol(softmax(old_logits[])), a) do x, y

@findmyway (Member):

Could be simplified to `logits[CartesianIndex.(a, 1:length(a))]`.
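
For illustration, a small standalone sketch (made-up data; `softmax` comes from Flux/NNlib) showing that the suggested `CartesianIndex` indexing selects the same per-sample entries as the `map` over `eachcol`:

```julia
using Flux: softmax

logits = randn(Float32, 4, 3)   # 4 actions × 3 samples
a = [2, 1, 4]                   # action taken in each sample
probs = softmax(logits)         # column-wise softmax (dims = 1 by default)

# map over columns, as in the current code
p_map = map(eachcol(probs), a) do col, act
    col[act]
end

# equivalent: one CartesianIndex per (action, sample) pair
p_idx = probs[CartesianIndex.(a, 1:length(a))]

@assert p_map ≈ p_idx
```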

end

# store logits as intermediate value
old_logits = Ref{Matrix{Float32}}()

@findmyway (Member):

Why is a `Ref` used here?

@baedan (Contributor, author):

There were some oddities with Zygote second-order derivatives w.r.t. implicit parameters when I tried `local old_logits`, yielding inconsistent results after `mapreduce(vec, vcat, gradient)` (I sometimes got a `Ref(0)` term as the first element). I'm not sure this is still necessary, however, since I've since changed how the second-order gradient is calculated.

Comment on lines 105 to 109
for _ in 1:p.max_backtrack_step
    θ = θₖ + Δ
    search_condition(θ) && break
    Δ = Δ * p.backtrack_coeff
end

@findmyway (Member):

It seems in-place updating would be fine here?
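
A minimal sketch of what in-place updating could look like here (`search_condition`, `max_backtrack_step`, and `backtrack_coeff` mirror the names in the snippet above; `backtrack!` and everything else is a hypothetical helper, not this PR's code):

```julia
# Backtracking line search with in-place updates: θ and Δ are reused as buffers
# instead of allocating a new vector on every iteration.
function backtrack!(θ, θₖ, Δ, search_condition; max_backtrack_step = 10, backtrack_coeff = 0.8f0)
    for _ in 1:max_backtrack_step
        θ .= θₖ .+ Δ               # overwrite θ in place rather than rebinding it
        search_condition(θ) && return true
        Δ .*= backtrack_coeff      # shrink the step in place
    end
    return false                   # no acceptable step within the backtrack budget
end
```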

export action_distribution, policy_gradient_estimate, IsPolicyGradient
export conjugate_gradient!

struct IsPolicyGradient end

@findmyway (Member):

Based on the usages, it seems making this a subtype of `AbstractPolicyGradient <: AbstractPolicy` would be better?

@baedan (Contributor, author):

I think so! Or, better yet, perhaps a `PolicyGradient` wrapper type?
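
To make the two options concrete, here is a rough sketch (not code from this PR; `AbstractPolicy` is the existing supertype in ReinforcementLearningBase.jl, all other names are invented for illustration):

```julia
using ReinforcementLearningBase: AbstractPolicy

# Option (a): a dedicated abstract subtype that the shared utilities dispatch on,
# instead of the IsPolicyGradient trait.
abstract type AbstractPolicyGradient <: AbstractPolicy end
# struct TRPOPolicy <: AbstractPolicyGradient ... end

# Option (b): a concrete wrapper that owns the shared policy-gradient machinery
# and delegates the algorithm-specific pieces to a learner.
struct PolicyGradient{L} <: AbstractPolicy
    learner::L   # e.g. a TRPO- or VPG-specific learner
end
```

Either way, `policy_gradient_estimate` and friends could then dispatch on the type rather than taking an explicit trait argument.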


function policy_gradient_estimate(::IsPolicyGradient, policy, states, actions, advantage)
    gs = gradient(params(policy.approximator)) do
        action_logits = action_distribution(policy.dist, policy.approximator(states))

@findmyway (Member):

`action_logits` → `action_distribution`?

```
"""
action_distribution(dist::Type{T}, model_output) where {T<:ContinuousDistribution} =
    map(col -> dist(col...), eachcol(model_output))

@findmyway (Member):

This makes some extra assumptions:

  1. The parameters of the distribution are all of the same length and size (scalars, to be specific)
  2. The output of the model is a Matrix

Maybe we can figure out a more elegant way to use StructArrays.jl here later.

@baedan (Contributor, author):

Yeah, the semantics of `dist` are pretty bad here. After this I realized that punning on `dist` wouldn't even work with a `Normal`, since we usually want the network to output the log of the variance, not the variance itself.

Perhaps we could just ask the user to specify the distribution type with a trait, and overload a `dist` function (with a better name, of course).
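
A hedged sketch of that idea (`output_to_distribution` and `action_distributions` are made-up names, not this PR's API), assuming the Gaussian head outputs a mean and a log-variance per column:

```julia
using Distributions
using Flux: softmax

# Users overload one method per distribution type, so e.g. a Gaussian policy can
# interpret the second entry of each column as log(σ²) rather than σ itself.
output_to_distribution(::Type{<:Normal}, col) = Normal(col[1], exp(col[2] / 2))   # col = (μ, log σ²)
output_to_distribution(::Type{<:Categorical}, col) = Categorical(softmax(col))    # col = raw logits

action_distributions(D, model_output::AbstractMatrix) =
    map(col -> output_to_distribution(D, col), eachcol(model_output))
```

For example, `action_distributions(Categorical, model(states))` would return one `Categorical` per column of the model output.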

See [here](https://spinningup.openai.com/en/latest/algorithms/trpo.html#key-equations) for more information.
"""
function surrogate_advantage(model, states, actions, advantage, action_logits)
    π_θₖ = map(eachcol(softmax(action_logits)), actions) do a, b

@findmyway (Member):

Same as above

    π_θₖ = map(eachcol(softmax(action_logits)), actions) do a, b
        a[b]
    end
    π_θ = map(eachcol(softmax(model(states))), actions) do a, b

@findmyway (Member):

Same as above
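
For reference, the quantity `surrogate_advantage` computes is the surrogate advantage from the linked Spinning Up page, i.e. the importance-weighted advantage

```math
\mathcal{L}(\theta_k, \theta) = \mathbb{E}_{s, a \sim \pi_{\theta_k}}\!\left[ \frac{\pi_\theta(a \mid s)}{\pi_{\theta_k}(a \mid s)} \, A^{\pi_{\theta_k}}(s, a) \right],
```

with `π_θ` and `π_θₖ` above being the per-sample probabilities of the taken actions under the new and old parameters, respectively.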

@findmyway (Member) commented:

I'll merge this first. I may find some time in the next week to polish this further ;)

@findmyway merged commit 0a344ce into JuliaReinforcementLearning:master on Sep 11, 2022.