Re-add expansion of Enzyme.@import_{r,f}rule to BijectorsEnzymeExt #346

Closed
wants to merge 3 commits into from

Member

@penelopeysm penelopeysm commented Nov 6, 2024

Since Enzyme now works, but Julia hasn't been fixed, this re-enables the extension. See #339 (comment).

In the process, it also re-enables tests that are no longer broken.

@penelopeysm penelopeysm force-pushed the py/enzyme-ext branch 3 times, most recently from f9ee0d5 to d3eafb2 Compare November 6, 2024 22:05
Member

@devmotion devmotion left a comment

IMO that's incomprehensible. What's the reason for not just implementing the forward and reverse-mode Enzyme rules? The function and the partial derivatives are very simple, and a direct implementation is likely faster and avoids unnecessary overhead. As also mentioned in the docstrings of @import_rrule and @import_frule, these macros most likely do not give you the most performant implementation (and they might lead to broken rules but I assume it's fine in this simple case here).

@penelopeysm
Member Author

I agree it is very incomprehensible, but we discussed this previously (on Slack) and none of us really have the time to figure out how to rewrite the rule ourselves. (I definitely don't, I don't know if anyone else's schedule has freed up, though I doubt it.) @wsmoses suggested that we could, in the worst case, run macroexpand, so that's what I've done, and it's only really intended as a stopgap measure until 1.11.2 is released, at which point we can get rid of all of this.

@penelopeysm penelopeysm force-pushed the py/enzyme-ext branch 2 times, most recently from 5729b55 to 3f096f8 Compare November 6, 2024 22:44
@devmotion
Member

devmotion commented Nov 6, 2024

> it's only really intended as a stopgap measure until 1.11.2 is released, at which point we can get rid of all of this.

My point is that even when the Pkg issues are fixed, I think ideally you would use neither @import_frule nor @import_rrule.

@penelopeysm
Member Author

I will pass that on :)

@penelopeysm
Member Author

But maybe one of you, @willtebbutt or @mhauru, is better placed to do this (at some point in time)?

@yebai
Member

yebai commented Nov 7, 2024

Regardless of how we want to implement these rules, Bijectors is only a user of the rules defined here.

So, it feels more natural for these rules to be hosted as an EnzymeRootsExt extension, living in either the Enzyme.jl or the Roots.jl repo.

@yebai
Member

yebai commented Nov 7, 2024

@devmotion, since you wrote the ChainRulesCore rule for Roots.jl, would you be happy to do the same for Enzyme, or coach @mhauru to do it?

@yebai
Member

yebai commented Nov 7, 2024

Here are the rules for ReverseDiff, which can be adapted for Enzyme:

function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:TrackedReal}
    return track(find_alpha, wt_y, wt_u_hat, b)
end
@grad function find_alpha(wt_y::TrackedReal, wt_u_hat::TrackedReal, b::TrackedReal)
    α = find_alpha(value(wt_y), value(wt_u_hat), value(b))
    ∂wt_y = inv(1 + wt_u_hat * sech(α + b)^2)
    ∂wt_u_hat = -tanh(α + b) * ∂wt_y
    ∂b = ∂wt_y - 1
    find_alpha_pullback(Δ::Real) = (Δ * ∂wt_y, Δ * ∂wt_u_hat, Δ * ∂b)
    return α, find_alpha_pullback
end
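
For reference, the three partials in the pullback above are what implicit differentiation gives. This is a sketch and rests on an assumption not stated in this thread: that find_alpha(wt_y, wt_u_hat, b) returns the root α of α + wt_u_hat * tanh(α + b) = wt_y. The shorthand y = wt_y and u = wt_u_hat is introduced here only for readability.

```latex
% Assumption: \alpha is defined implicitly by F = 0, with y = wt_y, u = wt_u_hat.
\begin{aligned}
F(\alpha; y, u, b) &= \alpha + u \tanh(\alpha + b) - y = 0, \\
\frac{\partial F}{\partial \alpha} &= 1 + u \,\operatorname{sech}^2(\alpha + b), \\
\frac{\partial \alpha}{\partial y}
  &= -\frac{\partial F / \partial y}{\partial F / \partial \alpha}
   = \frac{1}{1 + u \,\operatorname{sech}^2(\alpha + b)}, \\
\frac{\partial \alpha}{\partial u}
  &= -\frac{\tanh(\alpha + b)}{1 + u \,\operatorname{sech}^2(\alpha + b)}
   = -\tanh(\alpha + b)\,\frac{\partial \alpha}{\partial y}, \\
\frac{\partial \alpha}{\partial b}
  &= -\frac{u \,\operatorname{sech}^2(\alpha + b)}{1 + u \,\operatorname{sech}^2(\alpha + b)}
   = \frac{\partial \alpha}{\partial y} - 1.
\end{aligned}
```

Under that assumption, these correspond to ∂wt_y, ∂wt_u_hat and ∂b in the ReverseDiff rule, and an Enzyme rule for the same function would presumably need the same three quantities.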

@penelopeysm
Member Author

penelopeysm commented Nov 7, 2024

Is it sensible to merge this first and keep track of the proper rule-writing in another issue?

Right now we have a missing rule and that will make things break.

If we'd rather let things sit broken for now, that's fine with me too! Just thought to prod.

@yebai
Member

yebai commented Nov 7, 2024

It doesn't matter much since the code is related to normalising flows, which Turing does not use. It is more for users who directly work with Bijectors.

I'd strongly encourage fixing these rules properly and moving them to an EnzymeRoots extension instead of keeping them here.

@devmotion
Member

> @devmotion, since you wrote the ChainRulesCore rule for Roots.jl, would you be happy to do the same for Enzyme, or coach @mhauru to do it?

Well, one could do that. But that's again too generic for Bijectors IMO. The only interest here is one specific function, so only a rule for this specific function is needed. I assume a rule for this specific function will generally perform better than the generic Roots.jl rules.

I should have some time this evening, I'll open a draft PR and see how far I can get.

3 participants