Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
cscherrer committed Nov 17, 2019
1 parent 1512d5e commit 16912d7
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/core/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ end

function type2model(::Type{Model{A,B,M}}) where {A,B,M}
args = [fieldnames(A)...]
body = interpret(B)
body = from_type(B)
Model(from_type(M), convert(Vector{Symbol},args), body)
end

Expand Down
4 changes: 2 additions & 2 deletions src/importance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ using MonteCarloMeasurements

export importanceSample
@inline function importanceSample(p::JointDistribution, q::JointDistribution, _data)
return _importanceSample(p.model, p.args, q.model, q.args, _data)
return _importanceSample(getmodule(p.model), p.model, p.args, q.model, q.args, _data)
end

@gg function _importanceSample(p::Model, _pargs, q::Model, _qargs, _data)
@gg M function _importanceSample(M::Module, p::Model, _pargs, q::Model, _qargs, _data)
p = type2model(p)
q = type2model(q)

Expand Down
6 changes: 3 additions & 3 deletions src/particles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ parts(d::iid; N=1000) = map(1:d.size) do j parts(d.dist) end


@inline function particles(m::JointDistribution)
return _particles(m.model, m.args)
return _particles(getmodule(m.model), m.model, m.args)
end

@gg function _particles(_m::Model, _args)
@gg M function _particles(M::Module, _m::Model, _args)
type2model(_m) |> sourceParticles() |> loadvals(_args, NamedTuple())
end

@gg function _particles(_m::Model, _args::NamedTuple{()})
@gg M function _particles(M::Module, _m::Model, _args::NamedTuple{()})
type2model(_m) |> sourceParticles()
end

Expand Down
4 changes: 2 additions & 2 deletions src/primitives/likelihood-weighting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
export weightedSample

function weightedSample(m::JointDistribution, _data)
return _weightedSample(m.model, m.args, _data)
return _weightedSample(getmodule(m.model), m.model, m.args, _data)
end

@gg function _weightedSample(_m::Model, _args, _data)
@gg M function _weightedSample(M::Module, _m::Model, _args, _data)
type2model(_m) |> sourceWeightedSample(_data) |> loadvals(_args, _data)
end

Expand Down
4 changes: 2 additions & 2 deletions src/primitives/logpdf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
export logpdf

function logpdf(m::JointDistribution,x, method=logpdf)
return method(m.model, m.args, x)
return method(getmodule(m.model), m.model, m.args, x)
end

@gg function logpdf(_m::Model, _args, _data)
@gg M function logpdf(M::Module, _m::Model, _args, _data)
type2model(_m) |> sourceLogpdf() |> loadvals(_args, _data)
end

Expand Down
9 changes: 5 additions & 4 deletions src/primitives/rand.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@ export rand
EmptyNTtype = NamedTuple{(),Tuple{}} where T<:Tuple

@inline function rand(m::JointDistribution)
return _rand(m.model, m.args)
@show getmodule(m.model)
return _rand(getmodule(m.model), m.model, m.args)
end

@inline function rand(m::Model)
return _rand(m, NamedTuple())
return _rand(getmodule(m), m, NamedTuple())
end

@gg function _rand(_m::Model, _args)
@gg M function _rand(M::Module, _m::Model, _args)
type2model(_m) |> sourceRand() |> loadvals(_args, NamedTuple())
end

@gg function _rand(_m::Model, _args::NamedTuple{()})
@gg M function _rand(M::Module, _m::Model, _args::NamedTuple{()})
type2model(_m) |> sourceRand()
end

Expand Down
4 changes: 2 additions & 2 deletions src/primitives/xform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ export xform


function xform(m::JointDistribution{A, B}, _data) where {A,B}
return _xform(m.model, m.args, _data)
return _xform(getmodule(m.model), m.model, m.args, _data)
end

@gg function _xform(_m::Model{Asub,B}, _args::A, _data) where {Asub, A,B}
@gg M function _xform(M::Module, _m::Model{Asub,B}, _args::A, _data) where {Asub, A,B}
type2model(_m) |> sourceXform(_data) |> loadvals(_args, _data)
end

Expand Down
4 changes: 2 additions & 2 deletions src/symbolic/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ end
export reduce

function reduce(m::JointDistribution,x)
return _reduce(m.model, m.args, x)
return _reduce(getmodule(m.model), m.model, m.args, x)
end

@gg function _reduce(_m::Model, _args, _data)
@gg M function _reduce(M::Module, _m::Model, _args, _data)
type2model(_m) |> sourceReduce() |> loadvals(_args, _data)
end

Expand Down
8 changes: 4 additions & 4 deletions src/symbolic/symbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ function __init__()
))

@eval begin
@gg function codegen(_m::Model, _args, _data)
@gg M function codegen(M::Module, _m::Model, _args, _data)
f = _codegen(type2model(_m))
:($f(_args, _data))
end
Expand Down Expand Up @@ -430,14 +430,14 @@ end
symlogpdf(d,x::Sym) = logpdf(d,x)

function symlogpdf(m::JointDistribution)
return _symlogpdf(m.model)
return _symlogpdf(getmodule(m.model), m.model)
end

function symlogpdf(m::Model)
return _symlogpdf(m)
return _symlogpdf(getmodule(m), m)
end

@gg function _symlogpdf(_m::Model)
@gg M function _symlogpdf(M::Module, _m::Model)
type2model(_m) |> canonical |> sourceSymlogpdf()
end

Expand Down

0 comments on commit 16912d7

Please sign in to comment.