PyTorch Softmax Ops (#846)
Co-authored-by: HarshvirSandhu <[email protected]>
HAKSOAT and HarshvirSandhu authored Jun 28, 2024
1 parent f3d2ede commit 17fa8b1
Showing 3 changed files with 100 additions and 1 deletion.
2 changes: 1 addition & 1 deletion environment.yml
@@ -9,7 +9,7 @@ channels:
 dependencies:
   - python>=3.10
   - compilers
-  - numpy>=1.17.0
+  - numpy>=1.17.0,<2
   - scipy>=0.14,<1.14.0
   - filelock
   - etuples
50 changes: 50 additions & 0 deletions pytensor/link/pytorch/dispatch/elemwise.py
@@ -2,6 +2,7 @@

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad


@pytorch_funcify.register(Elemwise)
@@ -34,3 +35,52 @@ def dimshuffle(x):
        return res

    return dimshuffle


@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, **kwargs):
    axis = op.axis
    dtype = kwargs["node"].inputs[0].dtype

    if not dtype.startswith("float"):
        raise NotImplementedError(
            "Pytorch Softmax is not currently implemented for non-float types."
        )

    def softmax(x):
        if axis is not None:
            return torch.softmax(x, dim=axis)
        else:
            return torch.softmax(x.ravel(), dim=0).reshape(x.shape)

    return softmax


@pytorch_funcify.register(LogSoftmax)
def pytorch_funcify_LogSoftmax(op, **kwargs):
    axis = op.axis
    dtype = kwargs["node"].inputs[0].dtype

    if not dtype.startswith("float"):
        raise NotImplementedError(
            "Pytorch LogSoftmax is not currently implemented for non-float types."
        )

    def log_softmax(x):
        if axis is not None:
            return torch.log_softmax(x, dim=axis)
        else:
            return torch.log_softmax(x.ravel(), dim=0).reshape(x.shape)

    return log_softmax


@pytorch_funcify.register(SoftmaxGrad)
def pytorch_funcify_SoftmaxGrad(op, **kwargs):
    axis = op.axis

    def softmax_grad(dy, sm):
        dy_times_sm = dy * sm
        return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm

    return softmax_grad
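
For reference (an editorial note, not part of the diff): softmax_grad above is the vector-Jacobian product of softmax along the op's axis. With output gradient g and softmax output s,

\frac{\partial L}{\partial x_i}
    = (g \odot s)_i - \Big(\sum_j (g \odot s)_j\Big)\, s_i
    = s_i \Big(g_i - \sum_j g_j\, s_j\Big),

which is exactly dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm, with the sum taken over axis and broadcast back via keepdim=True.
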
49 changes: 49 additions & 0 deletions tests/link/pytorch/test_elemwise.py
@@ -1,9 +1,11 @@
import numpy as np
import pytest

import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, tensor, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py

@@ -53,3 +55,50 @@ def test_pytorch_elemwise():

    fg = FunctionGraph([x], [out])
    compare_pytorch_and_py(fg, [[0.9, 0.9]])


@pytest.mark.parametrize("dtype", ["float64", "int64"])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis, dtype):
x = matrix("x", dtype=dtype)
out = softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)

if dtype == "int64":
with pytest.raises(
NotImplementedError,
match="Pytorch Softmax is not currently implemented for non-float types.",
):
compare_pytorch_and_py(fgraph, [test_input])
else:
compare_pytorch_and_py(fgraph, [test_input])


@pytest.mark.parametrize("dtype", ["float64", "int64"])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_logsoftmax(axis, dtype):
x = matrix("x", dtype=dtype)
out = log_softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)

if dtype == "int64":
with pytest.raises(
NotImplementedError,
match="Pytorch LogSoftmax is not currently implemented for non-float types.",
):
compare_pytorch_and_py(fgraph, [test_input])
else:
compare_pytorch_and_py(fgraph, [test_input])


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax_grad(axis):
dy = matrix("dy")
dy_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
sm = matrix("sm")
sm_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = SoftmaxGrad(axis=axis)(dy, sm)
fgraph = FunctionGraph([dy, sm], [out])
compare_pytorch_and_py(fgraph, [dy_value, sm_value])
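
A rough end-to-end usage sketch (editorial addition, not part of this commit). It assumes PyTensor exposes its PyTorch backend under the "PYTORCH" compilation mode; compiling a softmax graph this way routes evaluation through the pytorch_funcify_Softmax dispatch added above.

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.special import softmax

x = pt.matrix("x", dtype="float64")
out = softmax(x, axis=1)

# Assumed mode name for the PyTorch linker; adjust if the backend is registered differently.
f = pytensor.function([x], out, mode="PYTORCH")

data = np.arange(6, dtype="float64").reshape(2, 3)
print(f(data))  # each row sums to 1, computed via torch.softmax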
