From 6264afd33f9441697d2f5ed4a75bbcc37f15feba Mon Sep 17 00:00:00 2001 From: "Killian Q. Zhuo" Date: Tue, 22 Oct 2024 20:17:22 +0800 Subject: [PATCH] static_parameter expr (#175) * comment about TypedSlot * find an Instruction dynamically * translate static_parameter expr * translate literal variables * test case * test case for static parameter * drop compat for Julia < 1.10; mark as breaking release * Restrict CI to 1.10 * eval static parameters in args * Test whether Julia 1.7 still works * eval static parameter when binding vars * [skip ci] fix typos --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> Co-authored-by: Penelope Yong --- .github/workflows/Testing.yaml | 1 + Project.toml | 2 +- perf/benchmark.jl | 3 ++- src/Libtask.jl | 3 ++- src/tapedfunction.jl | 18 +++++++++++++++--- test/issues.jl | 21 +++++++++++++++++++++ 6 files changed, 42 insertions(+), 6 deletions(-) diff --git a/.github/workflows/Testing.yaml b/.github/workflows/Testing.yaml index 48d03b62..728f92e1 100644 --- a/.github/workflows/Testing.yaml +++ b/.github/workflows/Testing.yaml @@ -12,6 +12,7 @@ jobs: matrix: version: - '1.7' + - '1.10' - '1' - 'nightly' os: diff --git a/Project.toml b/Project.toml index bc778669..41385419 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" license = "MIT" desc = "Tape based task copying in Turing" repo = "https://github.com/TuringLang/Libtask.jl.git" -version = "0.8.7" +version = "0.8.8" [deps] FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" diff --git a/perf/benchmark.jl b/perf/benchmark.jl index 2b2312c0..dcfb3638 100644 --- a/perf/benchmark.jl +++ b/perf/benchmark.jl @@ -114,7 +114,8 @@ println("======= breakdown benchmark =======") x = rand(100000) tf = Libtask.TapedFunction(ackley, x, nothing) tf(x, nothing); -ins = tf.tape[45] +idx = findlast((x)->isa(x, Libtask.Instruction), tf.tape) +ins = tf.tape[idx] b = ins.input[1] @show ins.input |> length diff --git a/src/Libtask.jl b/src/Libtask.jl index ac349a80..8fa79533 100644 --- a/src/Libtask.jl +++ b/src/Libtask.jl @@ -9,7 +9,8 @@ export TArray, tzeros, tfill, TRef # legacy types back compat @static if isdefined(Core, :TypedSlot) || isdefined(Core.Compiler, :TypedSlot) - # Julia v1.10 removed Core.TypedSlot + # Julia v1.10 moved Core.TypedSlot to Core.Compiler + # Julia v1.11 removed Core.Compiler.TypedSlot const TypedSlot = @static if isdefined(Core, :TypedSlot) Core.TypedSlot else diff --git a/src/tapedfunction.jl b/src/tapedfunction.jl index 2d947aaf..5e85b87b 100644 --- a/src/tapedfunction.jl +++ b/src/tapedfunction.jl @@ -269,9 +269,10 @@ end const IRVar = Union{Core.SSAValue, Core.SlotNumber} -function bind_var!(var_literal, bindings::Bindings, ir::Core.CodeInfo) - # for literal constants - push!(bindings, var_literal) +function bind_var!(var, bindings::Bindings, ir::Core.CodeInfo) + # for literal constants, and static parameters + var = Meta.isexpr(var, :static_parameter) ? ir.parent.sparam_vals[var.args[1]] : var + push!(bindings, var) idx = length(bindings) return idx end @@ -368,6 +369,14 @@ function translate!!(var::IRVar, line::Core.SlotNumber, return Instruction(func, input, output) end +function translate!!(var::IRVar, line::Number, # literal vars + bindings::Bindings, isconst::Bool, ir) + func = identity + input = (bind_var!(line, bindings, ir),) + output = bind_var!(var, bindings, ir) + return Instruction(func, input, output) +end + function translate!!(var::IRVar, line::NTuple{N, Symbol}, bindings::Bindings, isconst::Bool, ir) where {N} # for syntax (; x, y, z), see Turing.jl#1873 @@ -439,6 +448,9 @@ function translate!!(var::IRVar, line::Expr, end return Instruction(identity, (_bind_fn(rhs),), _bind_fn(lhs)) end + elseif head === :static_parameter + v = ir.parent.sparam_vals[line.args[1]] + return Instruction(identity, (_bind_fn(v),), _bind_fn(var)) else @error "Unknown Expression: " typeof(var) var typeof(line) line throw(ErrorException("Unknown Expression")) diff --git a/test/issues.jl b/test/issues.jl index f9deeffd..370c0235 100644 --- a/test/issues.jl +++ b/test/issues.jl @@ -59,4 +59,25 @@ r = tf(1, 2) @test r == (c=3, x=1, y=2) end + + @testset "Issue-Libtask-174, SSAValue=Int and static parameter" begin + # SSAValue = Int + function f() + # this line generates: %1 = 1::Core.Const(1) + r = (a = 1) + return nothing + end + tf = Libtask.TapedFunction(f) + r = tf() + @test r == nothing + + # static parameter + function g(::Type{T}) where {T} + a = zeros(T, 10) + end + tf = Libtask.TapedFunction(g, Float64) + r = tf(Float64) + @test r == zeros(Float64, 10) + end + end