From de5893f24ff265143eec5a83a394bed57f1cb296 Mon Sep 17 00:00:00 2001 From: Jeremiah <4462211+jeremiahpslewis@users.noreply.github.com> Date: Tue, 26 Mar 2024 17:21:09 +0100 Subject: [PATCH] Update Docs for v0.11 release (#1056) * update run function * update docs * fix naming, update docs * fix random walk example * Add wrapper test * update docs * update docs * fix / update docs * bump trajectories * fix tests * add player type * syntax * migrate tictactoe * fix import * Add RLCore as dependency to RLEnvs * update player * Fix tests * Fix player state in abstract_learner.jl * type annotations * Add PlayerNamedTuple * Fix files * Simplify Player syntax * symbol -> player * Fix tests * fix * Move player struct * Fix tests * Fix typo * Fix * Fix player * Fix test * Fix Poker * Fix wrapper * Fix tests * Fix naming * Fix env tests * Fix KuhnPoker * Fix env * Fix type ambiguity * Fix pigenv * Fix tic tac toe * Fix errors --------- Co-authored-by: Jeremiah Lewis <--get> --- Project.toml | 4 +- docs/Project.toml | 2 +- docs/homepage/guide/index.md | 2 +- docs/src/How_to_implement_a_new_algorithm.md | 33 +++-- docs/src/How_to_use_hooks.md | 113 +++++++----------- .../How_to_write_a_customized_environment.md | 53 ++++---- docs/src/non_episodic.md | 31 ++--- docs/src/tips.md | 32 ++--- docs/src/tutorial.md | 52 ++++---- src/ReinforcementLearningBase/Project.toml | 2 +- .../src/interface.jl | 10 +- src/ReinforcementLearningCore/Project.toml | 4 +- .../src/core/core.jl | 1 + .../src/core/hooks.jl | 28 ++--- .../src/core/player.jl | 7 ++ .../src/core/stages.jl | 4 +- .../src/core/stop_conditions.jl | 30 +++-- .../src/policies/agent/agent_srt_cache.jl | 4 +- .../src/policies/agent/multi_agent.jl | 60 +++++++--- .../src/policies/learners/abstract_learner.jl | 4 +- .../src/policies/q_based_policy.jl | 2 +- .../src/policies/random_policy.jl | 2 +- .../test/core/core.jl | 1 + .../test/core/hooks.jl | 13 +- .../test/core/player.jl | 4 + .../test/core/stop_conditions.jl | 17 ++- .../policies/learners/abstract_learner.jl | 10 +- .../test/policies/multi_agent.jl | 111 +++++++++-------- .../test/policies/q_based_policy.jl | 2 +- .../Project.toml | 9 +- .../src/ReinforcementLearningEnvironments.jl | 1 + .../src/environments/3rd_party/pettingzoo.jl | 6 +- .../src/environments/examples/CartPoleEnv.jl | 4 +- .../src/environments/examples/KuhnPokerEnv.jl | 29 ++--- .../src/environments/examples/PigEnv.jl | 39 +++--- .../examples/RockPaperScissorsEnv.jl | 24 ++-- .../src/environments/examples/TicTacToeEnv.jl | 53 ++++---- .../environments/examples/TinyHanabiEnv.jl | 23 ++-- .../src/environments/wrappers/wrappers.jl | 8 +- .../environments/examples/random_walk_1d.jl | 42 +++++++ .../examples/rock_paper_scissors.jl | 6 +- .../test/environments/examples/tic_tac_toe.jl | 18 +-- .../test/environments/wrappers/wrappers.jl | 71 ++++++++++- .../test/runtests.jl | 1 + .../hooks/total_reward_per_last_n_episodes.jl | 4 +- .../hooks/total_reward_per_last_n_episodes.jl | 6 +- 46 files changed, 566 insertions(+), 416 deletions(-) create mode 100644 src/ReinforcementLearningCore/src/core/player.jl create mode 100644 src/ReinforcementLearningCore/test/core/player.jl diff --git a/Project.toml b/Project.toml index 300ab06ae..009748135 100644 --- a/Project.toml +++ b/Project.toml @@ -12,9 +12,9 @@ ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921" [compat] Reexport = "0.2, 1" -ReinforcementLearningBase = "0.12" +ReinforcementLearningBase = "0.13" ReinforcementLearningCore = "0.15" 
-ReinforcementLearningEnvironments = "0.8" +ReinforcementLearningEnvironments = "0.9" julia = "1.6" [extras] diff --git a/docs/Project.toml b/docs/Project.toml index b8f72ce4c..8d6fe746b 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,6 @@ [deps] ArcadeLearningEnvironment = "b7f77d8d-088d-5e02-8ac0-89aab2acc977" -BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" DemoCards = "311a05b2-6137-4a5a-b473-18580a3d38b5" diff --git a/docs/homepage/guide/index.md b/docs/homepage/guide/index.md index cf695c3dc..e7de146de 100644 --- a/docs/homepage/guide/index.md +++ b/docs/homepage/guide/index.md @@ -85,7 +85,7 @@ Usually a closure or a functional object will be used to store some intermediate In most cases, you don't need to write a customized hook. Some generic hooks are provided so that you can inject logic at the appropriate time: - [`DoEveryNSteps`](https://juliareinforcementlearning.org/ReinforcementLearning.jl/latest/rl_core/#ReinforcementLearningCore.DoEveryNSteps) -- [`DoEveryNEpisode`](https://juliareinforcementlearning.org/ReinforcementLearning.jl/latest/rl_core/#ReinforcementLearningCore.DoEveryNEpisode) +- [`DoEveryNEpisodes`](https://juliareinforcementlearning.org/ReinforcementLearning.jl/latest/rl_core/#ReinforcementLearningCore.DoEveryNEpisodes) However, if you do need to write a customized hook, the following methods must be provided: diff --git a/docs/src/How_to_implement_a_new_algorithm.md b/docs/src/How_to_implement_a_new_algorithm.md index 80a04f27f..20c3673ba 100644 --- a/docs/src/How_to_implement_a_new_algorithm.md +++ b/docs/src/How_to_implement_a_new_algorithm.md @@ -10,7 +10,6 @@ function _run(policy::AbstractPolicy, stop_condition::AbstractStopCondition, hook::AbstractHook, reset_condition::AbstractResetCondition) - push!(policy, PreExperimentStage(), env) is_stop = false while !is_stop @@ -18,17 +17,17 @@ function _run(policy::AbstractPolicy, push!(policy, PreEpisodeStage(), env) optimise!(policy, PreEpisodeStage()) - while !reset_condition(policy, env) # one episode + while !check!(reset_condition, policy, env) # one episode push!(policy, PreActStage(), env) optimise!(policy, PreActStage()) - RLBase.plan!(policy, env) + action = RLBase.plan!(policy, env) act!(env, action) push!(policy, PostActStage(), env, action) optimise!(policy, PostActStage()) - if check_stop(stop_condition, policy, env) + if check!(stop_condition, policy, env) is_stop = true break end @@ -36,17 +35,17 @@ function _run(policy::AbstractPolicy, push!(policy, PostEpisodeStage(), env) optimise!(policy, PostEpisodeStage()) + end push!(policy, PostExperimentStage(), env) hook end - ``` Implementing a new algorithm mainly consists of creating your own `AbstractPolicy` (or `AbstractLearner`, see [this section](#using-resources-from-rlcore)) subtype, its action sampling method (by overloading `Base.push!(policy::YourPolicyType, env)`) and implementing its behavior at each stage. However, ReinforcemementLearning.jl provides plenty of pre-implemented utilities that you should use to 1) have less code to write 2) lower the chances of bugs and 3) make your code more understandable and maintainable (if you intend to contribute your algorithm). ## Using Agents -The recommended way is to use the policy wrapper `Agent`. An agent is itself an `AbstractPolicy` that wraps a policy and a trajectory (also called Experience Replay Buffer in RL literature). 
Agent comes with default implementations of `push!(agent, stage, env)` and `plan!(agent, env)` that will probably fit what you need at most stages so that you don't have to write them again. Looking at the [source code](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/main/src/ReinforcementLearningCore/src/policies/agent.jl/), we can see that the default Agent calls are +The recommended way is to use the policy wrapper `Agent`. An agent is itself an `AbstractPolicy` that wraps a policy and a trajectory (also called Experience Replay Buffer in reinforcement learning literature). Agent comes with default implementations of `push!(agent, stage, env)` and `plan!(agent, env)` that will probably fit what you need at most stages so that you don't have to write them again. Looking at the [source code](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/main/src/ReinforcementLearningCore/src/policies/agent.jl/), we can see that the default Agent calls are ```julia function Base.push!(agent::Agent, ::PreEpisodeStage, env::AbstractEnv) @@ -61,21 +60,21 @@ end The function `RLBase.plan!(agent::Agent, env::AbstractEnv)`, is called at the `action = RLBase.plan!(policy, env)` line. It simply gets an action from the policy of the agent by calling `RLBase.plan!(your_new_policy, env)` function. At the `PreEpisodeStage()`, the agent pushes the initial state to the trajectory. At the `PostActStage()`, the agent pushes the transition to the trajectory. -If you need a different behavior at some stages, then you can overload the `Base.push!(Agent{<:YourPolicyType}, [stage,] env)` or `Base.push!(Agent{<:Any, <: YourTrajectoryType}, [stage,] env)`, or `Base.plan!`, depending on whether you have a custom policy or just a custom trajectory. For example, many algorithms (such as PPO) need to store an additional trace of the logpdf of the sampled actions and thus overload the function at the `PreActStage()`. +If you need a different behavior at some stages, then you can overload the `Base.push!(Agent{<:YourPolicyType}, [stage,] env)` or `Base.push!(Agent{<:Any, <: YourTrajectoryType}, [stage,] env)`, or `Base.plan!`, depending on whether you have a custom policy or just a custom trajectory. For example, many algorithms (such as PPO) need to store an additional trace of the `logpdf` of the sampled actions and thus overload the function at the `PreActStage()`. ## Updating the policy Finally, you need to implement the learning function by implementing `RLBase.optimise!(::YourPolicyType, ::Stage, ::Trajectory)`. By default this does nothing at all stages. Overload it on the stage where you wish to optimise (most often, at `PostActStage()` or `PostEpisodeStage()`). This function should loop the trajectory to sample batches. Inside the loop, put whatever is required. For example: ```julia -function RLBase.optimise!(p::YourPolicyType, ::PostEpisodeStage, traj::Trajectory) - for batch in traj - optimise!(p, batch) +function RLBase.optimise!(policy::YourPolicyType, ::PostEpisodeStage, trajectory::Trajectory) + for batch in trajectory + optimise!(policy, batch) end end ``` -where `optimise!(p, batch)` is a function that will typically compute the gradient and update a neural network, or update a tabular policy. What is inside the loop is free to be whatever you need but it's a good idea to implement a `optimise!(p::YourPolicyType, batch::NamedTuple)` function for clarity instead of coding everything in the loop. This is further discussed in the next section on `Trajectory`s. 
An example of where this could be different is when you want to update priorities, see [the PER learner](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/main/src/ReinforcementLearningZoo/src/algorithms/dqns/prioritized_dqn.jl) for an example. +where `optimise!(policy, batch)` is a function that will typically compute the gradient and update a neural network, or update a tabular policy. What is inside the loop is free to be whatever you need but it's a good idea to implement a `optimise!(policy::YourPolicyType, batch::NamedTuple)` function for clarity instead of coding everything in the loop. This is further discussed in the next section on `Trajectory`s. ## ReinforcementLearningTrajectories @@ -112,13 +111,13 @@ ReinforcementLearningTrajectories' design aims to eventually support distributed The sampler is the object that will fetch data in your trajectory to create the `batch` in the optimise for loop. The simplest one is the `BatchSampler{names}(batchsize, rng)`.`batchsize` is the number of elements to sample and `rng` is an optional argument that you may set to a custom rng for reproducibility. `names` is the set of traces the sampler must query. For example a `BatchSampler{(:state, :action, :next_state)}(32)` will sample a named tuple `(state = [32 states], action=[32 actions], next_state=[32 states that are one-off with respect that in state])`. -## Using resources from RLCore +## Using resources from ReinforcementLearningCore -RL algorithms typically only differ partially but broadly use the same mechanisms. The subpackage RLCore contains some modules that you can reuse to implement your algorithm. -These will take care of many aspects of training for you. See the [RLCore manual](./rlcore.md) +RL algorithms typically only differ partially but broadly use the same mechanisms. The subpackage ReinforcementLearningCore contains some modules that you can reuse to implement your algorithm. +These will take care of many aspects of training for you. See the [ReinforcementLearningCore manual](./rlcore.md) ### Utils -In utils/distributions.jl you will find implementations of gaussian log probabilities functions that are both GPU compatible and differentiable and that do not require the overhead of using Distributions.jl structs. +In `utils/distributions.jl` you will find implementations of gaussian log probabilities functions that are both GPU compatible and differentiable and that do not require the overhead of using `Distributions.jl` structs. ## Conventions Finally, there are a few "conventions" and good practices that you should follow, especially if you intend to contribute to this package (don't worry we'll be happy to help if needed). @@ -127,9 +126,9 @@ Finally, there are a few "conventions" and good practices that you should follow ReinforcementLearning.jl aims to provide a framework for reproducible experiments. To do so, make sure that your policy type has a `rng` field and that all random operations (e.g. action sampling) use `rand(your_policy.rng, args...)`. For trajectory sampling, you can set the sampler's rng to that of the policy when creating and agent or simply instantiate its own rng. ### GPU compatibility -Deep RL algorithms are often much faster when the neural nets are updated on a GPU. For now, we only support CUDA.jl as a backend. This means that you will have to think about the transfer of data between the CPU (where the trajectory is) and the GPU memory (where the neural nets are). 
To do so you will find in utils/device.jl some functions that do most of the work for you. The ones that you need to know are `send_to_device(device, data)` that sends data to the specified device, `send_to_host(data)` which sends data to the CPU memory (it fallbacks to `send_to_device(Val{:cpu}, data)`) and `device(x)` that returns the device on which `x` is. +Deep RL algorithms are often much faster when the neural nets are updated on a GPU. This means that you will have to think about the transfer of data between the CPU (where the trajectory is) and the GPU memory (where the neural nets are). `Flux.jl` offers `gpu` and `cpu` functions to make it easier to send data back and forth. Normally, you should be able to write a single implementation of your algorithm that works on CPU and GPUs thanks to the multiple dispatch offered by Julia. -GPU friendlyness will also require that your code does not use _scalar indexing_ (see the CUDA.jl documentation for more information), make sure to test your algorithm on the GPU after disallowing scalar indexing by using `CUDA.allowscalar(false)`. +GPU friendliness will also require that your code does not use _scalar indexing_ (see the `CUDA.jl` or `Metal.jl` documentation for more information); when using `CUDA.jl` make sure to test your algorithm on the GPU after disallowing scalar indexing by using `CUDA.allowscalar(false)`. Finally, it is a good idea to implement the `Flux.gpu(yourpolicy)` and `cpu(yourpolicy)` functions, for user convenience. Be careful that sampling on the GPU requires a specific type of rng, you can generate one with `CUDA.default_rng()` diff --git a/docs/src/How_to_use_hooks.md b/docs/src/How_to_use_hooks.md index b6938f93a..318a9e7a6 100644 --- a/docs/src/How_to_use_hooks.md +++ b/docs/src/How_to_use_hooks.md @@ -8,10 +8,12 @@ programming. We write the code in a loop and execute them step by step. ```julia while true - env |> policy |> env + action = plan!(policy, env) + act!(env, action) + # write your own logic here # like saving parameters, recording loss function, evaluating policy, etc. - stop_condition(env, policy) && break + check!(stop_condition, env, policy) && break is_terminated(env) && reset!(env) end ``` @@ -30,18 +32,19 @@ execution pipeline. However, we believe this is not necessary in Julia. With the declarative programming approach, we gain much more flexibilities. Now the question is how to design the hook. A natural choice is to wrap the -comments part in the above pseudocode into a function: +comments part in the above pseudo-code into a function: ```julia while true - env |> policy |> env - hook(policy, env) - stop_condition(env, policy) && break + action = plan!(policy, env) + act!(env, action) + push!(hook, policy, env) + check!(stop_condition, env, policy) && break is_terminated(env) && reset!(env) end ``` -But sometimes, we'd like to have a more fingrained control. So we split the calling +But sometimes, we'd like to have a more fine-grained control. So we split the calling of hooks into several different stages: - [`PreExperimentStage`](@ref) @@ -54,20 +57,22 @@ of hooks into several different stages: ## How to define a customized hook? By default, an instance of [`AbstractHook`](@ref) will do nothing when called -with `(hook::AbstractHook)(::AbstractStage, policy, env)`. So when writing a +with `push!(hook::AbstractHook, ::AbstractStage, policy, env)`. So when writing a customized hook, you only need to implement the necessary runtime logic. 
For example, assume we want to record the wall time of each episode. ```@repl how_to_use_hooks using ReinforcementLearning +import Base.push! Base.@kwdef mutable struct TimeCostPerEpisode <: AbstractHook t::UInt64 = time_ns() time_costs::Vector{UInt64} = [] end -(h::TimeCostPerEpisode)(::PreEpisodeStage, policy, env) = h.t = time_ns() -(h::TimeCostPerEpisode)(::PostEpisodeStage, policy, env) = push!(h.time_costs, time_ns()-h.t) +Base.push!(h::TimeCostPerEpisode, ::PreEpisodeStage, policy, env) = h.t = time_ns() +Base.push!(h::TimeCostPerEpisode, ::PostEpisodeStage, policy, env) = push!(h.time_costs, time_ns()-h.t) h = TimeCostPerEpisode() + run(RandomPolicy(), CartPoleEnv(), StopAfterNEpisodes(10), h) h.time_costs ``` @@ -77,14 +82,13 @@ h.time_costs - [`StepsPerEpisode`](@ref) - [`RewardsPerEpisode`](@ref) - [`TotalRewardPerEpisode`](@ref) -- [`TotalBatchRewardPerEpisode`](@ref) ## Periodic jobs Sometimes, we'd like to periodically run some functions. Two handy hooks are provided for this kind of tasks: -- [`DoEveryNEpisode`](@ref) +- [`DoEveryNEpisodes`](@ref) - [`DoEveryNSteps`](@ref) Following are some typical usages. @@ -98,7 +102,7 @@ run( policy, CartPoleEnv(), StopAfterNEpisodes(100), - DoEveryNEpisode(;n=10) do t, policy, env + DoEveryNEpisodes(;n=10) do t, policy, env # In real world cases, the policy is usually wrapped in an Agent, # we need to extract the inner policy to run it in the *actor* mode. # Here for illustration only, we simply use the original policy. @@ -117,40 +121,33 @@ run( ### Save parameters -[BSON.jl](https://github.com/JuliaIO/BSON.jl) is recommended to save the parameters of a policy. +[JLD2.jl](https://github.com/JuliaIO/JLD2.jl) is recommended to save the parameters of a policy. ```@repl how_to_use_hooks -using Flux -using Flux.Losses: huber_loss -using BSON +using ReinforcementLearning +using JLD2 -env = CartPoleEnv(; T = Float32) -ns, na = length(state(env)), length(action_space(env)) +env = RandomWalk1D() +ns, na = length(state_space(env)), length(action_space(env)) policy = Agent( - policy = QBasedPolicy( - learner = BasicDQNLearner( - approximator = NeuralNetworkApproximator( - model = Chain( - Dense(ns, 128, relu; init = glorot_uniform), - Dense(128, 128, relu; init = glorot_uniform), - Dense(128, na; init = glorot_uniform), - ) |> cpu, - optimizer = Adam(), - ), - batchsize = 32, - min_replay_history = 100, - loss_func = huber_loss, - ), - explorer = EpsilonGreedyExplorer( - kind = :exp, - ϵ_stable = 0.01, - decay_steps = 500, + QBasedPolicy(; + learner = TDLearner( + TabularQApproximator(n_state = ns, n_action = na), + :SARS; ), + explorer = EpsilonGreedyExplorer(ϵ_stable=0.01), ), - trajectory = CircularArraySARTTrajectory( - capacity = 1000, - state = Vector{Float32} => (ns,), + Trajectory( + CircularArraySARTSTraces(; + capacity = 1, + state = Int64 => (), + action = Int64 => (), + reward = Float64 => (), + terminal = Bool => (), + ), + DummySampler(), + InsertSampleRatioController(), ), ) @@ -161,40 +158,10 @@ run( env, StopAfterNSteps(10_000), DoEveryNSteps(n=1_000) do t, p, e - ps = params(p) - f = joinpath(parameters_dir, "parameters_at_step_$t.bson") - BSON.@save f ps + ps = policy.policy.learner.approximator + f = joinpath(parameters_dir, "parameters_at_step_$t.jld2") + JLD2.@save f ps println("parameters at step $t saved to $f") end ) ``` - -### Logging data - -Below we demonstrate how to use -[TensorBoardLogger.jl](https://github.com/PhilipVinc/TensorBoardLogger.jl) to -log runtime metrics. 
But users could also other tools like -[wandb](https://wandb.ai/site) through -[PyCall.jl](https://github.com/JuliaPy/PyCall.jl). - - -```@repl how_to_use_hooks -using TensorBoardLogger -using Logging -tf_log_dir = "logs" -lg = TBLogger(tf_log_dir, min_level = Logging.Info) -total_reward_per_episode = TotalRewardPerEpisode() -hook = ComposedHook( - total_reward_per_episode, - DoEveryNEpisode() do t, agent, env - with_logger(lg) do - @info "training" reward = total_reward_per_episode.rewards[end] - end - end -) -run(RandomPolicy(), CartPoleEnv(), StopAfterNEpisodes(50), hook) -readdir(tf_log_dir) -``` - -Then run `tensorboard --logdir logs` and open the link on the screen in your -browser. (Obviously you need to install tensorboard first.) diff --git a/docs/src/How_to_write_a_customized_environment.md b/docs/src/How_to_write_a_customized_environment.md index 149dc7cf6..8b38d0c07 100644 --- a/docs/src/How_to_write_a_customized_environment.md +++ b/docs/src/How_to_write_a_customized_environment.md @@ -7,7 +7,7 @@ write many different kinds of environments based on interfaces defined in The most commonly used interfaces to describe reinforcement learning tasks is [OpenAI/Gym](https://gym.openai.com/). Inspired by it, we expand those -interfaces a little to utilize the multiple-dispatch in Julia and to cover +interfaces a little to utilize multiple-dispatch in Julia and to cover multi-agent environments. ## The Minimal Interfaces to Implement @@ -24,7 +24,7 @@ state_space(env::YourEnv) reward(env::YourEnv) is_terminated(env::YourEnv) reset!(env::YourEnv) -(env::YourEnv)(action) +act!(env::YourEnv, action) ``` ## An Example: The LotteryEnv @@ -55,7 +55,13 @@ The `LotteryEnv` has only one field named `reward`, by default it is initialized with `nothing`. Now let's implement the necessary interfaces: ```@repl customized_env -RLBase.action_space(env::LotteryEnv) = (:PowerRich, :MegaHaul, nothing) +struct LotteryAction{a} + function LotteryAction(a) + new{a}() + end +end + +RLBase.action_space(env::LotteryEnv) = LotteryAction.([:PowerRich, :MegaHaul, nothing]) ``` Here `RLBase` is just an alias for `ReinforcementLearningBase`. @@ -78,12 +84,13 @@ in the initial state again. The only left one is to implement the game logic: ```@repl customized_env -function (x::LotteryEnv)(action) - if action == :PowerRich +function RLBase.act!(x::LotteryEnv, action) + if action == LotteryAction(:PowerRich) x.reward = rand() < 0.01 ? 100_000_000 : -10 - elseif action == :MegaHaul + elseif action == LotteryAction(:MegaHaul) x.reward = rand() < 0.05 ? 1_000_000 : -10 - elseif isnothing(action) x.reward = 0 + elseif action == LotteryAction(nothing) + x.reward = 0 else @error "unknown action of $action" end @@ -102,12 +109,13 @@ RLBase.test_runnable!(env) It is a simple smell test which works like this: -``` +```julia n_episode = 10 for _ in 1:n_episode reset!(env) while !is_terminated(env) - env |> action_space |> rand |> env + action = rand(action_space(env)) + act!(env, action) end end ``` @@ -117,7 +125,7 @@ ReinforcementLearning.jl also work. Similar to the test above, let's try the [`RandomPolicy`](@ref) first: ```@repl customized_env -run(RandomPolicy(action_space(env)), env, StopAfterNEpisodes(1_000)) +run(RandomPolicy(action_space(env)), env, StopAfterNEpisodes(1_000)) ``` If no error shows up, then it means our environment at least works with @@ -141,21 +149,19 @@ Now suppose we'd like to use a tabular based monte carlo method to estimate the state-action value. 
```@repl customized_env -using Flux: InvDecay p = QBasedPolicy( - learner = MonteCarloLearner(; - approximator=TabularQApproximator( - ;n_state = length(state_space(env)), + learner = TDLearner( + TabularQApproximator( + n_state = length(state_space(env)), n_action = length(action_space(env)), - opt = InvDecay(1.0) - ) + ), :SARS ), explorer = EpsilonGreedyExplorer(0.1) ) -p(env) +plan!(p, env) ``` -Oops, we get an error here. So what does it mean? +Oops, we get an error here. So what does it mean? Before answering this question, let's spend some time on understanding the policy we defined above. A [`QBasedPolicy`](@ref) @@ -168,9 +174,9 @@ by the `learner`. Inside of the [`MonteCarloLearner`](@ref), a That's the problem! A [`TabularQApproximator`](@ref) only accepts states of type `Int`. ```@repl customized_env -p.learner.approximator(1, 1) # Q(s, a) -p.learner.approximator(1) # [Q(s, a) for a in action_space(env)] -p.learner.approximator(false) +RLCore.forward(p.learner.approximator, 1, 1) # Q(s, a) +RLCore.forward(p.learner.approximator, 1) # [Q(s, a) for a in action_space(env)] +RLCore.forward(p.learner.approximator, false) ``` OK, now we know where the problem is. But how to fix it? @@ -191,7 +197,7 @@ wrapped_env = ActionTransformedEnv( action_mapping = i -> action_space(env)[i], action_space_mapping = _ -> Base.OneTo(3), ) -p(wrapped_env) +plan!(p, wrapped_env) ``` Nice job! Now we are ready to run the experiment: @@ -385,7 +391,8 @@ environments must take a collection of actions from different players as input. ```@repl customized_env rps = RockPaperScissorsEnv(); action_space(rps) -rps(rand(action_space(rps))) +action = plan!(RandomPolicy(), rps) +act!(rps, action) ``` ### [`ChanceStyle`](@ref) diff --git a/docs/src/non_episodic.md b/docs/src/non_episodic.md index 655990a31..cfd4e61d8 100644 --- a/docs/src/non_episodic.md +++ b/docs/src/non_episodic.md @@ -13,34 +13,35 @@ To manage this, we provide the `ResetAfterNSteps(n)` condition as an argument to ## Custom reset conditions -You can specify a custom `reset_condition` instead of using the built-in's. Your condition must be callable with the method `my_condition(policy, env)`. For example, here is how to implement a custom condition that checks for a terminal state but will also reset if the episode is too long: +You can specify a custom `reset_condition` instead of using the built-in's. Your condition must be callable with the method `RLCore.check!(my_condition, policy, env)`. For example, here is how to implement a custom condition that checks for a terminal state but will also reset if the episode is too long: ```julia +using ReinforcementLearning +import ReinforcementLearning: RLCore reset_n_steps = ResetAfterNSteps(10000) -function my_condition(policy, env) +struct MyCondition <: AbstractResetCondition end + +function RLCore.check!(my_condition::MyCondition, policy, env) terminal = is_terminated(env) - too_long = reset_n_steps(policy, env) + too_long = RLCore.check!(reset_n_steps, policy, env) return terminal || too_long end - -run(agent, env, stop_condition, hook, my_condition) +env = RandomWalk1D() +agent = RandomPolicy() +stop_condition = StopIfEnvTerminated() +hook = EmptyHook() +run(agent, env, stop_condition, hook, MyCondition()) ``` We can instead make a callable struct instead of a function to avoid the global `reset_n_step`. 
```julia -mutable struct MyCondition -reset_after +mutable struct MyCondition1 <: AbstractResetCondition + reset_after end -(c::MyCondition)(policy, env) = is_terminated(env) || c.reset_after(policy, env) - -run(agent, env, stop_condition, hook, MyCondition(ResetAfterNSteps(10000))) -``` - -A last possibility is to use an anonymous function. This approach cannot be used to implement stateful conditions (such as `ResetAfterNSteps`). For example here is alternative way to implement `ResetIfEnvTerminated`: +RLCore.check!(c::MyCondition1, policy, env) = is_terminated(env) || RLCore.check!(c.reset_after, policy, env) -```julia -run(agent, env, stop_condition, hook, (p,e) -> is_terminated(e)) +run(agent, env, stop_condition, hook, MyCondition1(ResetAfterNSteps(10000))) ``` diff --git a/docs/src/tips.md b/docs/src/tips.md index e55c88c9e..6bd180846 100644 --- a/docs/src/tips.md +++ b/docs/src/tips.md @@ -2,32 +2,22 @@ ## How to setup local development environment? -You can activate the local development mode as follows: from the base project directory, -load `ReinforcementLearning` via `using ReinforcementLearning`. -Then run `ReinforcementLearning.activate_devmode!()`. +You can activate the local development mode as follows: from the base project directory, run: + +```julia +using Pkg +Pkg.develop(path="src/ReinforcementLearningBase") +Pkg.develop(path="src/ReinforcementLearningCore") +Pkg.develop(path="src/ReinforcementLearningEnvironments") +Pkg.develop(path="src/ReinforcementLearningFarm") # optional +``` + Sometimes, you may need to add some extra dependencies. Remember to switch the environment before adding new packages. For example, if you want to add -`Statistics` in `ReinforcementLearningBase`, first run `]activate +`Statistics` to `ReinforcementLearningBase`, first run `]activate src/ReinforcementLearningBase`, then `]add Statistics`. -## How to contribute a new experiment? - -We use the [DemoCards.jl](https://johnnychen94.github.io/DemoCards.jl/stable/) -to generate the documentation of all the experiments. If you want to contribute -a new experiment, simply create a `Your_Experiment.jl` file in a specific -algorithm category under the `docs/experiments` folder. -Note that this file should follow the format defined in -[Literate.jl](https://github.com/fredrikekre/Literate.jl). And then update the -`config.json` file correspondingly. If your experiment needs an extra -dependency, remember to update both `docs/Project.toml` and -`src/ReinforcementLearningExperiments/Project.toml`. - -!!! note - All the cells after the `#+ tangle=true` line in `Your_Experment.jl` will be extracted into the - `ReinforcementLearningExperiments` package automatically. This feature is - supported by [Weave.jl](https://weavejl.mpastell.com/stable/usage/#tangle). - ## How to enable debug timings for experiment runs? Call `RLCore.TimerOutputs.enable_debug_timings(RLCore)` and default timings for hooks, policies and optimization steps will be printed. How do I reset the timer? Call `RLCore.TimerOutputs.reset_timer!(RLCore.timer)`. How do I show the timer results? Call `RLCore.timer`. diff --git a/docs/src/tutorial.md b/docs/src/tutorial.md index f389498a1..1aa07057d 100644 --- a/docs/src/tutorial.md +++ b/docs/src/tutorial.md @@ -17,14 +17,18 @@ Let's get familiar with some basic interfaces first. 
```@repl randomwalk1d using ReinforcementLearning env = RandomWalk1D() + S = state_space(env) s = state(env) # the initial position A = action_space(env) + is_terminated(env) + while true - env(rand(A)) + act!(env, rand(A)) is_terminated(env) && break end + state(env) reward(env) ``` @@ -48,41 +52,25 @@ run( ) ``` -The [`RandomPolicy`](@ref) simply draws a random element from the legal action -set at each step. Beyond that, we can also set the action at each position ahead -of time by using a [`TabularPolicy`](@ref). - -```@repl randomwalk1d -NS, NA = length(S), length(A) -policy = TabularPolicy(;table=Dict(zip(1:NS, fill(2, NS)))) -run( - policy, - RandomWalk1D(), - StopAfterNEpisodes(10), - TotalRewardPerEpisode() -) -``` - Next, let's introduce one of the most common policies, the [`QBasedPolicy`](@ref). It contains two parts, a state-action value function to estimate the estimated value of each state-action pair and an explorer to select which action to take based on the result of the state-action values. ```@repl randomwalk1d -using Flux: InvDecay policy = QBasedPolicy( - learner = MonteCarloLearner(; - approximator=TabularQApproximator( - ;n_state = NS, + learner = TDLearner( + TabularQApproximator( + n_state = NS, n_action = NA, - opt = InvDecay(1.0) - ) + ), + :SARS ), explorer = EpsilonGreedyExplorer(0.1) ) ``` -Here we choose the [`MonteCarloLearner`](@ref) and the +Here we choose the [`TDLearner`](@ref) and the [`EpsilonGreedyExplorer`](@ref). But you can also replace them with some other Q value learners or value explorers. Similar to what we did before, we can apply this policy to the `env` to estimate its performance. @@ -105,12 +93,24 @@ in this case. To run policies in the **learner** mode, a dedicated wrapper polic [`Agent`](@ref) is provided. ```@repl randomwalk1d +using ReinforcementLearningTrajectories + +trajectory = Trajectory( + ElasticArraySARTSTraces(; + state = Int64 => (), + action = Int64 => (), + reward = Float64 => (), + terminal = Bool => (), + ), + DummySampler(), + InsertSampleRatioController(), +) agent = Agent( - policy = policy, - trajectory = VectorSARTTrajectory() + policy = RandomPolicy(), + trajectory = trajectory ) run(agent, env, StopAfterNEpisodes(10), TotalRewardPerEpisode()) ``` -Here the [`VectorSARTTrajectory`](@ref) is used to store the **S**tate, +Here the [`Trajectory`](@ref) is used to store the **S**tate, **A**ction, **R**eward, is_**T**erminated info during interactions with the environment. 
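A minimal, illustrative sketch assembled only from constructors that appear in this PR's own examples (tutorial.md and How_to_use_hooks.md): it wires the `TDLearner`-based `QBasedPolicy` and the new `Trajectory` API into a single `Agent` run on `RandomWalk1D`. The specific keyword values, e.g. `capacity = 1` and `EpsilonGreedyExplorer(0.1)`, are copied from those examples and should be treated as assumptions rather than tuned settings.

```julia
using ReinforcementLearning
using ReinforcementLearningTrajectories

env = RandomWalk1D()
ns, na = length(state_space(env)), length(action_space(env))

# Tabular Q-values updated by a :SARS TD learner, as in the tutorial above.
policy = QBasedPolicy(
    learner = TDLearner(
        TabularQApproximator(n_state = ns, n_action = na),
        :SARS,
    ),
    explorer = EpsilonGreedyExplorer(0.1),
)

# Trajectory layout mirroring the JLD2 hook example earlier in this PR.
trajectory = Trajectory(
    CircularArraySARTSTraces(;
        capacity = 1,
        state = Int64 => (),
        action = Int64 => (),
        reward = Float64 => (),
        terminal = Bool => (),
    ),
    DummySampler(),
    InsertSampleRatioController(),
)

# The Agent stores transitions in the trajectory and lets the learner optimise.
agent = Agent(policy = policy, trajectory = trajectory)

hook = TotalRewardPerEpisode()
run(agent, env, StopAfterNEpisodes(10), hook)
hook.rewards
```

Relative to the tutorial's closing example, the only change is swapping the `RandomPolicy` inside the `Agent` for the `QBasedPolicy` defined earlier, which is the same combination the parameter-saving hook example in this PR uses.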
diff --git a/src/ReinforcementLearningBase/Project.toml b/src/ReinforcementLearningBase/Project.toml index 69f9e26bf..5978559aa 100644 --- a/src/ReinforcementLearningBase/Project.toml +++ b/src/ReinforcementLearningBase/Project.toml @@ -1,7 +1,7 @@ name = "ReinforcementLearningBase" uuid = "e575027e-6cd6-5018-9292-cdc6200d2b44" authors = ["Johanni Brea ", "Jun Tian "] -version = "0.12.3" +version = "0.13.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/ReinforcementLearningBase/src/interface.jl b/src/ReinforcementLearningBase/src/interface.jl index 47f57d781..65204c160 100644 --- a/src/ReinforcementLearningBase/src/interface.jl +++ b/src/ReinforcementLearningBase/src/interface.jl @@ -394,13 +394,14 @@ abstract type AbstractEpisodeStyle end # General ##### -@api struct DefaultPlayer end +@api abstract type AbstractPlayer end +@api struct DefaultPlayer <: AbstractPlayer end @api const DEFAULT_PLAYER = DefaultPlayer() -@api struct ChancePlayer end +@api struct ChancePlayer <: AbstractPlayer end @api const CHANCE_PLAYER = ChancePlayer() -@api struct SimultaneousPlayer end +@api struct SimultaneousPlayer <: AbstractPlayer end @api const SIMULTANEOUS_PLAYER = SimultaneousPlayer() @api struct Spector end @@ -450,6 +451,7 @@ Get all available actions from environment. See also: [`legal_action_space`](@ref) """ @multi_agent_env_api action_space(env::AbstractEnv, player=current_player(env)) +action_space(env::AbstractEnv, ::DefaultPlayer) = action_space(env) """ legal_action_space(env, player=current_player(env)) @@ -460,7 +462,7 @@ For environments of [`MINIMAL_ACTION_SET`](@ref), the result is the same with @multi_agent_env_api legal_action_space(env::AbstractEnv, player=current_player(env)) = legal_action_space(ActionStyle(env), env, player) -legal_action_space(::MinimalActionSet, env, player) = action_space(env) +legal_action_space(::MinimalActionSet, env, player::AbstractPlayer) = action_space(env) """ legal_action_space_mask(env, player=current_player(env)) -> AbstractArray{Bool} diff --git a/src/ReinforcementLearningCore/Project.toml b/src/ReinforcementLearningCore/Project.toml index 6f6987a59..1714fd9bb 100644 --- a/src/ReinforcementLearningCore/Project.toml +++ b/src/ReinforcementLearningCore/Project.toml @@ -36,8 +36,8 @@ GPUArrays = "8, 9, 10" Metal = "1.0" ProgressMeter = "1" Reexport = "1" -ReinforcementLearningBase = "0.12" -ReinforcementLearningTrajectories = "0.3.7" +ReinforcementLearningBase = "0.13" +ReinforcementLearningTrajectories = "0.4" Statistics = "1" StatsBase = "0.32, 0.33, 0.34" TimerOutputs = "0.5" diff --git a/src/ReinforcementLearningCore/src/core/core.jl b/src/ReinforcementLearningCore/src/core/core.jl index 6ade5aa16..052e88d12 100644 --- a/src/ReinforcementLearningCore/src/core/core.jl +++ b/src/ReinforcementLearningCore/src/core/core.jl @@ -1,3 +1,4 @@ +include("player.jl") include("stages.jl") include("stop_conditions.jl") include("hooks.jl") diff --git a/src/ReinforcementLearningCore/src/core/hooks.jl b/src/ReinforcementLearningCore/src/core/hooks.jl index b9482f540..f93afbb56 100644 --- a/src/ReinforcementLearningCore/src/core/hooks.jl +++ b/src/ReinforcementLearningCore/src/core/hooks.jl @@ -6,7 +6,7 @@ export AbstractHook, TotalRewardPerEpisode, BatchStepsPerEpisode, TimePerStep, - DoEveryNEpisode, + DoEveryNEpisodes, DoEveryNSteps, DoOnExit @@ -89,7 +89,7 @@ Base.getindex(h::StepsPerEpisode) = h.steps Base.push!(hook::StepsPerEpisode, ::PostActStage, args...) 
= hook.count += 1 -Base.push!(hook::StepsPerEpisode, stage::PostEpisodeStage, agent, env, ::Symbol) = Base.push!(hook, stage, agent, env) +Base.push!(hook::StepsPerEpisode, stage::PostEpisodeStage, agent, env, ::Player) = Base.push!(hook, stage, agent, env) function Base.push!(hook::StepsPerEpisode, ::PostEpisodeStage, agent, env) Base.push!(hook.steps, hook.count) @@ -123,10 +123,10 @@ function Base.push!(h::RewardsPerEpisode{T}, ::PreEpisodeStage, agent, env) wher push!(h.rewards, T[]) end -Base.push!(h::RewardsPerEpisode, s::PreEpisodeStage, agent, env, ::Symbol) = push!(h, s, agent, env) +Base.push!(h::RewardsPerEpisode, s::PreEpisodeStage, agent, env, ::Player) = push!(h, s, agent, env) Base.push!(h::RewardsPerEpisode, ::PostActStage, agent::P, env::E) where {P <: AbstractPolicy, E <: AbstractEnv} = push!(last(h.rewards), reward(env)) -Base.push!(h::RewardsPerEpisode, ::PostActStage, agent::P, env::E, player::Symbol) where {P <: AbstractPolicy, E <: AbstractEnv} = push!(last(h.rewards), reward(env, player)) +Base.push!(h::RewardsPerEpisode, ::PostActStage, agent::Policy, env::E, player::Player) where {Policy <: AbstractPolicy, E <: AbstractEnv, Player <: AbstractPlayer} = push!(last(h.rewards), reward(env, player)) ##### # TotalRewardPerEpisode @@ -155,7 +155,7 @@ end Base.getindex(h::TotalRewardPerEpisode) = h.rewards Base.push!(h::TotalRewardPerEpisode, ::PostActStage, agent::P, env::E) where {P <: AbstractPolicy, E <: AbstractEnv} = h.reward += reward(env) -Base.push!(h::TotalRewardPerEpisode, ::PostActStage, agent::P, env::E, player::Symbol) where {P <: AbstractPolicy, E <: AbstractEnv} = h.reward += reward(env, player) +Base.push!(h::TotalRewardPerEpisode, ::PostActStage, agent::P, env::E, player::Player) where {P <: AbstractPolicy, E <: AbstractEnv, Player <: AbstractPlayer} = h.reward += reward(env, player) function Base.push!(hook::TotalRewardPerEpisode, ::PostEpisodeStage, @@ -195,8 +195,8 @@ function Base.push!(hook::TotalRewardPerEpisode, stage::Union{PostEpisodeStage, PostExperimentStage}, agent, env, - player::Symbol -) + player::Player +) where {Player <: AbstractPlayer} push!(hook, stage, agent, @@ -245,8 +245,8 @@ function Base.push!(hook::BatchStepsPerEpisode, stage::PostActStage, agent, env, - player::Symbol -) + player::Player +) where {Player <: AbstractPlayer} push!(hook, stage, agent, @@ -310,21 +310,21 @@ function Base.push!(hook::DoEveryNSteps, ::PostActStage, agent, env) end """ - DoEveryNEpisode(f; n=1, t=0) + DoEveryNEpisodes(f; n=1, t=0) Execute `f(t, agent, env)` every `n` episode. `t` is a counter of episodes. 
""" -mutable struct DoEveryNEpisode{S<:Union{PreEpisodeStage,PostEpisodeStage},F} <: AbstractHook +mutable struct DoEveryNEpisodes{S<:Union{PreEpisodeStage,PostEpisodeStage},F} <: AbstractHook f::F n::Int t::Int end -DoEveryNEpisode(f::F; n=1, t=0, stage::S=PostEpisodeStage()) where {S,F} = - DoEveryNEpisode{S,F}(f, n, t) +DoEveryNEpisodes(f::F; n=1, t=0, stage::S=PostEpisodeStage()) where {S,F} = + DoEveryNEpisodes{S,F}(f, n, t) -function Base.push!(hook::DoEveryNEpisode{S}, ::S, agent, env) where {S} +function Base.push!(hook::DoEveryNEpisodes{S}, ::S, agent, env) where {S} hook.t += 1 if hook.t % hook.n == 0 hook.f(hook.t, agent, env) diff --git a/src/ReinforcementLearningCore/src/core/player.jl b/src/ReinforcementLearningCore/src/core/player.jl new file mode 100644 index 000000000..af828682b --- /dev/null +++ b/src/ReinforcementLearningCore/src/core/player.jl @@ -0,0 +1,7 @@ +struct Player <: AbstractPlayer + name::Symbol + + function Player(name) + new(Symbol(name)) + end +end diff --git a/src/ReinforcementLearningCore/src/core/stages.jl b/src/ReinforcementLearningCore/src/core/stages.jl index 61e48f57d..815e61392 100644 --- a/src/ReinforcementLearningCore/src/core/stages.jl +++ b/src/ReinforcementLearningCore/src/core/stages.jl @@ -18,8 +18,8 @@ struct PostActStage <: AbstractStage end Base.push!(p::AbstractPolicy, ::AbstractStage, ::AbstractEnv) = nothing Base.push!(p::AbstractPolicy, ::PostActStage, ::AbstractEnv, action) = nothing -Base.push!(p::AbstractPolicy, ::AbstractStage, ::AbstractEnv, ::Symbol) = nothing -Base.push!(p::AbstractPolicy, ::PostActStage, ::AbstractEnv, action, ::Symbol) = nothing +Base.push!(p::AbstractPolicy, ::AbstractStage, ::AbstractEnv, ::Player) = nothing +Base.push!(p::AbstractPolicy, ::PostActStage, ::AbstractEnv, action, ::Player) = nothing RLBase.optimise!(policy::P, ::S) where {P<:AbstractPolicy,S<:AbstractStage} = nothing diff --git a/src/ReinforcementLearningCore/src/core/stop_conditions.jl b/src/ReinforcementLearningCore/src/core/stop_conditions.jl index 8004aa924..4d391896b 100644 --- a/src/ReinforcementLearningCore/src/core/stop_conditions.jl +++ b/src/ReinforcementLearningCore/src/core/stop_conditions.jl @@ -1,5 +1,5 @@ export AbstractStopCondition, StopAfterNSteps, - StopAfterNEpisodes, StopIfEnvTerminated, StopSignal, StopAfterNoImprovement, StopAfterNSeconds, ComposedStopCondition + StopAfterNEpisodes, StopIfEnvTerminated, StopSignal, StopAfterNoImprovement, StopAfterNSeconds, StopIfAll, StopIfAny import ProgressMeter @@ -7,24 +7,34 @@ import ProgressMeter abstract type AbstractStopCondition end ##### -# ComposedStopCondition +# AnyStopCondition ##### """ - ComposedStopCondition(stop_conditions...; reducer = any) + AnyStopCondition(stop_conditions...) -The result of `stop_conditions` is reduced by `reducer`. The default `reducer` is the `any` function, which means that the condition is true when any one of the `stop_conditions...` is true. Can be replaced by any function returning a boolean. For example `reducer = x->sum(x) >= 2` will require at least two of the conditions to be true. +The result of `stop_conditions` is reduced by `any`. """ -struct ComposedStopCondition{S,reducer} <: AbstractStopCondition +struct StopIfAny{S<:Tuple} <: AbstractStopCondition stop_conditions::S - reducer - function ComposedStopCondition(stop_conditions...; reducer = any) - new{typeof(stop_conditions),reducer}(stop_conditions, reducer) + function StopIfAny(stop_conditions...) 
+ new{typeof(stop_conditions)}(stop_conditions) end end -function check!(s::ComposedStopCondition{S,R}, policy::P, env::E) where {S,R,P<:AbstractPolicy,E<:AbstractEnv} - s.reducer(check!(sc, policy, env) for sc in s.stop_conditions) +function check!(s::StopIfAny{S}, policy::P, env::E) where {S<:Tuple, P<:AbstractPolicy, E<:AbstractEnv} + any(check!.(s.stop_conditions, (policy,), (env,))) +end + +struct StopIfAll{S<:Tuple} <: AbstractStopCondition + stop_conditions::S + function StopIfAll(stop_conditions...) + new{typeof(stop_conditions)}(stop_conditions) + end +end + +function check!(s::StopIfAll{S}, policy::P, env::E) where {S<:Tuple, P<:AbstractPolicy, E<:AbstractEnv} + all(check!.(s.stop_conditions, (policy,), (env,))) end ##### diff --git a/src/ReinforcementLearningCore/src/policies/agent/agent_srt_cache.jl b/src/ReinforcementLearningCore/src/policies/agent/agent_srt_cache.jl index 7bd0bed80..1693af4b8 100644 --- a/src/ReinforcementLearningCore/src/policies/agent/agent_srt_cache.jl +++ b/src/ReinforcementLearningCore/src/policies/agent/agent_srt_cache.jl @@ -27,12 +27,12 @@ struct SART{S,A,R,T} end # This method is used to push a state and action to a trace -function Base.push!(ts::Union{CircularArraySARTSTraces,ElasticArraySARTTraces}, xs::SA) +function Base.push!(ts::Union{CircularArraySARTSTraces,ElasticArraySARTSTraces}, xs::SA) push!(ts.traces[1].trace, xs.state) push!(ts.traces[2].trace, xs.action) end -function Base.push!(ts::Union{CircularArraySARTSTraces,ElasticArraySARTTraces}, xs::SART) +function Base.push!(ts::Union{CircularArraySARTSTraces,ElasticArraySARTSTraces}, xs::SART) push!(ts.traces[1].trace, xs.state) push!(ts.traces[2].trace, xs.action) push!(ts.traces[3], xs.reward) diff --git a/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl b/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl index d3b3fd3da..72704ca37 100644 --- a/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl +++ b/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl @@ -1,5 +1,4 @@ -export MultiAgentPolicy -export MultiAgentHook +export MultiAgentPolicy, MultiAgentHook, Player, PlayerTuple using Random # for RandomPolicy @@ -7,15 +6,38 @@ import Base.getindex import Base.iterate import Base.push! +""" + PlayerTuple + +A NamedTuple that maps players to their respective values. +""" +struct PlayerTuple{N,T} + data::NamedTuple{N,T} + + function PlayerTuple(data::Pair...) + nt = NamedTuple(first(item).name => last(item) for item in data) + new{typeof(nt).parameters...}(nt) + end + + function PlayerTuple(data::Base.Generator) + PlayerTuple(collect(data)...) + end +end + +Base.getindex(nt::PlayerTuple, player::Player) = nt.data[player.name] +Base.keys(nt::PlayerTuple) = Player.(keys(nt.data)) +Base.iterate(nt::PlayerTuple) = iterate(nt.data) +Base.iterate(nt::PlayerTuple, state) = iterate(nt.data, state) + """ MultiAgentPolicy(agents::NT) where {NT<: NamedTuple} MultiAgentPolicy is a policy struct that contains `<:AbstractPolicy` structs indexed by the player's symbol. 
""" -struct MultiAgentPolicy{names,T} <: AbstractPolicy - agents::NamedTuple{names,T} +struct MultiAgentPolicy{players,T} <: AbstractPolicy + agents::PlayerTuple{players, T} - function MultiAgentPolicy(agents::NamedTuple{names,T}) where {names,T} - new{names,T}(agents) + function MultiAgentPolicy(agents::PlayerTuple{players,T}) where {players,T} + new{players, T}(agents) end end @@ -23,11 +45,11 @@ end MultiAgentHook(hooks::NT) where {NT<: NamedTuple} MultiAgentHook is a hook struct that contains `<:AbstractHook` structs indexed by the player's symbol. """ -struct MultiAgentHook{names,T} <: AbstractHook - hooks::NamedTuple{names,T} +struct MultiAgentHook{players,T} <: AbstractHook + hooks::PlayerTuple{players,T} - function MultiAgentHook(hooks::NamedTuple{names,T}) where {names,T} - new{names,T}(hooks) + function MultiAgentHook(hooks::PlayerTuple{players,T}) where {players, T} + new{players,T}(hooks) end end @@ -48,10 +70,10 @@ function Base.iterate(current_player_iterator::CurrentPlayerIterator, state) end Base.iterate(p::MultiAgentPolicy) = iterate(p.agents) -Base.iterate(p::MultiAgentPolicy, s) = iterate(p.agents, s) +Base.iterate(p::MultiAgentPolicy, state) = iterate(p.agents, state) -Base.getindex(p::MultiAgentPolicy, s::Symbol) = p.agents[s] -Base.getindex(h::MultiAgentHook, s::Symbol) = h.hooks[s] +Base.getindex(p::MultiAgentPolicy, player::Player) = p.agents[player] +Base.getindex(h::MultiAgentHook, player::Player) = h.hooks[player] Base.keys(p::MultiAgentPolicy) = keys(p.agents) Base.keys(p::MultiAgentHook) = keys(p.hooks) @@ -186,7 +208,7 @@ function Base.push!(multiagent::MultiAgentPolicy, stage::S, env::E) where {S<:Ab end # Like in the single-agent case, push! at the PostActStage() calls push! on each player. -function Base.push!(agent::Agent, ::PreEpisodeStage, env::AbstractEnv, player::Symbol) +function Base.push!(agent::Agent, ::PreEpisodeStage, env::AbstractEnv, player::Player) push!(agent.trajectory, (state = state(env, player),)) end @@ -196,7 +218,7 @@ function Base.push!(multiagent::MultiAgentPolicy, s::PreEpisodeStage, env::E) wh end end -function RLBase.plan!(agent::Agent, env::AbstractEnv, player::Symbol) +function RLBase.plan!(agent::Agent, env::AbstractEnv, player::Player) RLBase.plan!(agent.policy, env, player) end @@ -214,7 +236,7 @@ function Base.push!(multiagent::MultiAgentPolicy, ::PostActStage, env::E, action end end -function Base.push!(agent::Agent, ::PostEpisodeStage, env::AbstractEnv, p::Symbol) +function Base.push!(agent::Agent, ::PostEpisodeStage, env::AbstractEnv, player::Player) if haskey(agent.trajectory, :next_action) action = RLBase.plan!(agent.policy, env, p) push!(agent.trajectory, PartialNamedTuple((action = action, ))) @@ -227,18 +249,18 @@ function Base.push!(hook::MultiAgentHook, stage::S, multiagent::MultiAgentPolicy end end -@inline function _push!(stage::AbstractStage, policy::P, env::E, player::Symbol, hook::H, hook_tuple...) where {P <: AbstractPolicy, E <: AbstractEnv, H <: AbstractHook} +@inline function _push!(stage::AbstractStage, policy::P, env::E, player::Player, hook::H, hook_tuple...) where {P <: AbstractPolicy, E <: AbstractEnv, H <: AbstractHook} push!(hook, stage, policy, env, player) _push!(stage, policy, env, player, hook_tuple...) 
end -_push!(stage::AbstractStage, policy::P, env::E, player::Symbol) where {P <: AbstractPolicy, E <: AbstractEnv} = nothing +_push!(stage::AbstractStage, policy::P, env::E, player::Player) where {P <: AbstractPolicy, E <: AbstractEnv} = nothing function Base.push!(composed_hook::ComposedHook{T}, stage::AbstractStage, policy::P, env::E, - player::Symbol + player::Player ) where {T <: Tuple, P <: AbstractPolicy, E <: AbstractEnv} _push!(stage, policy, env, player, composed_hook.hooks...) end diff --git a/src/ReinforcementLearningCore/src/policies/learners/abstract_learner.jl b/src/ReinforcementLearningCore/src/policies/learners/abstract_learner.jl index 09f0942a6..d78524c14 100644 --- a/src/ReinforcementLearningCore/src/policies/learners/abstract_learner.jl +++ b/src/ReinforcementLearningCore/src/policies/learners/abstract_learner.jl @@ -12,7 +12,7 @@ function forward(learner::L, env::E) where {L <: AbstractLearner, E <: AbstractE end # Take Learner and Environment, get state, send to RLCore.forward(Learner, State) -function forward(learner::L, env::E, player::Symbol) where {L <: AbstractLearner, E <: AbstractEnv} +function forward(learner::L, env::E, player::Player) where {L <: AbstractLearner, E <: AbstractEnv, Player <: AbstractPlayer} env |> (x -> state(x, player)) |> (x -> forward(learner, x)) end @@ -25,7 +25,7 @@ function RLBase.plan!(explorer::AbstractExplorer, learner::AbstractLearner, env: RLBase.plan!(explorer, forward(learner, env), legal_action_space_) end -function RLBase.plan!(explorer::AbstractExplorer, learner::AbstractLearner, env::AbstractEnv, player::Symbol) +function RLBase.plan!(explorer::AbstractExplorer, learner::AbstractLearner, env::AbstractEnv, player::AbstractPlayer) legal_action_space_ = RLBase.legal_action_space_mask(env, player) return RLBase.plan!(explorer, forward(learner, env, player), legal_action_space_) end diff --git a/src/ReinforcementLearningCore/src/policies/q_based_policy.jl b/src/ReinforcementLearningCore/src/policies/q_based_policy.jl index bb51c1297..4d3f91384 100644 --- a/src/ReinforcementLearningCore/src/policies/q_based_policy.jl +++ b/src/ReinforcementLearningCore/src/policies/q_based_policy.jl @@ -31,7 +31,7 @@ function RLBase.plan!(policy::QBasedPolicy{L,Ex}, env::E) where {Ex<:AbstractExp RLBase.plan!(policy.explorer, policy.learner, env) end -function RLBase.plan!(policy::QBasedPolicy{L,Ex}, env::E, player::Symbol) where {Ex<:AbstractExplorer,L<:TDLearner,E<:AbstractEnv} +function RLBase.plan!(policy::QBasedPolicy{L,Ex}, env::E, player::Player) where {Ex<:AbstractExplorer,L<:TDLearner,E<:AbstractEnv, Player<:AbstractPlayer} RLBase.plan!(policy.explorer, policy.learner, env, player) end diff --git a/src/ReinforcementLearningCore/src/policies/random_policy.jl b/src/ReinforcementLearningCore/src/policies/random_policy.jl index 3719d478c..0937db6e0 100644 --- a/src/ReinforcementLearningCore/src/policies/random_policy.jl +++ b/src/ReinforcementLearningCore/src/policies/random_policy.jl @@ -31,7 +31,7 @@ function RLBase.plan!(p::RandomPolicy{Nothing,RNG}, env::AbstractEnv) where {RNG return rand(p.rng, legal_action_space_) end -function RLBase.plan!(p::RandomPolicy{Nothing,RNG}, env::E, player::Symbol) where {E<:AbstractEnv, RNG<:AbstractRNG} +function RLBase.plan!(p::RandomPolicy{Nothing,RNG}, env::E, player::Player) where {E<:AbstractEnv, RNG<:AbstractRNG, Player <: AbstractPlayer} legal_action_space_ = RLBase.legal_action_space(env, player) return rand(p.rng, legal_action_space_) end diff --git 
a/src/ReinforcementLearningCore/test/core/core.jl b/src/ReinforcementLearningCore/test/core/core.jl index 8277118c5..0823ae2d1 100644 --- a/src/ReinforcementLearningCore/test/core/core.jl +++ b/src/ReinforcementLearningCore/test/core/core.jl @@ -1,4 +1,5 @@ include("base.jl") +include("player.jl") include("hooks.jl") include("reset_conditions.jl") include("stop_conditions.jl") diff --git a/src/ReinforcementLearningCore/test/core/hooks.jl b/src/ReinforcementLearningCore/test/core/hooks.jl index 13bc64765..c8243f137 100644 --- a/src/ReinforcementLearningCore/test/core/hooks.jl +++ b/src/ReinforcementLearningCore/test/core/hooks.jl @@ -1,4 +1,5 @@ struct MockHook <: AbstractHook end +struct TestPlayer <: AbstractPlayer end """ test_noop!(hook; stages=[PreActStage()]) @@ -10,7 +11,7 @@ function test_noop!(hook::AbstractHook; stages=[PreActStage(), PostActStage(), P env = RandomWalk1D() env.pos = 7 policy = RandomPolicy(legal_action_space(env)) - + player = TestPlayer() hook_fieldnames = fieldnames(typeof(hook)) for mode in [:MultiAgent, :SingleAgent] for stage in stages @@ -18,7 +19,7 @@ function test_noop!(hook::AbstractHook; stages=[PreActStage(), PostActStage(), P if mode == :SingleAgent push!(hook_copy, stage, policy, env) elseif mode == :MultiAgent - push!(hook_copy, stage, policy, env, :player_i) + push!(hook_copy, stage, policy, env, player) end for field_ in hook_fieldnames if getfield(hook, field_) isa Ref @@ -168,10 +169,10 @@ end @test env.pos == 2 end -@testset "DoEveryNEpisode" begin - h_1 = DoEveryNEpisode((hook, agent, env) -> (env.pos += 1); n=2, stage=PreEpisodeStage()) - h_2 = DoEveryNEpisode((hook, agent, env) -> (env.pos += 1); n=2, stage=PostEpisodeStage()) - h_3 = DoEveryNEpisode((hook, agent, env) -> (env.pos += 1); n=1) +@testset "DoEveryNEpisodes" begin + h_1 = DoEveryNEpisodes((hook, agent, env) -> (env.pos += 1); n=2, stage=PreEpisodeStage()) + h_2 = DoEveryNEpisodes((hook, agent, env) -> (env.pos += 1); n=2, stage=PostEpisodeStage()) + h_3 = DoEveryNEpisodes((hook, agent, env) -> (env.pos += 1); n=1) h_list = (h_1, h_2, h_3) stage_list = (PreEpisodeStage(), PostEpisodeStage(), PostEpisodeStage()) diff --git a/src/ReinforcementLearningCore/test/core/player.jl b/src/ReinforcementLearningCore/test/core/player.jl new file mode 100644 index 000000000..6f5ba7c78 --- /dev/null +++ b/src/ReinforcementLearningCore/test/core/player.jl @@ -0,0 +1,4 @@ +@testset "Player" begin + @test Player(1) == Player(Symbol(1)) + @test Player("test").name == :test +end diff --git a/src/ReinforcementLearningCore/test/core/stop_conditions.jl b/src/ReinforcementLearningCore/test/core/stop_conditions.jl index c41269ab8..c1e4c3fbe 100644 --- a/src/ReinforcementLearningCore/test/core/stop_conditions.jl +++ b/src/ReinforcementLearningCore/test/core/stop_conditions.jl @@ -11,15 +11,26 @@ import ReinforcementLearningCore.check! 
@test sum([check!(stop_condition, policy, env) for i in 1:20]) == 11 end -@testset "ComposedStopCondition" begin +@testset "StopIfAny" begin stop_10 = StopAfterNSteps(10) stop_3 = StopAfterNSteps(3) env = RandomWalk1D() policy = RandomPolicy(legal_action_space(env)) - composed_stop = ComposedStopCondition(stop_10, stop_3) - @test sum([check!(composed_stop, policy, env) for i in 1:20]) == 18 + composed_stop = StopIfAny(stop_10, stop_3) + @test sum([RLCore.check!(composed_stop, policy, env) for i in 1:20]) == 18 +end + +@testset "StopIfAll" begin + stop_10 = StopAfterNSteps(10) + stop_3 = StopAfterNSteps(3) + + env = RandomWalk1D() + policy = RandomPolicy(legal_action_space(env)) + + composed_stop = StopIfAll(stop_10, stop_3) + @test sum([RLCore.check!(composed_stop, policy, env) for i in 1:20]) == 11 end @testset "StopAfterNEpisodes" begin diff --git a/src/ReinforcementLearningCore/test/policies/learners/abstract_learner.jl b/src/ReinforcementLearningCore/test/policies/learners/abstract_learner.jl index ae3df0202..b6e7ec6c3 100644 --- a/src/ReinforcementLearningCore/test/policies/learners/abstract_learner.jl +++ b/src/ReinforcementLearningCore/test/policies/learners/abstract_learner.jl @@ -15,7 +15,7 @@ struct MockLearner <: AbstractLearner end end RLBase.state(::MockEnv, ::Observation{Any}, ::DefaultPlayer) = 1 - RLBase.state(::MockEnv, ::Observation{Any}, ::Symbol) = 1 + RLBase.state(::MockEnv, ::Observation{Any}, ::Player) = 1 env = MockEnv() learner = MockLearner() @@ -23,7 +23,7 @@ struct MockLearner <: AbstractLearner end output = RLCore.forward(learner, env) @test output == Float64[1.0, 2.0] - output = RLCore.forward(learner, env, Symbol(1)) + output = RLCore.forward(learner, env, Player(1)) @test output == Float64[1.0, 2.0] end @@ -44,18 +44,18 @@ struct MockLearner <: AbstractLearner end @testset "Plan with Player" begin # Mock explorer, environment, and learner - function RLBase.action_space(::MockEnv, ::Symbol) + function RLBase.action_space(::MockEnv, ::Player) return [1, 2] end - function RLBase.plan!(::MockExplorer, learner::MockLearner, env::MockEnv, p::Symbol) + function RLBase.plan!(::MockExplorer, learner::MockLearner, env::MockEnv, p::Player) return rand(2) end env = MockEnv() learner = MockLearner() explorer = MockExplorer() - player = :player1 + player = Player(:player1) output = RLBase.plan!(explorer, learner, env, player) diff --git a/src/ReinforcementLearningCore/test/policies/multi_agent.jl b/src/ReinforcementLearningCore/test/policies/multi_agent.jl index 7102e71d1..2e43efac0 100644 --- a/src/ReinforcementLearningCore/test/policies/multi_agent.jl +++ b/src/ReinforcementLearningCore/test/policies/multi_agent.jl @@ -3,6 +3,20 @@ using ReinforcementLearningTrajectories using ReinforcementLearningBase using DomainSets + +@testset "Basic PlayerTuple tests" begin + nt = PlayerTuple(Player(1) => "test1", Player(2) => "test2") + @test nt.data == (; Symbol(1) => "test1", Symbol(2) => "test2") + @test typeof(nt).parameters == typeof(nt.data).parameters + @test nt[Player(1)] == "test1" + @test PlayerTuple(Player(i) => i for i in 1:2) == PlayerTuple(Player(1) => 1, Player(2) => 2) + + @test iterate(nt) == ("test1", 2) + @test iterate(nt, 1) == ("test1", 2) + collect(iterate(nt)) + +end + @testset "MultiAgentPolicy" begin trajectory_1 = Trajectory( CircularArraySARTSTraces(; capacity = 1), @@ -16,12 +30,14 @@ using DomainSets InsertSampleRatioController(n_inserted = -1), ) - multiagent_policy = MultiAgentPolicy((; - :Cross => Agent(RandomPolicy(), trajectory_1), - :Nought => 
Agent(RandomPolicy(), trajectory_2), - )) + multiagent_policy = MultiAgentPolicy( + PlayerTuple( + Player(:Cross) => Agent(RandomPolicy(), trajectory_1), + Player(:Nought) => Agent(RandomPolicy(), trajectory_2), + ) + ) - @test multiagent_policy.agents[:Cross].policy isa RandomPolicy + @test multiagent_policy.agents[Player(:Cross)].policy isa RandomPolicy end @testset "MultiAgentHook" begin @@ -34,8 +50,8 @@ end TimePerStep() ) - multiagent_hook = MultiAgentHook((; :Cross => composed_hook, :Nought => EmptyHook())) - @test multiagent_hook.hooks[:Cross][3] isa StepsPerEpisode + multiagent_hook = MultiAgentHook(PlayerTuple(Player(:Cross) => composed_hook, Player(:Nought) => EmptyHook())) + @test multiagent_hook.hooks[Player(:Cross)][3] isa StepsPerEpisode end @testset "CurrentPlayerIterator" begin @@ -48,7 +64,7 @@ end RLBase.act!(env, 1) i == 2 && break end - @test player_log == [:Cross, :Nought] + @test player_log == Player.([:Cross, :Nought]) end @testset "Basic TicTacToeEnv (Sequential) env checks" begin @@ -64,41 +80,41 @@ end InsertSampleRatioController(n_inserted = -1), ) - multiagent_policy = MultiAgentPolicy((; - :Cross => Agent(RandomPolicy(), trajectory_1), - :Nought => Agent(RandomPolicy(), trajectory_2), + multiagent_policy = MultiAgentPolicy(PlayerTuple( + Player(:Cross) => Agent(RandomPolicy(), trajectory_1), + Player(:Nought) => Agent(RandomPolicy(), trajectory_2), )) - multiagent_hook = MultiAgentHook((; :Cross => StepsPerEpisode(), :Nought => StepsPerEpisode())) + multiagent_hook = MultiAgentHook(PlayerTuple(Player(:Cross) => StepsPerEpisode(), Player(:Nought) => StepsPerEpisode())) env = TicTacToeEnv() stop_condition = StopIfEnvTerminated() hook = StepsPerEpisode() - @test RLBase.reward(env, :Cross) == 0 + @test RLBase.reward(env, Player(:Cross)) == 0 @test length(RLBase.legal_action_space(env)) == 9 Base.run(multiagent_policy, env, Sequential(), stop_condition, multiagent_hook) - @test multiagent_hook.hooks[:Nought].steps[1] > 0 - @test multiagent_hook.hooks[:Cross].steps[1] > 0 + @test multiagent_hook.hooks[Player(:Nought)].steps[1] > 0 + @test multiagent_hook.hooks[Player(:Cross)].steps[1] > 0 @test RLBase.is_terminated(env) - @test RLEnvs.is_win(env, :Cross) isa Bool - @test RLEnvs.is_win(env, :Nought) isa Bool - @test RLBase.reward(env, :Cross) == (RLBase.reward(env, :Nought) * -1) - @test RLBase.legal_action_space_mask(env, :Cross) == falses(9) + @test RLEnvs.is_win(env, Player(:Cross)) isa Bool + @test RLEnvs.is_win(env, Player(:Nought)) isa Bool + @test RLBase.reward(env, Player(:Cross)) == (RLBase.reward(env, Player(:Nought)) * -1) + @test RLBase.legal_action_space_mask(env, Player(:Cross)) == falses(9) @test RLBase.legal_action_space(env) == [] - @test RLBase.state(env, Observation{BitArray{3}}(), :Cross) isa BitArray{3} - @test RLBase.state_space(env, Observation{BitArray{3}}(), :Cross) isa ArrayProductDomain - @test RLBase.state_space(env, Observation{String}(), :Cross) isa DomainSets.FullSpace{String} - @test RLBase.state(env, Observation{String}(), :Cross) isa String + @test RLBase.state(env, Observation{BitArray{3}}(), Player(:Cross)) isa BitArray{3} + @test RLBase.state_space(env, Observation{BitArray{3}}(), Player(:Cross)) isa ArrayProductDomain + @test RLBase.state_space(env, Observation{String}(), Player(:Cross)) isa DomainSets.FullSpace{String} + @test RLBase.state(env, Observation{String}(), Player(:Cross)) isa String @test RLBase.state(env, Observation{String}()) isa String end @testset "next_player!" 
begin env = TicTacToeEnv() - @test RLBase.next_player!(env) == :Nought + @test RLBase.next_player!(env) == Player(:Nought) end @testset "Basic RockPaperScissors (simultaneous) env checks" begin @@ -114,19 +130,19 @@ end InsertSampleRatioController(n_inserted = -1), ) - @test MultiAgentPolicy((; - Symbol(1) => Agent(RandomPolicy(), trajectory_1), - Symbol(2) => Agent(RandomPolicy(), trajectory_2), + @test MultiAgentPolicy(PlayerTuple( + Player(1) => Agent(RandomPolicy(), trajectory_1), + Player(2) => Agent(RandomPolicy(), trajectory_2), )) isa MultiAgentPolicy - @test MultiAgentPolicy((; - Symbol(1) => Agent(RandomPolicy(), trajectory_1), - Symbol(2) => Agent(RandomPolicy(), trajectory_2), + @test MultiAgentPolicy(PlayerTuple( + Player(1) => Agent(RandomPolicy(), trajectory_1), + Player(2) => Agent(RandomPolicy(), trajectory_2), )) isa MultiAgentPolicy - multiagent_policy = MultiAgentPolicy((; - Symbol(1) => Agent(RandomPolicy(), trajectory_1), - Symbol(2) => Agent(RandomPolicy(), trajectory_2), + multiagent_policy = MultiAgentPolicy(PlayerTuple( + Player(1) => Agent(RandomPolicy(), trajectory_1), + Player(2) => Agent(RandomPolicy(), trajectory_2), )) env = RockPaperScissorsEnv() @@ -139,34 +155,35 @@ end TimePerStep() ) - multiagent_hook = MultiAgentHook((; Symbol(1) => composed_hook, Symbol(2) => EmptyHook())) + multiagent_hook = MultiAgentHook(PlayerTuple(Player(1) => composed_hook, Player(2) => EmptyHook())) @test Base.iterate(RLCore.CurrentPlayerIterator(env))[1] == SimultaneousPlayer() @test Base.iterate(RLCore.CurrentPlayerIterator(env), env)[1] == SimultaneousPlayer() @test Base.iterate(multiagent_policy)[1] isa Agent @test Base.iterate(multiagent_policy, 1)[1] isa Agent - @test Base.getindex(multiagent_policy, Symbol(1)) isa Agent - @test Base.getindex(multiagent_hook, Symbol(1))[3] isa StepsPerEpisode + @test Base.getindex(multiagent_policy, Player(1)) isa Agent + @test Base.getindex(multiagent_hook, Player(1))[3] isa StepsPerEpisode - @test Base.keys(multiagent_policy) == (Symbol(1), Symbol(2)) - @test Base.keys(multiagent_hook) == (Symbol(1), Symbol(2)) + @test Base.keys(multiagent_policy) == (Player(1), Player(2)) + @test Base.keys(multiagent_hook) == (Player(1), Player(2)) @test length(RLBase.legal_action_space(env)) == 9 Base.run(multiagent_policy, env, stop_condition, multiagent_hook) - @test multiagent_hook[Symbol(1)][1].steps[1][1] == 1 - @test -1 <= multiagent_hook[Symbol(1)][2].rewards[1][1] <= 1 - @test multiagent_hook[Symbol(1)][3].steps[1] == 1 - @test -1 <= multiagent_hook[Symbol(1)][4].rewards[1][1] <= 1 - @test 0 <= multiagent_hook[Symbol(1)][5].times[1] <= 5 + @test multiagent_hook[Player(1)][1].steps[1][1] == 1 + @test -1 <= multiagent_hook[Player(1)][2].rewards[1][1] <= 1 + @test multiagent_hook[Player(1)][3].steps[1] == 1 + @test -1 <= multiagent_hook[Player(1)][4].rewards[1][1] <= 1 + @test 0 <= multiagent_hook[Player(1)][5].times[1] <= 5 # Add more hook tests here... 
# TODO: Split up TicTacToeEnv and MultiAgent tests @test RLBase.is_terminated(env) - @test RLBase.legal_action_space(env) == () - @test RLBase.action_space(env, Symbol(1)) == ('💎', '📃', '✂') + @test RLBase.legal_action_space(env) == action_space(env) + @test RLBase.legal_action_space(env, Player(1)) == () + @test RLBase.action_space(env, Player(1)) == ('💎', '📃', '✂') env = RockPaperScissorsEnv() push!(multiagent_policy, PreActStage(), env) a = RLBase.plan!(multiagent_policy, env) @@ -178,8 +195,8 @@ end @testset "Sequential Environments correctly ended by termination signal" begin #rng = StableRNGs.StableRNG(123) e = TicTacToeEnv(); - m = MultiAgentPolicy(NamedTuple((player => RandomPolicy() for player in players(e)))) - hooks = MultiAgentHook(NamedTuple((p => EmptyHook() for p ∈ players(e)))) + m = MultiAgentPolicy(PlayerTuple(player => RandomPolicy() for player in players(e))) + hooks = MultiAgentHook(PlayerTuple(p => EmptyHook() for p ∈ players(e))) let err = nothing try diff --git a/src/ReinforcementLearningCore/test/policies/q_based_policy.jl b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl index cbe3382ab..ea9bd8f34 100644 --- a/src/ReinforcementLearningCore/test/policies/q_based_policy.jl +++ b/src/ReinforcementLearningCore/test/policies/q_based_policy.jl @@ -25,7 +25,7 @@ learner = TDLearner(q_approx, :SARS) explorer = EpsilonGreedyExplorer(0.1) policy = QBasedPolicy(learner, explorer) - player = :player1 + player = Player(:player1) @test 1 <= RLBase.plan!(policy, env) <= 9 end end diff --git a/src/ReinforcementLearningEnvironments/Project.toml b/src/ReinforcementLearningEnvironments/Project.toml index fbc7884f0..818e9d01d 100644 --- a/src/ReinforcementLearningEnvironments/Project.toml +++ b/src/ReinforcementLearningEnvironments/Project.toml @@ -1,6 +1,6 @@ name = "ReinforcementLearningEnvironments" uuid = "25e41dd2-4622-11e9-1641-f1adca772921" -version = "0.8.9" +version = "0.9.0" [deps] CommonRLInterface = "d842c3ba-07a1-494f-bbec-f5741b0a3e98" @@ -12,6 +12,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44" +ReinforcementLearningCore = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -22,7 +23,7 @@ DelimitedFiles = "1" IntervalSets = "0.7" MacroTools = "0.5" OpenSpiel = "0.2.1" -ReinforcementLearningBase = "0.12" +ReinforcementLearningBase = "0.13" ReinforcementLearningCore = "0.15" Requires = "1.0" StatsBase = "0.32, 0.33, 0.34" @@ -30,12 +31,12 @@ julia = "1.6" [extras] ArcadeLearningEnvironment = "b7f77d8d-088d-5e02-8ac0-89aab2acc977" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d" DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" -ReinforcementLearningCore = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -44,12 +45,12 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" [targets] test = [ "ArcadeLearningEnvironment", + "JLD2", "Conda", "DomainSets", "OpenSpiel", "OrdinaryDiffEq", "PyCall", - 
"ReinforcementLearningCore", "StableRNGs", "Statistics", "Test", diff --git a/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl b/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl index 39d57ea6b..b2ab9b700 100644 --- a/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl +++ b/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl @@ -1,6 +1,7 @@ module ReinforcementLearningEnvironments using ReinforcementLearningBase +using ReinforcementLearningCore using Requires using Random diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/pettingzoo.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/pettingzoo.jl index 38aa4542b..16999469f 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/pettingzoo.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/pettingzoo.jl @@ -62,7 +62,7 @@ end RLBase.state_space(env::PettingZooEnv, ::Observation{Any}, players) = Space(Dict(player => state_space(env, player) for player in players)) # partial observability -RLBase.state_space(env::PettingZooEnv, ::Observation{Any}, player::Symbol) = space_transform(env.pyenv.observation_space(String(player))) +RLBase.state_space(env::PettingZooEnv, ::Observation{Any}, player::Player) = space_transform(env.pyenv.observation_space(String(player))) # for full observability. Be careful: action_space has also to be adjusted # RLBase.state_space(env::PettingZooEnv, ::Observation{Any}, player::String) = space_transform(env.pyenv.state_space) @@ -73,7 +73,7 @@ RLBase.state_space(env::PettingZooEnv, ::Observation{Any}, player::Symbol) = spa RLBase.action_space(env::PettingZooEnv, players::Tuple{Symbol}) = Space(Dict(p => action_space(env, p) for p in players)) -RLBase.action_space(env::PettingZooEnv, player::Symbol) = space_transform(env.pyenv.action_space(String(player))) +RLBase.action_space(env::PettingZooEnv, player::Player) = space_transform(env.pyenv.action_space(String(player))) RLBase.action_space(env::PettingZooEnv, player::Integer) = space_transform(env.pyenv.action_space(env.pyenv.agents[player])) @@ -119,7 +119,7 @@ function RLBase.act!(env::PettingZooEnv, action) end # reward of player ====================================================================================================================== -function RLBase.reward(env::PettingZooEnv, player::Symbol) +function RLBase.reward(env::PettingZooEnv, player::Player) env.pyenv.rewards[String(player)] end diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/CartPoleEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/CartPoleEnv.jl index 3d33fbc9d..4d729589a 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/CartPoleEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/CartPoleEnv.jl @@ -92,8 +92,8 @@ function RLBase.state_space(env::CartPoleEnv{T}) where {T} (typemin(T) .. typemax(T)) end -RLBase.action_space(env::CartPoleEnv{<:AbstractFloat,Int}, player) = Base.OneTo(2) -RLBase.action_space(env::CartPoleEnv{<:AbstractFloat,<:AbstractFloat}, player) = -1.0 .. 1.0 +RLBase.action_space(env::CartPoleEnv{<:AbstractFloat,Int}, ::DefaultPlayer) = Base.OneTo(2) +RLBase.action_space(env::CartPoleEnv{<:AbstractFloat,<:AbstractFloat}, ::DefaultPlayer) = -1.0 .. 
1.0 function RLBase.reset!(env::CartPoleEnv{T}) where {T} env.state[:] = T(0.1) * rand(env.rng, T, 4) .- T(0.05) diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/KuhnPokerEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/KuhnPokerEnv.jl index 39f9cb03c..b07f69a8c 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/KuhnPokerEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/KuhnPokerEnv.jl @@ -1,5 +1,6 @@ export KuhnPokerEnv +const KUHN_POKER_PLAYER_INDEX = Dict(Player(1) => 1, Player(2) => 2, CHANCE_PLAYER => 0) const KUHN_POKER_CARDS = (:J, :Q, :K) const KUHN_POKER_CARD_COMBINATIONS = ((:J, :Q), (:J, :K), (:Q, :J), (:Q, :K), (:K, :J), (:K, :Q)) @@ -87,11 +88,11 @@ end RLBase.is_terminated(env::KuhnPokerEnv) = length(env.actions) == 2 && (env.actions[1] == :bet || env.actions[2] == :pass) || length(env.actions) == 3 -RLBase.players(env::KuhnPokerEnv) = (1, 2, CHANCE_PLAYER) +RLBase.players(env::KuhnPokerEnv) = (Player(1), Player(2), CHANCE_PLAYER) -function RLBase.state(env::KuhnPokerEnv, ::InformationSet{Tuple{Vararg{Symbol}}}, p::Int) - if length(env.cards) >= p - (env.cards[p], env.actions...) +function RLBase.state(env::KuhnPokerEnv, ::InformationSet{Tuple{Vararg{Symbol}}}, player::Player) + if length(env.cards) >= KUHN_POKER_PLAYER_INDEX[player] + (env.cards[KUHN_POKER_PLAYER_INDEX[player]], env.actions...) else () end @@ -99,9 +100,9 @@ end RLBase.state(env::KuhnPokerEnv, ::InformationSet{Tuple{Vararg{Symbol}}}, ::ChancePlayer) = Tuple(env.cards) -RLBase.state_space(env::KuhnPokerEnv, ::InformationSet{Tuple{Vararg{Symbol}}}, p) = KUHN_POKER_STATES +RLBase.state_space(env::KuhnPokerEnv, ::InformationSet{Tuple{Vararg{Symbol}}}, player::AbstractPlayer) = KUHN_POKER_STATES -RLBase.action_space(env::KuhnPokerEnv, ::Int) = Base.OneTo(length(KUHN_POKER_ACTIONS)) +RLBase.action_space(env::KuhnPokerEnv, ::Player) = Base.OneTo(length(KUHN_POKER_ACTIONS)) RLBase.action_space(env::KuhnPokerEnv, ::ChancePlayer) = Base.OneTo(length(KUHN_POKER_CARDS)) RLBase.legal_action_space(env::KuhnPokerEnv, p::ChancePlayer) = Tuple(x for x in action_space(env, p) if KUHN_POKER_CARDS[x] ∉ env.cards) @@ -125,17 +126,17 @@ function RLBase.prob(env::KuhnPokerEnv, ::ChancePlayer) end end -RLBase.act!(env::KuhnPokerEnv, action::Int, p::Int) = RLBase.act!(env, KUHN_POKER_ACTIONS[action], p) +RLBase.act!(env::KuhnPokerEnv, action::Int, p::Player) = RLBase.act!(env, KUHN_POKER_ACTIONS[action], p) RLBase.act!(env::KuhnPokerEnv, action::Int, p::ChancePlayer) = RLBase.act!(env, KUHN_POKER_CARDS[action], p) RLBase.act!(env::KuhnPokerEnv, action::Symbol, ::ChancePlayer) = push!(env.cards, action) -RLBase.act!(env::KuhnPokerEnv, action::Symbol, ::Int) = push!(env.actions, action) +RLBase.act!(env::KuhnPokerEnv, action::Symbol, ::Player) = push!(env.actions, action) RLBase.reward(::KuhnPokerEnv, ::ChancePlayer) = 0 -function RLBase.reward(env::KuhnPokerEnv, p) +function RLBase.reward(env::KuhnPokerEnv, p::Player) if is_terminated(env) v = KUHN_POKER_REWARD_TABLE[(env.cards..., env.actions...)] - p == 1 ? v : -v + p == Player(1) ? 
v : -v else 0 end @@ -145,13 +146,13 @@ RLBase.current_player(env::KuhnPokerEnv) = if length(env.cards) < 2 CHANCE_PLAYER elseif length(env.actions) == 0 - 1 + Player(1) elseif length(env.actions) == 1 - 2 + Player(2) elseif length(env.actions) == 2 - 1 + Player(1) else - 2 # actually the game is over now + Player(2) # actually the game is over now end RLBase.NumAgentStyle(::KuhnPokerEnv) = MultiAgent(2) diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/PigEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/PigEnv.jl index c3be31fb2..3f1ff5f31 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/PigEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/PigEnv.jl @@ -5,7 +5,7 @@ const PIG_N_SIDES = 6 mutable struct PigEnv{N} <: AbstractEnv scores::Vector{Int} - current_player::Int + current_player::Player is_chance_player_active::Bool tmp_score::Int end @@ -17,30 +17,39 @@ See [wiki](https://en.wikipedia.org/wiki/Pig_(dice_game)) for explanation of thi Here we use it to demonstrate how to write a game with more than 2 players. """ -PigEnv(; n_players=2) = PigEnv{n_players}(zeros(Int, n_players), 1, false, 0) +PigEnv(; n_players=2) = PigEnv{n_players}(zeros(Int, n_players), Player(1), false, 0) + +function next_player(env::PigEnv) + next_player_int = parse(Int64, string(env.current_player.name)) + 1 + if next_player_int > length(players(env)) + return Player(1) + else + return Player(next_player_int) + end +end function RLBase.reset!(env::PigEnv) fill!(env.scores, 0) - env.current_player = 1 + env.current_player = Player(1) env.is_chance_player_active = false env.tmp_score = 0 end RLBase.current_player(env::PigEnv) = env.is_chance_player_active ? CHANCE_PLAYER : env.current_player -RLBase.players(env::PigEnv) = 1:length(env.scores) -RLBase.action_space(env::PigEnv, ::Int) = (:roll, :hold) +RLBase.players(env::PigEnv) = Player.(1:length(env.scores)) +RLBase.action_space(env::PigEnv, ::Player) = (:roll, :hold) RLBase.action_space(env::PigEnv, ::ChancePlayer) = Base.OneTo(PIG_N_SIDES) RLBase.prob(env::PigEnv, ::ChancePlayer) = fill(1 / 6, 6) # TODO: uniform distribution, more memory efficient -RLBase.state(env::PigEnv, ::Observation{Vector{Int}}, p) = env.scores -RLBase.state_space(env::PigEnv, ::Observation, p) = ArrayProductDomain([0 .. (PIG_TARGET_SCORE + PIG_N_SIDES - 1) for _ in env.scores]) +RLBase.state(env::PigEnv, ::Observation{Vector{Int}}, p::AbstractPlayer) = env.scores +RLBase.state_space(env::PigEnv, ::Observation, p::AbstractPlayer) = ArrayProductDomain([0 .. 
(PIG_TARGET_SCORE + PIG_N_SIDES - 1) for _ in env.scores])
 
 RLBase.is_terminated(env::PigEnv) = any(s >= PIG_TARGET_SCORE for s in env.scores)
 
-function RLBase.reward(env::PigEnv, player)
+function RLBase.reward(env::PigEnv, player::AbstractPlayer)
     winner = findfirst(>=(PIG_TARGET_SCORE), env.scores)
     if isnothing(winner)
         0
@@ -51,16 +60,13 @@ function RLBase.reward(env::PigEnv, player)
     end
 end
 
-function RLBase.act!(env::PigEnv, action, player::Int)
+function RLBase.act!(env::PigEnv, action, player::Player)
     if action == :roll
         env.is_chance_player_active = true
     else
-        env.scores[player] += env.tmp_score
+        env.scores[parse(Int64, string(player.name))] += env.tmp_score
         env.tmp_score = 0
-        env.current_player += 1
-        if env.current_player > length(players(env))
-            env.current_player = 1
-        end
+        env.current_player = next_player(env)
     end
 end
 
@@ -68,10 +74,7 @@ function RLBase.act!(env::PigEnv, action, ::ChancePlayer)
     env.is_chance_player_active = false
     if action == 1
         env.tmp_score = 0
-        env.current_player += 1
-        if env.current_player > length(players(env))
-            env.current_player = 1
-        end
+        env.current_player = next_player(env)
     else
         env.tmp_score += action
     end
diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/RockPaperScissorsEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/RockPaperScissorsEnv.jl
index d2acb9177..d343cf384 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/examples/RockPaperScissorsEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/examples/RockPaperScissorsEnv.jl
@@ -9,11 +9,11 @@ import CommonRLInterface
 simultaneous, zero sum game.
 """
 Base.@kwdef mutable struct RockPaperScissorsEnv <: AbstractEnv
-    reward::NamedTuple{(Symbol(1), Symbol(2)), Tuple{Int64, Int64}} = (; Symbol(1) => 0, Symbol(2) => 0)
+    reward::PlayerTuple{(Symbol(1), Symbol(2)), Tuple{Int64, Int64}} = PlayerTuple(Player(1) => 0, Player(2) => 0)
     is_done::Bool = false
 end
 
-RLBase.players(::RockPaperScissorsEnv) = (Symbol(1), Symbol(2))
+RLBase.players(::RockPaperScissorsEnv) = (Player(1), Player(2))
 
 """
 Note that although this is a two-player game, the current player is always a
@@ -21,26 +21,26 @@ dummy simultaneous player.
 """
 RLBase.current_player(::RockPaperScissorsEnv) = SIMULTANEOUS_PLAYER
 
-RLBase.action_space(::RockPaperScissorsEnv, ::Symbol) = ('💎', '📃', '✂')
+RLBase.action_space(::RockPaperScissorsEnv, ::Player) = ('💎', '📃', '✂')
 
 RLBase.action_space(::RockPaperScissorsEnv, ::SimultaneousPlayer) =
    Tuple((i, j) for i in ('💎', '📃', '✂') for j in ('💎', '📃', '✂'))
 
 RLBase.action_space(env::RockPaperScissorsEnv) = action_space(env, SIMULTANEOUS_PLAYER)
 
-RLBase.legal_action_space(env::RockPaperScissorsEnv, p) =
-    is_terminated(env) ? () : action_space(env, p)
+RLBase.legal_action_space(env::RockPaperScissorsEnv, player::Player) =
+    is_terminated(env) ? () : action_space(env, player)
 
 "Since it's a one-shot game, the state space doesn't have much meaning."
-RLBase.state_space(::RockPaperScissorsEnv, ::Observation, p) = Base.OneTo(1)
+RLBase.state_space(::RockPaperScissorsEnv, ::Observation, ::AbstractPlayer) = Base.OneTo(1)
 
 """
 For multi-agent environments, we usually implement the most detailed one.
 """
-RLBase.state(::RockPaperScissorsEnv, ::Observation, p) = 1
+RLBase.state(::RockPaperScissorsEnv, ::Observation, ::AbstractPlayer) = 1
 
-RLBase.reward(env::RockPaperScissorsEnv) = env.is_done ?
env.reward : (; Symbol(1) => 0, Symbol(2) => 0) -RLBase.reward(env::RockPaperScissorsEnv, p::Symbol) = reward(env)[p] +RLBase.reward(env::RockPaperScissorsEnv) = env.is_done ? env.reward : PlayerTuple(Player(1) => 0, Player(2) => 0) +RLBase.reward(env::RockPaperScissorsEnv, player::Player) = reward(env)[player] RLBase.is_terminated(env::RockPaperScissorsEnv) = env.is_done RLBase.reset!(env::RockPaperScissorsEnv) = env.is_done = false @@ -48,11 +48,11 @@ RLBase.reset!(env::RockPaperScissorsEnv) = env.is_done = false # TODO: Consider using CRL.all_act! and adjusting run function accordingly function RLBase.act!(env::RockPaperScissorsEnv, (x, y)) if x == y - env.reward = (; Symbol(1) => 0, Symbol(2) => 0) + env.reward = PlayerTuple(Player(1) => 0, Player(2) => 0) elseif x == '💎' && y == '✂' || x == '✂' && y == '📃' || x == '📃' && y == '💎' - env.reward = (; Symbol(1) => 1, Symbol(2) => -1) + env.reward = PlayerTuple(Player(1) => 1, Player(2) => -1) else - env.reward = (; Symbol(1) => -1, Symbol(2) => 1) + env.reward = PlayerTuple(Player(1) => -1, Player(2) => 1) end env.is_done = true end diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl index e9ba35dc3..efe33e6f7 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl @@ -1,24 +1,25 @@ export TicTacToeEnv import ReinforcementLearningBase: RLBase +import ReinforcementLearningCore: Player import CommonRLInterface mutable struct TicTacToeEnv <: AbstractEnv board::BitArray{3} - player::Symbol + player::Player end function TicTacToeEnv() board = BitArray{3}(undef, 3, 3, 3) fill!(board, false) board[:, :, 1] .= true - TicTacToeEnv(board, :Cross) + TicTacToeEnv(board, Player(:Cross)) end function RLBase.reset!(env::TicTacToeEnv) fill!(env.board, false) env.board[:, :, 1] .= true - env.player = :Cross + env.player = Player(:Cross) end struct TicTacToeInfo @@ -30,21 +31,21 @@ const TIC_TAC_TOE_STATE_INFO = Dict{ TicTacToeEnv, NamedTuple{ (:index, :is_terminated, :winner), - Tuple{Int,Bool,Union{Nothing,Symbol}}, + Tuple{Int,Bool,Union{Nothing,Player}}, }, }() Base.hash(env::TicTacToeEnv, h::UInt) = hash(env.board, h) Base.isequal(a::TicTacToeEnv, b::TicTacToeEnv) = isequal(a.board, b.board) -Base.to_index(::TicTacToeEnv, player) = player == :Cross ? 2 : 3 +Base.to_index(::TicTacToeEnv, player::Player) = player == Player(:Cross) ? 2 : 3 -RLBase.action_space(::TicTacToeEnv, player) = Base.OneTo(9) +RLBase.action_space(::TicTacToeEnv, player::Player) = Base.OneTo(9) -RLBase.legal_action_space(env::TicTacToeEnv, p) = findall(legal_action_space_mask(env)) +RLBase.legal_action_space(env::TicTacToeEnv, player::Player) = findall(legal_action_space_mask(env)) -function RLBase.legal_action_space_mask(env::TicTacToeEnv, p) - if is_win(env, :Cross) || is_win(env, :Nought) +function RLBase.legal_action_space_mask(env::TicTacToeEnv, player::Player) + if is_win(env, Player(:Cross)) || is_win(env, Player(:Nought)) falses(9) else vec(env.board[:, :, 1]) @@ -59,25 +60,25 @@ function RLBase.act!(env::TicTacToeEnv, action::CartesianIndex{2}) end function RLBase.next_player!(env::TicTacToeEnv) - env.player = env.player == :Cross ? :Nought : :Cross + env.player = env.player == Player(:Cross) ? 
Player(:Nought) : Player(:Cross) end -RLBase.players(::TicTacToeEnv) = (:Cross, :Nought) +RLBase.players(::TicTacToeEnv) = (Player(:Cross), Player(:Nought)) -RLBase.state(env::TicTacToeEnv) = state(env, Observation{Int}(), 1) -RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = env.board -RLBase.state(env::TicTacToeEnv, ::RLBase.AbstractStateStyle) = state(env::TicTacToeEnv, Observation{Int}(), 1) -RLBase.state(env::TicTacToeEnv, ::Observation{Int}, p) = +RLBase.state(env::TicTacToeEnv) = state(env, Observation{Int}(), Player(:Any)) +RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}, player) = env.board +RLBase.state(env::TicTacToeEnv, ::RLBase.AbstractStateStyle) = state(env::TicTacToeEnv, Observation{Int}(), Player(1)) +RLBase.state(env::TicTacToeEnv, ::Observation{Int}, player::Player) = get_tic_tac_toe_state_info()[env].index -RLBase.state_space(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = ArrayProductDomain(fill(false:true, 3, 3, 3)) -RLBase.state_space(env::TicTacToeEnv, ::Observation{Int}, p) = +RLBase.state_space(env::TicTacToeEnv, ::Observation{BitArray{3}}, player::Player) = ArrayProductDomain(fill(false:true, 3, 3, 3)) +RLBase.state_space(env::TicTacToeEnv, ::Observation{Int}, player::Player) = Base.OneTo(length(get_tic_tac_toe_state_info())) -RLBase.state_space(env::TicTacToeEnv, ::Observation{String}, p) = fullspace(String) +RLBase.state_space(env::TicTacToeEnv, ::Observation{String}, player::Player) = fullspace(String) -RLBase.state(env::TicTacToeEnv, ::Observation{String}) = state(env::TicTacToeEnv, Observation{String}(), 1) +RLBase.state(env::TicTacToeEnv, ::Observation{String}) = state(env::TicTacToeEnv, Observation{String}(), Player(1)) -function RLBase.state(env::TicTacToeEnv, ::Observation{String}, p) +function RLBase.state(env::TicTacToeEnv, ::Observation{String}, player::Player) buff = IOBuffer() for i in 1:3 for j in 1:3 @@ -97,7 +98,7 @@ end RLBase.is_terminated(env::TicTacToeEnv) = get_tic_tac_toe_state_info()[env].is_terminated -function RLBase.reward(env::TicTacToeEnv, player::Symbol) +function RLBase.reward(env::TicTacToeEnv, player::Player) if is_terminated(env) winner = get_tic_tac_toe_state_info()[env].winner if isnothing(winner) @@ -112,7 +113,7 @@ function RLBase.reward(env::TicTacToeEnv, player::Symbol) end end -function is_win(env::TicTacToeEnv, player::Symbol) +function is_win(env::TicTacToeEnv, player::Player) b = env.board p = Base.to_index(env, player) @inbounds begin @@ -139,10 +140,10 @@ function get_tic_tac_toe_state_info() if !haskey(TIC_TAC_TOE_STATE_INFO, env) n += 1 has_empty_pos = any(view(env.board, :, :, 1)) - w = if is_win(env, :Cross) - :Cross - elseif is_win(env, :Nought) - :Nought + w = if is_win(env, Player(:Cross)) + Player(:Cross) + elseif is_win(env, Player(:Nought)) + Player(:Nought) else nothing end diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/TinyHanabiEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/TinyHanabiEnv.jl index 4b201efb0..57229822c 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/TinyHanabiEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/TinyHanabiEnv.jl @@ -43,21 +43,21 @@ function RLBase.reset!(env::TinyHanabiEnv) empty!(env.actions) end -RLBase.players(env::TinyHanabiEnv) = 1:2 +RLBase.players(env::TinyHanabiEnv) = Player.(1:2) RLBase.current_player(env::TinyHanabiEnv) = if length(env.cards) < 2 CHANCE_PLAYER elseif length(env.actions) == 0 - 1 + Player(1) else - 2 + Player(2) 
end RLBase.act!(env::TinyHanabiEnv, action, ::ChancePlayer) = push!(env.cards, action) -RLBase.act!(env::TinyHanabiEnv, action, ::Int) = push!(env.actions, action) +RLBase.act!(env::TinyHanabiEnv, action, ::Player) = push!(env.actions, action) -RLBase.action_space(env::TinyHanabiEnv, ::Int) = Base.OneTo(3) +RLBase.action_space(env::TinyHanabiEnv, ::Player) = Base.OneTo(3) RLBase.action_space(env::TinyHanabiEnv, ::ChancePlayer) = Base.OneTo(2) RLBase.legal_action_space(env::TinyHanabiEnv, ::ChancePlayer) = findall(!in(env.cards), 1:2) @@ -80,24 +80,25 @@ RLBase.state_space(env::TinyHanabiEnv, ::InformationSet, ::ChancePlayer) = ((0,), (0, 1), (0, 2), (0, 1, 2), (0, 2, 1)) # (chance_player_id(0), chance_player's actions...) RLBase.state(env::TinyHanabiEnv, ::InformationSet, ::ChancePlayer) = (0, env.cards...) -function RLBase.state_space(env::TinyHanabiEnv, ::InformationSet, p::Int) +function RLBase.state_space(env::TinyHanabiEnv, ::InformationSet, p::Player) Tuple( (p, c..., a...) for p in 1:2 for c in ((), 1, 2) for a in ((), 1:3..., ((i, j) for i in 1:3 for j in 1:3)...) ) end -function RLBase.state(env::TinyHanabiEnv, ::InformationSet, p::Int) - card = length(env.cards) >= p ? env.cards[p] : () - (p, card..., env.actions...) +function RLBase.state(env::TinyHanabiEnv, ::InformationSet, player::Player) + player_int = parse(Int, string(player.name)) + card = length(env.cards) >= player_int ? env.cards[player_int] : () + (player_int, card..., env.actions...) end RLBase.is_terminated(env::TinyHanabiEnv) = length(env.actions) == 2 -RLBase.reward(env::TinyHanabiEnv, player) = +RLBase.reward(env::TinyHanabiEnv, player::AbstractPlayer) = is_terminated(env) ? env.reward_table[env.actions..., env.cards...] : 0 RLBase.act!(env::TinyHanabiEnv, action::Int, ::ChancePlayer) = push!(env.cards, action) -RLBase.act!(env::TinyHanabiEnv, action::Int, ::Int) = push!(env.actions, action) +RLBase.act!(env::TinyHanabiEnv, action::Int, ::Player) = push!(env.actions, action) RLBase.NumAgentStyle(::TinyHanabiEnv) = MultiAgent(2) RLBase.DynamicStyle(::TinyHanabiEnv) = SEQUENTIAL diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/wrappers.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/wrappers.jl index a216d4690..a0c10d8d9 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/wrappers.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/wrappers.jl @@ -18,13 +18,13 @@ for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API) end # avoid ambiguous -RLBase.state(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle, p) = - state(env[], ss, p) +RLBase.state(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle, player::Player) = + state(env[], ss, player) RLBase.state(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle) = state(env[], ss) RLBase.state_space(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle) = state_space(env[], ss) -RLBase.state_space(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle, p) = - state_space(env[], ss, p) +RLBase.state_space(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle, player::Player) = + state_space(env[], ss, player) include("ActionTransformedEnv.jl") include("DefaultStateStyle.jl") diff --git a/src/ReinforcementLearningEnvironments/test/environments/examples/random_walk_1d.jl b/src/ReinforcementLearningEnvironments/test/environments/examples/random_walk_1d.jl index 330c72788..a616c2e7f 100644 --- 
a/src/ReinforcementLearningEnvironments/test/environments/examples/random_walk_1d.jl +++ b/src/ReinforcementLearningEnvironments/test/environments/examples/random_walk_1d.jl @@ -74,3 +74,45 @@ end @test reward(env) == 0 @test (@allocated reward(env)) == 0 end + +@testset "Full run" begin + using JLD2 + + env = RandomWalk1D() + ns, na = length(state_space(env)), length(action_space(env)) + + policy = Agent( + QBasedPolicy(; + learner = TDLearner( + TabularQApproximator(n_state = ns, n_action = na), + :SARS; + ), + explorer = EpsilonGreedyExplorer(ϵ_stable=0.01), + ), + Trajectory( + CircularArraySARTSTraces(; + capacity = 1, + state = Int64 => (), + action = Int64 => (), + reward = Float64 => (), + terminal = Bool => (), + ), + DummySampler(), + InsertSampleRatioController(), + ), + ) + + parameters_dir = mktempdir() + + run( + policy, + env, + StopAfterNSteps(10_000), + DoEveryNSteps(n=1_000) do t, p, e + ps = policy.policy.learner.approximator + f = joinpath(parameters_dir, "parameters_at_step_$t.jld2") + JLD2.@save f ps + println("parameters at step $t saved to $f") + end + ) +end diff --git a/src/ReinforcementLearningEnvironments/test/environments/examples/rock_paper_scissors.jl b/src/ReinforcementLearningEnvironments/test/environments/examples/rock_paper_scissors.jl index 293d4bfce..1d1749f06 100644 --- a/src/ReinforcementLearningEnvironments/test/environments/examples/rock_paper_scissors.jl +++ b/src/ReinforcementLearningEnvironments/test/environments/examples/rock_paper_scissors.jl @@ -11,10 +11,10 @@ while !is_terminated(env) RLBase.act!(env, rand(rng, legal_action_space(env))) end - @test RLBase.reward(env, Symbol(1)) == (-1 * RLBase.reward(env, Symbol(2))) + @test RLBase.reward(env, Player(1)) == (-1 * RLBase.reward(env, Player(2))) @test RLBase.is_terminated(env) isa Bool - push!(rewards[1], RLBase.reward(env, Symbol(1))) - push!(rewards[2], RLBase.reward(env, Symbol(2))) + push!(rewards[1], RLBase.reward(env, Player(1))) + push!(rewards[2], RLBase.reward(env, Player(2))) reset!(env) end diff --git a/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl b/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl index c10ab6628..0eca516ff 100644 --- a/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl +++ b/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl @@ -14,12 +14,12 @@ InsertSampleRatioController(n_inserted = -1), ) - multiagent_policy = MultiAgentPolicy((; - :Cross => Agent(RandomPolicy(), trajectory_1), - :Nought => Agent(RandomPolicy(), trajectory_2), + multiagent_policy = MultiAgentPolicy(PlayerTuple( + Player(:Cross) => Agent(RandomPolicy(), trajectory_1), + Player(:Nought) => Agent(RandomPolicy(), trajectory_2), )) - multiagent_hook = MultiAgentHook((; :Cross => StepsPerEpisode(), :Nought => StepsPerEpisode())) + multiagent_hook = MultiAgentHook(PlayerTuple(Player(:Cross) => StepsPerEpisode(), Player(:Nought) => StepsPerEpisode())) env = TicTacToeEnv() stop_condition = StopIfEnvTerminated() @@ -29,11 +29,11 @@ @test length(state_space(env, Observation{Int}())) == 5478 - @test RLBase.state(env, Observation{BitArray{3}}(), :Cross) == env.board - @test RLBase.state_space(env, Observation{BitArray{3}}(), :Cross) isa ArrayProductDomain - @test RLBase.state_space(env, Observation{String}(), :Cross) isa DomainSets.FullSpace{String} - @test RLBase.state(env, Observation{String}(), :Cross) isa String + @test RLBase.state(env, Observation{BitArray{3}}(), Player(:Cross)) 
== env.board + @test RLBase.state_space(env, Observation{BitArray{3}}(), Player(:Cross)) isa ArrayProductDomain + @test RLBase.state_space(env, Observation{String}(), Player(:Cross)) isa DomainSets.FullSpace{String} + @test RLBase.state(env, Observation{String}(), Player(:Cross)) isa String @test RLBase.state(env, Observation{String}()) isa String Base.run(multiagent_policy, env, stop_condition, multiagent_hook) - @test RLBase.legal_action_space_mask(env, :Cross) == falses(9) + @test RLBase.legal_action_space_mask(env, Player(:Cross)) == falses(9) end diff --git a/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl b/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl index c79d6e283..baf016a47 100644 --- a/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl +++ b/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl @@ -130,20 +130,79 @@ reset!(env) - @test reward(env) == (; Symbol(1) => 0, Symbol(2) => 0) + @test reward(env) == PlayerTuple(Player(1) => 0, Player(2) => 0) @test is_terminated(env) == false RLBase.act!(env, ['💎', '📃']) - @test reward(env, Symbol(1)) == -1 - @test reward(env, Symbol(2)) == 1 + @test reward(env, Player(1)) == -1 + @test reward(env, Player(2)) == 1 @test is_terminated(env) == true reset!(env) - @test reward(env) == (; Symbol(1) => 0, Symbol(2) => 0) - @test reward(env, Symbol(1)) == 0 - @test reward(env, Symbol(2)) == 0 + @test reward(env) == PlayerTuple(Player(1) => 0, Player(2) => 0) + @test reward(env, Player(1)) == 0 + @test reward(env, Player(2)) == 0 @test is_terminated(env) == false end end + +@testset "Wrapped Env run" begin + Base.@kwdef mutable struct LotteryEnv <: AbstractEnv + reward::Union{Nothing, Int} = nothing + end + + struct LotteryAction{a} + function LotteryAction(a) + new{a}() + end + end + + RLBase.action_space(env::LotteryEnv) = LotteryAction.([:PowerRich, :MegaHaul, nothing]) + + RLBase.reward(env::LotteryEnv) = env.reward + RLBase.state(env::LotteryEnv) = !isnothing(env.reward) + RLBase.state_space(env::LotteryEnv) = [false, true] + RLBase.is_terminated(env::LotteryEnv) = !isnothing(env.reward) + RLBase.reset!(env::LotteryEnv) = env.reward = nothing + + function RLBase.act!(x::LotteryEnv, action) + if action == LotteryAction(:PowerRich) + x.reward = rand() < 0.01 ? 100_000_000 : -10 + elseif action == LotteryAction(:MegaHaul) + x.reward = rand() < 0.05 ? 1_000_000 : -10 + elseif action == LotteryAction(nothing) + x.reward = 0 + else + @error "unknown action of $action" + end + end + + env = LotteryEnv() + + p = QBasedPolicy( + learner = TDLearner( + TabularQApproximator( + n_state = length(state_space(env)), + n_action = length(action_space(env)), + ), :SARS + ), + explorer = EpsilonGreedyExplorer(0.1) + ) + + wrapped_env = ActionTransformedEnv( + StateTransformedEnv( + env; + state_mapping=s -> s ? 
1 : 2, + state_space_mapping = _ -> Base.OneTo(2) + ); + action_mapping = i -> action_space(env)[i], + action_space_mapping = _ -> Base.OneTo(3), + ) + @test plan!(p, wrapped_env) ∈ [1, 2, 3] + + h = TotalRewardPerEpisode() + e = run(p, wrapped_env, StopAfterNEpisodes(1_000), h) + @test h.reward ∈ [-10, 100_000_000, 1_000_000, 0] +end diff --git a/src/ReinforcementLearningEnvironments/test/runtests.jl b/src/ReinforcementLearningEnvironments/test/runtests.jl index d8dfe872f..80bb8fe8a 100644 --- a/src/ReinforcementLearningEnvironments/test/runtests.jl +++ b/src/ReinforcementLearningEnvironments/test/runtests.jl @@ -12,6 +12,7 @@ using Statistics using OrdinaryDiffEq using TimerOutputs using Conda +using JLD2 Conda.add("gym") Conda.add("numpy") diff --git a/src/ReinforcementLearningFarm/src/hooks/total_reward_per_last_n_episodes.jl b/src/ReinforcementLearningFarm/src/hooks/total_reward_per_last_n_episodes.jl index 816f63744..1fb4e95ce 100644 --- a/src/ReinforcementLearningFarm/src/hooks/total_reward_per_last_n_episodes.jl +++ b/src/ReinforcementLearningFarm/src/hooks/total_reward_per_last_n_episodes.jl @@ -25,7 +25,7 @@ Base.push!( ::PostActStage, agent::P, env::E, - player::Symbol, + player::Player, ) where {P<:AbstractPolicy,E<:AbstractEnv,B<:CircularArrayBuffer} = h.rewards[end] += reward(env, player) @@ -41,5 +41,5 @@ Base.push!( stage::Union{PreEpisodeStage,PostEpisodeStage,PostExperimentStage}, agent, env, - player::Symbol, + player::Player, ) where {B<:CircularArrayBuffer} = Base.push!(hook, stage, agent, env) diff --git a/src/ReinforcementLearningFarm/test/hooks/total_reward_per_last_n_episodes.jl b/src/ReinforcementLearningFarm/test/hooks/total_reward_per_last_n_episodes.jl index 0d865e5fa..16e09c222 100644 --- a/src/ReinforcementLearningFarm/test/hooks/total_reward_per_last_n_episodes.jl +++ b/src/ReinforcementLearningFarm/test/hooks/total_reward_per_last_n_episodes.jl @@ -20,10 +20,10 @@ using ReinforcementLearningFarm: TotalRewardPerLastNEpisodes agent = RandomPolicy() for i = 1:15 - push!(hook, PreEpisodeStage(), agent, env, :Cross) - push!(hook, PostActStage(), agent, env, :Cross) + push!(hook, PreEpisodeStage(), agent, env, Player(:Cross)) + push!(hook, PostActStage(), agent, env, Player(:Cross)) @test length(hook.rewards) == min(i, 10) - @test hook.rewards[min(i, 10)] == reward(env, :Cross) + @test hook.rewards[min(i, 10)] == reward(env, Player(:Cross)) end end end