Skip to content

Commit

Permalink
fix nabla for function with different in/out types (#4)
Browse files Browse the repository at this point in the history
* fix nabla for function with different in/out types

* bump version
  • Loading branch information
mavenlin authored Dec 1, 2023
1 parent 9d36e38 commit e8c9bed
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion autofd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@
"operators",
]

__version__ = "0.0.4" # noqa
__version__ = "0.0.5" # noqa
10 changes: 8 additions & 2 deletions autofd/operators/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,9 @@ def nabla_f(*args):
Returns:
(nabla f): the gradient function.
"""
assert method in (
jax.jacfwd, jax.jacrev
), "method need to be jax.jacfwd or jax.jacrev"
return nabla_p.bind(f, argnums=argnums, method=method)


Expand All @@ -714,8 +717,11 @@ def nabla_spec(f, *, argnums, method):
)
ospec = f.ret.spec
jspec = tree_map(
lambda osp:
tree_map(lambda isp: Spec(osp.shape + isp.shape, isp.dtype), ispec), ospec
lambda osp: tree_map(
lambda isp: Spec(
osp.shape + isp.shape, osp.dtype if method == jax.jacfwd else isp.dtype
), ispec
), ospec
)
return (Ret(jspec), *f.arg), None

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ ignore = E731

[metadata]
name = autofd
version = 0.0.4
version = 0.0.5
author = "Min Lin"
author_email = "[email protected]"
description = "Automatic Functional Derivative in JAX"
Expand Down

0 comments on commit e8c9bed

Please sign in to comment.