
more general TriangleMesh differentiation support #1208

Open
tylerflex opened this issue Oct 17, 2023 · 7 comments · May be fixed by #2095
@tylerflex (Collaborator)

No description provided.

tylerflex self-assigned this Oct 17, 2023
tylerflex mentioned this issue Mar 15, 2024
@so-rose commented May 14, 2024

Just 2c; massive +1 for this.

My integration relies on geometry generated using Blender 3D's "Geometry Nodes", a visual scripting system for intuitive and interactive geometry generation.

With the flexibility of arbitrary triangle meshes come some caveats; namely, the scripted, parameterized inputs to the geometry-generating GeoNodes tree must be realized at the moment a td.TriangleMesh object is generated. The mesh generation is, on some level, just a function; Blender-the-software must be invoked somewhere in order to actually run it.
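To make that concrete, here is a minimal sketch (my own illustration: the tetrahedron generator is a made-up stand-in for a GeoNodes tree, and it assumes the td.TriangleMesh.from_vertices_faces constructor):

```python
import numpy as np
import tidy3d as td

# Illustrative only: a tiny parameterized "mesh generator" standing in for a
# Blender GeoNodes tree. The parameter must be realized into concrete
# vertices/faces before a td.TriangleMesh can exist.
def make_mesh(height: float) -> td.TriangleMesh:
    vertices = np.array(
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, height]]
    )
    faces = np.array([[0, 2, 1], [0, 1, 3], [1, 2, 3], [0, 3, 2]])  # watertight tetrahedron
    return td.TriangleMesh.from_vertices_faces(vertices, faces)

mesh = make_mesh(height=0.8)  # "height" is fixed the moment the mesh is built
```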

I imagine this is far from the most exotic geometry generation floating around. Thus, the optimization loop of any "generate arbitrary geometry w/ optimization of symbolic inputs" workflow would probably need to be a local + cloud loop, e.g. with td.TriangleMesh wrapping a non-jax callback while itself being registered with jax.

So, not easy. Still, if this were undertaken, almost infinite flexibility would be available to us users, with very little effort, when it comes to structures that can be optimized! Which is very fun.

@tylerflex (Collaborator, Author)

I think this should be possible... hopefully in the next few months we can implement it.

@tylerflex (Collaborator, Author)

FYI: We are in the process of deprecating the adjoint (jax) plugin and making tidy3d natively compatible with autograd. So you should be able to just take your existing functions using regular tidy3d components and call autograd.grad() on them without any API modifications.
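As a very rough sketch of what that could look like (my own illustration; the commented-out geometry and the placeholder objective are made up, and the final native-autograd workflow may differ):

```python
import autograd
import autograd.numpy as anp
import tidy3d as td  # with native autograd support, tracers should flow through tidy3d components

def objective(radius):
    # With the native integration, you'd build regular tidy3d components here, e.g.
    #   cylinder = td.Structure(
    #       geometry=td.Cylinder(radius=radius, length=1.0, center=(0, 0, 0), axis=2),
    #       medium=td.Medium(permittivity=4.0),
    #   )
    # then assemble a td.Simulation, run it, and return a scalar figure of merit.
    # A pure-autograd placeholder stands in for all of that below:
    return anp.sum(radius ** 2)

grad_fn = autograd.grad(objective)  # no tidy3d-specific API, just autograd.grad
print(grad_fn(0.5))  # -> 1.0
```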

Once we get this working, implementing things like triangle mesh will be a lot simpler.

@so-rose commented May 14, 2024

@tylerflex That sounds amazing! I'm glad it's in the plan. Just to make sure I'm understanding correctly; the autograd is in reference to this? https://github.com/HIPS/autograd

If so, two thoughts/worries:

  • HIPS/autograd seems to not be actively developed anymore, per the message on the main page of the repository. Incidentally, the four main devs seem to now be working on jax?
  • On a personal note, I'm relying quite heavily on jax, not just for gradients, but also for JIT optimization + GPU-based processing of output data, sympy support (see sympy's codegen, which is extra nice together with the optics module), and possibly sharding in the future.
    • Perhaps I'd just like to give feedback that jax support in tidy3d truly is a "killer feature" for me (and maybe others), for far more reasons than just gradients.
    • For example, CPU-based jax alone is allowing me to manipulate FieldTime monitor output as real-time, interactively adjustable videos of field monitors, which should scale even to large / volumetric fields due to jax's GPU support.

@tylerflex (Collaborator, Author) commented May 14, 2024

> @tylerflex That sounds amazing! I'm glad it's in the plan. Just to make sure I'm understanding correctly; the autograd is in reference to this? https://github.com/HIPS/autograd

Yea that's the one.

> If so, two thoughts/worries:
>
>   • HIPS/autograd seems to not be actively developed anymore, per the message on the main page of the repository. Incidentally, the four main devs seem to now be working on jax?

That's true. We are considering forking autograd ourselves and maintaining a version of it, since we're mainly just using it for auto-diff; jax is proving quite challenging to work with when auto-diff is all we need:

  1. jax and jaxlib are ~50 MB, whereas autograd is just ~50kB, so there is much less of an issue making autograd a core dependency. It also has very few dependencies of its own, mainly just numpy.
  2. Many users have installation issues with jax.
>   • On a personal note, I'm relying quite heavily on jax, not just for gradients, but also for JIT optimization + GPU-based processing of output data, sympy support (see sympy's codegen, which is extra nice together with the optics module), and possibly sharding in the future.
>     • Perhaps I'd just like to give feedback that jax support in tidy3d truly is a "killer feature" for me (and maybe others), for far more reasons than just gradients.
>     • For example, CPU-based jax alone is allowing me to manipulate FieldTime monitor output as real-time, interactively adjustable videos of field monitors, which should scale even to large / volumetric fields due to jax's GPU support.

It would be interesting to learn more about how you use jax + tidy3d for JIT and GPU processing on the front end. These features seem to not work for me with tidy3d.

We are planning to support autograd 'natively', but also to write converters from jax/pytorch/tensorflow to tidy3d autograd. So you should still be able to use jax auto-diff features. And we'll keep the adjoint plugin around, although we probably won't develop much for it.

@so-rose commented May 14, 2024

> jax and jaxlib are hundreds of MB, whereas autograd is just 50kB. So there is much less of an issue making autograd a core dependency.

Well, that makes sense! jaxlib[cpu] is indeed quite big; I have no GPU kernels installed right now, but I can see it eats 250MB.

  • Looking at my .venv, xla_extension.so is the culprit - which must be a brute-force "include everything" set of architecture-specific binary procedures? I wonder if anything can be done. It's probably hard, though. Unrolled loops and all that.
  • For GPU support, only Linux is supported with a special installation procedure, which is likely very confusing (CUDA on Windows is experimental, and there's also an experimental Metal backend, but yeah). Doesn't bother me much, but I'm not making a widely distributed commercial product where customers need support for all this chaos 😃 .
  • Oh, and it needs a good BLAS too. Semi easy on Linux, no idea elsewhere...

As mentioned, I'm currently sticking to jaxlib[cpu], to validate my methodology. However, I'm keeping a rather strict "jit/vmap/etc. everything" usage pattern that plays nice with the way GPUs like their memory, and I'd be surprised if jaxlib[cuda12] were to have trouble with the XLA code I'd like it to run. After all, the whole point of jax is to make GPU-based machine learning go fast.

> It would be interesting to learn more about how you use jax + tidy3d for JIT and GPU processing on the front end.

Sure, if you're curious. So, all of this is in the context of a visual node-based system I'm building in Blender for my BSc thesis (https://github.com/so-rose/blender_maxwell/; I wouldn't try to install it right now, it's quite unstable as of writing).

Tl;dr: My use case is not "normal" by a long stretch, but I just deal with the data directly by leaving xarray land as fast as possible, to take advantage of jax's happy path of "jit everything and run only that".

Each "math node" composes a function lazily, with data originating from an "extract data" node that pulls out a jax-compatible array sourced directly from xarray's data attribute (don't think a copy happens, but if so, only one).

  • Each node might do jnp operations (only jnp), run through a lambdify'ed sympy expression, or declare parameters, which must all eventually be inserted into the top-level function. This might sound restrictive, but I promise it isn't!
  • At the end, it runs the final/top-level function through a @jit, which compiles the super-inefficient function of function of ... into optimized XLA bytecode, which by its nature should run on anything jax supports.
  • This compiled function stays cached in a separate "flow lane", so that changes to previously-declared parameters only re-run the function.
  • Mind you, that first "extract data" doesn't need to eagerly produce an array; it could just as easily (with some whispering about shape to jax) lazily load data from a disk, for final evaluation only when the result is needed.
  • The final node, e.g. a "visualize node", does the jit and (implicitly) caches it, then runs it to produce pixels that are blitted directly to a Blender image buffer (which I've seen take at most 3 ms for a high-res image), so the user can see the result almost instantly.
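Here is a rough, self-contained sketch of that composition pattern (my own toy example; the node functions, the sympy expression, and the shapes are made up, and it assumes sympy's lambdify supports the "jax" backend):

```python
import jax
import jax.numpy as jnp
import numpy as np
import sympy as sp

# toy "extract data" node: a jax-compatible array, in the real thing pulled straight
# from an xarray DataArray's .data attribute
raw = np.random.rand(100, 100).astype(np.float32)

# toy "math nodes": each one is just a small function using jnp (or lambdified sympy)
x = sp.Symbol("x")
sympy_node = sp.lambdify(x, sp.sin(x) ** 2, modules="jax")  # sympy expr -> jax-traceable fn

def abs_node(data):
    return jnp.abs(data)

# the composed top-level function gets a single @jit; XLA fuses the whole chain
@jax.jit
def composed(data, scale):
    return scale * sympy_node(abs_node(data))  # "scale" is a late-bound parameter

out = composed(raw, 2.0)  # first call compiles and caches
out = composed(raw, 3.0)  # changing the parameter just re-runs the cached function
```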

My completely average laptop can process a 100x1x100x200 complex frequency-domain FieldMonitor with a squeeze, frequency index selection, any of real/imag/abs, and a BW->RGB interpolated colormap on the order of microseconds. I've yet to throw any operations at it that meaningfully make it slow down (even things like computing the determinant of matrix-folded F-X-Y data, SVDs, etc.). Which is why I'm quite convinced that far larger datasets will scale beautifully when I flick the "GPU" button; 41ms (24fps) really is a long, long time, and most data sizes one cares to deal with won't bog down a modern GPU within that budget.
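For reference, a minimal sketch of that kind of pipeline (my own illustration; the shapes match the numbers quoted above, and the colormap is a toy black-to-white interpolation):

```python
import jax
import jax.numpy as jnp

@jax.jit
def field_to_rgb(field, freq_idx):
    im = jnp.abs(jnp.squeeze(field)[:, :, freq_idx])        # (100,1,100,200) -> (100,100)
    im = (im - im.min()) / (im.max() - im.min() + 1e-12)    # normalize to [0, 1]
    lo = jnp.array([0.0, 0.0, 0.0])                         # black
    hi = jnp.array([1.0, 1.0, 1.0])                         # white
    return im[..., None] * hi + (1.0 - im[..., None]) * lo  # BW -> RGB, (100,100,3)

field = jnp.ones((100, 1, 100, 200), dtype=jnp.complex64)   # stand-in FieldMonitor data
rgb = field_to_rgb(field, 5)                                # select frequency index 5
```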

Of course, generally one just needs to deduce things like transmission losses, for which all of this is vastly over-engineered. But my hope is that this approach is also flexible enough to do some really exotic stuff: especially loss functions, or trying to work out which assumptions differ between simulations and experiments.

> These features seem to not work for me with tidy3d.

I mean, relatively little "just works" if we're talking about directly @jit-ing a lot of what Tidy3D comes with. But honestly, it's not a showstopper; data is data, after all.

  • Tidy3D Methods: I haven't needed direct compatibility much so far, but the productive path has been to copy/paste (my project is AGPL, so the license should be compatible, but I include a note in the docstring) and %s/np/jnp/g any functions I need (like amp_time for source time dependence), to get around the primary differences between np and jnp. It's not ideal, but it's manageable in the few cases where I need it.
    • Mainly, the sharp edges seem to be type promotion to float64/complex128 (when numpy does it on its own, jax loses the ability to trace), which jax really hates, as well as anything that changes the array shape (jax simply doesn't jit with dynamically sized arrays; it's a design choice). Simply using jnp fixes the first, not the second, but Tidy3D seems generally designed in a way that makes the second not such a big issue. (A rough sketch of the np -> jnp swap follows this list.)
  • Output Data: xarray really is lovely, but the happy path in my case has been to extract the raw data (to compose high-performance operations on) and to manually track index names, coordinates, etc. on the side (to deduce exactly which operations to compose). So, reinventing tensors again, I suppose!
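As promised above, a rough sketch of the np -> jnp swap (my own illustration; the function below is a made-up stand-in for a copied time-dependence function, not tidy3d's actual amp_time, and dtypes are kept in 32-bit to dodge the promotion issue):

```python
import jax
import jax.numpy as jnp

# hypothetical source-time-dependence function after a %s/np/jnp/g style rewrite;
# keeping everything in jnp (and in 32-bit) lets jax trace and jit it
def gaussian_pulse(t, freq0, fwidth, offset):
    twidth = 1.0 / (2.0 * jnp.pi * fwidth)
    omega0 = 2.0 * jnp.pi * freq0
    time_shifted = t - offset * twidth
    envelope = jnp.exp(-(time_shifted ** 2) / (2.0 * twidth ** 2))
    return envelope * jnp.exp(1j * omega0 * t)

times = jnp.linspace(0.0, 2e-13, 512, dtype=jnp.float32)
amps = jax.jit(gaussian_pulse)(times, 2.0e14, 1.0e13, 5.0)
```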

> We are planning to support autograd 'natively', but also to write converters from jax/pytorch/tensorflow to tidy3d autograd.

Fantastic. I'd be happy to give feedback on the jax side of things once it's ready for user consumption. Though I'll be sticking with the adjoint plugin for now, of course.

I hope it's at least interesting to see why I ask about jax support! It can be a bit sharp, but sometimes also magical!

tylerflex changed the title from "TriangleMesh support in adjoint plugin" to "more general TriangleMesh differentiation support" on Aug 2, 2024
tylerflex linked a pull request on Dec 6, 2024 that will close this issue
@tylerflex (Collaborator, Author)

Just FYI for anyone reading this: this is basically done in #2095. We still need to make trimesh and autograd get along, and add a bunch more tests, but all of the hard bits are done.
