Skip to content

Commit 3705f60

Browse files
committed
switch from Ref to AtomicRef
1 parent 697e1fb commit 3705f60

File tree

2 files changed

+37
-15
lines changed

2 files changed

+37
-15
lines changed

src/StableTasks.jl

+10-4
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,16 @@ module StableTasks
33
macro spawn end
44
macro spawnat end
55

6-
using Base: RefValue
7-
struct StableTask{T}
8-
t::Task
9-
ret::RefValue{T}
6+
mutable struct AtomicRef{T}
7+
@atomic x::T
8+
AtomicRef{T}() where {T} = new{T}()
9+
AtomicRef(x::T) where {T} = new{T}(x)
10+
AtomicRef{T}(x) where {T} = new{T}(convert(T, x))
11+
end
12+
13+
mutable struct StableTask{T}
14+
const t::Task
15+
ret::AtomicRef{T}
1016
end
1117

1218
include("internals.jl")

src/internals.jl

+27-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
module Internals
22

3-
import StableTasks: @spawn, @spawnat, StableTask
3+
import StableTasks: @spawn, @spawnat, StableTask, AtomicRef
4+
5+
Base.getindex(r::AtomicRef) = @atomic r.x
6+
Base.setindex!(r::AtomicRef{T}, x) where {T} = @atomic r.x = convert(T, x)
47

58
function Base.fetch(t::StableTask{T}) where {T}
69
fetch(t.t)
@@ -25,32 +28,45 @@ Base.schedule(t::StableTask) = (schedule(t.t); t)
2528
Base.schedule(t, val; error=false) = (schedule(t.t, val; error); t)
2629

2730

28-
macro spawn(ex)
31+
macro spawn(args...)
32+
tp = QuoteNode(:default)
33+
na = length(args)
34+
if na == 2
35+
ttype, ex = args
36+
if ttype isa QuoteNode
37+
ttype = ttype.value
38+
if ttype !== :interactive && ttype !== :default
39+
throw(ArgumentError("unsupported threadpool in StableTasks.@spawn: $ttype"))
40+
end
41+
tp = QuoteNode(ttype)
42+
else
43+
tp = ttype
44+
end
45+
elseif na == 1
46+
ex = args[1]
47+
else
48+
throw(ArgumentError("wrong number of arguments in @spawn"))
49+
end
2950
letargs = _lift_one_interp!(ex)
3051

3152
thunk = replace_linenums!(:(() -> ($(esc(ex)))), __source__)
3253
var = esc(Base.sync_varname) # This is for the @sync macro which sets a local variable whose name is
3354
# the symbol bound to Base.sync_varname
3455
# I asked on slack and this is apparently safe to consider a public API
35-
set_pool = if VERSION < v"1.9"
36-
nothing
37-
else
38-
:(Threads._spawn_set_thrpool(task, :default))
39-
end
4056
quote
4157
let $(letargs...)
4258
f = $thunk
4359
T = Core.Compiler.return_type(f, Tuple{})
44-
ref = Ref{T}()
60+
ref = AtomicRef{T}()
4561
f_wrap = () -> (ref[] = f(); nothing)
4662
task = Task(f_wrap)
4763
task.sticky = false
48-
$set_pool
64+
Threads._spawn_set_thrpool(task, $(esc(tp)))
4965
if $(Expr(:islocal, var))
5066
put!($var, task) # Sync will set up a Channel, and we want our task to be in there.
5167
end
5268
schedule(task)
53-
StableTask(task, ref)
69+
StableTask{T}(task, ref)
5470
end
5571
end
5672
end
@@ -75,7 +91,7 @@ macro spawnat(thrdid, ex)
7591
let $(letargs...)
7692
thunk = $thunk
7793
RT = Core.Compiler.return_type(thunk, Tuple{})
78-
ret = Ref{RT}()
94+
ret = AtomicRef{RT}()
7995
thunk_wrap = () -> (ret[] = thunk(); nothing)
8096
local task = Task(thunk_wrap)
8197
task.sticky = true

0 commit comments

Comments
 (0)