Skip to content

Commit 8c7491a

Browse files
authored
Merge pull request #76 from facebookresearch/stop_grad
Add stop gradient to tangent vector field calcuation
2 parents d461c4b + f75a975 commit 8c7491a

File tree

4 files changed

+4
-3
lines changed

4 files changed

+4
-3
lines changed

.bumpversion.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[tool.bumpversion]
2-
current_version = "v0.5.2"
2+
current_version = "v0.5.3"
33
commit = true
44
commit_args = "--no-verify"
55
tag = true

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22

33
name = "fmmax"
4-
version = "v0.5.2"
4+
version = "v0.5.3"
55
description = "Fourier modal method with Jax"
66
readme = "README.md"
77
requires-python = ">=3.7"

src/fmmax/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22

3-
__version__ = "v0.5.2"
3+
__version__ = "v0.5.3"
44

55
from . import (
66
basis,

src/fmmax/vector.py

+1
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def compute_tangent_field(
144144
Returns:
145145
The normal field, `(tx, ty)`.
146146
"""
147+
arr = jax.lax.stop_gradient(arr)
147148
batch_shape = arr.shape[:-2]
148149
arr = utils.atleast_nd(arr, n=3)
149150
arr = arr.reshape((-1,) + arr.shape[-2:])

0 commit comments

Comments
 (0)