Skip to content

Commit abe30f1

Browse files
Updated random sampling of Box
For added consistency with gymnasium: - Finite intervals are sampled uniformly, - Infinite intervals are sampled normally, - Semi-infinite intervals are sampled as shifted exponential distributions.
1 parent 889dbc5 commit abe30f1

File tree

5 files changed

+65
-4
lines changed

5 files changed

+65
-4
lines changed

Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Jun Tian <[email protected]> and contributors"]
44
version = "0.2.1"
55

66
[deps]
7+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
78
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
89
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
910
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/CommonRLSpaces.jl

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using Reexport
77
using StaticArrays
88
using FillArrays
99
using Random
10+
using Distributions
1011
import Base: clamp
1112

1213
export

src/array.jl

+52-4
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,61 @@ Box(lower::Array, upper::Array) = Box(lower, upper; convert_to_static=true)
4343

4444
SpaceStyle(::Box) = ContinuousSpaceStyle()
4545

46-
function Base.rand(rng::AbstractRNG, sp::Random.SamplerTrivial{<:Box})
46+
"""
47+
Base.rand(::AbstractRNG, ::Random.SamplerTrivial{<:Box})
48+
49+
Generate an array where each element is sampled from a dimension of a Box space.
50+
51+
* Finite intervals [a,b] are sampled from uniform distributions.
52+
* Semi-infinite intervals (a,Inf) and (-Inf,b) are sampled from shifted exponential distributions.
53+
* Infinite intervals (-Inf,Inf) are sampled from normal distributions.
54+
55+
#Example
56+
```julia
57+
julia> using Random: seed!
58+
59+
julia> using Distributions: Uniform, Normal, Exponential
60+
61+
julia> box = Box([-10, -Inf, 3], [10, Inf, Inf])
62+
Box{StaticArraysCore.SVector{3, Float64}}([-10.0, -Inf, 3.0], [10.0, Inf, Inf])
63+
64+
julia> seed!(0); rand(box)
65+
3-element StaticArraysCore.SVector{3, Float64} with indices SOneTo(3):
66+
-1.8860105821594164
67+
0.13392275765318448
68+
3.837385552384043
69+
70+
julia> seed!(0); [rand(Uniform(-10,10)), rand(Normal()), 3+rand(Exponential())]
71+
3-element Vector{Float64}:
72+
-1.8860105821594164
73+
0.13392275765318448
74+
3.837385552384043
75+
```
76+
"""
77+
function Base.rand(rng::AbstractRNG, sp::Random.SamplerTrivial{Box{T}}) where {T}
4778
box = sp[]
48-
return box.lower + rand_similar(rng, box.lower) .* (box.upper-box.lower)
79+
x = [rand_interval(rng, lb, ub) for (lb, ub) in zip(box.lower, box.upper)]
80+
return T(x)
4981
end
5082

51-
rand_similar(rng::AbstractRNG, a::StaticArray) = rand(rng, typeof(a))
52-
rand_similar(rng::AbstractRNG, a::AbstractArray) = rand(rng, eltype(a), size(a)...)
83+
function rand_interval(rng::AbstractRNG, lb::T, ub::T) where {T <: Real}
84+
offset, sign = zero(T), one(T)
85+
86+
if isfinite(lb) && isfinite(ub)
87+
dist = Uniform(lb, ub)
88+
elseif isfinite(lb) && !isfinite(ub)
89+
offset = lb
90+
dist = Exponential(one(T))
91+
elseif !isfinite(lb) && isfinite(ub)
92+
offset = ub
93+
sign = -one(T)
94+
dist = Exponential(one(T))
95+
else
96+
dist = Normal(zero(T), one(T))
97+
end
98+
99+
return offset + sign * rand(rng, dist)
100+
end
53101

54102
Base.in(x::AbstractArray, b::Box) = all(b.lower .<= x .<= b.upper)
55103

test/array.jl

+9
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,15 @@
7878
end
7979
end
8080

81+
@testset "Box random sample" begin
82+
box = Box([-10, -Inf, 3, -Inf], [10, Inf, Inf, 6])
83+
Random.seed!(0)
84+
x = rand(box)
85+
Random.seed!(0)
86+
y = SA[rand(Uniform(-10, 10)), rand(Normal()), 3+rand(Exponential()), 6-rand(Exponential())]
87+
@test x == y
88+
end
89+
8190
@testset "Interval to box conversion" begin
8291
@test convert(Box, 1..2) == Box([1], [2])
8392
end

test/runtests.jl

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ using CommonRLSpaces
22
using Test
33

44
using StaticArrays
5+
using Distributions
6+
using Random
57

68
@testset "CommonRLSpaces.jl" begin
79
include("basic.jl")

0 commit comments

Comments
 (0)