Skip to content

Commit ced97ee

Browse files
authored
fix: refactor test loops (#848)
* Improve type stability in correctness tests * Add some fixes * Fixes * Missing conj * Reprepare cov * Add changelog
1 parent 353cac6 commit ced97ee

File tree

10 files changed

+561
-341
lines changed

10 files changed

+561
-341
lines changed

DifferentiationInterfaceTest/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.10.1...main)
99

10+
### Fixed
11+
12+
- Refactor test loops ([#848](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/848))
13+
1014
## [0.10.1](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.10.0...DifferentiationInterfaceTest-v0.10.1)
1115

1216
### Added

DifferentiationInterfaceTest/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterfaceTest"
22
uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.10.1"
4+
version = "0.10.2"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -12,6 +12,7 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1212
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1313
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
15+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1516
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1617
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1718
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -54,6 +55,7 @@ JLArrays = "0.1, 0.2"
5455
LinearAlgebra = "1"
5556
Lux = "1.1.0"
5657
LuxTestUtils = "1.3.1, 2"
58+
PrecompileTools = "1.2.1"
5759
ProgressMeter = "1"
5860
Random = "1"
5961
SparseArrays = "1"

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module DifferentiationInterfaceTestJLArraysExt
33
import DifferentiationInterface as DI
44
import DifferentiationInterfaceTest as DIT
55
using JLArrays: JLArray, JLVector, JLMatrix, jl
6+
using PrecompileTools: @compile_workload
67

78
jl_num_to_vec(x::Number) = sin.(jl([1, 2]) .* x)
89
jl_num_to_mat(x::Number) = hcat(jl_num_to_vec(x), jl_num_to_vec(3x))
@@ -42,4 +43,8 @@ function DIT.gpu_scenarios(args...; kwargs...)
4243
return myjl.(scens)
4344
end
4445

46+
@compile_workload begin
47+
DIT.gpu_scenarios(; include_constantified=true, include_cachified=true)
48+
end
49+
4550
end

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import DifferentiationInterface as DI
44
import DifferentiationInterfaceTest as DIT
55
using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm
66
using StaticArrays: StaticArray, MArray, MMatrix, MVector, SArray, SMatrix, SVector
7+
using PrecompileTools: @compile_workload
78

89
static_num_to_vec(x::Number) = sin.(SVector(1, 2) .* x)
910
static_num_to_mat(x::Number) = hcat(static_num_to_vec(x), static_num_to_vec(3x))
@@ -61,4 +62,8 @@ function DIT.static_scenarios(args...; kwargs...)
6162
return mystatic.(scens)
6263
end
6364

65+
@compile_workload begin
66+
DIT.static_scenarios(; include_constantified=true, include_cachified=true)
67+
end
68+
6469
end

DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ using DifferentiationInterface: PreparationMismatchError
9494
using DocStringExtensions: TYPEDFIELDS, TYPEDSIGNATURES
9595
using JET: @test_opt
9696
using LinearAlgebra: Adjoint, Diagonal, Transpose, I, dot, parent
97+
using PrecompileTools: @compile_workload
9798
using ProgressMeter: ProgressUnknown, next!
9899
using Random: AbstractRNG, default_rng, rand!
99100
using SparseArrays:
@@ -114,7 +115,16 @@ List of all second-order operators, to facilitate exclusion during tests.
114115
"""
115116
const SECOND_ORDER = [:hvp, :second_derivative, :hessian]
116117

117-
const ALL_OPS = vcat(FIRST_ORDER, SECOND_ORDER)
118+
const ALL_OPS = (
119+
:pushforward,
120+
:pullback,
121+
:derivative,
122+
:gradient,
123+
:jacobian,
124+
:hvp,
125+
:second_derivative,
126+
:hessian,
127+
)
118128

119129
include("utils.jl")
120130

@@ -128,6 +138,7 @@ include("scenarios/empty.jl")
128138
include("scenarios/extensions.jl")
129139

130140
include("tests/correctness_eval.jl")
141+
include("tests/prep_eval.jl")
131142
include("tests/type_stability_eval.jl")
132143
include("tests/benchmark.jl")
133144
include("tests/benchmark_eval.jl")
@@ -140,4 +151,8 @@ export Scenario, compute_results
140151
export test_differentiation, benchmark_differentiation
141152
export DifferentiationBenchmarkDataRow
142153

154+
@compile_workload begin
155+
default_scenarios(; include_constantified=true, include_cachified=true)
156+
end
157+
143158
end

DifferentiationInterfaceTest/src/scenarios/default.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ function arr_to_num_hessian(x0)
306306
return convert(typeof(similar(x0, length(x0), length(x0))), H)
307307
end
308308

309-
arr_to_num_pushforward(x, dx) = sum(arr_to_num_gradient(x) .* dx)
309+
arr_to_num_pushforward(x, dx) = sum(conj.(arr_to_num_gradient(x)) .* dx)
310310
arr_to_num_pullback(x, dy) = arr_to_num_gradient(x) .* dy
311311
arr_to_num_hvp(x, dx) = reshape(arr_to_num_hessian(x) * vec(dx), size(x))
312312

DifferentiationInterfaceTest/src/test_differentiation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ function test_differentiation(
164164
sparsity,
165165
reprepare,
166166
)
167+
test_prep(adapted_backend, scen)
167168
end
168169
yield()
169170
(type_stability != :none) && @testset "Type stability" begin

0 commit comments

Comments
 (0)