Support more symint ops #17

Merged · 3 commits · Nov 11, 2024
3 changes: 3 additions & 0 deletions codegen/xla_native_functions.yaml
@@ -406,12 +406,15 @@ symint:
- empty.memory_format
- empty_strided
- expand_copy
- full
- new_empty_strided
- view_copy
- diagonal_backward
- narrow_copy
- select_backward
- select.int
- slice_backward
- slice_copy.Tensor
# See Note: [Disabling functionalization]
- expand
- view
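Ops listed under `symint:` keep their SymInt size arguments in the generated lowerings, so sizes taken from bounded-dynamic tensors can flow into them. A rough sketch of the usage this enables (not part of the diff; it assumes the DISC backend and the bounded-dynamic API exercised by the new test file below):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

pd = torch._C._EnablePythonDispatcher()
t = torch.randn(5, 2).to(xm.xla_device())
# Mark dim 0 as dynamic with an upper bound of 10.
torch_xla._XLAC._xla_mark_bounded_dynamic(t, [0], [10])
# The slice end point comes from the dynamic dimension, so the newly listed
# slice/select symint overloads are exercised and the '<=10' bound is kept.
out = t[1:t.shape[0], :]
print(torch_xla._XLAC._get_xla_tensors_text([out]))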
327 changes: 327 additions & 0 deletions test/ds/test_bounded_dynamic_ops.py
@@ -0,0 +1,327 @@
import os
import sys
import unittest
import torch
import torch.nn.functional as F
import torch_xla
import torch_xla.core.xla_model as xm

sys.path.insert(1, os.path.join(sys.path[0], '..'))
import test_utils

PD = torch._C._EnablePythonDispatcher()
XLA_DEVICE = xm.xla_device()


def _mark_dynamic(t, dims, bounds):
  # Mark `dims` of `t` as bounded-dynamic, with `bounds` as the upper bounds.
  torch_xla._XLAC._xla_mark_bounded_dynamic(t, dims, bounds)


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size,
                      dtype: torch.dtype,
                      device: torch.device,
                      past_key_values_length: int = 0):
  """
  Make causal mask used for bi-directional self-attention.
  """
  bsz, tgt_len = input_ids_shape
  mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
  mask_cond = torch.arange(mask.size(-1), device=device)
  mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
  mask = mask.to(dtype)

  if past_key_values_length > 0:
    mask = torch.cat([
        torch.zeros(
            tgt_len, past_key_values_length, dtype=dtype, device=device), mask
    ],
                     dim=-1)
  return mask[None, None, :, :].expand(bsz, 1, tgt_len,
                                       tgt_len + past_key_values_length)
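# Illustrative example (not part of the diff): _make_causal_mask(
#     torch.Size([1, 3]), torch.float32, torch.device('cpu')) returns a
# [1, 1, 3, 3] tensor whose [0, 0] slice is
#     [[0, min, min],
#      [0,   0, min],
#      [0,   0,   0]]
# with min = torch.finfo(torch.float32).min, i.e. each position may only
# attend to itself and earlier positions.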


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len=None):
  """
  Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
  """
  bsz, src_len = mask.size()
  tgt_len = tgt_len if tgt_len is not None else src_len

  expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len,
                                                src_len).to(dtype)

  inverted_mask = 1.0 - expanded_mask

  return inverted_mask.masked_fill(
      inverted_mask.to(torch.bool),
      torch.finfo(dtype).min)
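# Illustrative example (not part of the diff): for a [1, 3] padding mask
# [[1, 1, 0]] (last position padded), _expand_mask returns a [1, 1, 3, 3]
# tensor in which the kept columns are 0 and the padded column is filled with
# torch.finfo(dtype).min in every row.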


def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds,
                                    past_key_values_length):
  # create causal mask
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  combined_attention_mask = None
  if input_shape[-1] > 1:
    combined_attention_mask = _make_causal_mask(
        input_shape,
        inputs_embeds.dtype,
        device=inputs_embeds.device,
        past_key_values_length=past_key_values_length,
    )

  if attention_mask is not None:
    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
    expanded_attn_mask = _expand_mask(
        attention_mask, inputs_embeds.dtype,
        tgt_len=input_shape[-1]).to(inputs_embeds.device)
    combined_attention_mask = (
        expanded_attn_mask if combined_attention_mask is None else
        expanded_attn_mask + combined_attention_mask)

  return combined_attention_mask
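# Note (not part of the diff): when both masks exist they are summed, so a
# position ends up masked (a large negative bias) if it is either a future
# token (causal mask) or a padding token (expanded attention_mask).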


class TestBoundedDynamicOps(test_utils.XlaTestCase):

  def _diff_output(self,
                   torch_out,
                   xla_out,
                   atol=1e-3,
                   rtol=1e-5,
                   equal_nan=True):
    if isinstance(torch_out, torch.Tensor):
      self.assertIsInstance(xla_out, torch.Tensor)
      torch_out = torch_out.detach().cpu()
      xla_out = xla_out.detach().cpu()
      self.assertEqual(xla_out.dtype, torch_out.dtype)
      self.assertEqual(torch_out.shape, xla_out.shape)
      self.assertTrue(
          torch.allclose(
              torch_out, xla_out, atol=atol, rtol=rtol, equal_nan=equal_nan))
    elif isinstance(torch_out, (tuple, list)):
      self.assertIsInstance(xla_out, (tuple, list))
      self.assertEqual(len(torch_out), len(xla_out))
      for o1, o2 in zip(torch_out, xla_out):
        self._diff_output(o1, o2, atol=atol, rtol=rtol)
    else:
      self.assertEqual(torch_out, xla_out)

  def test_add(self):
    t1 = torch.randn([5, 2])
    t2 = torch.randn([5, 2])
    torch_out = t1 + t2

    t1 = t1.to(XLA_DEVICE)
    t2 = t2.to(XLA_DEVICE)
    _mark_dynamic(t1, [0], [10])
    _mark_dynamic(t2, [0], [10])
    xla_out = t1 + t2
    self._diff_output(torch_out, xla_out)

  def test_add_broadcast(self):
    t1 = torch.randn([5, 2])
    t2 = torch.randn([2])
    torch_out = t1 + t2

    t1 = t1.to(XLA_DEVICE)
    t2 = t2.to(XLA_DEVICE)
    _mark_dynamic(t1, [0], [10])
    xla_out = t1 + t2
    self._diff_output(torch_out, xla_out)

  def test_add_scalar(self):
    t1 = torch.randn([5, 2])
    t2 = 1.0
    torch_out = t1 + t2

    t1 = t1.to(XLA_DEVICE)
    _mark_dynamic(t1, [0], [10])
    xla_out = t1 + t2
    self._diff_output(torch_out, xla_out)

  def test_reshape(self):
    x = torch.randn(4, 101, 100)
    y = torch.randn(4 * 101 * 100)
    torch_out = y.reshape(x.shape[0], x.shape[1], -1)

    x = x.to(XLA_DEVICE)
    y = y.to(XLA_DEVICE)
    _mark_dynamic(x, [0, 1], [10, 200])
    _mark_dynamic(y, [0], [10 * 200 * 100])
    xla_out = y.reshape(x.shape[0], x.shape[1], -1)
    self._diff_output(torch_out, xla_out)

  def test_flatten(self):
    x = torch.randn(4, 101, 100)
    torch_out = x.flatten(0, 1)

    x = x.to(XLA_DEVICE)
    _mark_dynamic(x, [0], [10])
    xla_out = x.flatten(0, 1)
    self._diff_output(torch_out, xla_out)

  def test_arange(self):
    x = torch.randn(4, 101, 100)
    torch_out = torch.arange(
        0, (x.shape[0] + 1) * x.shape[1],
        step=x.shape[1],
        dtype=torch.int32,
        device=x.device)

    x = x.to(XLA_DEVICE)
    _mark_dynamic(x, [1], [200])
    xla_out = torch.arange(
        0, (x.shape[0] + 1) * x.shape[1],
        step=x.shape[1],
        dtype=torch.int32,
        device=x.device)
    self._diff_output(torch_out, xla_out)

  def test_slice_with_backward(self):
    x = torch.randn(4, 101, 100)
    y = torch.randn(4, 201, 100)
    x.requires_grad = True
    y.requires_grad = True
    torch_out = y[0:10, 10:x.shape[1], ...]
    torch.autograd.backward(torch_out, torch.zeros_like(torch_out))
    torch_grad = y.grad

    x = x.detach().to(XLA_DEVICE)
    y = y.detach().to(XLA_DEVICE)
    x.requires_grad = True
    y.requires_grad = True
    _mark_dynamic(x, [1], [200])
    xla_out = y[0:10, 10:x.shape[1], ...]
    torch.autograd.backward(xla_out, torch.zeros_like(xla_out))
    xla_grad = y.grad

    self._diff_output(torch_out, xla_out)
    self._diff_output(torch_grad, xla_grad)

  def test_attn_mask(self):
    inputs_embeds = torch.randn(4, 101)
    attention_mask = torch.ones((4, 101),
                                dtype=torch.bool).to(inputs_embeds.device)
    torch_out = _prepare_decoder_attention_mask(
        attention_mask, (inputs_embeds.shape[0], inputs_embeds.shape[1]),
        inputs_embeds, 0)

    inputs_embeds = inputs_embeds.to(XLA_DEVICE)
    attention_mask = attention_mask.to(XLA_DEVICE)
    _mark_dynamic(inputs_embeds, [1], [200])
    _mark_dynamic(attention_mask, [1], [200])
    xla_out = _prepare_decoder_attention_mask(
        attention_mask, (inputs_embeds.shape[0], inputs_embeds.shape[1]),
        inputs_embeds, 0)

    self._diff_output(torch_out, xla_out)

  def test_matmul_0(self):
    t1 = torch.randn([5, 2]).to(torch.bfloat16)
    t2 = torch.randn([2, 3]).to(torch.bfloat16)
    torch_out = t1.to("cuda") @ t2.to("cuda")

    t1 = t1.to(XLA_DEVICE)
    t2 = t2.to(XLA_DEVICE)
    _mark_dynamic(t1, [0], [10])
    xla_out = t1 @ t2

    self.assertIn('<=10,3', torch_xla._XLAC._get_xla_tensors_text([xla_out]))
    self._diff_output(torch_out, xla_out)

  def test_matmul_1(self):
    t1 = torch.randn([5, 2]).to(torch.bfloat16)
    t2 = torch.randn([2]).to(torch.bfloat16)
    torch_out = t1.to("cuda") @ t2.to("cuda")

    t1 = t1.to(XLA_DEVICE)
    t2 = t2.to(XLA_DEVICE)
    _mark_dynamic(t1, [0], [10])
    xla_out = t1 @ t2

    self.assertIn('<=10', torch_xla._XLAC._get_xla_tensors_text([xla_out]))
    self._diff_output(torch_out, xla_out)

  def test_matmul_2(self):
    t1 = torch.randn([10, 5, 2]).to(torch.bfloat16)
    t2 = torch.randn([2]).to(torch.bfloat16)
    torch_out = t1.to("cuda") @ t2.to("cuda")

    t1 = t1.to(XLA_DEVICE)
    t2 = t2.to(XLA_DEVICE)
    _mark_dynamic(t1, [0, 1], [20, 10])
    xla_out = t1 @ t2
    self.assertIn('<=20,<=10', torch_xla._XLAC._get_xla_tensors_text([xla_out]))
    self._diff_output(torch_out.cpu(), xla_out)

  def test_matmul_3(self):
    t1 = torch.randn([10, 3, 4]).to(torch.bfloat16)
    t2 = torch.randn([10, 4, 5]).to(torch.bfloat16)
    torch_out = t1.to("cuda") @ t2.to("cuda")

    t1 = t1.to(XLA_DEVICE)
    t2 = t2.to(XLA_DEVICE)
    _mark_dynamic(t1, [0, 1], [20, 10])
    _mark_dynamic(t2, [0], [20])
    xla_out = t1 @ t2
    self.assertIn('<=20,<=10,5',
                  torch_xla._XLAC._get_xla_tensors_text([xla_out]))
    self._diff_output(torch_out, xla_out)

  def test_matmul_4(self):
    t1 = torch.randn([10, 3, 4]).to(torch.bfloat16)
    t2 = torch.randn([4, 5]).to(torch.bfloat16)
    torch_out = t1.to("cuda") @ t2.to("cuda")

    t1 = t1.to(XLA_DEVICE)
    t2 = t2.to(XLA_DEVICE)
    _mark_dynamic(t1, [0, 1], [20, 10])
    xla_out = t1 @ t2
    self.assertIn('<=20,<=10,5',
                  torch_xla._XLAC._get_xla_tensors_text([xla_out]))
    self._diff_output(torch_out, xla_out)

  def test_triu(self):
    t = torch.randn(4, 4)
    torch_out = torch.triu(t, diagonal=1)

    t = t.to(XLA_DEVICE)
    _mark_dynamic(t, [0, 1], [10, 10])
    xla_out = torch.triu(t, diagonal=1)

    self.assertIn('<=10,<=10', torch_xla._XLAC._get_xla_tensors_text([xla_out]))
    self._diff_output(torch_out, xla_out)

  def test_nll_loss_with_backward(self):
    logits = torch.randn(20, 30)
    target = torch.randint(0, 30, (20,), dtype=torch.long)
    logits.requires_grad = True
    torch_out = F.nll_loss(logits, target)
    torch_out.backward()
    torch_grad = logits.grad

    logits = logits.detach().to(XLA_DEVICE)
    logits.requires_grad = True
    target = target.to(XLA_DEVICE)
    _mark_dynamic(logits, [0], [50])
    _mark_dynamic(target, [0], [50])
    xla_out = F.nll_loss(logits, target)
    xla_out.backward()
    xla_grad = logits.grad
    self.assertIn('<=50,30', torch_xla._XLAC._get_xla_tensors_text([xla_grad]))

    self._diff_output(torch_out, xla_out)
    self._diff_output(torch_grad, xla_grad)


if __name__ == '__main__':
  assert test_utils.is_disc_backend()
  os.environ['USE_BOUND_FOR_SHAPE_COMPARE'] = os.getenv(
      'USE_BOUND_FOR_SHAPE_COMPARE', '1')
  # exit=False so the result can be inspected and cleanup can run below.
  test = unittest.main(exit=False)
  # Disable the Python dispatcher flag before exiting.
  del PD
  sys.exit(0 if test.result.wasSuccessful() else 1)