Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use rrule of KwFunc for Core.kwcall #270

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

nmheim
Copy link
Contributor

@nmheim nmheim commented Feb 13, 2024

This seems to fix non-differentiable keyword arguments by constructing the KwFunc defined in Diffractor.

@@ -244,6 +244,10 @@ struct KwFunc{T,S}
end
(kw::KwFunc)(args...) = kw.kwf(args...)

function ChainRulesCore.rrule(::typeof(Core.kwcall), kwargs, f, args...)
rrule(KwFunc(f), kwargs, f, args...)
Copy link
Member

@oxinabox oxinabox Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why isn't this

Suggested change
rrule(KwFunc(f), kwargs, f, args...)
rrule(f, args...; kwargs...)

is that the same, or is it different?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

@nmheim nmheim Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not exactly sure why the KwFunc struct is needed though.. it seems like could be done via rrule(::typeof(Core.kwcall), kwargs, f, args...) directly?

Copy link
Contributor Author

@nmheim nmheim Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing the KwFunc and dispatching on kwcall directly seems to work, but I was afraid to remove something which I don't exactly understand

function ChainRulesCore.rrule(::typeof(Core.kwcall), kwargs, f, args...)
    r = Core.kwfunc(rrule)(kwargs, rrule, f, args...)
    if r === nothing
        return nothing
    end
    x, back = r
    x, Δ->begin
        (NoTangent(), NoTangent(), back(Δ)...)
    end
end

Copy link
Member

@oxinabox oxinabox Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh this might be the thing that is there to avoid ADing through so much of the kwarg machinery in the nested AD case

@oxinabox
Copy link
Member

oxinabox commented Feb 13, 2024

@Keno I think this does not have the same concerns that #266 has, but I am not sure.
Because for reverse mode everything is already bad.

Should we do this, and reconsider it later as needed?
It unbreaks a lot of tests

@nmheim nmheim mentioned this pull request Feb 21, 2024
@Keno
Copy link
Collaborator

Keno commented Feb 28, 2024

Same as in #266 I think in order for this not to make inference worse, the method should be split into kw and non-kw versions.

Comment on lines +228 to +232
function (::∂⃖{N})(::typeof(Core.kwcall), kwargs, f::T, args...) where {T, N}
if N == 1
# Base case (inlined to avoid ambiguities with manually specified
# higher order rules)
z = rrule(DiffractorRuleConfig(), KwFunc(f), kwargs, f, args...)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

its basically a copy of the non-kw version of the function, but that is what we want in order to avoid ADing through the kw machinery, if there are no kws if I understand correctly?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants