Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update dynamo-xla to the new torch-xla version and add DLPack support #25

Merged
merged 22 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
9d907a6
Add dlpack support (#7025)
vanbasten23 May 22, 2024
7e039bc
Make from_dlpack handle cuda synchronization implicitly for input ten…
vanbasten23 May 30, 2024
5140bcd
Reuse DLDeviceType. (#7163)
vanbasten23 Jun 1, 2024
d45968c
Automatically move CUDA non XLA Tensors to XLA Device and back to CUD…
changm Mar 13, 2024
8cb3371
Add a `DynamoSyncInputExecuteTime` counter (#6813)
JackCaoG Apr 1, 2024
575c31c
Fix runtime error when run dynamo with a profiler scope (#6913)
JackCaoG Apr 11, 2024
ac7988c
Add function for retrieving fallback operations. (#7116)
ysiraichi May 29, 2024
3ea8994
Integrate dlpack to dynamo. (#7173)
vanbasten23 Jun 14, 2024
dfed484
Add support for dynamic shape in dynamo (#7676)
wonjoolee95 Jul 23, 2024
dbd1978
Optimize dynamo dynamic shape caching (#7726)
JackCaoG Jul 24, 2024
869d181
Fix the crash when symint is part of the output with dynamic torch.co…
JackCaoG Jul 25, 2024
5ac31f1
reenable dynamo dynamic shape test (#7775)
JackCaoG Jul 30, 2024
6390b83
Support mark_dynamic (#7812)
JackCaoG Aug 7, 2024
85aab79
DYNAMO RNG seed update optimization (#7884)
JackCaoG Aug 20, 2024
624a3d5
CPU time optimization for GraphInputMatcher (#7895)
JackCaoG Aug 21, 2024
9b4b1c3
Optimize _split_xla_args_tensor_sym_constant (#7900)
JackCaoG Aug 22, 2024
37858ad
Fix the crash with copy op in dynamo (#7902)
JackCaoG Aug 23, 2024
9708b7a
Move the dynamo bridge to the _dynamo repo (#7909)
JackCaoG Aug 26, 2024
4ab34e8
add dynamo config skip_input_data_check (#7913)
JackCaoG Aug 26, 2024
088b845
In dynamo optim_mode avoid unnecessary set_attr (#7915)
JackCaoG Aug 27, 2024
4c87fa4
fix dynamo inplace copy (#7933)
zpcore Sep 3, 2024
069c89e
fix bug for compile related to dlpack
yitongh Oct 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/dynamo/test_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from functorch.compile import aot_module_simplified, make_boxed_compiler
from torch._dynamo import disable

import torch_xla.core.dynamo_bridge as bridge
import torch_xla._dynamo.dynamo_bridge as bridge
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as metrics
from torch import fx, nn
Expand Down
247 changes: 203 additions & 44 deletions test/dynamo/test_dynamo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import os
import sys

from absl.testing import absltest, parameterized
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu
import torch_xla.debug.metrics as met
import torch_xla.core.xla_env_vars as xenv
from torch_xla import runtime as xr
import torch_xla.debug.profiler as xp
import torch.optim as optim
import torch.nn as nn
import torch._dynamo as dynamo
Expand All @@ -24,7 +27,7 @@


def _is_on_tpu():
return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU'
return xr.device_type() == 'TPU'


skipOnTpu = unittest.skipIf(_is_on_tpu(), 'Not supported on TPU')
Expand All @@ -51,17 +54,80 @@ def random_op(self, a):
return torch.randn(5, 5, device=a.device) + a

def test_random_op_different_result_each_run(self):
xm.wait_device_ops()
met.clear_all()
dynamo_random_op = torch.compile(
self.random_op, backend="openxla", fullgraph=True)
t = torch.randn(5, 5).to(xm.xla_device())
dynamo_res_1 = dynamo_random_op(t)
dynamo_res_2 = dynamo_random_op(t)
dynamo_res_3 = dynamo_random_op(t)
# retriving/updating rng seed in the breidge should not cause transferToServer
self.assertNotIn("TransferFromDeviceTime", met.metric_names())
# updating rng seed will result in transferToServer
self.assertIn("TransferToDeviceTime", met.metric_names())
self.assertFalse(torch.allclose(dynamo_res_1, dynamo_res_2))
self.assertFalse(torch.allclose(dynamo_res_2, dynamo_res_3))


class DynamoInferenceBasicTest(unittest.TestCase):
class DynamoLTCInteractionTest(unittest.TestCase):

def index_copy_inplace(self, cache, update_indices, xk):
cache.index_copy_(0, update_indices, xk)

def test_mark_step_after_dynamo(self):
cache_len = 512
kv_heads = 8
head_dim = 128
running = 16

device = xm.xla_device()
cache = torch.rand((cache_len, kv_heads, head_dim)).to(device)
update_indices = torch.randint(
0, cache_len, (running,), dtype=torch.long).to(device)
xk = torch.rand((running, kv_heads, head_dim)).to(device)

dynamo_index_copy_inplace = torch.compile(
self.index_copy_inplace, backend="openxla", fullgraph=True)
met.clear_all()
for i in range(10):
dynamo_index_copy_inplace(cache, update_indices, xk)
xm.wait_device_ops()
current_execute_time = met.metric_data('ExecuteTime')[0]
# This mark_step should be a no-op and don't trigger additional execution.
xm.mark_step()
xm.wait_device_ops()
self.assertEqual(current_execute_time, met.metric_data('ExecuteTime')[0])

def test_copy_op(self):

def copy_a_to_b(a):
res = a.cos()
copy = torch.ops.aten.copy.default(a, res)
return copy

device = torch_xla.device()
compiled_copy = torch.compile(copy_a_to_b, backend="openxla")
a = torch.randn(2, 9).to(device)
res = compiled_copy(a)
self.assertTrue(torch.allclose(res, a))


class DynamoProfilerTest(unittest.TestCase):

def dummy_fn(self, a):
return torch.sin(a) + a

def test_dynamo_with_trace(self):
dynamo_dummy = torch.compile(
self.dummy_fn, backend="openxla", fullgraph=True)
t = torch.randn(2, 3, 4, device=xm.xla_device())
for i in range(10):
with xp.Trace('build_graph'):
t = dynamo_dummy(t)


class DynamoInferenceBasicTest(parameterized.TestCase):

@classmethod
def setUpClass(self):
Expand All @@ -72,6 +138,20 @@ def fn_simple(self, x, y):
b = torch.sin(y)
return a + b

def _choose_proper_device(self, initialize_on_cuda):
if not initialize_on_cuda:
return xm.xla_device()

assert initialize_on_cuda
if xr.device_type() != "CUDA" or not torch.cuda.is_available():
self.skipTest(
"Skip this test because it requires xr.device_type()=='CUDA' and torch.cuda.is_available()."
)
os.environ.update({
xenv.ZERO_COPY_ENABLED: "1",
})
return "cuda:0"

def test_simple_model(self):
device = xm.xla_device()
x = torch.tensor(100.0)
Expand All @@ -83,7 +163,7 @@ def test_simple_model(self):
res_xla_dynamo = fn_simple_dynamo(xla_x, xla_y)
self.assertIn('xla::add', met.counter_names())
self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
# verifiy that tracing is skipped in following runs
# verify that tracing is skipped in following runs
met.clear_counters()
res_xla_dynamo_2 = fn_simple_dynamo(xla_x, xla_y)
self.assertNotIn('xla::add', met.counter_names())
Expand All @@ -102,22 +182,76 @@ def test_simple_model(self):
torch_xla._XLAC._get_xla_tensor_debug_info(xla_xy))
self.assertNotIn('XLAData: None',
torch_xla._XLAC._get_xla_tensor_debug_info(xla_y3))
# Dynamo has to sync the input since they are intermedate IR(xla_xy and xla_y3)
self.assertEqual(met.counter_value('DynamoSyncInputExecuteTime'), 1)

# Tests that the dynamo bridge automatically moves tensors to XLA device,
# then back to the original device.
@unittest.skipIf(xr.device_type() != "CUDA",
f"GPU tests should only run on GPU devices.")
@parameterized.parameters(
"0",
"1",
)
def test_simple_model_automoves_tensors(self, zero_copy_enabled):
os.environ.update({
xenv.ZERO_COPY_ENABLED: zero_copy_enabled,
})
x = torch.tensor(100.0, requires_grad=True, device="cuda:0")
y = torch.tensor(200.0, requires_grad=True, device="cuda:0")
original_device = x.device
eager_result = self.fn_simple(x, y)

# Since all tests run in the same process, have to reset the metrics report.
met.clear_all()
torch._dynamo.reset()

fn_simple_dynamo = torch.compile(self.fn_simple, backend="openxla")
res_xla_dynamo = fn_simple_dynamo(x, y)
self.assertIn('xla::add', met.counter_names())
self.assertTrue(res_xla_dynamo.device == original_device)
self.assertTrue(torch.allclose(eager_result, res_xla_dynamo))

# verify that tracing is skipped in following runs
met.clear_counters()
res_xla_dynamo_reused = fn_simple_dynamo(x, y)
self.assertNotIn('xla::add', met.counter_names())
self.assertTrue(res_xla_dynamo_reused.device == original_device)
self.assertTrue(torch.allclose(eager_result, res_xla_dynamo_reused))

def test_fn_without_input(self):
# verify that dynamo can handle different inputs
res_xla_dynamo_different = fn_simple_dynamo(x + y, y * 3)
res_cpu_3 = self.fn_simple(x + y, y * 3)
self.assertTrue(res_xla_dynamo_different.device == original_device)
self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_different))

# There should not be any fallbacks.
self.assertEqual(torch_xla._XLAC._get_executed_fallback_ops(), [])

@parameterized.parameters(
True,
False,
)
def test_fn_without_input(self, initialize_on_cuda):

def fn_without_input(device):
constant = 0.835
expanded = torch.full((4, 4), constant, device=device)
arange = torch.arange(16, device=device).reshape(4, 4)
return expanded + arange

device = xm.xla_device()
device = self._choose_proper_device(initialize_on_cuda)

compiled_fn = torch.compile(fn_without_input, backend='openxla')
res_cpu = fn_without_input('cpu')
res_xla_dynamo = compiled_fn(device)
self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))

def test_simple_model_with_in_place_ops(self):
@parameterized.parameters(
True,
False,
)
def test_simple_model_with_in_place_ops(self, initialize_on_cuda):

class TestModel(nn.Module):

Expand All @@ -139,90 +273,115 @@ def forward(self, index, copy_tensor, input_tensor, op_name):
output = input_tensor + self.self_tensor
return output

device = self._choose_proper_device(initialize_on_cuda)

torch._dynamo.reset()
met.clear_all()
device = xm.xla_device()

cpu_model = TestModel()
xla_model = TestModel(device).to(device)
compiled_model = torch.compile(xla_model, backend='openxla')
device_model = TestModel(device).to(device)
compiled_model = torch.compile(device_model, backend='openxla')

input_tensor = torch.ones(3)
copy_tensor = torch.rand(5, 3)
index = torch.tensor([0, 4, 2, 1, 3])
xla_input_tensor = input_tensor.to(device)
xla_copy_tensor = copy_tensor.to(device)
xla_index = index.to(device)
device_input_tensor = input_tensor.to(device)
device_copy_tensor = copy_tensor.to(device)
device_index = index.to(device)

in_place_ops = ['copy_', 'add_', 'abs_']
for in_place_op in in_place_ops:
res_cpu = cpu_model.forward(
index, copy_tensor, input_tensor, op_name=in_place_op)
res_xla_dynamo = compiled_model.forward(
xla_index, xla_copy_tensor, xla_input_tensor, op_name=in_place_op)
self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))

def test_einsum(self):
res_device_dynamo = compiled_model.forward(
device_index,
device_copy_tensor,
device_input_tensor,
op_name=in_place_op)
self.assertTrue(torch.allclose(res_cpu, res_device_dynamo.cpu()))

@parameterized.parameters(
True,
False,
)
def test_einsum(self, initialize_on_cuda):
# einsum currently does not have meta function to compute the shape hence
# will fallback to XLA with FakeTensor as input to infer the output shape.
def einsum_mm(a, b):
return torch.einsum('ijkl,ijlm->ijkm', a, b)

device = xm.xla_device()
a = torch.randn(4, 4, 4, 4).to(xm.xla_device())
b = torch.randn(4, 4, 4, 4).to(xm.xla_device())
device = self._choose_proper_device(initialize_on_cuda)
a = torch.randn(4, 4, 4, 4).to(device)
b = torch.randn(4, 4, 4, 4).to(device)
xm.mark_step()

dynamo_einsum_mm = torch.compile(einsum_mm, backend="openxla")
res_xla_dynamo = dynamo_einsum_mm(a, b)
res_xla_non_dynamo = einsum_mm(a, b)
res_device_dynamo = dynamo_einsum_mm(a, b)
res_device_non_dynamo = einsum_mm(a, b)
self.assertTrue(
torch.allclose(res_xla_non_dynamo.cpu(), res_xla_dynamo.cpu()))
torch.allclose(res_device_non_dynamo.cpu(), res_device_dynamo.cpu()))

def test_simple_model_with_different_input_shape(self):
met.clear_counters()
device = xm.xla_device()
xla_x = torch.randn(5, 5).to(device)
xla_y = torch.randn(5, 5).to(device)
xla_z = torch.randn(10, 10).to(device)
@parameterized.parameters(
True,
False,
)
def test_simple_model_with_different_input_shape(self, initialize_on_cuda):
met.clear_all()
device = self._choose_proper_device(initialize_on_cuda)
# We need to make `dim` depend on `initialize_on_cuda` because the XLA compilation cache
# does not clean itself between the parameterized tests.
dim = 5 + int(initialize_on_cuda)
device_x = torch.randn(dim, dim).to(device)
device_y = torch.randn(dim, dim).to(device)
new_dim = 2 * dim
device_z = torch.randn(new_dim, new_dim).to(device)
fn_simple_dynamo = torch.compile(self.fn_simple, backend="openxla")
fn_simple_dynamo(xla_x, xla_x)
fn_simple_dynamo(device_x, device_x)
compile_count = met.metric_data('CompileTime')[0]
# Execute with input with same shape should not trigger additional compilation
fn_simple_dynamo(xla_y, xla_y)
fn_simple_dynamo(device_y, device_y)
self.assertEqual(met.metric_data('CompileTime')[0], compile_count)
# Give `fn_simple_dynamo` an input with different shappe, we expect
# dynamo to recognize this is a different graph and let XLA to retrace/recompile
res_xla_dynamo_3 = fn_simple_dynamo(xla_z, xla_z)
res_xla_dynamo_3 = fn_simple_dynamo(device_z, device_z)
self.assertEqual(met.metric_data('CompileTime')[0], compile_count + 1)
self.assertTrue(
torch.allclose(
res_xla_dynamo_3.cpu(),
self.fn_simple(xla_z.cpu(), xla_z.cpu()),
self.fn_simple(device_z.cpu(), device_z.cpu()),
rtol=1e-05,
atol=1e-05))

@skipOnTpu
def test_resnet18(self):
device = xm.xla_device()
batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)
sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
def get_loader(self, device, sample_count, batch_size=4):
batch_size = xu.getenv_as('BATCH_SIZE', int, defval=batch_size)
loader = xu.SampleGenerator(
data=(torch.randn(batch_size, 3, 224, 224, device=device),
torch.zeros(batch_size, dtype=torch.int64, device=device)),
sample_count=sample_count)
return loader

@skipOnTpu
@parameterized.parameters(
True,
False,
)
def test_resnet18(self, initialize_on_cuda):
device = self._choose_proper_device(initialize_on_cuda)
batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)
sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
loader = self.get_loader(device, sample_count, batch_size=4)
resnet18 = torchvision.models.resnet18()
resnet18.eval()
xla_resnet18 = torchvision.models.resnet18()
xla_resnet18.load_state_dict(resnet18.state_dict())
xla_resnet18.to(device)
xla_resnet18.eval()
device_resnet18 = torchvision.models.resnet18()
device_resnet18.load_state_dict(resnet18.state_dict())
device_resnet18.to(device)
device_resnet18.eval()
# materalize the fake data for test purpose
xm.mark_step()
xm.wait_device_ops()
met.clear_all()
dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
for data, _ in loader:
dynamo_resnet18 = torch.compile(xla_resnet18, backend='openxla')
output = dynamo_resnet18(data)
output_cpu = resnet18(data.cpu())
self.assertTrue(
Expand Down Expand Up @@ -526,7 +685,7 @@ def test_resnet18(self):
met.metric_data('RunCachedGraphOutputData')[0], sample_count * 3)


class DynamErrorMessageTest(unittest.TestCase):
class DynamoErrorMessageTest(unittest.TestCase):

def test_mixed_cpu_tensor(self):
device = xm.xla_device()
Expand Down
3 changes: 1 addition & 2 deletions test/dynamo/test_dynamo_aliasing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.experimental.dynamo_set_buffer_donor
from torch_xla.core.dynamo_bridge import alias_with_buffer_donor_config
from torch_xla._dynamo.dynamo_bridge import alias_with_buffer_donor_config


class TestBufferDonationUtil(unittest.TestCase):
Expand Down
Loading