Skip to content

Commit

Permalink
Make pullback error for ColVecs and RowVecs a bit more informative (#523
Browse files Browse the repository at this point in the history
)

* make adjoint error message a bit more informative

* Apply suggestions from code review

* Update Project.toml

* Update src/chainrules.jl
  • Loading branch information
torfjelde authored Sep 27, 2023
1 parent 5127a26 commit cf937ce
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.56"
version = "0.10.57"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
12 changes: 8 additions & 4 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,11 @@ function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix)
function ColVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}})
return error(
"Pullback on AbstractVector{<:AbstractVector}.\n" *
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" *
"To solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`",
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n" *
"or because some external computation has acted on `ColVecs` to produce a vector of vectors." *
"In the former case, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`." *
"In the latter case, one needs to track down the `rrule` whose pullback returns a `Vector{Vector{T}}`," *
" rather than a `Tangent`, as the cotangent / gradient for `ColVecs` input, and circumvent it."
)
end
return ColVecs(X), ColVecs_pullback
Expand All @@ -162,8 +165,9 @@ function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix)
function RowVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}})
return error(
"Pullback on AbstractVector{<:AbstractVector}.\n" *
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" *
"To solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`",
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n" *
"or because some external computation has acted on `RowVecs` to produce a vector of vectors." *
"If it is the former, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`",
)
end
return RowVecs(X), RowVecs_pullback
Expand Down

2 comments on commit cf937ce

@willtebbutt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/92333

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.57 -m "<description of version>" cf937ceed9b2666a812cf1ec7350ac6c5a231d00
git push origin v0.10.57

Please sign in to comment.