Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to DynamicPPLBenchmarks #346

Draft
wants to merge 20 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions benchmarks/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@ uuid = "d94a1522-c11e-44a7-981a-42bf5dc1a001"
version = "0.1.0"

[deps]
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DiffUtils = "8294860b-85a6-42f8-8c35-d911f667b5f6"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LibGit2 = "76f85450-5226-5b5a-8eaa-529ad045b433"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Weave = "44d3d7a6-8a23-5bf8-98c5-b353f8df5ec9"
34 changes: 29 additions & 5 deletions benchmarks/benchmark_body.jmd
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
```julia
@time model_def(data)();
@time model_def(data...)();
```

```julia
m = time_model_def(model_def, data);
m = time_model_def(model_def, data...);
```

```julia
suite = make_suite(m);
results = run(suite);
results = run(suite, seconds=WEAVE_ARGS[:seconds]);
yebai marked this conversation as resolved.
Show resolved Hide resolved
```

```julia
Expand All @@ -19,11 +19,35 @@ results["evaluation_untyped"]
results["evaluation_typed"]
```

```julia
let k = "evaluation_simple_varinfo_nt"
haskey(results, k) && results[k]
end
```

```julia
let k = "evaluation_simple_varinfo_componentarray"
haskey(results, k) && results[k]
end
```

```julia
let k = "evaluation_simple_varinfo_dict"
haskey(results, k) && results[k]
end
```

```julia
let k = "evaluation_simple_varinfo_dict_from_nt"
haskey(results, k) && results[k]
end
```

```julia; echo=false; results="hidden";
BenchmarkTools.save(joinpath("results", WEAVE_ARGS[:name], "$(m.name)_benchmarks.json"), results)
```

```julia; wrap=false
```julia; wrap=false; echo=false
if WEAVE_ARGS[:include_typed_code]
typed = typed_code(m)
end
Expand All @@ -37,7 +61,7 @@ end
```

```julia; wrap=false; echo=false;
if haskey(WEAVE_ARGS, :name_old)
if WEAVE_ARGS[:include_typed_code] && haskey(WEAVE_ARGS, :name_old)
# We want to compare the generated code to the previous version.
import DiffUtils
typed_old = deserialize(joinpath("results", WEAVE_ARGS[:name_old], "$(m.name).jls"));
Expand Down
107 changes: 87 additions & 20 deletions benchmarks/benchmarks.jmd
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
# Benchmarks
`j display("text/markdown", "## $(WEAVE_ARGS[:name]) ##")`

## Setup
### Setup ###
yebai marked this conversation as resolved.
Show resolved Hide resolved

```julia
using BenchmarkTools, DynamicPPL, Distributions, Serialization
```

```julia
import DynamicPPLBenchmarks: time_model_def, make_suite, typed_code, weave_child
using DynamicPPLBenchmarks
using DynamicPPLBenchmarks: time_model_def, make_suite, typed_code, weave_child
```

## Models
### Environment

### `demo1`
```julia; echo=false; skip="notebook"
DynamicPPLBenchmarks.display_environment()
```

### Models ###
yebai marked this conversation as resolved.
Show resolved Hide resolved

#### `demo1` ####
yebai marked this conversation as resolved.
Show resolved Hide resolved

```julia
@model function demo1(x)
Expand All @@ -23,14 +30,14 @@ import DynamicPPLBenchmarks: time_model_def, make_suite, typed_code, weave_child
end

model_def = demo1;
data = 1.0;
data = (1.0, );
yebai marked this conversation as resolved.
Show resolved Hide resolved
```

```julia; results="markup"; echo=false
weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS)
yebai marked this conversation as resolved.
Show resolved Hide resolved
```

### `demo2`
#### `demo2` ####
yebai marked this conversation as resolved.
Show resolved Hide resolved

```julia
@model function demo2(y)
yebai marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -43,17 +50,19 @@ weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS)
# Heads or tails of a coin are drawn from a Bernoulli distribution.
y[n] ~ Bernoulli(p)
end

return (; p)
end

model_def = demo2;
data = rand(0:1, 10);
data = (rand(0:1, 10), );
yebai marked this conversation as resolved.
Show resolved Hide resolved
```

```julia; results="markup"; echo=false
weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS)
yebai marked this conversation as resolved.
Show resolved Hide resolved
```

### `demo3`
#### `demo3` ####
yebai marked this conversation as resolved.
Show resolved Hide resolved

```julia
@model function demo3(x)
Expand All @@ -76,7 +85,8 @@ weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS)
k[i] ~ Categorical(w)
x[:,i] ~ MvNormal([μ[k[i]], μ[k[i]]], 1.)
yebai marked this conversation as resolved.
Show resolved Hide resolved
end
return k

return (; μ1, μ2, k)
end

model_def = demo3
Expand All @@ -88,43 +98,100 @@ N = 30
μs = [-3.5, 0.0]

# Construct the data points.
data = mapreduce(c -> rand(MvNormal([μs[c], μs[c]], 1.), N), hcat, 1:2);
data = (mapreduce(c -> rand(MvNormal([μs[c], μs[c]], 1.), N), hcat, 1:2), );
yebai marked this conversation as resolved.
Show resolved Hide resolved
```

```julia; echo=false
weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS)
yebai marked this conversation as resolved.
Show resolved Hide resolved
```

### `demo4`: loads of indexing
#### `demo4`: lots of variables

```julia
@model function demo4(n, ::Type{TV}=Vector{Float64}) where {TV}
@model function demo4_1k(::Type{TV}=Vector{Float64}) where {TV}
m ~ Normal()
x = TV(undef, n)
x = TV(undef, 1_000)
for i in eachindex(x)
x[i] ~ Normal(m, 1.0)
end

return (; m, x)
end

model_def = demo4
data = (100_000, );
model_def = demo4_1k
data = ();
```

