-
Notifications
You must be signed in to change notification settings - Fork 25
Autograd -> Jax conversion #433
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Wavebot is working now when converted to jax. Working on updating the code so that only the necessary functions use jax.numpy while the rest use numpy. |
* bug bix : DC and Nyquist frequency should not be devided by two before ifft * Changed td_to_fd to scale single sided frequency components rather than TD signal * minor bug fix from issue332 sandialabs#332
* added initial file changes based on sphinx_multiversion docs and WEC-Sim implementation * removed sphinx-multiversion since it is no longer supported and made manual multiversion * now uses absolute paths, commented out linkcheck for debugging * fixed docstring errors in utilities module * updating files again that somehow got reverted * fixing path in conf.py * don't run tutorials (will revert later) * handle file moves correctly, fixed if statement to make other versions appear * fixed two bugs in versions template * reverted temp changes, changes latest to main * switched latest to main * main branch now in root directory of pages * fixed URLs with change from last commit * make other branches visible before building * switched main branch tag for more testing * fixed typo * switched dev branch to an existing branch * renamed main to latest, changed version.html file name to avoid confusion * added prints about moving files so Sphinx output isn't misleading * fixed typo with quotations * changed versions.html name back because that broke things I guess * modified contributing documentation to reflect changes * add logic to remove duplicate 'latest' branch * Fixed pathing when already on latest * remove typo * Troubleshooting complete, switching back to correct branches for deployment * Removed extra word in docstring * removed redundant function * fixed pathing so returns to same file (and fixes tutorial/API docs) * changed latest branch for demonstration * switched back latest branch for deployment
* removed conda environment from workflows since newer capytaine/wavespectra work with Windows * fixed unnecessary capitalization * still create CI conda environment to fix Mac environment failures * added conda env fully back in, push workflow deploys docs, split PR workflow * conda environment activates again * mambaforge instead of miniforge * manual cache reset * reset to older version of setup-miniconda to troubleshoot
…supported versions (sandialabs#390) Co-authored-by: jtgrasb <[email protected]>
* Try specifying subversion * Test new cache * revert to 3.12 * Revert comment back to normal
|
To do:
|
|
I added just-in-time compilation to the optimization (objective function, constraints, and relative gradients) using jax.jit which should speed up the code. I also had to change the call to Here is the computation time with and without jit for the AquaHarmonics parameter sweep cell:
Based on this issue thread, scipy does computation on numpy arrays which means the data type keeps getting converted back and forth between jax and numpy arrays. This is why the jax without jit has such a large computation time and the jax with jit computation time is not so much faster than autograd. jax.vmap:
To do:
|
|
@cmichelenstrofer - I'm trying to debug the CI on this PR for macOS. The key part of the failing CI log seems to be the following, which is either due to some mismatch of When I install locally, I do the following trying to replicate the install commands in pr.yml. This seems to work on the GitHub runner (see, e.g., https://github.com/sandialabs/WecOptTool/actions/runs/17650331026/job/50178598578), but my machine tells me there aren't capytaine and wavespectra binaries for arm64 on conda-forge (see #324 for our previous discussion on this). So I instead do the following, which works fine. mamba create -n tmp_wot
mamba activate tmp_wot
mamba create -n tmp_wot
mamba install pip
cd wecopttool
pip install .
pip install gmsh pygmsh coveralls pytest # we could probably make a special set of optional dependencies in pyproject.toml to avoid this line
coverage run -m pytestBased on all of this, I pushed an update to pr.yml with f1e6e51 to basically do the whole installation via pip. Note that this bypasses the installation of capytaine and wavespectra via mamba which I believe you added to allow for caching (see #35)1 It seems to fix the problem on macOS, Ubuntu still works fine, and now Windows fails due a Segmentation fault 🤪 -- maybe this is due to these warnings which we have actually been receiving for a while about
Footnotes
|
|
For some reason, fully removing mamba and pip installing in editable mode using To do:
|
Pull Request Test Coverage Report for Build 18686165478Warning: This coverage report may be inaccurate.This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.
Details
💛 - Coveralls |
|
After encountering additional GitHub CI errors with uploading to coveralls (errors also present on dev branch), I had to edit the workflow further. Here are the changes I made to the pr.yml and push.yml:
This PR should now be ready for merge. |
* Update CONTRIBUTING.md to indicate PRs should be to the new `dev` branch * Update RELEASING.md to reflect new workflow with the `dev` branch * update docstrings (sandialabs#326) * damping naming and consistently change radiation damping (sandialabs#328) * issue 321 fd_to_td() bug (sandialabs#329) * bug bix : DC and Nyquist frequency should not be devided by two before ifft * Changed td_to_fd to scale single sided frequency components rather than TD signal * minor bug fix from issue332 sandialabs#332 * nodf -> ndof (sandialabs#334) * add DOI for Daniel's paper (sandialabs#336) * Lower tolerance for new test to fix CI failing occasionally * hyperlinks no longer have formatting, plus other small adjustments (sandialabs#348) * Merge to dev, not main (sandialabs#349) * Dev version of documentation site (sandialabs#347) * added initial file changes based on sphinx_multiversion docs and WEC-Sim implementation * removed sphinx-multiversion since it is no longer supported and made manual multiversion * now uses absolute paths, commented out linkcheck for debugging * fixed docstring errors in utilities module * updating files again that somehow got reverted * fixing path in conf.py * don't run tutorials (will revert later) * handle file moves correctly, fixed if statement to make other versions appear * fixed two bugs in versions template * reverted temp changes, changes latest to main * switched latest to main * main branch now in root directory of pages * fixed URLs with change from last commit * make other branches visible before building * switched main branch tag for more testing * fixed typo * switched dev branch to an existing branch * renamed main to latest, changed version.html file name to avoid confusion * added prints about moving files so Sphinx output isn't misleading * fixed typo with quotations * changed versions.html name back because that broke things I guess * modified contributing documentation to reflect changes * add logic to remove duplicate 'latest' branch * Fixed pathing when already on latest * remove typo * Troubleshooting complete, switching back to correct branches for deployment * Removed extra word in docstring * removed redundant function * fixed pathing so returns to same file (and fixes tutorial/API docs) * changed latest branch for demonstration * switched back latest branch for deployment * updated with new Capytaine docs URL * Add warnings when adding inertia and hydrostatic stiffness automatically (sandialabs#346) * CI workflow cleanup (sandialabs#352) * removed conda environment from workflows since newer capytaine/wavespectra work with Windows * fixed unnecessary capitalization * still create CI conda environment to fix Mac environment failures * added conda env fully back in, push workflow deploys docs, split PR workflow * conda environment activates again * mambaforge instead of miniforge * manual cache reset * reset to older version of setup-miniconda to troubleshoot * Updated workflows to newest Python version and changed references to supported versions (sandialabs#390) Co-authored-by: jtgrasb <[email protected]> * Revert to Python 3.12 (sandialabs#394) * Try specifying subversion * Test new cache * revert to 3.12 * Revert comment back to normal * use dev for docs and restrict sphinx (sandialabs#396) * Remove Sphinx version requirement (sandialabs#409) * v3.0.3 * v3.1 * Trying to convert tutorial 1 * Convert to jax progress * post-processing * clear outputs * wavebot tutorial running * wavebot tutorial running * Update to jax and numpy * Revert wavebot execution count * update pyproject.toml * Specify jax version for mac * try jaxlib * no jaxlib * add jax to environment manually * add jaxlib to env * ad jaxlib to pyproject * install jax manually for macos * conda init * install jax and jaxlib on macos * try arm64 * remove arm * make core optimization jittable * try pinning jax version * revert previous * install entirely w. pip * Add verbose outputs to testing * editable mode * no mamba * editable mode * update random inputs * remove cache environment * use pytest-cov * remove cache clear * Fix push.yml --------- Co-authored-by: Carlos A. Michelén Ströfer <[email protected]> Co-authored-by: Ryan Coe <[email protected]> Co-authored-by: Daniel Gaebele <[email protected]> Co-authored-by: Michael Devin <[email protected]> Co-authored-by: mcdevin <[email protected]>

Description
Convert autograd to jax.
Wavebot tutorial is working for the first optimization.