Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyTorch Softmax Ops #846

Merged
merged 51 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
27e2526
Add pytorch support for some basic Ops
HarshvirSandhu May 13, 2024
629d00b
update variable names, docstrings
HarshvirSandhu May 13, 2024
3eceb56
Avoid numpy conversion of torch Tensors
HarshvirSandhu May 17, 2024
3cde964
Fix typify and CheckAndRaise
HarshvirSandhu May 17, 2024
c003aa5
Fix Elemwise Ops
HarshvirSandhu May 17, 2024
8dc406e
Fix Scalar Ops
HarshvirSandhu May 17, 2024
a8f6ddb
Fix ruff-format
HarshvirSandhu May 17, 2024
9d535f5
Initial setup for pytorch tests
HarshvirSandhu May 23, 2024
c5600da
Fix mode parameters for pytorch
HarshvirSandhu May 23, 2024
54b6248
Prevent conversion of scalars to numpy
HarshvirSandhu May 23, 2024
19454b3
Update TensorConstantSignature and map dtypes to Tensor types
HarshvirSandhu May 23, 2024
92d7114
Add tests for basic ops
HarshvirSandhu May 23, 2024
5aae0e5
Remove torch from user facing API
HarshvirSandhu May 29, 2024
8c174dd
Add function to convert numpy arrays to pytorch tensors
HarshvirSandhu May 29, 2024
0977c3a
Avoid copy when converting to tensor
HarshvirSandhu May 29, 2024
1c23825
Fix tests
HarshvirSandhu May 29, 2024
c9195a8
Remove dispatches that are not tested
HarshvirSandhu May 31, 2024
b07805c
set path for pytorch tests
HarshvirSandhu May 31, 2024
9e8d3fc
Remove tensorflow probability from yml
HarshvirSandhu Jun 4, 2024
a2d3afa
Add checks for runtime broadcasting
HarshvirSandhu Jun 4, 2024
a577a80
Remove IfElse
HarshvirSandhu Jun 4, 2024
499a174
Remove dev notebook
HarshvirSandhu Jun 12, 2024
2826613
Fix check and raise
HarshvirSandhu Jun 12, 2024
62ffcec
Fix compare_pytorch_and_py
HarshvirSandhu Jun 12, 2024
acdbba1
Fix DimShuffle
HarshvirSandhu Jun 12, 2024
2519c65
Add tests for Elemwise operations
HarshvirSandhu Jun 12, 2024
eb6d5c2
Fix test for CheckAndRaise
HarshvirSandhu Jun 14, 2024
9f02a4f
Remove duplicate function
HarshvirSandhu Jun 14, 2024
caf2965
Remove device from pytorch_typify
HarshvirSandhu Jun 15, 2024
bf87eb9
Merge branch 'main' of https://github.com/HarshvirSandhu/pytensor int…
HarshvirSandhu Jun 15, 2024
2c27683
Solve merge conflict
HarshvirSandhu Jun 15, 2024
c603c6b
Use micromamba for pytorch install
HarshvirSandhu Jun 15, 2024
3f17107
Fix pytorch linker
HarshvirSandhu Jun 16, 2024
e850d8d
Fix typify and deepcopy
HarshvirSandhu Jun 16, 2024
e682fc4
Parametrize device in all tests
HarshvirSandhu Jun 16, 2024
bf4cf92
Install torch with cuda
HarshvirSandhu Jun 16, 2024
899e7f9
Fix test_pytorch_FunctionGraph_once
HarshvirSandhu Jun 16, 2024
04d2935
Remove device argument from test
HarshvirSandhu Jun 16, 2024
8ec7661
remove device from elemwise tests and add assertions
HarshvirSandhu Jun 17, 2024
bb7df41
skip tests if cuda is not available
HarshvirSandhu Jun 17, 2024
0441cf2
Fix tests
HarshvirSandhu Jun 18, 2024
85f2742
Merge branch 'main' of https://github.com/pymc-devs/pytensor into pyt…
HAKSOAT Jun 20, 2024
4ca5aca
Implemented softmax ops for PyTorch
HAKSOAT Jun 23, 2024
b9aca57
Merge remote-tracking branch 'upstream/main' into pytensor-pytorch-so…
HAKSOAT Jun 23, 2024
287d9c2
Switched to run softmax on all items if axis is None
HAKSOAT Jun 24, 2024
f42e2a0
Implemented log softmax
HAKSOAT Jun 24, 2024
35b17e0
Implemented softmaxgrad
HAKSOAT Jun 25, 2024
5efc3c8
Added checks and error raises for nonfloat inputs
HAKSOAT Jun 27, 2024
16e415a
Added checks and error raises for nonfloat inputs
HAKSOAT Jun 27, 2024
ffbc594
Merge branch 'pytensor-pytorch-softmax' of https://github.com/HAKSOAT…
HAKSOAT Jun 28, 2024
b4cdce0
Merge branch 'main' of https://github.com/pymc-devs/pytensor into pyt…
HAKSOAT Jun 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions pytensor/link/pytorch/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@ def dimshuffle(x):
@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, **kwargs):
axis = op.axis
dtype = kwargs["node"].outputs[0].dtype
Copy link
Member

