From 51f25c82feae29ed13bced9f5b171b0642a6d206 Mon Sep 17 00:00:00 2001
From: Yitong Huang
Date: Mon, 21 Oct 2024 17:33:59 +0800
Subject: [PATCH] Support marking bounded dynamic dimensions and use bound
 values for shape comparison (#6)

---
 test/ds/test_bounded_dynamic.py             | 73 +++++++++++++++++++++
 test/test_utils.py                          |  6 ++
 torch_xla/csrc/init_python_bindings.cpp     |  9 +++
 torch_xla/csrc/ir.cpp                       |  3 +-
 torch_xla/csrc/ir.h                         | 14 ++++
 torch_xla/csrc/ops/dynamic_ir.cpp           | 20 ++++++
 torch_xla/csrc/ops/dynamic_ir.h             | 12 ++++
 torch_xla/csrc/runtime/computation_client.h |  5 ++
 torch_xla/csrc/tensor.cpp                   | 33 ++++++++--
 torch_xla/csrc/tensor.h                     |  2 +
 10 files changed, 172 insertions(+), 5 deletions(-)
 create mode 100644 test/ds/test_bounded_dynamic.py

diff --git a/test/ds/test_bounded_dynamic.py b/test/ds/test_bounded_dynamic.py
new file mode 100644
index 00000000000..77b4f5630e7
--- /dev/null
+++ b/test/ds/test_bounded_dynamic.py
@@ -0,0 +1,73 @@
+import os
+import sys
+import unittest
+import torch, torch_xla
+import torch_xla.core.xla_model as xm
+import torch_xla.debug.metrics as met
+
+sys.path.insert(1, os.path.join(sys.path[0], '..'))
+import test_utils
+
+pd = torch._C._EnablePythonDispatcher()
+dev = xm.xla_device()
+
+
+def mark_dynamic(t, dims, bounds):
+  torch_xla._XLAC._xla_mark_bounded_dynamic(t, dims, bounds)
+
+
+class TestBoundedDynamicShapes(test_utils.XlaTestCase):
+
+  def test_mark_dynamic(self):
+    t1 = torch.randn([5, 2]).to(dev)
+    # t1 has size [<=10, 2]
+    mark_dynamic(t1, [0], [10])
+    self.assertIn('<=10', torch_xla._XLAC._get_xla_tensors_text([t1]))
+    if test_utils.is_disc_backend():
+      t1_cpu = t1.cpu()
+      self.assertEqual(t1_cpu.shape[0], 5)
+
+  def test_sizeGe(self):
+    met.clear_all()
+    t1 = torch.randn([5, 2]).to(dev)
+    # t1 has size [<=10, 2]
+    mark_dynamic(t1, [0], [10])
+    self.assertTrue(t1.shape[0] >= t1.shape[1])
+    self.assertGreater(met.counter_value("xla::size_ge"), 0)
+    self.assertIsNone(met.metric_data('CompileTime'))
+
+  def test_sizeLt(self):
+    met.clear_all()
+    t1 = torch.randn([5, 2]).to(dev)
+    # t1 has size [<=10, 2]
+    mark_dynamic(t1, [0], [10])
+    self.assertFalse(t1.shape[0] < t1.shape[1])
+    self.assertGreater(met.counter_value("xla::size_lt"), 0)
+    self.assertIsNone(met.metric_data('CompileTime'))
+
+  def test_sizeNe(self):
+    met.clear_all()
+    t1 = torch.randn([5, 2]).to(dev)
+    # t1 has size [<=10, 2]
+    mark_dynamic(t1, [0], [10])
+    self.assertTrue(t1.shape[0] != t1.shape[1])
+    self.assertGreater(met.counter_value("xla::size_ne"), 0)
+    self.assertIsNone(met.metric_data('CompileTime'))
+
+  def test_sizeEq(self):
+    met.clear_all()
+    t1 = torch.randn([5, 2]).to(dev)
+    # t1 has size [<=10, 2]
+    mark_dynamic(t1, [0], [10])
+    self.assertFalse(t1.shape[0] == 1)
+    self.assertGreater(met.counter_value("xla::size_eq"), 0)
+    self.assertIsNone(met.metric_data('CompileTime'))
+
+
+if __name__ == '__main__':
+  os.environ['USE_BOUND_FOR_SHAPE_COMPARE'] = os.getenv(
+      'USE_BOUND_FOR_SHAPE_COMPARE', '1')
+  test = unittest.main(exit=False)
+  # Disable the Python dispatcher before exiting.
+  del pd
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/test_utils.py b/test/test_utils.py
index 4aefdce6805..905e81d5af0 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -13,6 +13,12 @@
 import torch_xla.core.xla_model as xm
 import torch_xla.utils.utils as xu
 
+_IS_DISC_BACKEND = 'DISC_DEVICE' in os.environ
+
+
+def is_disc_backend():
+  return _IS_DISC_BACKEND
+
 
 def _set_rng_seed(seed):
   torch.manual_seed(seed)
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index f0373a485c5..d07eb856fe6 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -2329,6 +2329,15 @@ void InitXlaModuleBindings(py::module m) {
         }
         return result;
       });
+  m.def("_xla_mark_bounded_dynamic",
+        [](const at::Tensor& input, const std::vector<uint32_t>& dims,
+           const std::vector<int64_t>& bounds) {
+          TORCH_LAZY_COUNTER("XlaMarkBoundedDynamic", 1);
+          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+          XLA_CHECK(dims.size() == bounds.size())
+              << "dims.size() should be equal to bounds.size()";
+          xtensor->MarkBoundedDynamicDimension(dims, bounds);
+        });
   m.def("_xla_mark_dynamic", [](const at::Tensor& input, uint32_t dim) {
     TORCH_LAZY_COUNTER("XlaMarkDynamic", 1);
     XLATensorPtr xtensor = bridge::GetXlaTensor(input);
diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp
index 7a0a226c27d..beddda9faf2 100644
--- a/torch_xla/csrc/ir.cpp
+++ b/torch_xla/csrc/ir.cpp
@@ -101,7 +101,8 @@ XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::Shape shape,
                  torch::lazy::hash_t hash_seed)
     : torch::lazy::Node(op, shape, num_outputs),
       xla_shape_(std::move(xla_shape)),
-      node_hash_(GetOpHash(op, xla_shape_, hash_seed)),
+      node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
+      shape_hash_(torch::lazy::Hash(xla_shape_.ToString())),
       dag_hash_(node_hash_) {}
 
 XlaNode::XlaNode(torch::lazy::OpKind op, xla::Shape xla_shape,
diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h
index 0a98c1122ce..24b202904e4 100644
--- a/torch_xla/csrc/ir.h
+++ b/torch_xla/csrc/ir.h
@@ -112,6 +112,13 @@
   torch::lazy::hash_t node_hash() const { return node_hash_; }
 
   torch::lazy::hash_t hash() const override {
+    if (shape_hash_ != 0) {
+      auto dag_hash =
+          sharding_hash_ != 0
+              ? (torch::lazy::HashCombine(dag_hash_, sharding_hash_))
+              : dag_hash_;
+      return torch::lazy::HashCombine(dag_hash, shape_hash_);
+    }
     if (sharding_hash_ != 0) {
       return torch::lazy::HashCombine(dag_hash_, sharding_hash_);
     }
@@ -143,6 +150,12 @@
     unbounded_dynamic_dims_.insert(dim);
   }
 
+  void MarkBoundedDynamicDimension(uint32_t dim, int64_t bound) {
+    xla_shape_.set_dynamic_dimension(dim, true);
+    xla_shape_.set_dimensions(dim, bound);
+    shape_hash_ = torch::lazy::Hash(xla_shape_.ToString());
+  }
+
   const std::unordered_set<uint32_t>& dynamic_dims() const {
     return unbounded_dynamic_dims_;
   }
@@ -168,6 +181,7 @@
   torch::lazy::hash_t node_hash_ = 0;
   torch::lazy::hash_t dag_hash_;
   torch::lazy::hash_t sharding_hash_ = 0;
+  torch::lazy::hash_t shape_hash_ = 0;
 
   // Experimental sharding annotations attached to the IR node.
   std::vector<std::shared_ptr<xla::OpSharding>> output_shardings_;
diff --git a/torch_xla/csrc/ops/dynamic_ir.cpp b/torch_xla/csrc/ops/dynamic_ir.cpp
index b3236e9f144..55812cb631f 100644
--- a/torch_xla/csrc/ops/dynamic_ir.cpp
+++ b/torch_xla/csrc/ops/dynamic_ir.cpp
@@ -138,9 +138,14 @@ SizeEq::SizeEq(torch::lazy::Value a, torch::lazy::Value b)
   const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
   XLA_CHECK(dim_node_0);
   XLA_CHECK(dim_node_1);
+  upper_bound_ =
+      dim_node_0->getStaticValue() == dim_node_1->getStaticValue() ? 1 : 0;
 };
 
 int64_t SizeEq::getDynamicValue() const {
+  if (runtime::sys_util::GetEnvBool("USE_BOUND_FOR_SHAPE_COMPARE", false)) {
+    return upper_bound_;
+  }
   if (operand(0) == operand(1)) {
     return 1;
   }
@@ -163,9 +168,14 @@ SizeNe::SizeNe(torch::lazy::Value a, torch::lazy::Value b)
   const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
   XLA_CHECK(dim_node_0);
   XLA_CHECK(dim_node_1);
+  upper_bound_ =
+      dim_node_0->getStaticValue() != dim_node_1->getStaticValue() ? 1 : 0;
 };
 
 int64_t SizeNe::getDynamicValue() const {
+  if (runtime::sys_util::GetEnvBool("USE_BOUND_FOR_SHAPE_COMPARE", false)) {
+    return upper_bound_;
+  }
   const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
   const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
   XLA_CHECK(dim_node_0);
@@ -185,9 +195,14 @@ SizeGe::SizeGe(torch::lazy::Value a, torch::lazy::Value b)
   const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
   XLA_CHECK(dim_node_0);
   XLA_CHECK(dim_node_1);
+  upper_bound_ =
+      dim_node_0->getStaticValue() >= dim_node_1->getStaticValue() ? 1 : 0;
 };
 
 int64_t SizeGe::getDynamicValue() const {
+  if (runtime::sys_util::GetEnvBool("USE_BOUND_FOR_SHAPE_COMPARE", false)) {
+    return upper_bound_;
+  }
   const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
   const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
   XLA_CHECK(dim_node_0);
@@ -207,9 +222,14 @@ SizeLt::SizeLt(torch::lazy::Value a, torch::lazy::Value b)
   const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
   XLA_CHECK(dim_node_0);
   XLA_CHECK(dim_node_1);
+  upper_bound_ =
+      dim_node_0->getStaticValue() < dim_node_1->getStaticValue() ? 1 : 0;
 };
 
 int64_t SizeLt::getDynamicValue() const {
+  if (runtime::sys_util::GetEnvBool("USE_BOUND_FOR_SHAPE_COMPARE", false)) {
+    return upper_bound_;
+  }
   const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
   const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
   XLA_CHECK(dim_node_0);
diff --git a/torch_xla/csrc/ops/dynamic_ir.h b/torch_xla/csrc/ops/dynamic_ir.h
index 05d59538908..e1097ec0960 100644
--- a/torch_xla/csrc/ops/dynamic_ir.h
+++ b/torch_xla/csrc/ops/dynamic_ir.h
@@ -70,6 +70,9 @@ class SizeEq : public XlaNode, public torch::lazy::DimensionNode {
     // TODO: not sure we will ever need it?
     TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!");
   }
+
+ private:
+  int64_t upper_bound_;
 };
 
 class SizeNe : public XlaNode, public torch::lazy::DimensionNode {
@@ -85,6 +88,9 @@ class SizeNe : public XlaNode, public torch::lazy::DimensionNode {
     // TODO: not sure we will ever need it?
     TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!");
   }
+
+ private:
+  int64_t upper_bound_;
 };
 
 class SizeGe : public XlaNode, public torch::lazy::DimensionNode {
@@ -100,6 +106,9 @@ class SizeGe : public XlaNode, public torch::lazy::DimensionNode {
     // TODO: not sure we will ever need it?
     TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!");
   }
+
+ private:
+  int64_t upper_bound_;
 };
 
 class SizeLt : public XlaNode, public torch::lazy::DimensionNode {
@@ -115,6 +124,9 @@ class SizeLt : public XlaNode, public torch::lazy::DimensionNode {
     // TODO: not sure we will ever need it?
     TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!");
   }
+
+ private:
+  int64_t upper_bound_;
 };
 
 class SizeAdd : public XlaNode, public torch::lazy::DimensionNode {
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index 33b48255baf..11aac832318 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -71,6 +71,11 @@ class ComputationClient {
       should_donate_buffer_ = should_donate_buffer;
     }
 
+    void MarkBoundedDynamicDimension(uint32_t dim, int64_t bound) {
+      xla_shape_.set_dynamic_dimension(dim, true);
+      xla_shape_.set_dimensions(dim, bound);
+    }
+
    virtual std::string ToString() const = 0;
 
    virtual bool HasSharding() const = 0;
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 26fb7fe784b..831e4aee37e 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -723,6 +723,7 @@ c10::SymNode XLASymNodeImpl::mod(const c10::SymNode& other) {
 }
 
 c10::SymNode XLASymNodeImpl::eq(const c10::SymNode& other) {
+  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::size_");
   auto p_other = dynamic_cast<XLASymNodeImpl*>(other.get());
   XLA_CHECK(is_int()) << __FUNCTION__ << " with non-int NYI";
   XLA_CHECK(p_other->is_int()) << __FUNCTION__ << " with non-int NYI";
@@ -740,8 +741,13 @@ c10::SymNode XLASymNodeImpl::ne(const c10::SymNode& other) {
 }
 
 c10::SymNode XLASymNodeImpl::gt(const c10::SymNode& other) {
-  XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
-                   << " has not been implemented.";
+  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::size_");
+  auto p_other = dynamic_cast<XLASymNodeImpl*>(other.get());
+  XLA_CHECK(is_int()) << __FUNCTION__ << " with non-int NYI";
+  XLA_CHECK(p_other->is_int()) << __FUNCTION__ << " with non-int NYI";
+  // Use SizeLt with swapped operands to implement SizeGt.
+  auto n_lt = torch::lazy::MakeNode<SizeLt>(p_other->node(), node());
+  return c10::make_intrusive<XLASymNodeImpl>(n_lt, PyType::BOOL);
 }
 
 c10::SymNode XLASymNodeImpl::lt(const c10::SymNode& other) {
@@ -876,8 +882,7 @@ c10::SymNode XLASymNodeImpl::wrap_bool(bool num) {
 }
 
 int64_t XLASymNodeImpl::guard_int(const char* file, int64_t line) {
-  XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
-                   << " has not been implemented.";
+  return int_();
 }
 
 double XLASymNodeImpl::guard_float(const char* file, int64_t line) {
@@ -934,6 +939,26 @@ void XLATensor::MarkDynamicDimension(uint32_t dim) {
   xla_node->MarkDynamicDimension(dim);
 }
 
+void XLATensor::MarkBoundedDynamicDimension(
+    const std::vector<uint32_t>& dims, const std::vector<int64_t>& bounds) {
+  auto* ir_node = GetIrValue().node.get();
+  auto* xla_node = dynamic_cast<XlaNode*>(ir_node);
+  torch::lazy::BackendDataPtr backend_data = CurrentDataHandle();
+  if (backend_data == nullptr) {
+    backend_data =
+        torch::lazy::getBackend()->GetComputationDataFromNode(ir_node);
+  }
+  for (int i = 0; i < dims.size(); i++) {
+    xla_node->MarkBoundedDynamicDimension(dims[i], bounds[i]);
+    if (backend_data) {
+      std::dynamic_pointer_cast<runtime::ComputationClient::Data>(backend_data)
+          ->MarkBoundedDynamicDimension(dims[i], bounds[i]);
+    }
+  }
+  // Update generation so XLATensorImpl::SetupSymSizeProperties refreshes sizes.
+  data()->generation += 1;
+}
+
 bool XLATensor::SetNodeUserMetadata(
     std::shared_ptr<torch::lazy::UserMetaData> metadata) {
   auto* node = dynamic_cast<XlaNode*>(CurrentIrValue().node.get());
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index af0ef0b2260..09cbd4f3258 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -203,6 +203,8 @@ class XLATensor : public torch::lazy::LazyTensor {
   void SetScalarType(c10::optional<at::ScalarType> logical_element_type);
 
   void MarkDynamicDimension(uint32_t dim);
+  void MarkBoundedDynamicDimension(const std::vector<uint32_t>& dims,
+                                   const std::vector<int64_t>& bounds);
   // We don't use the upstream shape to provide xla::shape.
   runtime::util::MaybeRef<xla::Shape> shape() const;
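
For reference, a minimal usage sketch of the new binding, mirroring test/ds/test_bounded_dynamic.py above (the device setup and the default flag value here are illustrative assumptions, not part of the patch):

    import os
    # Resolve symbolic size comparisons from the recorded upper bounds,
    # as the test above does by default.
    os.environ.setdefault('USE_BOUND_FOR_SHAPE_COMPARE', '1')

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    pd = torch._C._EnablePythonDispatcher()
    t = torch.randn([5, 2]).to(xm.xla_device())
    # Mark dim 0 as bounded dynamic with upper bound 10, i.e. size [<=10, 2].
    torch_xla._XLAC._xla_mark_bounded_dynamic(t, [0], [10])
    # With the flag set, this comparison is answered from the bounds
    # without triggering a compile (no CompileTime metric in the test).
    assert t.shape[0] >= t.shape[1]
    del pd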