```julia; echo=false
weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS)
```

```julia
@model function demo4_dotted(n, ::Type{TV}=Vector{Float64}) where {TV}
@model function demo4_10k(::Type{TV}=Vector{Float64}) where {TV}
m ~ Normal()
x = TV(undef, n)
x = TV(undef, 10_000)
for i in eachindex(x)
x[i] ~ Normal(m, 1.0)
end

return (; m, x)
end

model_def = demo4_10k
data = ();
```

```julia; echo=false
weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS)
yebai marked this conversation as resolved.
Show resolved Hide resolved
```

```julia
@model function demo4_100k(::Type{TV}=Vector{Float64}) where {TV}
m ~ Normal()
x = TV(undef, 100_000)
for i in eachindex(x)
x[i] ~ Normal(m, 1.0)
end

return (; m, x)
end

model_def = demo4_100k
data = ();
```

```julia; echo=false
weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS)
yebai marked this conversation as resolved.
Show resolved Hide resolved
```


yebai marked this conversation as resolved.
Show resolved Hide resolved
#### `demo4_dotted`: `.~` for large number of variables

```julia
@model function demo4_100k_dotted(::Type{TV}=Vector{Float64}) where {TV}
m ~ Normal()
x = TV(undef, 100_000)
x .~ Normal(m, 1.0)

return (; m, x)
end

model_def = demo4_dotted
data = (100_000, );
model_def = demo4_100k_dotted
data = ();
```

```julia; echo=false
weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS)
yebai marked this conversation as resolved.
Show resolved Hide resolved
```

```julia; echo=false
if haskey(WEAVE_ARGS, :name_old)
display(MIME"text/markdown"(), "## Comparison with $(WEAVE_ARGS[:name_old]) ##")
end
```

```julia; echo=false
if haskey(WEAVE_ARGS, :name_old)
DynamicPPLBenchmarks.judgementtable(WEAVE_ARGS[:name], WEAVE_ARGS[:name_old])
end
```
70 changes: 70 additions & 0 deletions benchmarks/src/DynamicPPLBenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ module DynamicPPLBenchmarks

using DynamicPPL
using BenchmarkTools
using InteractiveUtils

using ComponentArrays: ComponentArrays

using Weave: Weave
using Markdown: Markdown
Expand Down Expand Up @@ -32,6 +35,40 @@ function benchmark_typed_varinfo!(suite, m)
return suite
end

function benchmark_simple_varinfo_namedtuple!(suite, m)
# We expect the model to return the random variables as a `NamedTuple`.
retvals = m()

# Populate.
vi = SimpleVarInfo{Float64}(retvals)
vi_ca = SimpleVarInfo{Float64}(ComponentArrays.ComponentArray(retvals))

# Evaluate.
suite["evaluation_simple_varinfo_nt"] = @benchmarkable $m($vi, $(DefaultContext()))
suite["evaluation_simple_varinfo_componentarrays"] = @benchmarkable $m($vi_ca, $(DefaultContext()))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
return suite
end

function benchmark_simple_varinfo_dict!(suite, m)
# Populate.
vi = SimpleVarInfo{Float64}(Dict())
retvals = m(vi)

# Evaluate.
suite["evaluation_simple_varinfo_dict"] = @benchmarkable $m($vi, $(DefaultContext()))

# We expect the model to return the random variables as a `NamedTuple`.
vns = map(keys(retvals)) do k
VarName{k}()
end
vi = SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals))))

# Evaluate.
suite["evaluation_simple_varinfo_dict_from_nt"] = @benchmarkable $m($vi, $(DefaultContext()))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

return suite
end

function typed_code(m, vi=VarInfo(m))
rng = DynamicPPL.Random.MersenneTwister(42)
spl = DynamicPPL.SampleFromPrior()
Expand All @@ -51,6 +88,11 @@ function make_suite(model)
benchmark_untyped_varinfo!(suite, model)
benchmark_typed_varinfo!(suite, model)

if isdefined(DynamicPPL, :SimpleVarInfo)
benchmark_simple_varinfo_namedtuple!(suite, model)
benchmark_simple_varinfo_dict!(suite, model)
end

return suite
end

Expand Down Expand Up @@ -151,6 +193,7 @@ function weave_benchmarks(
name=default_name(; include_commit_id=include_commit_id),
name_old=nothing,
include_typed_code=false,
seconds=10,
doctype="github",
outpath="results/$(name)/",
kwargs...,
Expand All @@ -159,6 +202,7 @@ function weave_benchmarks(
:benchmarkbody => benchmarkbody,
:name => name,
:include_typed_code => include_typed_code,
:seconds => seconds
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
)
if !isnothing(name_old)
args[:name_old] = name_old
Expand All @@ -168,4 +212,30 @@ function weave_benchmarks(
return Weave.weave(input, doctype; out_path=outpath, args=args, kwargs...)
end

function display_environment()
display("text/markdown", "Computer Information:")
vinfo = sprint(InteractiveUtils.versioninfo)
display("text/markdown", """
```
$(vinfo)
```
""")
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

ctx = Pkg.API.Context()

pkg_status = let io = IOBuffer()
Pkg.status(Pkg.API.Context(); io = io)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
String(take!(io))
end

display("text/markdown","""
Package Information:
""")
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

md = "```\n$(pkg_status)\n```"
display("text/markdown", md)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
display("text/markdown", md)
return display("text/markdown", md)

end

include("tables.jl")

end # module
Loading