Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Immutable versions of link and invlink #525

Merged
merged 27 commits into from
Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
7d4bcd9
added immutable versions of link and invlink
torfjelde Aug 31, 2023
050e099
added explicit invlink implementation for VarInfo
torfjelde Aug 31, 2023
3a77e03
remove false debug statement
torfjelde Aug 31, 2023
51013fb
fixed default impls of invlink for AbstractVarInfo
torfjelde Aug 31, 2023
0411c12
formatting
torfjelde Aug 31, 2023
7f633a8
use x to refer to the constrained space in invlink impl
torfjelde Aug 31, 2023
5d94640
added immuatable link implementation for VarInfo
torfjelde Aug 31, 2023
d37527a
added threadsafe versions of link and invlink
torfjelde Aug 31, 2023
3f7f4b1
added default implementations of link and invlink for DynamicTransfor…
torfjelde Aug 31, 2023
50b4332
formatting
torfjelde Aug 31, 2023
19c0dce
added tests for immutable link and invlink
torfjelde Aug 31, 2023
98fb4ba
export link and invlink
torfjelde Aug 31, 2023
4000c12
added link and invlink to docs
torfjelde Aug 31, 2023
2e7b71a
fixed setall! for UntypedVarInfo
torfjelde Aug 31, 2023
c822ea4
added testing model demo_one_variable_multiple_constraints
torfjelde Aug 31, 2023
afc5a26
fixed BangBang.setindex!! for setting vector in vector
torfjelde Aug 31, 2023
6465fa9
added tests with unflatten + linking
torfjelde Aug 31, 2023
4cc0a45
fixed reference to logabsdetjac in TestUtils
torfjelde Aug 31, 2023
d281487
improoved tests for unflatten + linking
torfjelde Aug 31, 2023
6f55139
improved testing of unflatten + linking a bit
torfjelde Aug 31, 2023
64ed2dd
added demo_lkjchol model to TestUtils
torfjelde Aug 31, 2023
d310fa8
formatting
torfjelde Aug 31, 2023
6a21ee2
fixed impl of link for AbstractVarInfo
torfjelde Aug 31, 2023
32b77ac
epxanded comment on BangBang hack
torfjelde Sep 1, 2023
4810374
Apply suggestions from code review
torfjelde Sep 1, 2023
eca5189
added references to BangBang issues and PRs talking about the
torfjelde Sep 1, 2023
e05bb09
Merge remote-tracking branch 'origin/torfjelde/immutable-versions-of-…
torfjelde Sep 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,8 @@
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)

Transforms the variables in `vi` to their linked space, using the transformation `t`.
Transform the variables in `vi` to their linked space, using the transformation `t`,
mutating `vi` if possible.

If `t` is not provided, `default_transformation(model, vi)` will be used.

Expand All @@ -383,12 +384,31 @@
return link!!(default_transformation(model, vi), vi, spl, model)
end

"""
link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)

Transform the variables in `vi` to their linked space, using the transformation `t`.
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

If `t` is not provided, `default_transformation(model, vi)` will be used.

See also: [`default_transformation`](@ref), [`invlink`](@ref).
"""
link(vi::AbstractVarInfo, model::Model) = link!!(deepcopy(vi), SampleFromPrior(), model)
function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return link!!(t, deepcopy(vi), SampleFromPrior(), model)

Check warning on line 399 in src/abstract_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/abstract_varinfo.jl#L397-L399

Added lines #L397 - L399 were not covered by tests
end
function link(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)

Check warning on line 401 in src/abstract_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/abstract_varinfo.jl#L401

Added line #L401 was not covered by tests
# Use `default_transformation` to decide which transformation to use if none is specified.
return link!!(default_transformation(model, vi), deepcopy(vi), spl, model)

Check warning on line 403 in src/abstract_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/abstract_varinfo.jl#L403

Added line #L403 was not covered by tests
end

