From d81b6994109064c3299c5cd20b78de1829313437 Mon Sep 17 00:00:00 2001 From: Yitong Huang Date: Mon, 11 Nov 2024 14:28:09 +0800 Subject: [PATCH] Support more symint ops (#17) --- codegen/xla_native_functions.yaml | 3 + test/ds/test_bounded_dynamic_ops.py | 327 +++++++++++++++++++++++++++ third_party/openxla | 2 +- torch_xla/csrc/aten_xla_type.cpp | 166 ++++++++++---- torch_xla/csrc/data_ops.cpp | 166 +++++++++++++- torch_xla/csrc/data_ops.h | 12 + torch_xla/csrc/elementwise.cpp | 5 + torch_xla/csrc/helpers.cpp | 176 +++++++++++++- torch_xla/csrc/helpers.h | 44 ++++ torch_xla/csrc/nll_loss.cpp | 106 ++++++++- torch_xla/csrc/ops/expand_symint.cpp | 9 + torch_xla/csrc/ops/ops.cpp | 30 ++- torch_xla/csrc/ops/ops.h | 6 + torch_xla/csrc/ops/select.cpp | 15 +- torch_xla/csrc/ops/select_symint.cpp | 56 +++++ torch_xla/csrc/ops/select_symint.h | 31 +++ torch_xla/csrc/ops/view_symint.cpp | 62 +++++ torch_xla/csrc/ops/view_symint.h | 32 +++ torch_xla/csrc/tensor_methods.cpp | 205 +++++++++++++++-- torch_xla/csrc/tensor_methods.h | 11 + torch_xla/csrc/torch_util.cpp | 26 +++ torch_xla/csrc/torch_util.h | 5 + torch_xla/csrc/xla_lower_util.cpp | 20 ++ torch_xla/csrc/xla_lower_util.h | 5 + 24 files changed, 1433 insertions(+), 87 deletions(-) create mode 100644 test/ds/test_bounded_dynamic_ops.py create mode 100644 torch_xla/csrc/ops/select_symint.cpp create mode 100644 torch_xla/csrc/ops/select_symint.h create mode 100644 torch_xla/csrc/ops/view_symint.cpp create mode 100644 torch_xla/csrc/ops/view_symint.h diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml index 199025dc7e1..565f5f7fea0 100644 --- a/codegen/xla_native_functions.yaml +++ b/codegen/xla_native_functions.yaml @@ -406,12 +406,15 @@ symint: - empty.memory_format - empty_strided - expand_copy + - full - new_empty_strided - view_copy - diagonal_backward - narrow_copy - select_backward - select.int + - slice_backward + - slice_copy.Tensor # See Note: [Disabling functionalization] - expand - view diff --git a/test/ds/test_bounded_dynamic_ops.py b/test/ds/test_bounded_dynamic_ops.py new file mode 100644 index 00000000000..2f52aca4041 --- /dev/null +++ b/test/ds/test_bounded_dynamic_ops.py @@ -0,0 +1,327 @@ +import os +import sys +import unittest +import torch +import torch.nn.functional as F +import torch_xla +import torch_xla.core.xla_model as xm + +sys.path.insert(1, os.path.join(sys.path[0], '..')) +import test_utils + +PD = torch._C._EnablePythonDispatcher() +XLA_DEVICE = xm.xla_device() + + +def _mark_dynamic(t, dims, bounds): + torch_xla._XLAC._xla_mark_bounded_dynamic(t, dims, bounds) + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. 
+ """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([ + torch.zeros( + tgt_len, past_key_values_length, dtype=dtype, device=device), mask + ], + dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, + tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len=None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, + src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), + torch.finfo(dtype).min) + + +def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, + past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, + tgt_len=input_shape[-1]).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else + expanded_attn_mask + combined_attention_mask) + + return combined_attention_mask + + +class TestBoundedDynamicOps(test_utils.XlaTestCase): + + def _diff_output(self, + torch_out, + xla_out, + atol=1e-3, + rtol=1e-5, + equal_nan=True): + if isinstance(torch_out, torch.Tensor): + self.assertIsInstance(xla_out, torch.Tensor) + torch_out = torch_out.detach().cpu() + xla_out = xla_out.detach().cpu() + self.assertEqual(xla_out.dtype, torch_out.dtype) + self.assertEqual(torch_out.shape, xla_out.shape) + self.assertTrue( + torch.allclose( + torch_out, xla_out, atol=atol, rtol=rtol, equal_nan=equal_nan)) + elif isinstance(torch_out, (tuple, list)): + self.assertIsInstance(xla_out, (tuple, list)) + self.assertEqual(len(torch_out), len(xla_out)) + for o1, o2 in zip(torch_out, xla_out): + self._diff_output(o1, o2, rtol, atol) + else: + self.assertEqual(torch_out, xla_out) + + def test_add(self): + t1 = torch.randn([5, 2]) + t2 = torch.randn([5, 2]) + torch_out = t1 + t2 + + t1 = t1.to(XLA_DEVICE) + t2 = t2.to(XLA_DEVICE) + _mark_dynamic(t1, [0], [10]) + _mark_dynamic(t2, [0], [10]) + xla_out = t1 + t2 + self._diff_output(torch_out, xla_out) + + def test_add_broadcast(self): + t1 = torch.randn([5, 2]) + t2 = torch.randn([2]) + torch_out = t1 + t2 + + t1 = t1.to(XLA_DEVICE) + t2 = t2.to(XLA_DEVICE) + _mark_dynamic(t1, [0], [10]) + xla_out = t1 + t2 + self._diff_output(torch_out, xla_out) + + def test_add_scalar(self): + t1 = torch.randn([5, 2]) + t2 = 1.0 + torch_out = t1 + t2 + + t1 = t1.to(XLA_DEVICE) + _mark_dynamic(t1, [0], [10]) + xla_out = t1 + t2 + self._diff_output(torch_out, xla_out) + + def test_reshape(self): + x = torch.randn(4, 101, 100) + y = torch.randn(4 * 101 * 100) + torch_out = y.reshape(x.shape[0], x.shape[1], 
-1) + + x = x.to(XLA_DEVICE) + y = y.to(XLA_DEVICE) + _mark_dynamic(x, [0, 1], [10, 200]) + _mark_dynamic(y, [0], [10 * 200 * 100]) + xla_out = y.reshape(x.shape[0], x.shape[1], -1) + self._diff_output(torch_out, xla_out) + + def test_flatten(self): + x = torch.randn(4, 101, 100) + torch_out = x.flatten(0, 1) + + x = x.to(XLA_DEVICE) + _mark_dynamic(x, [0], [10]) + xla_out = x.flatten(0, 1) + self._diff_output(torch_out, xla_out) + + def test_arange(self): + x = torch.randn(4, 101, 100) + torch_out = torch.arange( + 0, (x.shape[0] + 1) * x.shape[1], + step=x.shape[1], + dtype=torch.int32, + device=x.device) + + x = x.to(XLA_DEVICE) + _mark_dynamic(x, [1], [200]) + xla_out = torch.arange( + 0, (x.shape[0] + 1) * x.shape[1], + step=x.shape[1], + dtype=torch.int32, + device=x.device) + self._diff_output(torch_out, xla_out) + + def test_slice_with_backward(self): + x = torch.randn(4, 101, 100) + y = torch.randn(4, 201, 100) + x.requires_grad = True + y.requires_grad = True + torch_out = y[0:10, 10:x.shape[1], ...] + torch.autograd.backward(torch_out, torch.zeros_like(torch_out)) + torch_grad = y.grad + + x = x.detach().to(XLA_DEVICE) + y = y.detach().to(XLA_DEVICE) + x.requires_grad = True + y.requires_grad = True + _mark_dynamic(x, [1], [200]) + xla_out = y[0:10, 10:x.shape[1], ...] + torch.autograd.backward(xla_out, torch.zeros_like(xla_out)) + xla_grad = y.grad + + self._diff_output(torch_out, xla_out) + self._diff_output(torch_grad, xla_grad) + + def test_attn_mask(self): + inputs_embeds = torch.randn(4, 101) + attention_mask = torch.ones((4, 101), + dtype=torch.bool).to(inputs_embeds.device) + torch_out = _prepare_decoder_attention_mask( + attention_mask, (inputs_embeds.shape[0], inputs_embeds.shape[1]), + inputs_embeds, 0) + + inputs_embeds = inputs_embeds.to(XLA_DEVICE) + attention_mask = attention_mask.to(XLA_DEVICE) + _mark_dynamic(inputs_embeds, [1], [200]) + _mark_dynamic(attention_mask, [1], [200]) + xla_out = _prepare_decoder_attention_mask( + attention_mask, (inputs_embeds.shape[0], inputs_embeds.shape[1]), + inputs_embeds, 0) + + self._diff_output(torch_out, xla_out) + + def test_matmul_0(self): + t1 = torch.randn([5, 2]).to(torch.bfloat16) + t2 = torch.randn([2, 3]).to(torch.bfloat16) + torch_out = t1.to("cuda") @ t2.to("cuda") + + t1 = t1.to(XLA_DEVICE) + t2 = t2.to(XLA_DEVICE) + _mark_dynamic(t1, [0], [10]) + xla_out = t1 @ t2 + + self.assertIn('<=10,3', torch_xla._XLAC._get_xla_tensors_text([xla_out])) + self._diff_output(torch_out, xla_out) + + def test_matmul_1(self): + t1 = torch.randn([5, 2]).to(torch.bfloat16) + t2 = torch.randn([2]).to(torch.bfloat16) + torch_out = t1.to("cuda") @ t2.to("cuda") + + t1 = t1.to(XLA_DEVICE) + t2 = t2.to(XLA_DEVICE) + _mark_dynamic(t1, [0], [10]) + xla_out = t1 @ t2 + + self.assertIn('<=10', torch_xla._XLAC._get_xla_tensors_text([xla_out])) + self._diff_output(torch_out, xla_out) + + def test_matmul_2(self): + t1 = torch.randn([10, 5, 2]).to(torch.bfloat16) + t2 = torch.randn([2]).to(torch.bfloat16) + torch_out = t1.to("cuda") @ t2.to("cuda") + + t1 = t1.to(XLA_DEVICE) + t2 = t2.to(XLA_DEVICE) + _mark_dynamic(t1, [0, 1], [20, 10]) + xla_out = t1 @ t2 + self.assertIn('<=20,<=10', torch_xla._XLAC._get_xla_tensors_text([xla_out])) + self._diff_output(torch_out.cpu(), xla_out) + + def test_matmul_3(self): + t1 = torch.randn([10, 3, 4]).to(torch.bfloat16) + t2 = torch.randn([10, 4, 5]).to(torch.bfloat16) + torch_out = t1.to("cuda") @ t2.to("cuda") + + t1 = t1.to(XLA_DEVICE) + t2 = t2.to(XLA_DEVICE) + _mark_dynamic(t1, [0, 1], [20, 10]) + 
_mark_dynamic(t2, [0], [20]) + xla_out = t1 @ t2 + self.assertIn('<=20,<=10,5', + torch_xla._XLAC._get_xla_tensors_text([xla_out])) + self._diff_output(torch_out, xla_out) + + def test_matmul_4(self): + t1 = torch.randn([10, 3, 4]).to(torch.bfloat16) + t2 = torch.randn([4, 5]).to(torch.bfloat16) + torch_out = t1.to("cuda") @ t2.to("cuda") + + t1 = t1.to(XLA_DEVICE) + t2 = t2.to(XLA_DEVICE) + _mark_dynamic(t1, [0, 1], [20, 10]) + xla_out = t1 @ t2 + self.assertIn('<=20,<=10,5', + torch_xla._XLAC._get_xla_tensors_text([xla_out])) + self._diff_output(torch_out, xla_out) + + def test_triu(self): + t = torch.randn(4, 4) + torch_out = torch.triu(t, diagonal=1) + + t = t.to(XLA_DEVICE) + _mark_dynamic(t, [0, 1], [10, 10]) + xla_out = torch.triu(t, diagonal=1) + + self.assertIn('<=10,<=10', torch_xla._XLAC._get_xla_tensors_text([xla_out])) + self._diff_output(torch_out, xla_out) + + def test_nll_loss_with_backward(self): + logits = torch.randn(20, 30) + target = torch.randint(0, 30, (20,), dtype=torch.long) + logits.requires_grad = True + torch_out = F.nll_loss(logits, target) + torch_out.backward() + torch_grad = logits.grad + + logits = logits.detach().to(XLA_DEVICE) + logits.requires_grad = True + target = target.to(XLA_DEVICE) + _mark_dynamic(logits, [0], [50]) + _mark_dynamic(target, [0], [50]) + xla_out = F.nll_loss(logits, target) + xla_out.backward() + xla_grad = logits.grad + self.assertIn('<=50,30', torch_xla._XLAC._get_xla_tensors_text([xla_grad])) + + self._diff_output(torch_out, xla_out) + self._diff_output(torch_grad, xla_grad) + + +if __name__ == '__main__': + assert test_utils.is_disc_backend() + os.environ['USE_BOUND_FOR_SHAPE_COMPARE'] = os.getenv( + 'USE_BOUND_FOR_SHAPE_COMPARE', '1') + test = unittest.main() + # DISABLE PYTHON DISPATCHER FLAG + del PD + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/third_party/openxla b/third_party/openxla index 490cf85ccdc..09a5a713612 160000 --- a/third_party/openxla +++ b/third_party/openxla @@ -1 +1 @@ -Subproject commit 490cf85ccdc8d48e5b2143454930666d6a767964 +Subproject commit 09a5a713612f731e33a632e6cc3ccf44e7326980 diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 60296439ede..a19eb515e5e 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -477,7 +477,13 @@ at::Tensor XLANativeFunctions::_copy_from(const at::Tensor& self, torch::lazy::CopyTensor(tensor, dst.scalar_type(), /*copy=*/false); dst.resize_as_(typed_tensor).copy_(typed_tensor); } else { - tensor_methods::copy_(dst_tensor, self_tensor); + if (dst_tensor->shape().get().is_dynamic()) { + XLA_CHECK(self_tensor->shape().get().is_dynamic()) + << "self tensor is not dynamic!"; + tensor_methods::copy_symint_(dst_tensor, dst.sym_sizes(), self_tensor); + } else { + tensor_methods::copy_(dst_tensor, self_tensor); + } bridge::ReplaceXlaTensor(dst, dst_tensor); } return dst; @@ -580,6 +586,15 @@ XLANativeFunctions::_linalg_slogdet(const at::Tensor& self) { at::Tensor XLANativeFunctions::_log_softmax(const at::Tensor& self, int64_t dim, bool half_to_float) { + XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + if (self_tensor->shape().get().is_dynamic()) { + XLA_CHECK( + !runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) + << "FUNCTIONALIZATION should be enabled when using _log_softmax for " + "symint tensor"; + return at::functionalization::functionalize_aten_op::call(self, dim, half_to_float); + } TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto self_meta = 
to_meta(self); auto out_meta = at::meta::_log_softmax(self_meta, dim, half_to_float); @@ -631,13 +646,6 @@ at::Tensor XLANativeFunctions::add(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - // Currently, we disallow the case when both operands contain dynamic - // dimensions. This is consistent with PyTorch's behavior. - XLA_CHECK(!(tensor_has_dym_dim(self) && tensor_has_dym_dim(other))) - << "Both operands of torch.add cannot have dynamic dimensions at the " - "same time. This is not " - "supported in PyTorch/XLA."; - at::native::alpha_check(at::result_type(self, other), alpha); return DoBinaryOp(self, other, [&](const XLATensorPtr& xself, const XLATensorPtr& xother, @@ -1323,22 +1331,23 @@ at::Tensor XLANativeFunctions::fmod(const at::Tensor& self, }); } -at::Tensor XLANativeFunctions::full(at::IntArrayRef size, - const at::Scalar& fill_value, - c10::optional dtype, - c10::optional layout, - c10::optional device, - c10::optional pin_memory) { +at::Tensor XLANativeFunctions::full_symint(at::SymIntArrayRef size, + const at::Scalar& fill_value, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { TORCH_LAZY_FN_COUNTER("xla::"); // Fall back to CPU if layout or pin_memory are not default if (layout.value_or(at::Layout::Strided) != at::Layout::Strided || pin_memory.value_or(false)) { return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(full)>::call( - size, fill_value, dtype, layout, device, pin_memory); + C10_AS_INTARRAYREF_SLOW(size), fill_value, dtype, layout, device, + pin_memory); } - return bridge::AtenFromXlaTensor(tensor_methods::full( - absl::Span(size), fill_value, - GetXlaDeviceOrCurrent(device), at::dtype_or_default(dtype))); + return bridge::AtenFromXlaTensor(tensor_methods::full_symint( + size, fill_value, GetXlaDeviceOrCurrent(device), + at::dtype_or_default(dtype))); } at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim, @@ -2842,16 +2851,34 @@ at::Tensor XLANativeFunctions::sigmoid_backward(const at::Tensor& grad_output, bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(output))); } -at::Tensor XLANativeFunctions::slice_copy(const at::Tensor& self, int64_t dim, - c10::optional start, - c10::optional end, - int64_t step) { +at::Tensor XLANativeFunctions::slice_copy_symint( + const at::Tensor& self, int64_t dim, c10::optional start, + c10::optional end, c10::SymInt step) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - int64_t start_val = start.has_value() ? start.value() : 0; - int64_t end_val = end.has_value() ? end.value() : INT64_MAX; + c10::SymInt start_val = start.has_value() ? start.value() : 0; + c10::SymInt end_val = end.has_value() ? 
end.value() : self.sym_sizes()[dim]; + if (!start_val.is_symbolic() && start_val < 0) { + start_val = self.sym_sizes()[dim] + start_val; + } + if (!end_val.is_symbolic()) { + if (end_val > self.sym_sizes()[dim]) { + end_val = self.sym_sizes()[dim]; + } else if (end_val < 0) { + end_val = self.sym_sizes()[dim] + end_val; + } + } + if (XlaHelpers::IsDISCBackend() && + (start_val.is_symbolic() || end_val.is_symbolic() || + step.is_symbolic())) { + return bridge::AtenFromXlaTensor(bridge::SetBaseTensor( + tensor_methods::slice_symint(bridge::GetXlaTensor(self), dim, start_val, + end_val, step), + self)); + } return bridge::AtenFromXlaTensor(bridge::SetBaseTensor( - tensor_methods::slice(bridge::GetXlaTensor(self), dim, start_val, end_val, - step), + tensor_methods::slice(bridge::GetXlaTensor(self), dim, + start_val.expect_int(), end_val.expect_int(), + step.expect_int()), self)); } @@ -3037,13 +3064,6 @@ at::Tensor XLANativeFunctions::sub(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - // Currently, we disallow the case when both operands contain dynamic - // dimensions. This is consistent with PyTorch's behavior. - XLA_CHECK(!(tensor_has_dym_dim(self) && tensor_has_dym_dim(other))) - << "Both operands of torch.sub cannot have dynamic dimensions at the " - "same time. This is not " - "supported in PyTorch/XLA."; - CheckSubOperandTypes(self.scalar_type(), other.scalar_type()); at::native::alpha_check(at::result_type(self, other), alpha); return DoBinaryOp(self, other, @@ -3186,6 +3206,11 @@ at::Tensor& XLANativeFunctions::uniform_( at::Tensor XLANativeFunctions::unsqueeze_copy(const at::Tensor& self, int64_t dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + if (self_tensor->shape().get().is_dynamic()) { + return bridge::AtenFromXlaTensor(tensor_methods::unsqueeze_symint( + bridge::GetXlaTensor(self), self.sym_sizes(), dim)); + } return bridge::AtenFromXlaTensor( tensor_methods::unsqueeze(bridge::GetXlaTensor(self), dim)); } @@ -3353,12 +3378,35 @@ at::Tensor XLANativeFunctions::view_copy_symint(const at::Tensor& self, XLATensorPtr xla_input = bridge::GetXlaTensor(self); bool input_has_dyn_shape = xla_input->shape().get().is_dynamic(); - XLA_CHECK(!(input_has_dyn_shape && input_shape_static)) - << "This view op has dynamic input tensor but static input shape. This " - "behavior is currently unsupported; if the user believes this must be " - "supported, please file a feature request against PyTorch/XLA."; + if (input_has_dyn_shape && input_shape_static) { + std::vector sym_shape; + sym_shape.resize(shape.size()); + int dyn_dim = -1; + int64_t complete_element_count = 1; + for (int i = 0; i < int_shape.value().size(); i++) { + if (int_shape.value()[i] == -1) { + dyn_dim = i; + } else { + complete_element_count *= int_shape.value()[i]; + sym_shape[i] = c10::SymInt(int_shape.value()[i]); + } + } + XLA_CHECK(dyn_dim != -1) + << "This view op has dynamic input tensor but static input shape. 
This " + "behavior is currently unsupported; if the user believes this must " + "be " + "supported, please file a feature request against PyTorch/XLA."; + sym_shape[dyn_dim] = self.sym_numel() / complete_element_count; + return bridge::AtenFromXlaTensor(tensor_methods::view_symint( + xla_input, c10::SymIntArrayRef(sym_shape.data(), sym_shape.size()))); + } + + if (input_has_dyn_shape) { + return bridge::AtenFromXlaTensor( + tensor_methods::view_symint(xla_input, shape)); + } return bridge::AtenFromXlaTensor( - tensor_methods::view_symint(xla_input, shape)); + tensor_methods::view(xla_input, int_shape.value())); } at::Tensor XLANativeFunctions::where(const at::Tensor& condition, @@ -3596,8 +3644,21 @@ at::Tensor XLANativeFunctions::embedding_symint(const at::Tensor& weight, } TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::embedding( - bridge::GetXlaTensor(weight), bridge::GetXlaTensor(indices))); + auto weight_tensor = bridge::GetXlaTensor(weight); + auto indices_tensor = bridge::GetXlaTensor(indices); + if (weight_tensor->shape().get().is_dynamic() || + indices_tensor->shape().get().is_dynamic()) { + std::vector final_size; + auto indices_sizes = indices.sym_sizes(); + final_size.insert(final_size.begin(), indices_sizes.begin(), + indices_sizes.end()); + final_size.push_back(weight.sym_sizes()[1]); + return bridge::AtenFromXlaTensor(tensor_methods::embedding_symint( + bridge::GetXlaTensor(weight), bridge::GetXlaTensor(indices), final_size, + indices.sym_numel())); + } + return bridge::AtenFromXlaTensor( + tensor_methods::embedding(weight_tensor, indices_tensor)); } at::Tensor XLANativeFunctions::_euclidean_dist(const at::Tensor& x1, @@ -3739,17 +3800,24 @@ at::Tensor XLANativeFunctions::diagonal_backward_symint( diagonal_backward)>::call(grad_output, input_sizes, offset, dim1, dim2); } -at::Tensor XLANativeFunctions::slice_backward(const at::Tensor& grad_output, - at::IntArrayRef input_sizes, - int64_t dim, int64_t start, - int64_t end, int64_t step) { +at::Tensor XLANativeFunctions::slice_backward_symint( + const at::Tensor& grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, + c10::SymInt start, c10::SymInt end, c10::SymInt step) { // See Note: [Disabling functionalization] if (runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) { - return at::native::slice_backward(grad_output, input_sizes, dim, start, end, - step); - } - return at::functionalization::functionalize_aten_op::call(grad_output, input_sizes, dim, start, end, step); + return at::native::slice_backward( + grad_output, C10_AS_INTARRAYREF_SLOW(input_sizes), dim, + start.guard_int(__FILE__, __LINE__), end.guard_int(__FILE__, __LINE__), + step.guard_int(__FILE__, __LINE__)); + } + at::Scalar fill_value = grad_output.dtype() == at::kBool ? 
false : 0; + at::Tensor grad_input = full_symint( + input_sizes, fill_value, at::typeMetaToScalarType(grad_output.dtype()), + grad_output.layout(), grad_output.device(), false); + // TODO: we may need to support slice_scatter_symint + return slice_scatter( + grad_input, grad_output, dim, start.guard_int(__FILE__, __LINE__), + end.guard_int(__FILE__, __LINE__), step.guard_int(__FILE__, __LINE__)); } at::Tensor XLANativeFunctions::permute(const at::Tensor& self, diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp index 6508dd71cc3..06e10b92bf4 100644 --- a/torch_xla/csrc/data_ops.cpp +++ b/torch_xla/csrc/data_ops.cpp @@ -97,6 +97,27 @@ xla::XlaOp BuildView(xla::XlaOp input, absl::Span output_sizes) { return XlaHelpers::DynamicReshape(input, complete_output_sizes); } +xla::XlaOp BuildViewSymInt(xla::XlaOp input, + absl::Span size_ops, + const std::vector& upper_bounds, + const std::vector& dynamic_dims) { + xla::Shape output_shape = + xla::ShapeUtil::MakeShape(ShapeHelper::ShapeOfXlaOp(input).element_type(), + {upper_bounds}, {dynamic_dims}); + std::vector complete_size_ops; + size_t curr_size_ops_index = 0; + for (size_t i = 0; i < dynamic_dims.size(); i++) { + if (dynamic_dims[i]) { + complete_size_ops.push_back(size_ops[curr_size_ops_index++]); + } else { + complete_size_ops.push_back( + XlaHelpers::ScalarValue(upper_bounds[i], input.builder())); + } + } + return XlaHelpers::DynamicBoundedReshape(input, complete_size_ops, + output_shape); +} + xla::XlaOp SetDimensionSizes(xla::XlaOp input, absl::Span symbolic_output_sizes, std::vector dynamic_dims) { @@ -138,13 +159,77 @@ xla::XlaOp BuildExpand(xla::XlaOp input, auto input_sizes = XlaHelpers::SizesOfXlaOp(input); // Adjust the rank of the input to match the rank of the output. XLA_CHECK_LE(input_sizes.size(), output_sizes.size()); - input_sizes.insert(input_sizes.begin(), - output_sizes.size() - input_sizes.size(), 1); - xla::XlaOp implicit_reshape = XlaHelpers::DynamicReshape(input, input_sizes); + xla::XlaOp implicit_reshape = input; + + if (input_sizes.size() < output_sizes.size()) { + input_sizes.insert(input_sizes.begin(), + output_sizes.size() - input_sizes.size(), 1); + implicit_reshape = XlaHelpers::DynamicReshape(input, input_sizes); + } + if (output_sizes.size() == 0) { + return implicit_reshape; + } + + const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(implicit_reshape); + if (XlaHelpers::IsDISCBackend() && input_shape.is_dynamic()) { + std::vector get_dim_ops; + for (size_t dim = 0; dim < input_shape.rank(); dim++) { + if (input_shape.is_dynamic_dimension(dim)) { + get_dim_ops.push_back(xla::GetDimensionSize(implicit_reshape, dim)); + } else { + get_dim_ops.push_back(XlaHelpers::ScalarValue( + output_sizes[dim], implicit_reshape.builder())); + } + } + xla::XlaOp sym_op = XlaHelpers::CreateShapeTensor(get_dim_ops); + xla::Shape final_shape = xla::ShapeUtil::MakeShape( + XlaHelpers::TypeOfXlaOp(implicit_reshape), output_sizes, + runtime::util::ToVector(input_shape.dynamic_dimensions())); + return xla::DynamicBroadcastInDim( + implicit_reshape, sym_op, + torch::lazy::Iota(output_sizes.size()), final_shape); + } + return xla::BroadcastInDim(implicit_reshape, output_sizes, torch::lazy::Iota(output_sizes.size())); } +xla::XlaOp BuildExpandSymInt(xla::XlaOp input, + absl::Span output_sizes, + const std::vector& size_ops, + const std::vector& dynamic_dims) { + auto input_sizes = XlaHelpers::SizesOfXlaOp(input); + // Adjust the rank of the input to match the rank of the output. 
+ XLA_CHECK_LE(input_sizes.size(), output_sizes.size()); + xla::XlaOp implicit_reshape = input; + + if (input_sizes.size() < output_sizes.size()) { + input_sizes.insert(input_sizes.begin(), + output_sizes.size() - input_sizes.size(), 1); + implicit_reshape = XlaHelpers::DynamicReshape(input, input_sizes); + } + if (output_sizes.size() == 0) { + return implicit_reshape; + } + + size_t current_size_index = 0; + std::vector get_dim_ops; + for (size_t i = 0; i < dynamic_dims.size(); i++) { + if (dynamic_dims[i]) { + get_dim_ops.push_back(size_ops[current_size_index++]); + } else { + get_dim_ops.push_back( + XlaHelpers::ScalarValue(output_sizes[i], input.builder())); + } + } + xla::XlaOp sym_op = XlaHelpers::CreateShapeTensor(get_dim_ops); + xla::Shape final_shape = xla::ShapeUtil::MakeShape( + XlaHelpers::TypeOfXlaOp(input), output_sizes, dynamic_dims); + return xla::DynamicBroadcastInDim( + implicit_reshape, sym_op, torch::lazy::Iota(output_sizes.size()), + final_shape); +} + xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask, xla::XlaOp scalar) { const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); @@ -152,8 +237,30 @@ xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask, if (!xla::ShapeUtil::Compatible(input_shape, mask_shape)) { xla::Shape shape = XlaHelpers::GetPromotedShape(input_shape, mask_shape); - input = BuildExpand(input, shape.dimensions()); - mask = BuildExpand(mask, shape.dimensions()); + if (shape.is_dynamic()) { + std::vector size_ops; + int min_rank = std::min(input_shape.rank(), mask_shape.rank()); + for (int i = 0; i < shape.rank(); i++) { + if (shape.is_dynamic_dimension(i)) { + int input_dim = input_shape.rank() - min_rank + i; + int mask_dim = mask_shape.rank() - min_rank + i; + if (input_shape.is_dynamic_dimension(input_dim)) { + size_ops.push_back(xla::GetDimensionSize(input, input_dim)); + } else { + size_ops.push_back(xla::GetDimensionSize(mask, mask_dim)); + } + } + } + input = BuildExpandSymInt( + input, shape.dimensions(), size_ops, + runtime::util::ToVector(shape.dynamic_dimensions())); + mask = BuildExpandSymInt( + mask, shape.dimensions(), size_ops, + runtime::util::ToVector(shape.dynamic_dimensions())); + } else { + input = BuildExpand(input, shape.dimensions()); + mask = BuildExpand(mask, shape.dimensions()); + } } xla::XlaOp zero = xla::Zero(mask.builder(), XlaHelpers::TypeOfXlaOp(mask)); @@ -370,8 +477,13 @@ xla::XlaOp BuildUnselect(xla::XlaOp target, xla::XlaOp source, int64_t dim, xla::PrimitiveType pred_type = GetXlaPrimitiveTypeForCurrentDevice(xla::PrimitiveType::PRED); - xla::XlaOp source_true = XlaHelpers::ScalarBroadcast( - 1, pred_type, source_shape.dimensions(), source.builder()); + xla::XlaOp source_true; + if (source_shape.is_dynamic()) { + source_true = XlaHelpers::DynamicScalarBroadcast(1, pred_type, source); + } else { + source_true = XlaHelpers::ScalarBroadcast( + 1, pred_type, source_shape.dimensions(), source.builder()); + } xla::XlaOp pred_zero = xla::Zero(target.builder(), pred_type); xla::XlaOp zero = xla::Zero(target.builder(), target_shape.element_type()); xla::PaddingConfig padding_config; @@ -554,4 +666,44 @@ xla::XlaOp PadInDim(xla::XlaOp input, int64_t dim, int64_t pad_lo, return xla::Pad(input, *pad_value, padding_config); } +xla::XlaOp BuildSelectSymInt(xla::XlaOp input, int dim, xla::XlaOp start, + xla::XlaOp end, xla::XlaOp stride, + const xla::Shape& output_shape) { + const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); + + std::vector start_ops, limit_ops, stride_ops; + 
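+  // Descriptive note on the code below: build per-dimension start/limit/
+  // stride operands. The sliced dimension 'dim' uses the symbolic start/end/
+  // stride values, while every other dimension spans its full extent
+  // (queried with xla::GetDimensionSize when the dimension is dynamic). The
+  // three lists are then packed into rank-1 shape tensors and lowered as an
+  // "mhlo.real_dynamic_slice" custom call producing 'output_shape'.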
+ xla::XlaOp salar_zero = XlaHelpers::ScalarValue(0, input.builder()); + xla::XlaOp salar_one = XlaHelpers::ScalarValue(1, input.builder()); + + for (int i = 0; i < input_shape.rank(); ++i) { + if (i == dim) { + start_ops.push_back(start); + limit_ops.push_back(end); + stride_ops.push_back(stride); + } else { + start_ops.push_back(salar_zero); + if (input_shape.is_dynamic_dimension(i)) { + limit_ops.push_back(xla::GetDimensionSize(input, i)); + } else { + limit_ops.push_back(XlaHelpers::ScalarValue( + input_shape.dimensions(i), input.builder())); + } + stride_ops.push_back(salar_one); + } + } + + xla::Shape tensor_shape = xla::ShapeUtil::MakeShape( + XlaHelpers::TypeOfXlaOp(salar_zero), {input_shape.rank()}, {false}); + + xla::XlaOp start_tensor = XlaHelpers::CreateShapeTensor(start_ops); + xla::XlaOp limit_tensor = XlaHelpers::CreateShapeTensor(limit_ops); + xla::XlaOp stride_tensor = XlaHelpers::CreateShapeTensor(stride_ops); + xla::XlaOp output = xla::CustomCall( + input.builder(), "mhlo.real_dynamic_slice", + /*operands=*/{input, start_tensor, limit_tensor, stride_tensor}, + /*shape*/ output_shape); + return output; +} + } // namespace torch_xla diff --git a/torch_xla/csrc/data_ops.h b/torch_xla/csrc/data_ops.h index 067a77abc23..1192339e947 100644 --- a/torch_xla/csrc/data_ops.h +++ b/torch_xla/csrc/data_ops.h @@ -25,6 +25,10 @@ std::vector GetCompleteShape(absl::Span output_sizes, // Creates a new tensor with the same data as the input tensor and the specified // output size. xla::XlaOp BuildView(xla::XlaOp input, absl::Span output_sizes); +xla::XlaOp BuildViewSymInt(xla::XlaOp input, + absl::Span size_ops, + const std::vector& upper_bounds, + const std::vector& dynamic_dims); // Return a new XlaOp that reflects dynamic dimensions xla::XlaOp SetDimensionSizes(xla::XlaOp input, @@ -42,6 +46,10 @@ xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input); // output sizes. 
xla::XlaOp BuildExpand(xla::XlaOp input, absl::Span output_sizes); +xla::XlaOp BuildExpandSymInt(xla::XlaOp input, + absl::Span output_sizes, + const std::vector& size_ops, + const std::vector& dynamic_dims); xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask, xla::XlaOp scalar); @@ -111,6 +119,10 @@ xla::XlaOp BuildReplicationPadBackward(xla::XlaOp grad_output, xla::XlaOp input, xla::XlaOp PadInDim(xla::XlaOp input, int64_t dim, int64_t pad_lo, int64_t pad_hi, const xla::XlaOp* pad_value = nullptr); +xla::XlaOp BuildSelectSymInt(xla::XlaOp input, int dim, xla::XlaOp start, + xla::XlaOp end, xla::XlaOp stride, + const xla::Shape& output_shape); + } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_DATA_OPS_H_ diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp index 4facf43f6c8..3f819511f94 100644 --- a/torch_xla/csrc/elementwise.cpp +++ b/torch_xla/csrc/elementwise.cpp @@ -59,6 +59,11 @@ xla::XlaOp BuildThreshold(xla::XlaOp input, xla::XlaOp output, const xla::Shape& output_shape = ShapeHelper::ShapeOfXlaOp(output); xla::XlaOp xla_threshold = XlaHelpers::ScalarValue( threshold, input_shape.element_type(), builder); + if (input_shape.is_dynamic()) { + xla::XlaOp xla_value = XlaHelpers::DynamicScalarBroadcast( + value, output_shape.element_type(), input); + return xla::Select(xla::Gt(input, xla_threshold), output, xla_value); + } xla::XlaOp xla_value = XlaHelpers::ScalarValue( value, output_shape.element_type(), builder); return xla::Select(xla::Gt(input, xla_threshold), output, diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index 673dfbf9884..89d1585a60d 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -311,8 +311,26 @@ xla::XlaOp XlaHelpers::DynamicReshape(xla::XlaOp input, } auto info = GetDynamicReshapeInfo(input_shape, output_sizes); if (info) { - return xla::ReshapeWithInferredDimension(input, output_sizes, - info->dynamic_dimension); + int input_dyn_dim = -1; + for (int i = 0; i < input_shape.rank(); i++) { + if (input_shape.is_bounded_dynamic_dimension(i)) { + XLA_CHECK(input_dyn_dim == -1) + << "Reshaping with multiple bounded shapes is not supported."; + input_dyn_dim = i; + } + } + + auto output_shape = info->output_shape; + std::vector dim_sizes; + for (int i = 0; i < output_shape.rank(); i++) { + if (output_shape.is_bounded_dynamic_dimension(i)) { + dim_sizes.push_back(xla::GetDimensionSize(input, input_dyn_dim)); + } else { + dim_sizes.push_back(XlaHelpers::ScalarValue( + output_shape.dimensions(i), input.builder())); + } + } + return DynamicBoundedReshape(input, dim_sizes, output_shape); } return xla::Reshape(input, output_sizes); } @@ -340,6 +358,20 @@ bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) { }); } +xla::XlaOp XlaHelpers::DynamicBoundedReshape( + xla::XlaOp input, const std::vector& size_ops, + const xla::Shape& shape) { + if (!XlaHelpers::IsDISCBackend()) { + return xla::DynamicReshape( + input, size_ops, shape.dimensions(), + runtime::util::ToVector(shape.dynamic_dimensions())); + } + + auto shape_tensor = XlaHelpers::CreateShapeTensor(size_ops); + return xla::CustomCall(input.builder(), "mhlo.dynamic_reshape", + {input, shape_tensor}, shape); +} + xla::XlaOp XlaHelpers::DynamicUnboundedReshape( xla::XlaOp input, xla::XlaOp aux_input, absl::Span output_sizes) { @@ -731,7 +763,10 @@ std::pair XlaHelpers::PromoteShapes(xla::XlaOp op1, const xla::Shape& shape2 = ShapeHelper::ShapeOfXlaOp(op2); xla::Shape shape = GetPromotedShape(shape1, shape2); - if 
(shape1.is_unbounded_dynamic() || shape2.is_unbounded_dynamic()) { + if (XlaHelpers::IsDISCBackend() && + (shape1.is_dynamic() || shape2.is_dynamic())) { + return ImplicitBroadcastWithBoundedDynamicShapes(op1, op2, shape); + } else if (shape1.is_unbounded_dynamic() || shape2.is_unbounded_dynamic()) { return ImplicitBroadcastWithUnboundedDynamicShapes(op1, op2, shape); } @@ -812,6 +847,25 @@ std::vector ExtractDimensionSizesWithPadding(const xla::XlaOp op, return op_dims; } +xla::XlaOp CreateShapeTensorWithPadding(xla::XlaOp operand, int pad_count) { + const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(operand); + std::vector op_dims; + for (size_t i = 0; i < pad_count; i++) { + op_dims.push_back(XlaHelpers::ScalarValue(1, operand.builder())); + } + + for (int dim = 0; dim < shape.rank(); dim++) { + if (shape.is_dynamic_dimension(dim)) { + op_dims.push_back(xla::GetDimensionSize(operand, dim)); + } else { + op_dims.push_back(XlaHelpers::ScalarValue(shape.dimensions(dim), + operand.builder())); + } + } + + return XlaHelpers::CreateShapeTensor(op_dims); +} + // Stringify the broadcast dimensions to provide for the 'backend_config' // attribute of the generated custom_call. std::string StringifyBroadcastDimensions(std::vector broadcast_dims) { @@ -843,6 +897,77 @@ xla::XlaOp DynamicBroadcastInDim(xla::XlaOp op, const xla::Shape& final_shape, } // namespace +xla::XlaOp XlaHelpers::CreateOutputDimsTensor(xla::XlaOp operand) { + const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(operand); + std::vector get_dim_ops; + for (int dim = 0; dim < shape.rank(); dim++) { + if (shape.is_dynamic_dimension(dim)) { + get_dim_ops.push_back(xla::GetDimensionSize(operand, dim)); + } else { + get_dim_ops.push_back(XlaHelpers::ScalarValue( + shape.dimensions(dim), operand.builder())); + } + } + return XlaHelpers::CreateShapeTensor(get_dim_ops); +} + +xla::XlaOp XlaHelpers::CreateShapeTensor( + const std::vector& size_ops) { + XLA_CHECK(!size_ops.empty()) << "size_ops should not be empty"; + auto front = size_ops.front(); + xla::Shape shape = xla::ShapeUtil::MakeShape(XlaHelpers::TypeOfXlaOp(front), + {size_ops.size()}); + return xla::CustomCall(front.builder(), "tensor.from_elements", + /*operands=*/size_ops, + /*shape*/ shape); +} + +xla::XlaOp XlaHelpers::DynamicBoundedBroadcast( + xla::XlaOp input, xla::XlaOp aux_input, + const std::vector& aux_input_dimensions) { + const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); + const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input); + std::vector output_dimensions; + std::vector output_dynamic; + // Collect the dimension sizes and dynamic dimensions corresponding to the + // final broadcasted shape. 
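+  // The selected 'aux_input' dimensions come first, followed by every
+  // dimension of 'input', yielding the layout
+  // shape(aux_input)[aux_input_dimensions] x shape(input).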
+ for (auto dim : aux_input_dimensions) { + output_dimensions.push_back(aux_input_shape.dimensions(dim)); + output_dynamic.push_back(aux_input_shape.is_dynamic_dimension(dim)); + } + + for (int dim = 0; dim < input_shape.rank(); dim++) { + output_dimensions.push_back(input_shape.dimensions(dim)); + output_dynamic.push_back(input_shape.is_dynamic_dimension(dim)); + } + + std::vector get_dim_ops; + for (auto dim : aux_input_dimensions) { + if (aux_input_shape.is_dynamic_dimension(dim)) { + get_dim_ops.push_back(xla::GetDimensionSize(aux_input, dim)); + } else { + get_dim_ops.push_back(XlaHelpers::ScalarValue( + aux_input_shape.dimensions(dim), aux_input.builder())); + } + } + + for (int dim = 0; dim < input_shape.rank(); dim++) { + if (input_shape.is_dynamic_dimension(dim)) { + get_dim_ops.push_back(xla::GetDimensionSize(input, dim)); + } else { + get_dim_ops.push_back(XlaHelpers::ScalarValue( + input_shape.dimensions(dim), input.builder())); + } + } + + // Create shape tensor + auto shape_tensor = XlaHelpers::CreateShapeTensor(get_dim_ops); + + xla::Shape final_shape = xla::ShapeUtil::MakeShape( + input_shape.element_type(), output_dimensions, output_dynamic); + return DynamicBroadcastInDim(input, final_shape, shape_tensor); +} + xla::XlaOp XlaHelpers::DynamicUnboundedBroadcast( xla::XlaOp input, xla::XlaOp aux_input, const std::vector& aux_input_dimensions) { @@ -901,6 +1026,51 @@ xla::XlaOp XlaHelpers::DynamicUnboundedBroadcast( return DynamicBroadcastInDim(input, final_shape, concat_op); } +std::pair +XlaHelpers::ImplicitBroadcastWithBoundedDynamicShapes(xla::XlaOp op1, + xla::XlaOp op2, + const xla::Shape& shape) { + const xla::Shape& shape1 = ShapeHelper::ShapeOfXlaOp(op1); + const xla::Shape& shape2 = ShapeHelper::ShapeOfXlaOp(op2); + + XLA_CHECK(shape.dimensions().size() == + std::max(shape1.dimensions().size(), shape2.dimensions().size())); + + // Collect the dimension sizes of the 'op1' and 'op2' in 'op1_dims' and + // 'op2_dims' resp. with potential padding. + // Example: + // shape1 = [9, ?, 6, ?, ?] + // shape2 = [6, 1, 2] + // shape = [9, ?, 6, ?, 2] where ?: represents bounded dynamic size. + // + // rank(shape1) = rank(shape): No padding needed and + // the pre-broadcast result, 'op1_dims', just collects the dimension sizes + // of shape1. + // op1_dims = [9, ?, 6, ?, ?] + // rank(shape2) < rank(shape): Make the rank of shape2 match + // that of shape by padding it with 1's. The pre-broadcast result is + // 'op2_dims'. + // op2_dims = [1, 1, 6, 1, 2] + xla::XlaOp op1_dims = CreateShapeTensorWithPadding( + op1, shape.dimensions().size() - shape1.dimensions().size()); + xla::XlaOp op2_dims = CreateShapeTensorWithPadding( + op2, shape.dimensions().size() - shape2.dimensions().size()); + + if (shape1 == shape) { + return std::pair( + op1, DynamicBroadcastInDim(op2, shape, op1_dims)); + } else if (shape2 == shape) { + return std::pair( + DynamicBroadcastInDim(op1, shape, op2_dims), op2); + } + // The broadcasted shape is the max of the individual pre-broadcasted + // shapes. final_broadcast_dimensions = max(op1_dims, op2_dims). 
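+  // Continuing the example above: with op1_dims = [9, ?, 6, ?, ?] and
+  // op2_dims = [1, 1, 6, 1, 2], the element-wise maximum gives the runtime
+  // dimension sizes of the broadcasted shape [9, ?, 6, ?, 2], to which both
+  // operands are then dynamically broadcast.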
+ auto final_broadcast_dimensions = xla::Max(op1_dims, op2_dims); + return std::pair( + DynamicBroadcastInDim(op1, shape, final_broadcast_dimensions), + DynamicBroadcastInDim(op2, shape, final_broadcast_dimensions)); +} + std::pair XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes( xla::XlaOp op1, xla::XlaOp op2, const xla::Shape& shape) { diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h index 1afecb5b0bf..bd8622d5f10 100644 --- a/torch_xla/csrc/helpers.h +++ b/torch_xla/csrc/helpers.h @@ -15,6 +15,7 @@ #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/util.h" +#include "torch_xla/csrc/shape_helper.h" #include "tsl/platform/bfloat16.h" #include "xla/client/xla_builder.h" #include "xla/literal_util.h" @@ -134,6 +135,34 @@ class XlaHelpers { static xla::XlaOp CreateReturnValue(xla::XlaBuilder* builder, const std::vector& outputs); + static xla::XlaOp CreateOutputDimsTensor(xla::XlaOp operand); + + static xla::XlaOp CreateShapeTensor(const std::vector& size_ops); + + template + static xla::XlaOp DynamicScalarBroadcast(T scalar_value, + xla::PrimitiveType type, + xla::XlaOp aux_input) { + const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(aux_input); + xla::Shape output_shape = xla::ShapeUtil::MakeShape( + type, input_shape.dimensions(), + runtime::util::ToVector(input_shape.dynamic_dimensions())); + xla::XlaOp scalar_op = + ScalarValue(scalar_value, type, aux_input.builder()); + return xla::DynamicBroadcastInDim( + scalar_op, CreateOutputDimsTensor(aux_input), {}, output_shape); + } + + template + static xla::XlaOp DynamicScalarBroadcast(T scalar_value, + xla::XlaOp aux_input) { + xla::XlaOp scalar_op = ScalarValue(scalar_value, TypeOfXlaOp(aux_input), + aux_input.builder()); + return xla::DynamicBroadcastInDim(scalar_op, + CreateOutputDimsTensor(aux_input), {}, + ShapeHelper::ShapeOfXlaOp(aux_input)); + } + // Creates a scalar broadcasted to a given shape. template static xla::XlaOp ScalarBroadcast(T scalar_value, xla::PrimitiveType type, @@ -166,12 +195,23 @@ class XlaHelpers { false); } + static bool IsDISCBackend() { + return runtime::sys_util::GetEnvString("DISC_DEVICE", "") != ""; + } + // Creates custom_call to express dynamic reshape op using the dimension // sizes of 'aux_input'. static xla::XlaOp DynamicUnboundedReshape( xla::XlaOp input, xla::XlaOp aux_input, absl::Span output_sizes); + static xla::XlaOp DynamicBoundedReshape( + xla::XlaOp input, const std::vector& shape_ops, + const xla::Shape& shape); + + static xla::XlaOp DynamicBoundedBroadcast( + xla::XlaOp input, xla::XlaOp aux_input, + const std::vector& aux_input_dimensions); // Broadcasts 'input' shape to // shape(aux_input)[aux_input_dimensions] x shape(input). // This method is used as a replacement for xla::Broadcast when unbounded @@ -322,6 +362,10 @@ class XlaHelpers { ImplicitBroadcastWithUnboundedDynamicShapes(xla::XlaOp op1, xla::XlaOp op2, const xla::Shape& shape); + static std::pair + ImplicitBroadcastWithBoundedDynamicShapes(xla::XlaOp op1, xla::XlaOp op2, + const xla::Shape& shape); + // Retuns the explicit broadcasting specifications on operations between // arrays of different ranks. 
static std::vector getBroadcastDimensions(xla::XlaOp op1, diff --git a/torch_xla/csrc/nll_loss.cpp b/torch_xla/csrc/nll_loss.cpp index 566b5b9971a..b4ed26d5c43 100644 --- a/torch_xla/csrc/nll_loss.cpp +++ b/torch_xla/csrc/nll_loss.cpp @@ -67,11 +67,65 @@ xla::XlaOp LabelsToOneHot(xla::XlaBuilder* builder, int64_t depth, int axis, std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); xla::XlaOp one_hot_bool = xla::Eq(indices, iota, broadcast_dims); + const xla::Shape& one_hot_shape = ShapeHelper::ShapeOfXlaOp(one_hot_bool); + if (XlaHelpers::IsDISCBackend() && one_hot_shape.is_dynamic()) { + xla::Shape final_shape = xla::ShapeUtil::MakeShape( + XlaHelpers::TypeOfXlaOp(on_value), one_hot_shape.dimensions(), + runtime::util::ToVector(one_hot_shape.dynamic_dimensions())); + xla::XlaOp sym_op = XlaHelpers::CreateOutputDimsTensor(one_hot_bool); + on_value = xla::DynamicBroadcastInDim(on_value, sym_op, {}, final_shape); + off_value = xla::DynamicBroadcastInDim(off_value, sym_op, {}, final_shape); + return xla::Select(one_hot_bool, on_value, off_value); + } // Selects the user-provided off_value and on_value values. return xla::Select(one_hot_bool, xla::Broadcast(on_value, output_dimensions), xla::Broadcast(off_value, output_dimensions)); } +WeightScale DynamicGetMaskedWeight(xla::XlaOp weight, xla::XlaOp logits, + const xla::Shape& logits_shape, + xla::XlaOp labels, xla::XlaOp one_hot_labels, + int axis, int ignore_index, + bool non_zero_scale) { + const xla::Shape& labels_shape = ShapeHelper::ShapeOfXlaOp(labels); + xla::XlaOp valid_bitmap = xla::Ne( + labels, XlaHelpers::ScalarValue( + ignore_index, labels_shape.element_type(), labels.builder())); + xla::XlaOp xweight; + xla::Shape f32_shape = xla::ShapeUtil::MakeShape( + xla::PrimitiveType::F32, logits_shape.dimensions(), + runtime::util::ToVector(logits_shape.dynamic_dimensions())); + xla::XlaOp sym_op = XlaHelpers::CreateOutputDimsTensor(logits); + if (!weight.IsUninitialized()) { + xla::Shape weight_shape = f32_shape; + weight_shape.set_element_type(XlaHelpers::TypeOfXlaOp(weight)); + xweight = xla::DynamicBroadcastInDim(weight, sym_op, {1}, logits_shape); + } else { + xweight = XlaHelpers::DynamicScalarBroadcast(1.0, logits); + } + xla::XlaOp zeros = XlaHelpers::DynamicScalarBroadcast(0.0, logits); + std::vector broadcast_dims(labels_shape.rank()); + std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); + std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); + + xla::Shape valid_shape = f32_shape; + valid_shape.set_element_type(XlaHelpers::TypeOfXlaOp(valid_bitmap)); + xla::XlaOp xvalid_bitmap = + DynamicBroadcastInDim(valid_bitmap, sym_op, broadcast_dims, valid_shape); + xla::XlaOp result_weight = + xla::Select(xvalid_bitmap, xweight, zeros) * one_hot_labels; + + xla::XlaComputation add_func = + XlaHelpers::CreateAddComputation(logits_shape.element_type()); + xla::XlaOp zero = xla::Zero(labels.builder(), logits_shape.element_type()); + xla::XlaOp scale = xla::ReduceAll(result_weight, zero, add_func); + if (non_zero_scale) { + xla::XlaOp one = xla::One(labels.builder(), logits_shape.element_type()); + scale = xla::Select(xla::Ne(scale, zero), scale, one); + } + return {result_weight, scale}; +} + WeightScale GetMaskedWeight(xla::XlaOp weight, const xla::Shape& logits_shape, xla::XlaOp labels, xla::XlaOp one_hot_labels, int axis, int ignore_index, bool non_zero_scale) { @@ -129,14 +183,26 @@ xla::XlaOp BuildNllLoss(xla::XlaOp logits, xla::XlaOp labels, xla::XlaOp weight, // 0 or logits 
has NaN in that index). Without replacing them to 0, reduction // will return NaN regardless of labelded logit values. xla::XlaOp non_labeled_mask = xla::Ne(one_hot_labels, one); - labeled_logits = xla::Select(non_labeled_mask, - xla::Broadcast(zero, logits_shape.dimensions()), - labeled_logits); + xla::XlaOp zero_logits; + if (XlaHelpers::IsDISCBackend() && logits_shape.is_dynamic()) { + zero_logits = xla::DynamicBroadcastInDim( + zero, XlaHelpers::CreateOutputDimsTensor(logits), {}, logits_shape); + } else { + zero_logits = xla::Broadcast(zero, logits_shape.dimensions()); + } + labeled_logits = xla::Select(non_labeled_mask, zero_logits, labeled_logits); // When the whole target is equal to the ignore_index in the nll_loss forward, // pytorch will return nan hence scale should be 0. - WeightScale weight_scale = GetMaskedWeight( - weight, logits_shape, labels, one_hot_labels, classes_axis, ignore_index, - /*non_zero_scale=*/false); + WeightScale weight_scale; + if (XlaHelpers::IsDISCBackend() && logits_shape.is_dynamic()) { + weight_scale = DynamicGetMaskedWeight( + weight, logits, logits_shape, labels, one_hot_labels, classes_axis, + ignore_index, /*non_zero_scale=*/false); + } else { + weight_scale = + GetMaskedWeight(weight, logits_shape, labels, one_hot_labels, + classes_axis, ignore_index, /*non_zero_scale=*/false); + } labeled_logits = labeled_logits * weight_scale.weight; xla::XlaComputation add_func = XlaHelpers::CreateAddComputation(logits_shape.element_type()); @@ -172,17 +238,35 @@ xla::XlaOp BuildNllLossBackward(xla::XlaOp grad_output, xla::XlaOp logits, const xla::Shape& grad_output_shape = ShapeHelper::ShapeOfXlaOp(grad_output); xla::XlaOp grad = grad_output; if (grad_output_shape.rank() == 1) { - grad = xla::BroadcastInDim(grad, logits_shape.dimensions(), {0}); + if (XlaHelpers::IsDISCBackend() && logits_shape.is_dynamic()) { + grad = xla::DynamicBroadcastInDim( + grad, XlaHelpers::CreateOutputDimsTensor(logits), {0}, logits_shape); + } else { + grad = xla::BroadcastInDim(grad, logits_shape.dimensions(), {0}); + } } else if (grad_output_shape.rank() == 3) { // nll_loss_2d case - grad = xla::BroadcastInDim(grad, logits_shape.dimensions(), {0, 2, 3}); + if (XlaHelpers::IsDISCBackend() && logits_shape.is_dynamic()) { + grad = xla::DynamicBroadcastInDim( + grad, XlaHelpers::CreateOutputDimsTensor(logits), {0, 2, 3}, + logits_shape); + } else { + grad = xla::BroadcastInDim(grad, logits_shape.dimensions(), {0, 2, 3}); + } } xla::XlaOp result = xla::Neg(one_hot_labels) * grad; // When the whole target is equal to the ignore_index in the nll_loss // backward, pytorch will return 0 hence scale should not be 0. 
- WeightScale weight_scale = - GetMaskedWeight(weight, logits_shape, labels, one_hot_labels, - classes_axis, ignore_index, /*non_zero_scale=*/true); + WeightScale weight_scale; + if (XlaHelpers::IsDISCBackend() && logits_shape.is_dynamic()) { + weight_scale = DynamicGetMaskedWeight( + weight, logits, logits_shape, labels, one_hot_labels, classes_axis, + ignore_index, /*non_zero_scale=*/true); + } else { + weight_scale = + GetMaskedWeight(weight, logits_shape, labels, one_hot_labels, + classes_axis, ignore_index, /*non_zero_scale=*/true); + } result = result * weight_scale.weight; if (reduction_mode != ReductionMode::kMean) { return result; diff --git a/torch_xla/csrc/ops/expand_symint.cpp b/torch_xla/csrc/ops/expand_symint.cpp index bed61c74ad0..9623df9f934 100644 --- a/torch_xla/csrc/ops/expand_symint.cpp +++ b/torch_xla/csrc/ops/expand_symint.cpp @@ -5,6 +5,7 @@ #include "absl/strings/str_join.h" #include "torch_xla/csrc/data_ops.h" +#include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/runtime/debug_macros.h" @@ -58,6 +59,14 @@ XlaOpVector ExpandSymInt::Lower(LoweringContext* loctx) const { for (int i = 1; i < operands().size(); i++) { size_ops.push_back(loctx->GetOutputOp(operand(i))); } + if (size_ops.empty()) { + return ReturnOp(BuildExpand(input, upper_bounds_), loctx); + } + if (XlaHelpers::IsDISCBackend()) { + xla::XlaOp output = + BuildExpandSymInt(input, upper_bounds_, size_ops, dynamic_dims_); + return ReturnOp(output, loctx); + } xla::XlaOp output = SetDimensionSizes(BuildExpand(input, upper_bounds_), size_ops, dynamic_dims_); return ReturnOp(output, loctx); diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index 7391f8ff714..61c602c4470 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -16,6 +16,7 @@ #include "torch_xla/csrc/ops/arithmetic_ir_ops.h" #include "torch_xla/csrc/ops/constant.h" #include "torch_xla/csrc/ops/expand.h" +#include "torch_xla/csrc/ops/expand_symint.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/ops/log_softmax_backward.h" #include "torch_xla/csrc/ops/permute.h" @@ -222,7 +223,16 @@ torch::lazy::NodePtr Sigmoid(const torch::lazy::Value& input) { torch::lazy::NodePtr SigmoidBackward(const torch::lazy::Value& grad_output, const torch::lazy::Value& output) { - torch::lazy::Value scalar_1 = ScalarOp(1, GetXlaShape(output)); + torch::lazy::Value scalar_1; + auto output_shape = GetXlaShape(output); + if (output_shape.is_dynamic()) { + SymIntElements size_elements(output); + scalar_1 = ScalarOp(1, output_shape.element_type()); + scalar_1 = torch::lazy::MakeNode(scalar_1, size_elements); + } else { + scalar_1 = ScalarOp(1, GetXlaShape(output)); + } + auto lower_fn = [](const XlaNode& node, LoweringContext* loctx) -> XlaOpVector { xla::XlaOp grad_output = loctx->GetOutputOp(node.operand(0)); @@ -1082,4 +1092,22 @@ torch::lazy::NodePtr Mul(const torch::lazy::Value& input, std::move(lower_fn)); } +torch::lazy::NodePtr DynamicArange(const torch::lazy::Value& size, + const torch::lazy::Value& start, + const torch::lazy::Value& step, + xla::PrimitiveType scalar_type, + const xla::Shape& shape) { + auto lower_fn = [scalar_type, shape](const XlaNode& node, + LoweringContext* loctx) -> XlaOpVector { + xla::XlaOp xla_size = loctx->GetOutputOp(node.operand(0)); + xla::XlaOp xla_start = loctx->GetOutputOp(node.operand(1)); + xla::XlaOp xla_end = loctx->GetOutputOp(node.operand(2)); + return 
node.ReturnOp( + BuildDynamicArange(xla_size, xla_start, xla_end, scalar_type, shape), + loctx); + }; + return GenericOp(torch::lazy::OpKind(at::aten::arange), {size, start, step}, + shape, std::move(lower_fn)); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h index 76f0e165973..a41038ecfa9 100644 --- a/torch_xla/csrc/ops/ops.h +++ b/torch_xla/csrc/ops/ops.h @@ -263,6 +263,12 @@ torch::lazy::NodePtr Add(const torch::lazy::Value& input, torch::lazy::NodePtr Mul(const torch::lazy::Value& input, const torch::lazy::Value& other); +torch::lazy::NodePtr DynamicArange(const torch::lazy::Value& size, + const torch::lazy::Value& start, + const torch::lazy::Value& step, + xla::PrimitiveType scalar_type, + const xla::Shape& shape); + } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_OPS_OPS_H_ diff --git a/torch_xla/csrc/ops/select.cpp b/torch_xla/csrc/ops/select.cpp index a109bd49e44..eef46aa83e5 100644 --- a/torch_xla/csrc/ops/select.cpp +++ b/torch_xla/csrc/ops/select.cpp @@ -1,5 +1,7 @@ #include "torch_xla/csrc/ops/select.h" +#include "torch_xla/csrc/data_ops.h" +#include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/runtime/debug_macros.h" @@ -28,7 +30,18 @@ torch::lazy::NodePtr Select::Clone(torch::lazy::OpList operands) const { XlaOpVector Select::Lower(LoweringContext* loctx) const { xla::XlaOp input = loctx->GetOutputOp(operand(0)); const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); - if (!input_shape.is_unbounded_dynamic()) { + if (XlaHelpers::IsDISCBackend() && input_shape.is_bounded_dynamic()) { + xla::XlaOp start = + XlaHelpers::ScalarValue(start_, input.builder()); + xla::XlaOp end = XlaHelpers::ScalarValue(end_, input.builder()); + xla::XlaOp stride = + XlaHelpers::ScalarValue(stride_, input.builder()); + xla::Shape final_shape = + MakeSelectShape(input_shape, dim_, start_, end_, stride_); + xla::XlaOp output = + BuildSelectSymInt(input, dim_, start, end, stride, final_shape); + return ReturnOp(output, loctx); + } else if (!input_shape.is_unbounded_dynamic()) { xla::XlaOp output = xla::SliceInDim(input, start_, end_, GetStride(start_, end_, stride_), dim_); return ReturnOp(output, loctx); diff --git a/torch_xla/csrc/ops/select_symint.cpp b/torch_xla/csrc/ops/select_symint.cpp new file mode 100644 index 00000000000..90619e68982 --- /dev/null +++ b/torch_xla/csrc/ops/select_symint.cpp @@ -0,0 +1,56 @@ +#include "torch_xla/csrc/ops/select_symint.h" + +#include "torch_xla/csrc/data_ops.h" +#include "torch_xla/csrc/helpers.h" +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/shape_helper.h" + +namespace torch_xla { + +SelectSymInt::SelectSymInt(const torch::lazy::Value& input, int64_t dim, + const torch::lazy::Value& start, + const torch::lazy::Value& end, + const torch::lazy::Value& stride, + xla::Shape output_shape) + : XlaNode(xla_select, {input, start, end, stride}, output_shape, + /*num_outputs=*/1, torch::lazy::MHash(dim)), + dim_(dim) {} + +XlaOpVector SelectSymInt::Lower(LoweringContext* loctx) const { + xla::XlaOp input = loctx->GetOutputOp(operand(0)); + xla::XlaOp start = loctx->GetOutputOp(operand(1)); + xla::XlaOp end = loctx->GetOutputOp(operand(2)); + xla::XlaOp stride = loctx->GetOutputOp(operand(3)); + xla::XlaOp output = + BuildSelectSymInt(input, dim_, start, end, stride, xla_shape()); + return 
ReturnOp(output, loctx); +} + +std::string SelectSymInt::ToString() const { + std::stringstream ss; + ss << XlaNode::ToString() << ", dim=" << dim_; + return ss.str(); +} + +xla::Shape SelectSymInt::MakeSelectShape(const xla::Shape& shape, int64_t dim, + int64_t start, int64_t end, + int64_t stride) { + int64_t effective_stride = GetStride(start, end, stride); + xla::Shape select_shape(shape); + select_shape.set_dimensions( + dim, (end - start + effective_stride - 1) / effective_stride); + select_shape.set_dynamic_dimension(dim, true); + return select_shape; +} + +int64_t SelectSymInt::GetStride(int64_t start, int64_t end, int64_t stride) { + if (stride == 0) { + XLA_CHECK_EQ(start, end); + stride = 1; + } + return stride; +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/select_symint.h b/torch_xla/csrc/ops/select_symint.h new file mode 100644 index 00000000000..77345248923 --- /dev/null +++ b/torch_xla/csrc/ops/select_symint.h @@ -0,0 +1,31 @@ +#ifndef XLA_TORCH_XLA_CSRC_OPS_SELECT_SYMINT_H_ +#define XLA_TORCH_XLA_CSRC_OPS_SELECT_SYMINT_H_ + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { + +class SelectSymInt : public XlaNode { + public: + SelectSymInt(const torch::lazy::Value& input, int64_t dim, + const torch::lazy::Value& start, const torch::lazy::Value& end, + const torch::lazy::Value& stride, xla::Shape output_shape); + + XlaOpVector Lower(LoweringContext* loctx) const override; + + std::string ToString() const override; + + int64_t dim() const { return dim_; } + + static xla::Shape MakeSelectShape(const xla::Shape& shape, int64_t dim, + int64_t start, int64_t end, int64_t stride); + + static int64_t GetStride(int64_t start, int64_t end, int64_t stride); + + private: + int64_t dim_; +}; + +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_OPS_SELECT_SYMINT_H_ \ No newline at end of file diff --git a/torch_xla/csrc/ops/view_symint.cpp b/torch_xla/csrc/ops/view_symint.cpp new file mode 100644 index 00000000000..f77ada4a89f --- /dev/null +++ b/torch_xla/csrc/ops/view_symint.cpp @@ -0,0 +1,62 @@ +#include "torch_xla/csrc/ops/view_symint.h" + +#include "absl/strings/str_join.h" +#include "torch_xla/csrc/data_ops.h" +#include "torch_xla/csrc/helpers.h" +#include "torch_xla/csrc/lowering_context.h" +#include "xla/shape_util.h" + +namespace torch_xla { +namespace { + +std::vector GetValues( + const torch::lazy::Value& input, + const std::vector& dimensions) { + std::vector values; + values.reserve(dimensions.size() + 1); + values.push_back(input); + for (torch::lazy::NodePtr dim : dimensions) { + if (dim) { + // Dimension Node only exist for symbolic dimension. 
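+      // Note (descriptive only): a null NodePtr here corresponds to a static
+      // dimension, which is skipped; only symbolic (dynamic) dimensions add a
+      // size operand after the input, so operand(1..N) map to dynamic dims.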
+ values.push_back(torch::lazy::Value(dim, 0)); + } + } + return values; +} + +} // namespace + +ViewSymIntOp::ViewSymIntOp(const torch::lazy::Value& input, + const SymIntElements& size_elements, + xla::Shape output_shape) + : XlaNode( + torch::lazy::OpKind(at::aten::view), + GetValues(input, size_elements.GetSizeNodes()), output_shape, + /*num_outputs=*/1, + torch::lazy::MHash( + torch::lazy::ToVector(output_shape.dimensions()), + torch::lazy::ToVector(output_shape.dynamic_dimensions()))), + upper_bounds_(torch::lazy::ToVector(output_shape.dimensions())), + dynamic_dims_( + torch::lazy::ToVector(output_shape.dynamic_dimensions())) {} + +XlaOpVector ViewSymIntOp::Lower(LoweringContext* loctx) const { + xla::XlaOp input = loctx->GetOutputOp(operand(0)); + std::vector size_ops; + for (int i = 1; i < operands().size(); i++) { + size_ops.push_back(loctx->GetOutputOp(operand(i))); + } + xla::XlaOp output = + BuildViewSymInt(input, size_ops, upper_bounds_, dynamic_dims_); + return ReturnOp(output, loctx); +} + +std::string ViewSymIntOp::ToString() const { + std::stringstream ss; + ss << XlaNode::ToString() << ", size=(" << absl::StrJoin(upper_bounds_, ", ") + << ")" + << ", dynamic_dims=(" << absl::StrJoin(dynamic_dims_, ", ") << ")"; + return ss.str(); +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/view_symint.h b/torch_xla/csrc/ops/view_symint.h new file mode 100644 index 00000000000..2eaea2752cd --- /dev/null +++ b/torch_xla/csrc/ops/view_symint.h @@ -0,0 +1,32 @@ +#ifndef XLA_TORCH_XLA_CSRC_OPS_VIEW_SYMINT_H_ +#define XLA_TORCH_XLA_CSRC_OPS_VIEW_SYMINT_H_ + +#include + +#include "absl/types/span.h" +#include "torch_xla/csrc/ir.h" +#include "torch_xla/csrc/torch_util.h" + +namespace torch_xla { + +class ViewSymIntOp : public XlaNode { + public: + ViewSymIntOp(const torch::lazy::Value& input, + const SymIntElements& size_elements, xla::Shape output_shape); + + std::string ToString() const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + const std::vector& size() const { return upper_bounds_; }; + + const bool IsDynamic(int index) const { return dynamic_dims_[index]; }; + + private: + std::vector upper_bounds_; + std::vector dynamic_dims_; +}; + +} // namespace torch_xla + +#endif // XLA_TORCH_XLA_CSRC_OPS_VIEW_SYMINT_H_ diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 639245dbd55..97037b4195f 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -111,6 +111,7 @@ #include "torch_xla/csrc/ops/scatter_add.h" #include "torch_xla/csrc/ops/scatter_reduce.h" #include "torch_xla/csrc/ops/select.h" +#include "torch_xla/csrc/ops/select_symint.h" #include "torch_xla/csrc/ops/send.h" #include "torch_xla/csrc/ops/sgd_optimizer_step.h" #include "torch_xla/csrc/ops/softmax.h" @@ -136,6 +137,7 @@ #include "torch_xla/csrc/ops/var.h" #include "torch_xla/csrc/ops/var_mean.h" #include "torch_xla/csrc/ops/view.h" +#include "torch_xla/csrc/ops/view_symint.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/metrics.h" @@ -167,6 +169,16 @@ torch::lazy::Value MaybeExpand(const torch::lazy::Value& input, input, torch::lazy::ToVector(target_shape.dimensions())); } +torch::lazy::Value MaybeExpandSymInt(const torch::lazy::Value& input, + const xla::Shape& target_shape, + c10::SymIntArrayRef target_size) { + if (GetXlaShape(input).dimensions() == target_shape.dimensions()) { + return input; + } + SymIntElements 
size_elements = SymIntElements(target_size); + return torch::lazy::MakeNode(input, size_elements); +} + MinMaxValues GetMinMaxValues(const XLATensorPtr& tensor, const c10::optional& min, const c10::optional& max) { @@ -849,8 +861,39 @@ XLATensorPtr addmm(const XLATensorPtr& input, const XLATensorPtr& weight, void arange_out(XLATensorPtr& out, const at::Scalar& start, const at::Scalar& end, const at::Scalar& step, at::ScalarType scalar_type) { - out->SetIrValue(ARange(start, end, step, scalar_type)); - out->SetScalarType(scalar_type); + if (XlaHelpers::IsDISCBackend() && + (start.isSymbolic() || end.isSymbolic() || step.isSymbolic())) { + XLA_CHECK(start.isIntegral(/*includeBool=*/false) && + end.isIntegral(/*includeBool=*/false) && + step.isIntegral(/*includeBool=*/false)) + << "Only int start && end && step is supported, got " << start << ", " + << end << ", " << step; + // corner case + if (start.toSymInt() == end.toSymInt()) { + out->SetIrValue(ARange(0, 0, 1, scalar_type)); + out->SetScalarType(scalar_type); + return; + } + xla::PrimitiveType prim_type = + MakeXlaPrimitiveType(scalar_type, &out->GetDevice()); + auto range_size = + 1 + (end.toSymInt() - start.toSymInt() - 1) / step.toSymInt(); // ceil + + xla::Shape result_shape = xla::ShapeUtil::MakeShape( + prim_type, {GetSymIntUpperBound(range_size)}, {true}); + + const torch::lazy::BackendDevice& device = out->GetDevice(); + torch::lazy::Value size_value = GetSymIntValue(range_size, device); + torch::lazy::Value start_value = GetSymIntValue(start.toSymInt(), device); + torch::lazy::Value step_value = GetSymIntValue(step.toSymInt(), device); + + out->SetIrValue(DynamicArange(size_value, start_value, step_value, + prim_type, result_shape)); + out->SetScalarType(scalar_type); + } else { + out->SetIrValue(ARange(start, end, step, scalar_type)); + out->SetScalarType(scalar_type); + } } XLATensorPtr as_strided(const XLATensorPtr& input, std::vector size, @@ -1307,6 +1350,22 @@ XLATensorPtr embedding(const XLATensorPtr& weight, return tensor_ops::Embedding(weight, indices); } +XLATensorPtr embedding_symint(const XLATensorPtr& weight, + const XLATensorPtr& indices, + at::SymIntArrayRef final_size, + at::SymInt indices_numel) { + XLA_CHECK_EQ(weight->shape().get().rank(), 2); + XLA_CHECK(indices->dtype() == at::kLong || indices->dtype() == at::kInt); + + if (indices->shape().get().rank() == 1) { + return index_select(weight, 0, indices); + } + + XLATensorPtr embeddings = + index_select(weight, 0, view_symint(indices, {indices_numel})); + return view_symint(embeddings, final_size); +} + XLATensorPtr exp(const XLATensorPtr& input) { return input->CreateFrom(Exp(input->GetIrValue())); } @@ -2493,8 +2552,16 @@ XLATensorPtr rrelu_with_noise_backward(const XLATensorPtr& grad_output, XLATensorPtr rsub(const XLATensorPtr& input, const XLATensorPtr& other, const at::Scalar& alpha, c10::optional logical_element_type) { - torch::lazy::Value alpha_xla = XLAGraphExecutor::Get()->GetIrValueForScalar( - alpha, other->shape(), logical_element_type, other->GetDevice()); + torch::lazy::Value alpha_xla; + if (other->shape().get().is_static()) { + alpha_xla = XLAGraphExecutor::Get()->GetIrValueForScalar( + alpha, other->shape(), logical_element_type, other->GetDevice()); + } else { + SymIntElements sym_int_elements(other->GetIrValue()); + alpha_xla = XLAGraphExecutor::Get()->GetIrValueForScalar( + alpha, other->shape(), sym_int_elements, logical_element_type, + other->GetDevice()); + } return input->CreateFrom( Rsub(input->GetIrValue(), other->GetIrValue(), 
alpha_xla), @@ -2504,15 +2571,55 @@ XLATensorPtr rsub(const XLATensorPtr& input, const XLATensorPtr& other, XLATensorPtr rsub(const XLATensorPtr& input, const at::Scalar& other, const at::Scalar& alpha, c10::optional logical_element_type) { - torch::lazy::Value alpha_xla = XLAGraphExecutor::Get()->GetIrValueForScalar( - alpha, input->shape(), logical_element_type, input->GetDevice()); - torch::lazy::Value other_xla = XLAGraphExecutor::Get()->GetIrValueForScalar( - other, input->shape(), logical_element_type, input->GetDevice()); + torch::lazy::Value alpha_xla; + torch::lazy::Value other_xla; + if (input->shape().get().is_static()) { + alpha_xla = XLAGraphExecutor::Get()->GetIrValueForScalar( + alpha, input->shape(), logical_element_type, input->GetDevice()); + other_xla = XLAGraphExecutor::Get()->GetIrValueForScalar( + other, input->shape(), logical_element_type, input->GetDevice()); + } else { + SymIntElements sym_int_elements(input->GetIrValue()); + alpha_xla = XLAGraphExecutor::Get()->GetIrValueForScalar( + alpha, input->shape(), sym_int_elements, logical_element_type, + input->GetDevice()); + other_xla = XLAGraphExecutor::Get()->GetIrValueForScalar( + other, input->shape(), sym_int_elements, logical_element_type, + input->GetDevice()); + } return input->CreateFrom(Rsub(input->GetIrValue(), other_xla, alpha_xla), logical_element_type); } +void copy_symint_(XLATensorPtr& input, c10::SymIntArrayRef input_size, + XLATensorPtr& src) { + if (input->GetDevice() == src->GetDevice()) { + torch::lazy::Value copy_value; + if (input->dtype() == src->dtype()) { + copy_value = src->GetIrValue(); + } else { + copy_value = torch::lazy::MakeNode(src->GetIrValue(), + input->dtype(), src->dtype()); + } + input->SetIrValue( + MaybeExpandSymInt(copy_value, input->shape(), input_size)); + } else { + auto input_shape = input->shape(); + at::Tensor src_tensor = src->ToTensor(/*detached=*/true); + if (!torch_xla::runtime::util::Equal(src_tensor.sizes(), + input_shape.get().dimensions())) { + src_tensor = src_tensor.expand_symint(input_size); + } + input->UpdateFromTensor(std::move(src_tensor), /*sync=*/false); + } + + // Preserves sharding when copying. + if (src->sharding_spec() != nullptr) { + input->SetShardingSpec(*src->sharding_spec()); + } +} + void copy_(XLATensorPtr& input, XLATensorPtr& src) { if (input->GetDevice() == src->GetDevice()) { torch::lazy::Value copy_value; @@ -2634,6 +2741,47 @@ XLATensorPtr slice(const XLATensorPtr& input, int64_t dim, int64_t start, input->GetIrValue(), dim, start, end, step)); } +XLATensorPtr slice_symint(const XLATensorPtr& input, int64_t dim, + c10::SymInt start, c10::SymInt end, + c10::SymInt step) { + auto input_shape = input->shape(); + dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.get().rank()); + std::vector input_dims = torch_xla::runtime::util::ToVector( + input_shape.get().dimensions()); + if (input_dims[dim] == 0) { + // `GetCanonicalDimensionIndex` doesn't support case where dim size = 0. + // So we add a special handling in torch_xla. + return input->CreateFrom( + torch::lazy::MakeNode