From e9a5bc77744d5672337b33a6415451b44b150047 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Fri, 3 Nov 2023 14:51:32 -0400 Subject: [PATCH 1/5] dualT4 should not be promoted with T since T might already be a dual --- src/norecompile.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/norecompile.jl b/src/norecompile.jl index 8446c8bcc..ecee30d44 100644 --- a/src/norecompile.jl +++ b/src/norecompile.jl @@ -57,7 +57,7 @@ function wrapfun_iip(ff, dualT = dualgen(T) dualT1 = ArrayInterface.promote_eltype(T1, dualT) dualT2 = ArrayInterface.promote_eltype(T2, dualT) - dualT4 = dualgen(promote_type(T, T4)) + dualT4 = dualgen(T4) iip_arglists = (Tuple{T1, T2, T3, T4}, Tuple{dualT1, dualT2, T3, T4}, From 5134ffeb656068182144e9188c65fb013258483f Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 6 Nov 2023 15:33:47 -0500 Subject: [PATCH 2/5] comment and remove 4th method --- src/norecompile.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/norecompile.jl b/src/norecompile.jl index ecee30d44..1f0406050 100644 --- a/src/norecompile.jl +++ b/src/norecompile.jl @@ -59,10 +59,10 @@ function wrapfun_iip(ff, dualT2 = ArrayInterface.promote_eltype(T2, dualT) dualT4 = dualgen(T4) - iip_arglists = (Tuple{T1, T2, T3, T4}, - Tuple{dualT1, dualT2, T3, T4}, - Tuple{dualT1, T2, T3, dualT4}, - Tuple{dualT1, dualT2, T3, dualT4}) + iip_arglists = (Tuple{T1, T2, T3, T4}, # primal + Tuple{dualT1, dualT2, T3, T4}, # vjp + Tuple{dualT1, T2, T3, dualT4}, # tgrad + ) iip_returnlists = ntuple(x -> Nothing, 4) From 88c7527a94ce9e0381ffbfacf4f43d290820e7c4 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 6 Nov 2023 16:05:46 -0500 Subject: [PATCH 3/5] typo --- src/norecompile.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/norecompile.jl b/src/norecompile.jl index 1f0406050..52819287b 100644 --- a/src/norecompile.jl +++ b/src/norecompile.jl @@ -64,7 +64,7 @@ function wrapfun_iip(ff, Tuple{dualT1, T2, T3, dualT4}, # tgrad ) - iip_returnlists = ntuple(x -> Nothing, 4) + iip_returnlists = ntuple(x -> Nothing, length(iip_arglists)) fwt = map(iip_arglists, iip_returnlists) do A, R FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper{R, A}(Void(ff)) From 124bbb60c74d00dc481c50f6fdef19accaf6949e Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 8 Nov 2023 15:08:12 -0500 Subject: [PATCH 4/5] better fix --- src/norecompile.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/norecompile.jl b/src/norecompile.jl index 52819287b..01d3aa198 100644 --- a/src/norecompile.jl +++ b/src/norecompile.jl @@ -58,10 +58,12 @@ function wrapfun_iip(ff, dualT1 = ArrayInterface.promote_eltype(T1, dualT) dualT2 = ArrayInterface.promote_eltype(T2, dualT) dualT4 = dualgen(T4) + dualT4_T = promote_dual(dualT4, dualT) - iip_arglists = (Tuple{T1, T2, T3, T4}, # primal - Tuple{dualT1, dualT2, T3, T4}, # vjp - Tuple{dualT1, T2, T3, dualT4}, # tgrad + iip_arglists = (Tuple{T1, T2, T3, T4}, # primal + Tuple{dualT1, dualT2, T3, T4}, # vjp + Tuple{dualT1, T2, T3, dualT4}, # tgrad + Tuple{dualT1, T2, T3, dualT4_T}, # tgrad inside gradient wrt initial conditions ) iip_returnlists = ntuple(x -> Nothing, length(iip_arglists)) From 5473ea375fc00286fd656bb9b6bd59d9d5e769fd Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 13 Nov 2023 10:37:08 -0500 Subject: [PATCH 5/5] simplify --- src/norecompile.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/norecompile.jl b/src/norecompile.jl index 01d3aa198..d9288547d 100644 --- a/src/norecompile.jl +++ b/src/norecompile.jl @@ -57,13 +57,11 @@ function wrapfun_iip(ff, dualT = dualgen(T) dualT1 = ArrayInterface.promote_eltype(T1, dualT) dualT2 = ArrayInterface.promote_eltype(T2, dualT) - dualT4 = dualgen(T4) - dualT4_T = promote_dual(dualT4, dualT) + dualT4 = promote_dual(dualgen(T4), dualT) iip_arglists = (Tuple{T1, T2, T3, T4}, # primal Tuple{dualT1, dualT2, T3, T4}, # vjp Tuple{dualT1, T2, T3, dualT4}, # tgrad - Tuple{dualT1, T2, T3, dualT4_T}, # tgrad inside gradient wrt initial conditions ) iip_returnlists = ntuple(x -> Nothing, length(iip_arglists))