Skip to content

Commit 0683088

Browse files
authored
Default float type to float(Real), not Real (#685)
* Default float type to float(Real), not Real Closes #684 * Trigger CI on backport branches/PRs * Add integration test for #684 * Bump Turing version to 0.34 in test subfolder
1 parent cdd3407 commit 0683088

File tree

5 files changed

+21
-4
lines changed

5 files changed

+21
-4
lines changed

.github/workflows/CI.yml

+2
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ on:
44
push:
55
branches:
66
- master
7+
- backport-*
78
pull_request:
89
branches:
910
- master
11+
- backport-*
1012
merge_group:
1113
types: [checks_requested]
1214

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.28.4"
3+
version = "0.28.5"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/utils.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -811,9 +811,9 @@ end
811811
"""
812812
float_type_with_fallback(x)
813813
814-
Return type corresponding to `float(typeof(x))` if possible; otherwise return `Real`.
814+
Return type corresponding to `float(typeof(x))` if possible; otherwise return `float(Real)`.
815815
"""
816-
float_type_with_fallback(::Type) = Real
816+
float_type_with_fallback(::Type) = float(Real)
817817
float_type_with_fallback(::Type{T}) where {T<:Real} = float(T)
818818

819819
"""

test/turing/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,5 @@ DynamicPPL = "0.24, 0.25, 0.26, 0.27, 0.28"
1515
HypothesisTests = "0.11"
1616
MCMCChains = "6"
1717
ReverseDiff = "1.15"
18-
Turing = "0.33"
18+
Turing = "0.34"
1919
julia = "1.7"

test/turing/varinfo.jl

+15
Original file line numberDiff line numberDiff line change
@@ -342,4 +342,19 @@
342342
model = state_space(y, length(t))
343343
@test size(sample(model, NUTS(; adtype=AutoReverseDiff(true)), n), 1) == n
344344
end
345+
346+
if Threads.nthreads() > 1
347+
@testset "DynamicPPL#684: OrderedDict with multiple types when multithreaded" begin
348+
@model function f(x)
349+
ns ~ filldist(Normal(0, 2.0), 3)
350+
m ~ Uniform(0, 1)
351+
x ~ Normal(m, 1)
352+
end
353+
model = f(1)
354+
chain = sample(model, NUTS(), MCMCThreads(), 10, 2);
355+
loglikelihood(model, chain)
356+
logprior(model, chain)
357+
logjoint(model, chain)
358+
end
359+
end
345360
end

0 commit comments

Comments
 (0)