Skip to content

Commit

Permalink
Throw error if ordered unsafe to use (#241)
Browse files Browse the repository at this point in the history
* Throw error if transform not supported

* Document requirements of ordered

* Add ordered test

* Update Project.toml

---------

Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
sethaxen and yebai authored Feb 2, 2023
1 parent 8b924d0 commit d8ebcd4
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 3 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.11.0"
version = "0.11.1"


[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
9 changes: 8 additions & 1 deletion src/bijectors/ordered.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,15 @@ struct OrderedBijector <: Bijector end
ordered(d::Distribution)
Return a `Distribution` whose support are ordered vectors, i.e., vectors with increasingly ordered elements.
This transformation is currently only supported for otherwise unconstrained distributions.
"""
ordered(d::ContinuousMultivariateDistribution) = Bijectors.transformed(d, OrderedBijector())
function ordered(d::ContinuousMultivariateDistribution)
if !isa(bijector(d), Identity)
throw(ArgumentError("ordered transform is currently only supported for unconstrained distributions."))
end
return Bijectors.transformed(d, OrderedBijector())
end

with_logabsdet_jacobian(b::OrderedBijector, x) = transform(b, x), logabsdetjac(b, x)

Expand Down
22 changes: 21 additions & 1 deletion test/bijectors/ordered.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Bijectors: OrderedBijector
import Bijectors: OrderedBijector, ordered
using LinearAlgebra

@testset "OrderedBijector" begin
b = OrderedBijector()
Expand All @@ -14,3 +15,22 @@ import Bijectors: OrderedBijector
y = b(x)
@test sort(y) == y
end

@testset "ordered" begin
d = MvNormal(1:5, Diagonal(6:10))
d_ordered = ordered(d)
@test d_ordered isa Bijectors.TransformedDistribution
@test d_ordered.dist === d
@test d_ordered.transform isa OrderedBijector
y = randn(5)
x = inv(bijector(d_ordered))(y)
@test issorted(x)

d = Product(fill(Normal(), 5))
# currently errors because `bijector(Product(fill(Normal(), 5)))` is not an `Identity`
@test_broken ordered(d) isa Bijectors.TransformedDistribution

# non-Identity bijector is not supported
d = Dirichlet(ones(5))
@test_throws ArgumentError ordered(d)
end

0 comments on commit d8ebcd4

Please sign in to comment.