Skip to content

Commit

Permalink
DefaultTrajectoryState <= GraphTrajectoryState (#428)
Browse files Browse the repository at this point in the history
  • Loading branch information
Max9294D authored Jul 13, 2023
1 parent 546ffe3 commit 352d022
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/CP/valueselection/learning/environment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ function RLBase.prob(p::PPOPolicy, env::AbstractCPEnv{ST}) where {ST <: NonTabul
prob(p, s, mask)
end

function RLBase.prob(p::PPOPolicy{<:ActorCritic,Categorical}, state::DefaultTrajectoryState, mask)
function RLBase.prob(p::PPOPolicy{<:ActorCritic,Categorical}, state::GraphTrajectoryState, mask)
logits = p.approximator.actor(send_to_device(device(p.approximator), state))
if !isnothing(mask)
logits .+= ifelse.(mask, 0f0, typemin(Float32))
Expand Down

0 comments on commit 352d022

Please sign in to comment.