@ricardoV94 ricardoV94 Jun 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe better to check the input dtype, because we would fail if we pass an integer. PyTensor could start saying Softmax takes as input integers and outputs floats once we fix it? Sorry if I said the output before

Suggested change
dtype = kwargs["node"].outputs[0].dtype
dtype = kwargs["node"].inputs[0].dtype


def softmax(x):
if not torch.is_floating_point(x):
x = x.to(torch.float32)
if not dtype.startswith("float"):
raise NotImplementedError(
"Pytorch Softmax is not currently implemented for non-float types."
)

def softmax(x):
if axis is not None:
return torch.softmax(x, dim=axis)
else:
Expand All @@ -56,11 +59,14 @@ def softmax(x):
@pytorch_funcify.register(LogSoftmax)
def pytorch_funcify_LogSoftmax(op, **kwargs):
axis = op.axis
dtype = kwargs["node"].outputs[0].dtype

def log_softmax(x):
if not torch.is_floating_point(x):
x = x.to(torch.float32)
if not dtype.startswith("float"):
raise NotImplementedError(
"Pytorch LogSoftmax is not currently implemented for non-float types."
)

def log_softmax(x):
if axis is not None:
return torch.log_softmax(x, dim=axis)
else:
Expand Down
28 changes: 22 additions & 6 deletions tests/link/pytorch/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,40 @@ def test_pytorch_elemwise():
compare_pytorch_and_py(fg, [[0.9, 0.9]])


@pytest.mark.parametrize("dtype", ["float64", "int64"])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis):
x = matrix("x")
def test_softmax(axis, dtype):
x = matrix("x", dtype=dtype)
out = softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)

compare_pytorch_and_py(fgraph, [test_input])
if dtype == "int64":
with pytest.raises(
NotImplementedError,
match="Pytorch Softmax is not currently implemented for non-float types.",
):
compare_pytorch_and_py(fgraph, [test_input])
else:
compare_pytorch_and_py(fgraph, [test_input])


@pytest.mark.parametrize("dtype", ["float64", "int64"])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_logsoftmax(axis):
x = matrix("x")
def test_logsoftmax(axis, dtype):
x = matrix("x", dtype=dtype)
out = log_softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)

compare_pytorch_and_py(fgraph, [test_input])
if dtype == "int64":
with pytest.raises(
NotImplementedError,
match="Pytorch LogSoftmax is not currently implemented for non-float types.",
):
compare_pytorch_and_py(fgraph, [test_input])
else:
compare_pytorch_and_py(fgraph, [test_input])


@pytest.mark.parametrize("axis", [None, 0, 1])
Expand Down