Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix-JET-infer #321

Merged
merged 2 commits into from
Feb 15, 2022
Merged

Conversation

thautwarm
Copy link
Collaborator

fix JuliaStaging/GeneralizedGenerated.jl#69 :
Fix the following tests:

using Soss, JET

m1 = @model N begin
    p ~ Uniform()
    x ~ For(N) do j
            Bernoulli(p / j)
        end
    end

@test_opt rand(m1(10))


m3 = @model N begin
    p ~ Uniform()
    f(ctx) = Base.Fix1(ctx) do ctx, j
        Bernoulli(ctx.p / j)
    end
    x ~ For(f((p=p,)), N)
end

@test_opt rand(m3(10))

@thautwarm
Copy link
Collaborator Author

I'm not really sure why creating a gg function using mkfun and calling it later does not work.

@cscherrer
Copy link
Owner

Thanks! I'm getting an error though, does this work for you?

julia> rand(m1(10))
:(:(ERROR: task switch not allowed from inside staged nor pure functions
Stacktrace:
  [1] try_yieldto(undo::typeof(Base.ensure_rescheduled))
    @ Base ./task.jl:767
  [2] wait()
    @ Base ./task.jl:837
  [3] uv_write(s::Base.TTY, p::Ptr{UInt8}, n::UInt64)
    @ Base ./stream.jl:992
  [4] unsafe_write(s::Base.TTY, p::Ptr{UInt8}, n::UInt64)
    @ Base ./stream.jl:1064
  [5] unsafe_write
    @ ./io.jl:362 [inlined]
  [6] write
    @ ./strings/io.jl:244 [inlined]
  [7] print
    @ ./strings/io.jl:246 [inlined]
  [8] show_unquoted_quote_expr(io::IOContext{Base.TTY}, value::Any, indent::Int64, prec::Int64, quote_level::Int64)
    @ Base ./show.jl:1685
  [9] show(io::Base.TTY, ex::Expr)
    @ Base ./show.jl:1304
 [10] show(x::Expr)
    @ Base ./show.jl:393
 [11] #s54#41
    @ ~/git/Soss.jl/src/primitives/interpret.jl:96 [inlined]
 [12] var"#s54#41"(MC::Any, T::Any, ::Any, _mc::Any, #unused#::Any, _cfg::Any, _ctx::Any)
    @ Soss ./none:0
 [13] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:580
 [14] #rand#46
    @ ~/git/Soss.jl/src/primitives/rand.jl:35 [inlined]
 [15] rand
    @ ~/git/Soss.jl/src/primitives/rand.jl:34 [inlined]
 [16] #rand#44
    @ ~/git/Soss.jl/src/primitives/rand.jl:19 [inlined]
 [17] rand(m::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{102}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}})
    @ Soss ~/git/Soss.jl/src/primitives/rand.jl:19
 [18] top-level scope
    @ REPL[4]:1

@cscherrer
Copy link
Owner

I think this error comes from Base.show(xs), but without that things still aren't quite there:

julia> rand(m1(10))
ERROR: UndefVarError: _mc not defined
Stacktrace:
 [1] getproperty
   @ ./Base.jl:35 [inlined]
 [2] macro expansion
   @ ~/git/Soss.jl/src/primitives/interpret.jl:75 [inlined]
 [3] mkfun_call(_mc::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{102}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}, #unused#::typeof(Soss.tilde_rand), _cfg::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, _ctx::NamedTuple{(), Tuple{}})
   @ Soss ~/git/Soss.jl/src/primitives/interpret.jl:75
 [4] #rand#46
   @ ~/git/Soss.jl/src/primitives/rand.jl:35 [inlined]
 [5] rand
   @ ~/git/Soss.jl/src/primitives/rand.jl:34 [inlined]
 [6] #rand#44
   @ ~/git/Soss.jl/src/primitives/rand.jl:19 [inlined]
 [7] rand(m::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{102}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}})
   @ Soss ~/git/Soss.jl/src/primitives/rand.jl:19
 [8] top-level scope
   @ REPL[4]:1

But I do think we're headed in the right direction.

@cscherrer
Copy link
Owner

I found this to work for the test case:

@gg function mkfun_call(_mc::MC, ::T, _cfg, _ctx) where {MC, T}
    _m = type2model(MC)
    M = getmodule(_m)

    _args = argvalstype(MC)
    _obs = obstype(MC)

    tilde = T.instance
    body = _m.body |> loadvals(_args, _obs)
    body = _interpret(M, body, tilde, _args, _obs)

    q = MacroTools.flatten(quote
            local _retn
            _args = Soss.argvals(_mc)
            _obs = Soss.observations(_mc)
            _cfg = merge(_cfg, (args=_args, obs=_obs))
            $body
            _retn
        end)

    @under_global M q
end

Then I get

julia> rand(m1(10))
(p = 0.880486, x = Bool[1, 1, 0, 1, 0, 0, 0, 0, 0, 0])

julia> @test_opt rand(m1(10))
Test Passed
  Expression: #= REPL[9]:1 =# JET.@test_call analyzer = OptAnalyzer rand(m1(10))

But then I'm not sure how robust it is to using from a different module, etc. There are a few things here that don't yet make sense to me...

First, I think I see what mk_expr is trying to do, but I can't yet get it to work.

I see this idiom come up a lot:

M = ...
@q let M
    ... # body of the function
end

I don't see how this can work. M is defined before the quote, so wouldn't it have to be interpolated into the expression?

And then for testing... Is there a minimal "calling things from other modules" setup you'd use to test for robustness of this sort of thing? Does precompilation complicate that?

@thautwarm
Copy link
Collaborator Author

thautwarm commented Feb 15, 2022

Just remove the invocation of Base.show in mkfun_call

@cscherrer
Copy link
Owner

Does that work for you? If I do that I get an error that _mc is not found.

@thautwarm
Copy link
Collaborator Author

Oh, I don't know why but it now raises "_mc" undefined.
It shall be caused by some cleanup before PR. Fixing it.

@thautwarm thautwarm requested a review from cscherrer February 15, 2022 04:02
@thautwarm
Copy link
Collaborator Author

Updated. Could you please have a try?

@thautwarm
Copy link
Collaborator Author

It works locally now.

julia> rand(m3(10))
(p = 0.189824138731166, x = Bool[0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

julia> rand(m1(10))
(p = 0.421553717757589, x = Bool[0, 1, 0, 0, 0, 0, 0, 0, 0, 0])

julia> @test_opt rand(m1(10))
Test Passed
  Expression: #= REPL[6]:1 =# JET.@test_call analyzer = JET.OptAnalyzer rand(m1(10))

M is defined before the quote, so wouldn't it have to be interpolated into the expression?

I don't M is explicitly used in the code generated by Soss, but it is to resolve global variables like +, *, etc.
We just pass M to the first argument of mk_function/mk_expr, which means the callee module.

If any generated code is explicitly visiting M, please let me know and I make a minor fix to support visiting M.

Copy link
Owner

@cscherrer cscherrer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Works great, thank you!

@cscherrer cscherrer merged commit 4dc0bba into cscherrer:astmodels-dev Feb 15, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants