Skip to content

Commit 2713370

Browse files
committed
Remove toexpr overload
1 parent 712aabd commit 2713370

File tree

9 files changed

+141
-46
lines changed

9 files changed

+141
-46
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TaylorDiff"
22
uuid = "b36ab563-344f-407b-a36a-4f200bebf99c"
33
authors = ["Songchen Tan <[email protected]>"]
4-
version = "0.3.2"
4+
version = "0.3.3"
55

66
[deps]
77
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"

examples/Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
[deps]
2+
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
3+
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
4+
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
5+
LayeredLayouts = "f4a74d36-062a-4d48-97cd-1356bad1de4e"
6+
NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a"
27
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
38
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
49
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
510
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
611
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
712
TaylorIntegration = "92b13dbe-c966-51a2-8445-caca9f8a7d42"
813
TaylorSeries = "6aa5eb33-94cf-58f4-a9d0-e4b2c4fc25ea"
14+
TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"

examples/integration.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,12 @@ function build_jetcoeffs(f::ODEFunction{iip}, p, ::Val{P}, length = nothing) whe
9090
d = get_coefficient(fu, index - 1) / index
9191
u = append_coefficient(u, d)
9292
end
93-
build_function(u, u0, t0; expression = Val(false), cse = true)
93+
u_term = make_term.(u)
94+
build_function(u_term, u0, t0; expression = Val(false), cse = true)
95+
end
96+
97+
function make_term(a)
98+
term(TaylorScalar, Symbolics.unwrap(a.value), map(Symbolics.unwrap, a.partials))
9499
end
95100

96101
function simplify_scalar_test()
@@ -111,6 +116,13 @@ function simplify_array_test()
111116
@btime $fast_oop($prob.u0, $t0)
112117
end
113118

119+
P = 6
120+
prob = prob_ode_lotkavolterra
121+
t0 = prob.tspan[1]
122+
@btime jetcoeffs($prob.f, $prob.u0, $prob.p, $t0, Val($P))
123+
fast_oop, fast_iip = build_jetcoeffs(prob.f, prob.p, Val(10), length(prob.u0));
124+
@btime $fast_oop($prob.u0, $t0)
125+
114126
@generated function evaluate_polynomial(t::TaylorScalar{T, P}, z) where {T, P}
115127
ex = :(v[$(P + 1)])
116128
for i in P:-1:1

examples/plot_expression.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
using Symbolics, SymbolicUtils
2+
using Graphs, GraphMakie, CairoMakie, LayeredLayouts
3+
using Printf
4+
5+
function generate_dag(rexs...)
6+
exs = map(Symbolics.unwrap, rexs)
7+
dag = SimpleDiGraph()
8+
vertex_map = IdDict{Any, Int}()
9+
label_map = Dict{Int, String}()
10+
unicode_replace = Dict(
11+
"-" => "",
12+
"*" => "×",
13+
"/" => "÷"
14+
)
15+
function dfs(node)
16+
if haskey(vertex_map, node)
17+
return vertex_map[node]
18+
end
19+
add_vertex!(dag)
20+
v = nv(dag)
21+
vertex_map[node] = v
22+
if iscall(node)
23+
op = operation(node)
24+
args = map(dfs, arguments(node))
25+
for arg in args
26+
add_edge!(dag, v, arg)
27+
end
28+
rop = repr(op)
29+
label_map[v] = get(unicode_replace, rop, rop)
30+
elseif node isa Number
31+
label_map[v] = @sprintf "%.1g" node
32+
else
33+
label_map[v] = repr(node)
34+
end
35+
return v
36+
end
37+
for ex in exs
38+
dfs(ex)
39+
end
40+
return dag, label_map
41+
end
42+
43+
function plot_dag(dag, label_map)
44+
xs, ys, paths = solve_positions(Zarate(), dag)
45+
for (key, value) in paths
46+
paths[key] = (value[2], -value[1])
47+
end
48+
lay = Point.(zip(ys, -xs))
49+
wp = [Point2f.(zip(paths[e]...)) for e in edges(dag)]
50+
fig, ax, p = graphplot(
51+
dag; layout = lay, ilabels = [label_map[v] for v in 1:nv(dag)],
52+
waypoints = wp, node_color = :white)
53+
hidedecorations!(ax)
54+
hidespines!(ax)
55+
fig
56+
end
57+
58+
@variables a b c d
59+
x = (a + b) * (c + d)
60+
y = (a - b) * (c - d)
61+
z = (a + b) * (c - d)
62+
w = (a - b) * (c + d)
63+
p = x + y
64+
q = z - w
65+
rex = (x + y) / (z - w)
66+
67+
dag, label_map = generate_dag(p, q)
68+
fig = plot_dag(dag, label_map)
69+
save("expression_dag.png", fig)

examples/scaling.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
using Symbolics, SymbolicUtils
2+
3+
@variables x
4+
p = Symbolics.variables(:p, 0:20)
5+
6+
struct A{T, P} <: Real
7+
a::T
8+
b::NTuple{P, T}
9+
end
10+
11+
function make_nested_expressions(order)
12+
exprs = [x]
13+
for i in 1:order
14+
term = p[i + 1]
15+
for j in i:-1:1
16+
term += p[j] * exprs[i - j + 1]
17+
end
18+
push!(exprs, term)
19+
end
20+
exprs
21+
end
22+
23+
for order in 1:17
24+
final_expr = make_nested_expressions(order)[end]
25+
print("Order: $order")
26+
build_function(final_expr, x, p; expression = Val(false), cse = true)
27+
@time build_function(final_expr, x, p; expression = Val(false), cse = true)
28+
end
29+
30+
for order in 1:17
31+
exprs = make_nested_expressions(order)
32+
tuple_expr = Symbolics.Code.MakeTuple(Symbolics.unwrap.(exprs))
33+
print("Order: $order")
34+
@time f = build_function(tuple_expr, x, p; expression = Val(false), cse = true)
35+
@show f(1.0, 1.0:5.0)
36+
end
37+
38+
# But it still doesn't scale with the unwrap trick. What else can I do?
39+
for order in 1:17
40+
exprs = make_nested_expressions(order)
41+
struct_expr = term(A, Symbolics.unwrap(exprs[1]), (Symbolics.unwrap.(exprs[2:end])...,))
42+
print("Order: $order")
43+
@time f = build_function(struct_expr, x, p; expression = Val(false), cse = true)
44+
@show f(1.0, 1.0:5.0)
45+
end

examples/test.jl

Lines changed: 0 additions & 34 deletions
This file was deleted.

src/chainrules.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ function rrule(::typeof(Base.getindex), a::TaylorArray, i::Int...)
5757
end
5858

5959
function rrule(::typeof(*), A::AbstractMatrix{S},
60-
t::AbstractVector{TaylorScalar{T, N}}) where {N, S, T}
60+
t::AbstractVector{TaylorScalar{T, N}}) where {
61+
N, S <: Union{Real, Complex}, T <: Union{Real, Complex}}
6162
project_A = ProjectTo(A)
6263
function gemv_pullback(x̄)
6364
= reinterpret(reshape, T, x̄)
@@ -68,7 +69,8 @@ function rrule(::typeof(*), A::AbstractMatrix{S},
6869
end
6970

7071
function rrule(::typeof(*), A::AbstractMatrix{S},
71-
B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S, T}
72+
B::AbstractMatrix{TaylorScalar{T, N}}) where {
73+
N, S <: Union{Real, Complex}, T <: Union{Real, Complex}}
7274
project_A = ProjectTo(A)
7375
project_B = ProjectTo(B)
7476
function gemm_pullback(x̄)

src/primitive.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ sincos(t::TaylorScalar) = (sin(t), cos(t))
119119
@inline *(a::TaylorScalar, b::Number) = TaylorScalar(value(a) * b, partials(a) .* b)
120120
@inline /(a::TaylorScalar, b::Number) = TaylorScalar(value(a) / b, partials(a) ./ b)
121121

122-
const AMBIGUOUS_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, Complex, RoundingMode)
122+
const AMBIGUOUS_TYPES = (
123+
AbstractFloat, Irrational, Integer, Rational, Real, Complex, RoundingMode)
123124

124125
for op in [:>, :<, :(==), :(>=), :(<=)]
125126
for R in AMBIGUOUS_TYPES

src/utils.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,11 @@
33

44
using ChainRules
55
using ChainRulesCore
6-
using Symbolics: Symbolics, @variables, @rule, @register_symbolic, unwrap, isdiv
6+
using Symbolics: Symbolics, @variables, @rule, unwrap, isdiv
77
using SymbolicUtils.Code: toexpr
88
using MacroTools
99
using MacroTools: prewalk, postwalk
1010

11-
@register_symbolic TaylorScalar(x, y)
12-
function Symbolics.Code.toexpr(t::TaylorScalar, st)
13-
:($TaylorScalar($(Symbolics.Code.toexpr(t.value, st)),
14-
$(Symbolics.Code.toexpr(Symbolics.Code.MakeTuple(t.partials), st))))
15-
end
16-
1711
"""
1812
Pick a strategy for raising the derivative of a function.
1913
If the derivative is like 1 over something, raise with the division rule;

0 commit comments

Comments
 (0)