"""
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)

Transform the variables in `vi` to their constrained space, using the (inverse of)
transformation `t`.
transformation `t`, mutating `vi` if possible.

If `t` is not provided, `default_transformation(model, vi)` will be used.

Expand Down Expand Up @@ -434,6 +454,25 @@
return settrans!!(vi_new, NoTransformation())
end

"""
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model)
invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model)

Transform the variables in `vi` to their constrained space, using the (inverse of)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
transformation `t`.

If `t` is not provided, `default_transformation(model, vi)` will be used.

See also: [`default_transformation`](@ref), [`link`](@ref).
"""
invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), model)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SampleFromPrior was introduced before we had the mechanism for passing the t::Transformation. It is now no longer necessary. In the longer run, we can consider replacing SampleFromPrior with a suitable context for clarity and consistency.

Suggested change
invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), model)
invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), model)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep! This has always been the idea:) But for now, we're not quite read to do that.

function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return invlink(t, vi, SampleFromPrior(), model)

Check warning on line 470 in src/abstract_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/abstract_varinfo.jl#L468-L470

Added lines #L468 - L470 were not covered by tests
end
function invlink(vi::AbstractVarInfo, spl::AbstractSampler, model::Model)
return invlink(transformation(vi), vi, spl, model)

Check warning on line 473 in src/abstract_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/abstract_varinfo.jl#L472-L473

Added lines #L472 - L473 were not covered by tests
end

"""
maybe_invlink_before_eval!!([t::Transformation,] vi, context, model)

Expand Down
139 changes: 138 additions & 1 deletion src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,6 @@
end

function _inner_transform!(vi::VarInfo, vn::VarName, dist, f)
@debug "X -> ℝ for $(vn)..."
# TODO: Use inplace versions to avoid allocations
y, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, getval(vi, vn))
yvec = vectorize(dist, y)
Expand All @@ -899,6 +898,144 @@
return vi
end

function link(

Check warning on line 901 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L901

Added line #L901 was not covered by tests
::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model
)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
return _link(varinfo)

Check warning on line 904 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L904

Added line #L904 was not covered by tests
end

function _link(varinfo::UntypedVarInfo)
varinfo = deepcopy(varinfo)
return VarInfo(

Check warning on line 909 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L907-L909

Added lines #L907 - L909 were not covered by tests
_link_metadata!(varinfo, varinfo.metadata),
Base.Ref(getlogp(varinfo)),
Ref(get_num_produce(varinfo)),
)

Check warning on line 913 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L913

Added line #L913 was not covered by tests
end

function _link(varinfo::TypedVarInfo)
varinfo = deepcopy(varinfo)
md = map(Base.Fix1(_link_metadata!, varinfo), varinfo.metadata)

Check warning on line 918 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L916-L918

Added lines #L916 - L918 were not covered by tests
# TODO: Update logp, etc.
return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))

Check warning on line 920 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L920

Added line #L920 was not covered by tests
end

function _link_metadata!(varinfo::VarInfo, metadata::Metadata)
vns = metadata.vns

Check warning on line 924 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L923-L924

Added lines #L923 - L924 were not covered by tests

# Construct the new transformed values, and keep track of their lengths.
vals_new = map(vns) do vn

Check warning on line 927 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L927

Added line #L927 was not covered by tests
# Return early if we're already in unconstrained space.
if istrans(varinfo, vn)
return metadata.vals[getrange(metadata, vn)]

Check warning on line 930 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L929-L930

Added lines #L929 - L930 were not covered by tests
end

# Transform to constrained space.
x = getval(varinfo, vn)
dist = getdist(varinfo, vn)
f = link_transform(dist)
y, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, x)

Check warning on line 937 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L934-L937

Added lines #L934 - L937 were not covered by tests
# Vectorize value.
yvec = vectorize(dist, y)

Check warning on line 939 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L939

Added line #L939 was not covered by tests
# Accumulate the log-abs-det jacobian correction.
acclogp!!(varinfo, -logjac)

Check warning on line 941 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L941

Added line #L941 was not covered by tests
# Mark as no longer transformed.
settrans!!(varinfo, true, vn)

Check warning on line 943 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L943

Added line #L943 was not covered by tests
# Return the vectorized transformed value.
return yvec

Check warning on line 945 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L945

Added line #L945 was not covered by tests
end

# Determine new ranges.
ranges_new = similar(metadata.ranges)
offset = 0
for (i, v) in enumerate(vals_new)
r_start, r_end = offset + 1, length(v) + offset
offset = r_end
ranges_new[i] = r_start:r_end
end

Check warning on line 955 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L949-L955

Added lines #L949 - L955 were not covered by tests

# Now we just create a new metadata with the new `vals` and `ranges`.
return Metadata(

Check warning on line 958 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L958

Added line #L958 was not covered by tests
metadata.idcs,
metadata.vns,
ranges_new,
reduce(vcat, vals_new),
metadata.dists,
metadata.gids,
metadata.orders,
metadata.flags,
)
end

function invlink(

Check warning on line 970 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L970

Added line #L970 was not covered by tests
::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model
)
return _invlink(varinfo)

Check warning on line 973 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L973

Added line #L973 was not covered by tests
end

function _invlink(varinfo::UntypedVarInfo)
varinfo = deepcopy(varinfo)
return VarInfo(

Check warning on line 978 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L976-L978

Added lines #L976 - L978 were not covered by tests
_invlink_metadata!(varinfo, varinfo.metadata),
Base.Ref(getlogp(varinfo)),
Ref(get_num_produce(varinfo)),
)
end

function _invlink(varinfo::TypedVarInfo)
varinfo = deepcopy(varinfo)
md = map(Base.Fix1(_invlink_metadata!, varinfo), varinfo.metadata)

Check warning on line 987 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L985-L987

Added lines #L985 - L987 were not covered by tests
# TODO: Update logp, etc.
return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))

Check warning on line 989 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L989

Added line #L989 was not covered by tests
end

function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata)
vns = metadata.vns

Check warning on line 993 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L992-L993

Added lines #L992 - L993 were not covered by tests

# Construct the new transformed values, and keep track of their lengths.
vals_new = map(vns) do vn

Check warning on line 996 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L996

Added line #L996 was not covered by tests
# Return early if we're already in constrained space.
if !istrans(varinfo, vn)
return metadata.vals[getrange(metadata, vn)]

Check warning on line 999 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L998-L999

Added lines #L998 - L999 were not covered by tests
end

# Transform to constrained space.
y = getval(varinfo, vn)
dist = getdist(varinfo, vn)
f = invlink_transform(dist)
x, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, y)

Check warning on line 1006 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L1003-L1006

Added lines #L1003 - L1006 were not covered by tests
# Vectorize value.
xvec = vectorize(dist, x)

Check warning on line 1008 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L1008

Added line #L1008 was not covered by tests
# Accumulate the log-abs-det jacobian correction.
acclogp!!(varinfo, -logjac)

Check warning on line 1010 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L1010

Added line #L1010 was not covered by tests
# Mark as no longer transformed.
settrans!!(varinfo, false, vn)

Check warning on line 1012 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L1012

Added line #L1012 was not covered by tests
# Return the vectorized transformed value.
return xvec

Check warning on line 1014 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L1014

Added line #L1014 was not covered by tests
end

# Determine new ranges.
ranges_new = similar(metadata.ranges)
offset = 0
for (i, v) in enumerate(vals_new)
r_start, r_end = offset + 1, length(v) + offset
offset = r_end
ranges_new[i] = r_start:r_end
end

Check warning on line 1024 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L1018-L1024

Added lines #L1018 - L1024 were not covered by tests

# Now we just create a new metadata with the new `vals` and `ranges`.
return Metadata(

Check warning on line 1027 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L1027

Added line #L1027 was not covered by tests
metadata.idcs,
metadata.vns,
ranges_new,
reduce(vcat, vals_new),
metadata.dists,
metadata.gids,
metadata.orders,
metadata.flags,
)
end

"""
islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior})

Expand Down
Loading