18 changes: 18 additions & 0 deletions src/ntops/kernels/dropout.py
@@ -0,0 +1,18 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, p, seed, output):
output = ntl.where(ntl.rand(seed, input.offsets()) > p, input / (1 - p), 0) # noqa: F841


@functools.cache
def make(ndim):
tensors = (Tensor(ndim), Tensor(0), Tensor(0), Tensor(ndim))

return ninetoothed.make(arrangement, application, tensors)
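
The application body implements inverted dropout: each element survives with probability 1 - p, and survivors are rescaled by 1 / (1 - p) so the output's expected value matches the input. As a point of reference only, here is a minimal pure-PyTorch sketch of the same math (dropout_reference is a hypothetical helper, not part of this PR):

import torch

def dropout_reference(input, p, seed):
    # Draw one uniform sample per element from a seeded generator,
    # keep elements where the sample exceeds p, and rescale the
    # survivors by 1 / (1 - p) (inverted dropout).
    generator = torch.Generator(device=input.device).manual_seed(seed)
    mask = torch.rand(input.shape, generator=generator, device=input.device) > p
    return torch.where(mask, input / (1 - p), torch.zeros_like(input))

The functools.cache decorator on make means the arrangement/application pair is built once per ndim and reused on subsequent calls.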
24 changes: 24 additions & 0 deletions src/ntops/torch.py
@@ -1,3 +1,5 @@
import random

import torch

import ntops.kernels.abs
@@ -10,6 +12,7 @@
import ntops.kernels.clamp
import ntops.kernels.cos
import ntops.kernels.div
import ntops.kernels.dropout
import ntops.kernels.eq
import ntops.kernels.exp
import ntops.kernels.ge
@@ -147,6 +150,27 @@ def div(input, other, *, rounding_mode=None, out=None):
return out


def dropout(input, p=0.5, training=True, inplace=False):
if not training or p == 0:
if inplace:
return input
else:
return input.clone()

seed = random.randrange(0, 2**31)

if inplace:
output = input
else:
output = torch.empty_like(input)

kernel = ntops.kernels.dropout.make(input.ndim)

kernel(input, p, seed, output)

return output


def exp(input, *, out=None):
if out is None:
out = torch.empty_like(input)
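At the Python level, dropout follows the torch.nn.functional.dropout signature: with training=False or p == 0 it returns the input (cloned unless inplace=True), and otherwise it draws a fresh 31-bit seed and dispatches to the cached kernel for the input's ndim. A usage sketch, assuming a CUDA device since the kernels target the GPU:

import torch

import ntops.torch

x = torch.randn(4, 1024, device="cuda")
y = ntops.torch.dropout(x, p=0.3)            # out-of-place
ntops.torch.dropout(x, p=0.3, inplace=True)  # overwrites x
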
40 changes: 40 additions & 0 deletions tests/test_dropout.py
@@ -0,0 +1,40 @@
import random

import pytest
import torch
import torch.nn.functional as F

import ntops.torch
from tests.skippers import skip_if_cuda_not_available
from tests.utils import generate_arguments


@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
device = "cuda"

input = torch.randn(shape, dtype=dtype, device=device)
p = random.uniform(0, 1)

# TODO: Add `training` and `inplace` tests later.
ninetoothed_output = ntops.torch.dropout(input, p=p)
reference_output = F.dropout(input, p=p)

assert ninetoothed_output.shape == reference_output.shape

ninetoothed_non_zero_ratio = (
ninetoothed_output.nonzero().numel() / ninetoothed_output.ndim / input.numel()
)
reference_non_zero_ratio = (
reference_output.nonzero().numel() / reference_output.ndim / input.numel()
)

assert abs(ninetoothed_non_zero_ratio - reference_non_zero_ratio) < 0.1

assert torch.allclose(
ninetoothed_output[ninetoothed_output != 0],
input[ninetoothed_output != 0] / (1 - p),
)
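
The non-zero-ratio check verifies that both implementations keep roughly a 1 - p fraction of elements: nonzero() returns an (N, ndim) tensor of indices, so dividing its numel() by ndim recovers the survivor count N. An equivalent and arguably more direct formulation (a sketch, not what the test uses):

keep_ratio = (ninetoothed_output != 0).float().mean().item()
# Expected to be close to 1 - p; the test tolerates a 0.1 gap
# between the two implementations' ratios.

The final allclose assertion then confirms that every surviving element equals the corresponding input rescaled by 1 / (1 - p).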