Adds @returned_quantities macro #696

Open
wants to merge 57 commits into
base: master

torfjelde
Member

@torfjelde torfjelde commented Oct 23, 2024

This adds the @returned_quantities macro as discussed with @yebai and @mhauru.

This is meant to be a replacement for the @submodel macro, but without the ability to do automatic prefixing. It ends up looking like

julia> @model function demo1(x)
           x ~ Normal()
           return 1 + abs(x)
       end;

julia> @model function demo2(x, y, z)
           a = @returned_quantities prefix(demo1(x), "sub1")
           b = @returned_quantities prefix(demo1(y), "sub2")
           return z ~ Uniform(-a, b)
       end;

julia> rand(demo2(missing, missing, 0.4))
(var"sub1.x" = 0.5865756059371534, var"sub2.x" = -0.25563799658500047)

Likely TODOs:

  • Add deprecation warning to @submodel telling the user to use @returned_quantities.
  • Do we do the renaming of generated_quantities to returned_quantities in this PR?

Fix #691

codecov bot commented Oct 23, 2024

Codecov Report

Attention: Patch coverage is 93.75000% with 1 line in your changes missing coverage. Please review.

Project coverage is 77.78%. Comparing base (54691bf) to head (7aef65b).
Report is 3 commits behind head on master.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/contexts.jl | 80.00% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #696      +/-   ##
==========================================
- Coverage   79.22%   77.78%   -1.45%     
==========================================
  Files          30       30              
  Lines        4212     3938     -274     
==========================================
- Hits         3337     3063     -274     
  Misses        875      875              


@coveralls

coveralls commented Oct 23, 2024

Pull Request Test Coverage Report for Build 11627641486

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 15 of 16 (93.75%) changed or added relevant lines in 4 files are covered.
  • 62 unchanged lines in 2 files lost coverage.
  • Overall coverage decreased (-2.2%) to 77.447%

| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
| --- | --- | --- | --- |
| src/contexts.jl | 4 | 5 | 80.0% |

| Files with Coverage Reduction | New Missed Lines | % |
| --- | --- | --- |
| src/utils.jl | 23 | 82.48% |
| src/varinfo.jl | 39 | 80.24% |

| Totals | Coverage Status |
| --- | --- |
| Change from base Build 11381380435 | -2.2% |
| Covered Lines | 3046 |
| Relevant Lines | 3933 |

💛 - Coveralls

src/submodel_macro.jl Outdated Show resolved Hide resolved
@yebai
Member

yebai commented Oct 24, 2024

@torfjelde I suggest we change the prefix feature to a prefix_variables model operation (feel free to come up with better names). Then we could use the same functionality prefix_variables in more places, e.g.

# submodel prefixing
julia> @model function demo2(x, y, z)
           a = @returned_quantities prefix_variables(demo1(x), "sub1")
           b = @returned_quantities prefix_variables(demo1(y), "sub2")
           return z ~ Uniform(-a, b)
       end;

julia> rand(demo2(missing, missing, 0.4))
(var"sub1.x" = 0.5865756059371534, var"sub2.x" = -0.25563799658500047)

# rand prefixing 

julia> ret = rand(prefix_variables(demo1(1.), "prior_sample"))

# generated quantities / predict 

julia> returned_quantities(prefix_variables(demo1(1.), "generated_var_"), chain) 

This would also help unify the syntax of @generated_quantities and generated_quantities; IIRC, the only difference between them is that generated_quantities lacks the prefixing/renaming feature.

This could be further unified with NamedDist in the future. See, e.g., #414

@torfjelde
Member Author

We already have DynamicPPL.prefix, though this doesn't do exactly what you want here. We could easily just add

prefix(model::Model, x) = contextualize(model, PrefixContext(model.context, Symbol(x)))

or something as an additional definition.

However, I'm a bit worried about two things:

  1. It's quite verbose + a bit "too close to internals" for end-users.
  2. To achieve the same performance guarantees that we have currently, we need to wrap everything in Val before calling prefix(model, ...) 😕 This seems non-ideal to me vs. the current approach.

@yebai
Member

yebai commented Oct 25, 2024

It's quite verbose + a bit "too close to internals" for end-users.

I like the @returned_quantities(prefix(model, "prefix_")) syntax because it is

  • less mysterious than @returned_quantities model "prefix_"
  • all the other model operations could share this, e.g. rand(prefix(model, "prefix_")) to verify the effects of prefixing, which is very useful

prefix(model, x) is NOT any closer to internals than any other model operation APIs. They are the same, so this is not a problem.

To achieve the same performance guarantees that we have currently, we need to wrap everything in Val before calling prefix(model, ...) 😕 This seems non-ideal to me vs. the current approach.

Point taken, but this is very minor and a bit subjective.

@torfjelde
Member Author

torfjelde commented Oct 26, 2024

Point taken, but this is very minor and a bit subjective.

But this means that the user needs to be careful and do prefix(model, Val{:whatever}()); if we just do prefix(model, :whatever), this will lead to type-instabilities. Do we really want to force end-users of Turing.jl to explicitly use Val? 😕
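
The type-stability concern can be illustrated with a minimal sketch in plain Julia. It does not use DynamicPPL: `make_nt` is a hypothetical helper that stands in for any code whose return type depends on the prefix. The sketch assumes only that a prefix passed as a `Symbol` is a runtime value, while a prefix passed as `Val` is part of the argument's type.

```julia
# Hypothetical helper, not part of DynamicPPL: build a NamedTuple whose
# single field is named after the prefix.

# Runtime Symbol: the field name is only known when the code runs, so the
# compiler cannot infer a concrete return type.
make_nt(name::Symbol, x) = NamedTuple{(name,)}((x,))

# Val: the field name is encoded in the argument's type, so the return
# type can be inferred from the argument types alone.
make_nt(::Val{name}, x) where {name} = NamedTuple{(name,)}((x,))

# Ask inference what each method returns for the given argument types.
runtime_type = only(Base.return_types(make_nt, (Symbol, Float64)))
static_type = only(Base.return_types(make_nt, (Val{:sub1}, Float64)))

@show isconcretetype(runtime_type)
@show isconcretetype(static_type)
```

Both methods return the same value at runtime; the difference is only in what the compiler can prove ahead of time, which is what matters for code called in a hot loop.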

@yebai
Member

yebai commented Oct 27, 2024

It is a standard Julia performance trick, so it is okay.

By default, we can print a performance warning message when users call prefix(model, x::String) or similar.
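
One way the warning idea could look, as a rough sketch only: the names `PrefixedSketch`, `prefix_sketch`, and `get_prefix` are hypothetical and are not the DynamicPPL API. The `Val` method is the fast path, and the `String`/`Symbol` method warns once and forwards to it.

```julia
# Hypothetical stand-in for a prefixed model; not a DynamicPPL type.
# The prefix P is stored as a type parameter.
struct PrefixedSketch{P,M}
    model::M
end

# Fast path: the prefix arrives as Val, so P is known from the argument type.
prefix_sketch(model, ::Val{P}) where {P} = PrefixedSketch{P,typeof(model)}(model)

# Convenience path: accept a runtime String or Symbol, warn (at most once),
# and forward to the Val method. The conversion is a dynamic dispatch.
function prefix_sketch(model, p::Union{AbstractString,Symbol})
    @warn "prefix passed as a runtime value; use Val(:name) for type stability" maxlog=1
    return prefix_sketch(model, Val(Symbol(p)))
end

# Recover the prefix from the type parameter.
get_prefix(::PrefixedSketch{P}) where {P} = P
```

With this shape, `prefix_sketch(m, "sub1")` and `prefix_sketch(m, Val(:sub1))` produce the same object; only the first emits the warning.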

@yebai
Member

yebai commented Oct 27, 2024

I'm also happy to turn prefix into a macro: @prefix(model, :prefix_) if that helps. Then we could do

@returned_quantities @prefix(model, :prefix_)

@torfjelde
Member Author

Added a @prefix macro :) See the docstring of @returned_quantities for what it looks like 👍

torfjelde and others added 3 commits October 29, 2024 18:39
…cro' into torfjelde/returned-quantities-macro
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@yebai
Member

yebai commented Oct 30, 2024

Thanks, @torfjelde; I'm happy with the changes.

To minimise interface confusion (prefix vs. @prefix, and @returned_quantities vs. returned_quantities), shall we consider keeping only @prefix and @returned_quantities and deprecating generated_quantities and prefix?

Thoughts? @mhauru and @penelopeysm

Member

@mhauru mhauru left a comment

For prefix/@prefix, maybe keep both but only export the macro? It sounds like unless you know what you are doing, you should use @prefix. And if you know what you're doing, you don't need it to be exported. I do generally think it's a good idea to have a macro-free option available if possible.

For returned_quantities/@returned_quantities we still need both, because one is to be used outside of @model, the other inside, right? I forget what we concluded about this in our call, but I do worry users will mix the two up and get confusing errors.

src/submodel_macro.jl Outdated Show resolved Hide resolved
true

julia> # Or using some arbitrary expression.
@model submodel_prefix_expr() = a = @returned_quantities prefix=1 + 2 inner()
Member

I found

@returned_quantities prefix=1 + 2 inner()

hard and unintuitive to parse. I think

@returned_quantities prefix=(1 + 2) inner()

would be much clearer. Not sure if this is a documentation issue, or if we should disallow the former.

Member Author

That's a documentation issue IMO, as this is not doing any special parsing but relying on Julia's expression parsing.
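
The parsing point can be checked without the macro being defined, since `Meta.parse` only builds the expression. The sketch below compares the two spellings from this thread; the variable names are illustrative.

```julia
# Parse both spellings without evaluating them; the macro itself does not
# need to exist for parsing to succeed.
plain = Meta.parse("@returned_quantities prefix=1 + 2 inner()")
parenthesized = Meta.parse("@returned_quantities prefix=(1 + 2) inner()")

# A :macrocall expression stores the macro name and a LineNumberNode first;
# the arguments the macro receives start at index 3.
plain_args = plain.args[3:end]
parenthesized_args = parenthesized.args[3:end]

foreach(println, plain_args)
@show plain_args == parenthesized_args
```

Because `1 + 2` has whitespace on both sides of the operator, Julia treats it as a single binary expression inside the space-separated macro call, so the parentheses only help the human reader.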

docs/src/api.md Outdated Show resolved Hide resolved
@yebai
Member

yebai commented Oct 30, 2024

For returned_quantities/@returned_quantities we still need both, because one is to be used outside of @model, the other inside, right?

generated_quantities allows users to fix model parameter values and/or accept MCMC chain objects.
We can throw an error if users try to pass fixed parameter values or chain objects to @returned_quantities called within a model.

Then, @returned_quantities can match generated_quantities / returned_quantities exactly, thus allowing us to remove the generated_quantities / returned_quantities altogether.

@torfjelde
Member Author

Then, @returned_quantities can match generated_quantities / returned_quantities exactly, thus allowing us to remove the generated_quantities / returned_quantities altogether.

Just so we're all on the same page: @returned_quantities and returned_quantities will not match since the former only takes a single argument, while the other takes two, right? If so, then why would we want to raise explicit errors for incorrect arguments provided vs. just letting Julia raise the "not implemented error"?

@torfjelde
Member Author

Deprecated generated_quantities in favour of returned_quantities + removed the prefix=... argument for @prefix.

@@ -141,8 +141,12 @@ By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log
probability of `vi` with the returned value.
"""
function tilde_assume!!(context, right, vn, vi)
value, logp, vi = tilde_assume(context, right, vn, vi)
return value, acclogp_assume!!(context, vi, logp)
return if is_rhs_model(right)
Member Author

This can be generalized as we desire, e.g. if we want to do something special with latent(model), we can overload this to be true and then overload rand_like!!.

@@ -197,6 +201,11 @@ Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the informati
and indices; if needed, these can be accessed through this function, though.
"""
function tilde_observe!!(context, right, left, vname, vi)
is_rhs_model(right) && throw(
Member Author

Once we want "more" things to be allowed on right, we can easily deal with this by just generalizing is_rhs_model.

src/model.jl Outdated Show resolved Hide resolved
src/model.jl Show resolved Hide resolved
@torfjelde
Member Author

Accepted your suggestions @penelopeysm 👍

Member

@penelopeysm penelopeysm left a comment

Happy to approve once you're happy and if CI passes 😄

Comment on lines +1252 to +1257
"""
is_rhs_model(x)

Return `true` if `x` is a model or model wrapper, and `false` otherwise.
"""
is_rhs_model(x) = false
Member

Does this return true for Models themselves? It seems to me that it's only true for a ReturnedModelWrapper.

Member Author

We don't want to allow usage of Model on the RHS. So yeah, we should probably rename the function. Long-term we probably just want something like is_valid_rhs_tilde(...), etc.
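
The behaviour described in this exchange follows the usual trait-function pattern: a catch-all method returns `false` and only specific types opt in. A minimal sketch, using hypothetical stand-in types (`PlainModel`, `WrappedModel`) and a hypothetical function name rather than the real DynamicPPL definitions:

```julia
# Hypothetical stand-ins; the real Model and ReturnedModelWrapper are
# defined in DynamicPPL and carry more fields.
struct PlainModel end
struct WrappedModel{M}
    model::M
end

# Fallback: nothing is accepted on the right-hand side unless it opts in.
is_rhs_model_sketch(x) = false

# Only the wrapper opts in; a bare model keeps the `false` fallback.
is_rhs_model_sketch(::WrappedModel) = true

@show is_rhs_model_sketch(PlainModel())
@show is_rhs_model_sketch(WrappedModel(PlainModel()))
```

Allowing further types on the right-hand side would then only need one more one-line method, which is the generalization mentioned in the review comments.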

@torfjelde
Member Author

Okay, so the final call has been made and we'll introduce several wrappers:

  • struct ReturnedModelWrapper: wraps a Model to indicate that it's a model over the return-values rather than its latents.
  • abstract Distributional: a type indicating that something is distributional in some sense.
  • struct Sampleable <: Distributional: a wrapper type that specifies that whatever it wraps is sampleable.

In combination with these types, we'll introduce the following functionality

to_sampleable(x) = Sampleable(x)
returned(model::Model) = ReturnedModelWrapper(model)
to_submodel(model::Model) = to_sampleable(returned(model))

At the moment, Sampleable and Distributional (and their corresponding methods) don't have any particular purpose, but I believe @yebai has some plans for them in the future.
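
How the three wrappers compose can be sketched in plain Julia. Every name below carries a `Sketch` suffix to mark it as a hypothetical stand-in; the field names and the `ModelSketch` type are assumptions, and only the composition mirrors the three one-line definitions above.

```julia
# Hypothetical stand-ins for the types listed above; not DynamicPPL code.
abstract type DistributionalSketch end

# Wraps anything to mark it as sampleable.
struct SampleableSketch{T} <: DistributionalSketch
    inner::T
end

# Minimal placeholder for a model: it just holds a function.
struct ModelSketch{F}
    f::F
end

# Marks a model as "a model over its return value".
struct ReturnedModelWrapperSketch{M<:ModelSketch}
    model::M
end

# Same composition as the definitions in the comment above.
to_sampleable_sketch(x) = SampleableSketch(x)
returned_sketch(model::ModelSketch) = ReturnedModelWrapperSketch(model)
to_submodel_sketch(model::ModelSketch) = to_sampleable_sketch(returned_sketch(model))

m = ModelSketch(() -> 1.0)
s = to_submodel_sketch(m)
@show typeof(s)
```

The result is a `SampleableSketch` around a `ReturnedModelWrapperSketch` around the model, so dispatch on the abstract supertype is enough to recognise anything that is permitted on the right-hand side of `~` under this scheme.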

Moreover, the way we do it currently is that anything is allowed on the RHS of a ~, as long as you override the correct methods.
But I think the plan (@yebai should confirm this) is that going forward we only allow subtypes of Distribution (from Distributions.jl) and Distributional (our type) to occur on the RHS of a ~ rather than specifying this by using dispatching / method overloading.

@yebai also mentioned that we want methods such as fix, condition, etc. to be supported for ReturnedModelWrapper, but I'll leave this for future PRs as a) we're not going to document that people can do this atm, so users will just use this inside a @model, and b) this, IMO, requires some more thought as to exactly how to execute. If we do it naively, we'll have to duplicate all of the necessary methods on Model, which is non-ideal. A likely approach is to introduce another abstract type AbstractModel which both Model and ReturnedModelWrapper inherit from; we could then define fix, etc. on this abstract type and only overload contextualize for the specific method in question.
But because this has dragged on for so long, let's leave this for another time.
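
The AbstractModel idea could look roughly like the sketch below. This is one possible reading of the proposal, not a design from the PR: all names are hypothetical, and a NamedTuple stands in for a real context type. The sketch needs two small per-type methods (reading and replacing the context), while `fix_sketch` is written once on the abstract type.

```julia
# Hypothetical stand-ins; not DynamicPPL types.
abstract type AbstractModelSketch end

struct BaseModelSketch{C} <: AbstractModelSketch
    context::C
end

struct ReturnedSketch{M<:BaseModelSketch} <: AbstractModelSketch
    model::M
end

# Per-type methods: how to read and how to replace the context.
get_context(m::BaseModelSketch) = m.context
get_context(w::ReturnedSketch) = get_context(w.model)

contextualize_sketch(m::BaseModelSketch, ctx) = BaseModelSketch(ctx)
contextualize_sketch(w::ReturnedSketch, ctx) =
    ReturnedSketch(contextualize_sketch(w.model, ctx))

# Written once for every subtype: wrap the current context in a new one
# that records the fixed values (a NamedTuple is used as a toy context).
fix_sketch(m::AbstractModelSketch, values::NamedTuple) =
    contextualize_sketch(m, (fixed = values, parent = get_context(m)))

wrapped = ReturnedSketch(BaseModelSketch(nothing))
fixed = fix_sketch(wrapped, (x = 1.0,))
@show typeof(fixed)
```

The wrapper type survives the call to `fix_sketch`, which is the property that would otherwise require duplicating each method for the wrapper.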

@torfjelde
Member Author

Encountered another problem with ~ 😕

If m is used both inside the submodel and for the return-value (which is not uncommon, since the return value might very well be a random variable in the model), we can't distinguish the two:

julia> @model demo_inner() = m ~ Normal()
demo_inner (generic function with 2 methods)

julia> @model function demo_outer()
           m ~ to_submodel(demo_inner())
           return m
       end
demo_outer (generic function with 2 methods)

julia> model = demo_outer();

julia> model() ≠ 1.0
true

julia> conditioned_model = model | (m = 1.0, );

julia> conditioned_model()
ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported
...

@penelopeysm
Member

Encountered another problem with ~

:/

Does prefixing help?

@torfjelde
Member Author

Does prefixing help?

Aye, this was the first solution that popped into my mind too. Though this will technically fix it, we don't have a way of warning the user about this properly 😕

@mhauru
Member

mhauru commented Nov 27, 2024

A bit stumped by this at first sight. I think the desired user experience would be an error saying "variable names in the outer and inner model conflict, you must use prefixing", but not sure how to implement that neatly. I'm guessing there's a reason you haven't implemented that already when e.g. the user uses the same submodel twice without prefixing.

@yebai
Member

yebai commented Nov 27, 2024

In general, for errors that are hard or impossible to catch (at compile time), we could add a tool to check_model so users get more informative error messages:

https://turinglang.org/DynamicPPL.jl/stable/api/#Debugging-Utilities

@torfjelde
Member Author

A bit stumped by this at first sight. I think the desired user experience would be an error saying "variable names in the outer and inner model conflict, you must use prefixing", but not sure how to implement that neatly. I'm guessing there's a reason you haven't implemented that already when e.g. the user uses the same submodel twice without prefixing.

Exactly 😕

In general, for errors that are hard or impossible to catch (at compile time), we could add a tool to check_model so users get more informative error messages:

This crossed my mind, but the issue is ofc that this is only done upon explicit call by the user or in sample (when using Turing.jl only) and so it's quite limited in the sort of information it provides to the user 😕 This is the sort of issue that can be quite subtle, and so not erroring when used improperly is a bit scary I think.

But yeah, I don't see a better solution atm. The other solution is to either force usage of prefixing or just always prefix (neither of which is ideal).
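
For illustration, the kind of check discussed here could be as small as the sketch below. `check_name_clash` is a hypothetical helper operating on plain lists of variable names; it is not part of check_model, and it assumes that prefixed names take the form `prefix.name`, as in the `var"sub1.x"` output earlier in the thread.

```julia
# Hypothetical helper, not part of DynamicPPL: compare the variable names of
# an outer model with those a submodel would introduce.
function check_name_clash(outer_names, inner_names; prefix=nothing)
    # With a prefix, each inner name becomes "prefix.name".
    effective = prefix === nothing ? collect(inner_names) :
                [Symbol(prefix, ".", n) for n in inner_names]
    clashes = intersect(outer_names, effective)
    isempty(clashes) || throw(
        ArgumentError(
            "variable names in the outer and inner model conflict: " *
            join(clashes, ", ") * "; you must use prefixing",
        ),
    )
    return nothing
end

# The demo_outer / demo_inner case above: both use `m`.
check_name_clash([:m], [:m]; prefix="inner")  # passes: inner name is inner.m
```

Calling `check_name_clash([:m], [:m])` without a prefix throws the `ArgumentError`. As noted in the thread, such a check only helps when something actually invokes it.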

Successfully merging this pull request may close these issues.

submodel and generated_quantities operations on models
7 participants