Taking AD seriously #28
Good idea; incidentally, this is quite high on my priority list as well (though I expect to mostly need Jacobians).
Yes, Pymanopt does the first thing. I think geomstats also has some good ideas in this area; see for example the way they solved the problem of exp/log on the sphere being very hard to AD through: https://github.com/geomstats/geomstats/blob/048a99dc9dff3f86e42025264ac0811ed6888f7c/geomstats/geometry/hypersphere.py#L533-L588 . I think this may be one of the bigger challenges: many functions we use need to be special-cased around zero, and while that is not a problem for finite differences, it is for AD. Using Taylor expansions to some degree may be necessary.
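Here is a minimal sketch of the Taylor-expansion trick in plain Julia, for the unit sphere embedded in R^{n+1}. It is similar in spirit to the geomstats code linked above, but it is not a port of it. The idea is to write cos(θ) and sin(θ)/θ as functions of s = θ² and to switch to a truncated series for small s. That way, neither a square root nor a division by θ is evaluated at zero. The function names and the 1e-6 threshold are illustrative choices, not an existing API.

```julia
using LinearAlgebra

# cos(θ) and sin(θ)/θ as functions of s = θ^2. Below the threshold we use
# truncated Taylor series in s, so sqrt and the division by θ are never
# evaluated at s = 0; values (and AD derivatives) stay finite there.
function cos_and_sinc(s::Real)
    if s < 1e-6
        return 1 - s / 2 + s^2 / 24, 1 - s / 6 + s^2 / 120
    else
        θ = sqrt(s)
        return cos(θ), sin(θ) / θ
    end
end

# Exponential map on the unit sphere embedded in R^{n+1}:
# exp_p(X) = cos(|X|) p + (sin(|X|)/|X|) X for X tangent at p.
function sphere_exp(p::AbstractVector, X::AbstractVector)
    c, sc = cos_and_sinc(dot(X, X))
    return c .* p .+ sc .* X
end

p = [0.0, 0.0, 1.0]
q = sphere_exp(p, [0.3, -0.4, 0.0])   # the direction is tangent at p: dot(p, X) == 0
```

A plain norm(X) followed by sin(θ)/θ would typically still produce NaN or Inf derivatives at X = 0 under AD, because the derivative of sqrt at 0 is infinite. That is why the sketch branches on s rather than on θ.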
I don't quite understand at the moment what intrinsic geometric AD is. Probably it means working in charts, so I think relying on ONBs of tangent spaces is not the greatest idea? I've seen this paper: https://arxiv.org/abs/1812.11592 but I don't understand why they don't have to care about the metric. One thing I know is that …
Why would you classify ChainRules.jl's forward mode support as experimental? ChainRules just hit v1.0, which signals that its API is very stable. By comparison, ForwardDiff has not yet hit v1.0.
This is true if we want to support ADs like ForwardDiff and ReverseDiff, which are not likely to use ChainRules in the near future. IMO, however, we should start with ChainRules support, since that brings compatibility with 6 AD engines, including Diffractor. The real issue is that none of the reverse-mode ADs that use ChainRules support mutating functions (I think), and we've built our entire interface around mutating functions. So that might mean we only support forward-mode AD for now.

It's not clear to me how/why AD on manifolds must differ from normal AD. If we're working in a chart, shouldn't there be no difference? And how do we plan to use the AD?

Note that timing-wise, it might make sense to hold off on an implementation until AbstractDifferentiation.jl is released. It aims to unify the APIs of the different AD engines, which would replace the machinery that many packages like ours maintain to let the user select an AD backend.
I would mainly be interested in gradients and (approximate) Hessians.
That does indeed look quite technical. Maybe one could instead provide the differentials manually for those two, as I do in Manopt.jl (though there using the framework of Jacobi fields, which might be nice to use in general anyway).
I know, I just do not yet completely know how to treat both functions in the generality we usually have here in our package.
With intrinsic I mean a method that neither relies on the specific representation at hand (e.g. for the hyperbolic space it should work the same for all 3 point types / representations) nor relies on an embedding (since that might increase the dimension). Instead, it really just works with (co-)tangent vectors, bases of (co-)tangent spaces and generic tools like vector transport. This would yield methods that never require an embedding and, in the best case, only seldom require a chart.
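One way to read the intrinsic idea is to express tangent vectors in an orthonormal basis of the tangent space, so that later computations only see coordinate vectors. The sketch below does this for the unit sphere in plain Julia. tangent_onb, get_coords and get_vec are hypothetical helpers, similar in spirit to get_coordinates/get_vector in ManifoldsBase.jl. The basis here comes from LinearAlgebra.nullspace.

```julia
using LinearAlgebra

# Orthonormal basis of T_p S^n = {X : dot(p, X) = 0}, computed as the null
# space of the 1×(n+1) matrix p' (nullspace returns orthonormal columns).
tangent_onb(p::AbstractVector) = nullspace(reshape(p, 1, :))

# Coordinates of a tangent vector in that basis, and back.
get_coords(B, X) = B' * X
get_vec(B, c) = B * c

p = normalize([1.0, 2.0, 2.0])
B = tangent_onb(p)            # 3×2 matrix with orthonormal columns
X = [2.0, -1.0, 0.0]          # dot(p, X) == 0, so X is tangent at p
c = get_coords(B, X)          # 2 coordinates instead of 3 embedded entries
```

Because the basis is orthonormal, the coordinates preserve the induced inner product. For representations without an embedding, the basis would have to be constructed differently.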
That was my take-home-message from what I learned about ChainRules at JuliaCon.
Yeah, I am not sure about this yet either; we need the mutating functions for speed here, sure. But I think my first issue would actually be the embedding vs. intrinsic approaches (see my last comment).
As I wrote, there would be two (maybe then 3) approaches.
Sounds reasonable, sure. We could still sketch ideas and maybe start some code. As I just wrote, the first step would be to figure out which modes we have (I think we are now at 3 different ways to do it).
Mostly because I don't know of any stable forward mode AD library that uses ChainRules.jl. Is there one? ForwardDiff.jl's last breaking release was in 2018, so I'd say it's stable.
I think we could just gradually introduce non-mutating variants as we need them.
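Here is a sketch of what a non-mutating variant next to a mutating one could look like, using the tangent-space projection on the unit sphere as a stand-in. project_tangent and project_tangent! are hypothetical names. The allocating version contains no in-place writes, which is the form that reverse-mode ADs such as Zygote can typically differentiate. The ! version stays available for performance.

```julia
using LinearAlgebra

# Allocating (non-mutating) variant: a pure expression without in-place writes,
# the form reverse-mode ADs like Zygote can typically differentiate.
project_tangent(p, X) = X .- dot(p, X) .* p

# Mutating variant for performance-sensitive code: same formula, written into Y.
function project_tangent!(Y, p, X)
    Y .= X .- dot(p, X) .* p
    return Y
end

p = [0.0, 0.0, 1.0]
X = [1.0, 2.0, 3.0]
Y = similar(X)
project_tangent!(Y, p, X)   # overwrites Y
Z = project_tangent(p, X)   # allocates a new vector
```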
The thing is, we don't want to be constantly working in a chart. Once we get a gradient (for example), we want to perform a retraction using a closed-form formula, and most of our formulas work in an embedding. Even if we never work in an embedding, we may need to switch charts. There was some recent work on chart-based optimization and normalizing flows on manifolds that attempts to address these issues: https://arxiv.org/abs/1909.09501 and https://arxiv.org/abs/2006.10254 . AD seems to be mostly useful for gradients and Hessians in optimization. I'm going to experiment with Jacobians for continuous normalizing flows.
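For the sphere, one such closed-form retraction in the embedding is normalization, R_p(X) = (p + X)/|p + X|. The minimal sketch below compares it with the exact exponential map for a small step. The names retract_project and sphere_exp_exact are hypothetical.

```julia
using LinearAlgebra

# Projection retraction on the unit sphere: R_p(X) = (p + X) / |p + X|.
# It only uses the embedding R^{n+1}; no chart is involved.
retract_project(p::AbstractVector, X::AbstractVector) = normalize(p .+ X)

# Exact exponential map for comparison (valid for X ≠ 0 tangent at p).
function sphere_exp_exact(p::AbstractVector, X::AbstractVector)
    θ = norm(X)
    return cos(θ) .* p .+ (sin(θ) / θ) .* X
end

p = [0.0, 0.0, 1.0]
X = [0.3, -0.4, 0.0]
t = 1e-3
gap = norm(retract_project(p, t .* X) .- sphere_exp_exact(p, t .* X))  # O(t^3)
```

The two maps agree up to second order in t, so for small steps the cheaper retraction behaves like the exponential map.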
I think there are many other issues we can solve without waiting for that.
OK, but how does one manually provide such differentials to AD? I've done such things for ForwardDiff.jl, but the method only works for specific AD libraries (though ChainRules.jl may help?) and is quite ugly.
That's a very ambitious goal, and I don't know which generic tools are sufficient to solve this without making it too slow.
I did not say the intrinsic approach is easy, nor that much work has been done in this direction yet, but it would be really cool if something like that worked :)
I'm still not clear on this. Is there a write-up somewhere of how AD on manifolds should/must differ from standard AD? Or put another way, what quantities are you hoping to get out of AD? Standard AD will give you "tangents" (derivatives of the real values stored in your points w.r.t. some upstream real scalar) and "cotangents" (derivatives of some downstream real scalar w.r.t. the real values stored in your points). How does this differ from what you want from a manifold AD?
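To make the tangent/cotangent terminology concrete for points stored as vectors in an embedding, here is a small hand-written example for the normalization map x ↦ x/|x| onto the sphere. All names are hypothetical, and the Jacobian is written out by hand rather than computed by an AD tool. Forward mode applies the Jacobian J to a tangent v, reverse mode applies J' to a cotangent w, and the two are related by ⟨w, Jv⟩ = ⟨J'w, v⟩.

```julia
using LinearAlgebra

# Normalization map N(x) = x / |x|, which sends R^{n+1} \ {0} onto the sphere.
normalize_map(x) = x ./ norm(x)

# Hand-written Jacobian of N: J(x) = (I - y y') / |x| with y = N(x).
function jacobian_normalize(x)
    y = normalize_map(x)
    return (I - y * y') ./ norm(x)
end

# Forward mode: push a tangent v at x forward to J(x) v.
pushforward(x, v) = jacobian_normalize(x) * v
# Reverse mode: pull a cotangent w at N(x) back to J(x)' w.
pullback(x, w) = jacobian_normalize(x)' * w

x = [1.0, 2.0, 2.0]
v = [0.5, -1.0, 3.0]
w = [2.0, 0.0, -1.0]
Jv = pushforward(x, v)    # automatically tangent to the sphere at N(x)
Jtw = pullback(x, w)
```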
Oh, the result types will not differ, and neither will the inputs. We will work with (co)tangents for sure. The question is more about the “computational path” to take.

Let's look at the first variant (embedding) with an easy example: the sphere, which is isometrically embedded. If we have a function like the Rayleigh quotient on the sphere (…) A further disadvantage is that a manifold might not have such an embedding...

The same applies to the third approach; let's take the Rayleigh quotient again. We could define it as a function in a chart, i.e. on R^n (for S^n) instead of R^{n+1} in the embedding, as F = f(g(c)). Here g is a chart (more precisely the inverse of a chart, mapping coordinates to points) and c are the coordinates of x in the chart (…).

For the last one, we have Df, D*f or grad f for simple building blocks/functions on a manifold, define chain rules for these functions and derive an intrinsic AD (this might involve the Riemannian curvature tensor and such, not sure yet). With this approach we would maybe not have those overheads (up to curvature). There is one paper from 2018 describing this a little bit, but I have not seen it in source code.

So in the end, all 3 hopefully do the same thing, and we might implement some of them (just maybe not the chart one? I am not sure). They all have the same input and output, but their inner mechanisms are different. Does this help?
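Here is a small numerical sketch of the chart route for the Rayleigh quotient, in plain Julia. It assumes that g is the inverse stereographic projection from the north pole, R^n → S^n, so g plays the role of the chart map in F = f(g(c)). Its Jacobian is written out by hand in place of an AD call. The sketch pulls the Euclidean gradient back through the chart and then corrects with the coordinate metric G = J'J. This recovers the same Riemannian gradient as simply projecting the Euclidean gradient in the embedding. All function names are hypothetical.

```julia
using LinearAlgebra

# Rayleigh quotient on the sphere and its Euclidean gradient in R^{n+1}.
f(A, x) = dot(x, A * x)
egrad(A, x) = 2 .* (A * x)

# Inverse stereographic projection g: R^n -> S^n and its Jacobian J_g(c).
function g(c)
    s = 1 + dot(c, c)
    return vcat(2 .* c, dot(c, c) - 1) ./ s
end
function jac_g(c)
    s = 1 + dot(c, c)
    n = length(c)
    top = (2 / s) .* Matrix(1.0I, n, n) .- (4 / s^2) .* (c * c')
    return vcat(top, (4 / s^2) .* c')
end

# In the chart, F(c) = f(g(c)) is an ordinary function on R^n; by the chain
# rule its gradient is J_g(c)' * egrad(g(c)), which is what AD would return.
grad_F(A, c) = jac_g(c)' * egrad(A, g(c))

# Riemannian gradient from chart data, pushed into the embedding: J G^{-1} ∇F.
function rgrad_via_chart(A, c)
    J = jac_g(c)
    return J * ((J' * J) \ grad_F(A, c))
end

A = [2.0 1.0 0.0; 1.0 3.0 -1.0; 0.0 -1.0 1.0]
c = [0.3, -0.7]
x = g(c)
via_chart = rgrad_via_chart(A, c)
via_embedding = egrad(A, x) .- dot(x, egrad(A, x)) .* x   # projection onto T_x S^2
```

The extra solve with G is the metric bookkeeping that the chart route needs and that the embedding route avoids for isometrically embedded manifolds.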
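Here is a quick sketch of an rrule for distance:

```julia
using ChainRulesCore
using ManifoldsBase

function ChainRulesCore.rrule(::typeof(distance), M::AbstractManifold, p, q)
    d = distance(M, p, q)
    function distance_pullback(Ȳ)
        # grad_p d(p, q) = -log_p(q) / d and grad_q d(p, q) = -log_q(p) / d;
        # both are undefined for p == q (d == 0), so this sketch assumes p ≠ q.
        return NoTangent(), NoTangent(), -(Ȳ / d) * log(M, p, q), -(Ȳ / d) * log(M, q, p)
    end
    return d, distance_pullback
end
```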
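The formula in the pullback can be sanity-checked numerically on the unit sphere without any packages. sphere_distance, sphere_log and grad_distance_p are hypothetical helpers written for this check. The check compares ⟨-log_p(q)/d, X⟩ with a central finite difference of d along the curve t ↦ normalize(p + tX), whose velocity at t = 0 is X.

```julia
using LinearAlgebra

# Distance and logarithmic map on the unit sphere, written in the embedding.
sphere_distance(p, q) = acos(clamp(dot(p, q), -1.0, 1.0))
function sphere_log(p, q)
    v = q .- dot(p, q) .* p                          # part of q tangent at p
    return (sphere_distance(p, q) / norm(v)) .* v    # assumes q ≠ ±p
end

# Riemannian gradient of p ↦ d(p, q), i.e. what the pullback returns in the
# p slot for Ȳ = 1.
grad_distance_p(p, q) = -sphere_log(p, q) ./ sphere_distance(p, q)

p = normalize([1.0, 0.2, 0.1])
q = normalize([0.1, 1.0, 0.3])
X0 = [0.3, -0.5, 0.8]
X = X0 .- dot(p, X0) .* p                            # tangent direction at p
h = 1e-6
fd = (sphere_distance(normalize(p .+ h .* X), q) -
      sphere_distance(normalize(p .- h .* X), q)) / (2h)
```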
And note that for the extrinsic approach, the PRs JuliaManifolds/Manifolds.jl#423 and JuliaManifolds/Manifolds.jl#427 will bring the desired functionality at least for gradients (Hessians are the next thing I want to get into). In particular, they provide a more general approach than Manopt/Matlab's … This is quite a step towards the first approach mentioned.
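When the metric is not the one induced by the embedding, converting the Euclidean gradient takes more than a projection. A standard example uses the affine-invariant metric ⟨X, Y⟩_P = tr(P⁻¹XP⁻¹Y) on symmetric positive definite matrices, for which rgrad = P sym(G) P. The sketch below assumes that metric and is written in plain Julia; the function names are hypothetical. It computes both sides of the defining property ⟨rgrad, V⟩_P = ⟨G, V⟩_F for a symmetric V.

```julia
using LinearAlgebra

# Affine-invariant metric on SPD matrices: <X, Y>_P = tr(P^{-1} X P^{-1} Y).
spd_inner(P, X, Y) = tr((P \ X) * (P \ Y))

# Euclidean gradient G (w.r.t. the Frobenius inner product) -> Riemannian
# gradient for the affine-invariant metric: symmetrize, then apply P(.)P.
egrad_to_rgrad_spd(P, G) = P * ((G .+ G') ./ 2) * P

P = [2.0 0.5; 0.5 1.0]      # symmetric positive definite point
G = [1.0 -2.0; 0.3 0.7]     # some Euclidean gradient (not symmetric)
V = [0.4 0.1; 0.1 -0.2]     # symmetric tangent direction
R = egrad_to_rgrad_spd(P, G)
lhs = spd_inner(P, R, V)    # <rgrad, V>_P
rhs = tr(G' * V)            # <G, V>_F, the Euclidean directional derivative
```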
I think I would like to get into a proper way of handling AD on manifolds. I know we have quite a few issues open here (#17, JuliaManifolds/Manifolds.jl#42, #27, JuliaManifolds/Manifolds.jl#88, #29), and we also have some support already, but we have not made many changes recently. Maybe this would be a good topic to tackle next.
In general, I see two ways to go.
The first is what I think Pymanopt does. The second is, if I understand correctly, done to some extent in Manopt with finite differences; I tried to do something like that for the Hessian here but did not yet find the time to continue.
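Here is a minimal sketch of a finite-difference Hessian-vector product on the unit sphere, in the spirit of the Manopt approach mentioned above (not a port of its code). It steps along X with the projection retraction, transports the gradient back by projecting onto T_pS^n, and takes a forward difference. The Rayleigh quotient f(p) = p'Ap, the step t = 1e-6 and all function names are illustrative assumptions.

```julia
using LinearAlgebra

proj(p, X) = X .- dot(p, X) .* p          # projection onto T_p S^n
retr(p, X) = normalize(p .+ X)            # projection retraction
rgrad(A, p) = proj(p, 2 .* (A * p))       # Riemannian gradient of p'Ap

# Finite-difference approximation of Hess f(p)[X]: retract along t*X, bring
# the gradient back to T_p S^n by projection (a simple vector transport),
# then take a forward difference.
function hess_fd(A, p, X; t = 1e-6)
    q = retr(p, t .* X)
    return (proj(p, rgrad(A, q)) .- rgrad(A, p)) ./ t
end

A = [3.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 -2.0]
p = [0.0, 0.0, 1.0]       # eigenvector with eigenvalue -2
X = [1.0, 0.0, 0.0]       # tangent eigenvector with eigenvalue 3
Hx = hess_fd(A, p, X)
```

Take an eigenvector p with eigenvalue λ and a tangent eigenvector X with eigenvalue μ. The exact Riemannian Hessian is then 2(μ - λ)X, here [10, 0, 0], which the sketch reproduces up to O(t).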
The main challenge with the first point, I think, is the conversion of a Euclidean gradient in the embedding to the Riemannian gradient. Sometimes that is just a projection; sometimes it requires an adaptation to the metric. For example, I have not yet understood all the implications of what we would require to have something like ehess2rhess for the Hessian transform.

The main challenge for the second point is how to provide ChainRules.jl rules from the building blocks (several are available in Manopt.jl, e.g. basic differentials, adjoint differentials, and gradients). The good thing is that we already have tangent and cotangent vectors.
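For an isometrically embedded manifold like the sphere, both conversions have short closed forms. The sketch below follows the formulas used for the sphere in Manopt/Pymanopt: rgrad = P_p(egrad) and Hess f(p)[X] = P_p(ehess[X]) - ⟨p, egrad⟩X. The Julia functions themselves are hypothetical, not an existing Manifolds.jl API. The second Hessian term comes from the Weingarten map of the embedding, and it is exactly what a pure projection would miss.

```julia
using LinearAlgebra

# Orthogonal projection onto T_p S^n for the unit sphere in R^{n+1}.
proj(p, X) = X .- dot(p, X) .* p

# Euclidean -> Riemannian gradient: just a projection for the induced metric.
egrad2rgrad(p, eg) = proj(p, eg)

# Euclidean -> Riemannian Hessian-vector product. eg is the Euclidean gradient
# at p, ehessX the Euclidean Hessian applied to the tangent vector X.
ehess2rhess(p, eg, ehessX, X) = proj(p, ehessX) .- dot(p, eg) .* X

# Rayleigh quotient f(p) = p'Ap: egrad = 2Ap, ehess[X] = 2AX.
A = [3.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 -2.0]
p = [0.0, 0.0, 1.0]   # eigenvector, eigenvalue -2
X = [1.0, 0.0, 0.0]   # tangent eigenvector, eigenvalue 3
H = ehess2rhess(p, 2 .* (A * p), 2 .* (A * X), X)   # 2(3 - (-2))X = 10X
```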
For me the main questions / open points are:
- egrad2rgrad/ehess2rhess to complete approach 1?

Finally, would we want to do it here (we can start here for sure), or do we want to do a ManifoldsAD.jl package?
Let's keep this issue as an overview of this topic for now. Feel free to add more points to the list of things to do.