Skip to content

Commit

Permalink
Support mark bounded dynamic and use bound values for shape comparison (
Browse files Browse the repository at this point in the history
  • Loading branch information
yitongh authored Oct 21, 2024
1 parent 5d9590e commit 51f25c8
Show file tree
Hide file tree
Showing 10 changed files with 172 additions and 5 deletions.
73 changes: 73 additions & 0 deletions test/ds/test_bounded_dynamic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import os
import sys
import unittest
import torch, torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

sys.path.insert(1, os.path.join(sys.path[0], '..'))
import test_utils

pd = torch._C._EnablePythonDispatcher()
dev = xm.xla_device()


def mark_dynamic(t, dims, bounds):
torch_xla._XLAC._xla_mark_bounded_dynamic(t, dims, bounds)


class TestBoundedDynamicShapes(test_utils.XlaTestCase):

def test_mark_dynamic(self):
t1 = torch.randn([5, 2]).to(dev)
# t1 has size [<=10, 2]
mark_dynamic(t1, [0], [10])
self.assertIn('<=10', torch_xla._XLAC._get_xla_tensors_text([t1]))
if test_utils.is_disc_backend():
t1_cpu = t1.cpu()
self.assertEqual(t1_cpu.shape[0], 5)

def test_sizeGe(self):
met.clear_all()
t1 = torch.randn([5, 2]).to(dev)
# t1 has size [<=10, 2]
mark_dynamic(t1, [0], [10])
self.assertTrue(t1.shape[0] >= t1.shape[1])
self.assertGreater(met.counter_value("xla::size_ge"), 0)
self.assertIsNone(met.metric_data('CompileTime'))

def test_sizeLt(self):
met.clear_all()
t1 = torch.randn([5, 2]).to(dev)
# t1 has size [<=10, 2]
mark_dynamic(t1, [0], [10])
self.assertFalse(t1.shape[0] < t1.shape[1])
self.assertGreater(met.counter_value("xla::size_lt"), 0)
self.assertIsNone(met.metric_data('CompileTime'))

def test_sizeNe(self):
met.clear_all()
t1 = torch.randn([5, 2]).to(dev)
# t1 has size [<=10, 2]
mark_dynamic(t1, [0], [10])
self.assertTrue(t1.shape[0] != t1.shape[1])
self.assertGreater(met.counter_value("xla::size_ne"), 0)
self.assertIsNone(met.metric_data('CompileTime'))

def test_sizeEq(self):
met.clear_all()
t1 = torch.randn([5, 2]).to(dev)
# t1 has size [<=10, 2]
mark_dynamic(t1, [0], [10])
self.assertFalse(t1.shape[0] == 1)
self.assertGreater(met.counter_value("xla::size_eq"), 0)
self.assertIsNone(met.metric_data('CompileTime'))


if __name__ == '__main__':
os.environ['USE_BOUND_FOR_SHAPE_COMPARE'] = os.getenv(
'USE_BOUND_FOR_SHAPE_COMPARE', '1')
test = unittest.main()
# DISABLE PYTHON DISPATCHER FLAG
del pd
sys.exit(0 if test.result.wasSuccessful() else 1)
6 changes: 6 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu

_IS_DISC_BACKEND = 'DISC_DEVICE' in os.environ


def is_disc_backend():
return _IS_DISC_BACKEND


def _set_rng_seed(seed):
torch.manual_seed(seed)
Expand Down
9 changes: 9 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2329,6 +2329,15 @@ void InitXlaModuleBindings(py::module m) {
}
return result;
});
m.def("_xla_mark_bounded_dynamic",
[](const at::Tensor& input, const std::vector<uint32_t>& dims,
const std::vector<uint32_t>& bounds) {
TORCH_LAZY_COUNTER("XlaMarkBoundedDynamic", 1);
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
XLA_CHECK(dims.size() == bounds.size())
<< "dims.size() should be equal to bounds.size()";
xtensor->MarkBoundedDynamicDimension(dims, bounds);
});
m.def("_xla_mark_dynamic", [](const at::Tensor& input, uint32_t dim) {
TORCH_LAZY_COUNTER("XlaMarkDynamic", 1);
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::Shape shape,
torch::lazy::hash_t hash_seed)
: torch::lazy::Node(op, shape, num_outputs),
xla_shape_(std::move(xla_shape)),
node_hash_(GetOpHash(op, xla_shape_, hash_seed)),
node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
shape_hash_(torch::lazy::Hash(xla_shape_.ToString())),
dag_hash_(node_hash_) {}

XlaNode::XlaNode(torch::lazy::OpKind op, xla::Shape xla_shape,
Expand Down
14 changes: 14 additions & 0 deletions torch_xla/csrc/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ class XlaNode : public torch::lazy::Node {
torch::lazy::hash_t node_hash() const { return node_hash_; }

torch::lazy::hash_t hash() const override {
if (shape_hash_ != 0) {
auto dag_hash =
sharding_hash_ != 0
? (torch::lazy::HashCombine(dag_hash_, sharding_hash_))
: dag_hash_;
return torch::lazy::HashCombine(dag_hash, shape_hash_);
}
if (sharding_hash_ != 0) {
return torch::lazy::HashCombine(dag_hash_, sharding_hash_);
}
Expand Down Expand Up @@ -143,6 +150,12 @@ class XlaNode : public torch::lazy::Node {
unbounded_dynamic_dims_.insert(dim);
}

void MarkBoundedDynamicDimension(uint32_t dim, int64_t bound) {
xla_shape_.set_dynamic_dimension(dim, true);
xla_shape_.set_dimensions(dim, bound);
shape_hash_ = torch::lazy::Hash(xla_shape_.ToString());
}

const std::unordered_set<uint32_t>& dynamic_dims() const {
return unbounded_dynamic_dims_;
}
Expand All @@ -168,6 +181,7 @@ class XlaNode : public torch::lazy::Node {
torch::lazy::hash_t node_hash_ = 0;
torch::lazy::hash_t dag_hash_;
torch::lazy::hash_t sharding_hash_ = 0;
torch::lazy::hash_t shape_hash_ = 0;

// Experimental sharding annotations attached to the IR node.
std::vector<std::shared_ptr<xla::OpSharding>> output_shardings_;
Expand Down
20 changes: 20 additions & 0 deletions torch_xla/csrc/ops/dynamic_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,14 @@ SizeEq::SizeEq(torch::lazy::Value a, torch::lazy::Value b)
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
XLA_CHECK(dim_node_0);
XLA_CHECK(dim_node_1);
upper_bound_ =
dim_node_0->getStaticValue() == dim_node_1->getStaticValue() ? 1 : 0;
};

int64_t SizeEq::getDynamicValue() const {
if (runtime::sys_util::GetEnvBool("USE_BOUND_FOR_SHAPE_COMPARE", false)) {
return upper_bound_;
}
if (operand(0) == operand(1)) {
return 1;
}
Expand All @@ -163,9 +168,14 @@ SizeNe::SizeNe(torch::lazy::Value a, torch::lazy::Value b)
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
XLA_CHECK(dim_node_0);
XLA_CHECK(dim_node_1);
upper_bound_ =
dim_node_0->getStaticValue() != dim_node_1->getStaticValue() ? 1 : 0;
};

int64_t SizeNe::getDynamicValue() const {
if (runtime::sys_util::GetEnvBool("USE_BOUND_FOR_SHAPE_COMPARE", false)) {
return upper_bound_;
}
const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
XLA_CHECK(dim_node_0);
Expand All @@ -185,9 +195,14 @@ SizeGe::SizeGe(torch::lazy::Value a, torch::lazy::Value b)
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
XLA_CHECK(dim_node_0);
XLA_CHECK(dim_node_1);
upper_bound_ =
dim_node_0->getStaticValue() >= dim_node_1->getStaticValue() ? 1 : 0;
};

int64_t SizeGe::getDynamicValue() const {
if (runtime::sys_util::GetEnvBool("USE_BOUND_FOR_SHAPE_COMPARE", false)) {
return upper_bound_;
}
const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
XLA_CHECK(dim_node_0);
Expand All @@ -207,9 +222,14 @@ SizeLt::SizeLt(torch::lazy::Value a, torch::lazy::Value b)
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
XLA_CHECK(dim_node_0);
XLA_CHECK(dim_node_1);
upper_bound_ =
dim_node_0->getStaticValue() < dim_node_1->getStaticValue() ? 1 : 0;
};

int64_t SizeLt::getDynamicValue() const {
if (runtime::sys_util::GetEnvBool("USE_BOUND_FOR_SHAPE_COMPARE", false)) {
return upper_bound_;
}
const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
XLA_CHECK(dim_node_0);
Expand Down
12 changes: 12 additions & 0 deletions torch_xla/csrc/ops/dynamic_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ class SizeEq : public XlaNode, public torch::lazy::DimensionNode {
// TODO: not sure we will ever need it?
TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!");
}

private:
int64_t upper_bound_;
};

class SizeNe : public XlaNode, public torch::lazy::DimensionNode {
Expand All @@ -85,6 +88,9 @@ class SizeNe : public XlaNode, public torch::lazy::DimensionNode {
// TODO: not sure we will ever need it?
TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!");
}

private:
int64_t upper_bound_;
};

class SizeGe : public XlaNode, public torch::lazy::DimensionNode {
Expand All @@ -100,6 +106,9 @@ class SizeGe : public XlaNode, public torch::lazy::DimensionNode {
// TODO: not sure we will ever need it?
TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!");
}

private:
int64_t upper_bound_;
};

class SizeLt : public XlaNode, public torch::lazy::DimensionNode {
Expand All @@ -115,6 +124,9 @@ class SizeLt : public XlaNode, public torch::lazy::DimensionNode {
// TODO: not sure we will ever need it?
TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!");
}

private:
int64_t upper_bound_;
};

class SizeAdd : public XlaNode, public torch::lazy::DimensionNode {
Expand Down
5 changes: 5 additions & 0 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ class ComputationClient {
should_donate_buffer_ = should_donate_buffer;
}

void MarkBoundedDynamicDimension(uint32_t dim, int64_t bound) {
xla_shape_.set_dynamic_dimension(dim, true);
xla_shape_.set_dimensions(dim, bound);
}

virtual std::string ToString() const = 0;

virtual bool HasSharding() const = 0;
Expand Down
33 changes: 29 additions & 4 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,7 @@ c10::SymNode XLASymNodeImpl::mod(const c10::SymNode& other) {
}

c10::SymNode XLASymNodeImpl::eq(const c10::SymNode& other) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::size_");
auto p_other = dynamic_cast<XLASymNodeImpl*>(other.get());
XLA_CHECK(is_int()) << __FUNCTION__ << " with non-int NYI";
XLA_CHECK(p_other->is_int()) << __FUNCTION__ << " with non-int NYI";
Expand All @@ -740,8 +741,13 @@ c10::SymNode XLASymNodeImpl::ne(const c10::SymNode& other) {
}

c10::SymNode XLASymNodeImpl::gt(const c10::SymNode& other) {
XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
<< " has not been implemented.";
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::size_");
auto p_other = dynamic_cast<XLASymNodeImpl*>(other.get());
XLA_CHECK(is_int()) << __FUNCTION__ << " with non-int NYI";
XLA_CHECK(p_other->is_int()) << __FUNCTION__ << " with non-int NYI";
// use SizeLt to implement SizeGt
auto n_lt = torch::lazy::MakeNode<SizeLt>(p_other->node(), node());
return c10::make_intrusive<XLASymNodeImpl>(n_lt, PyType::BOOL);
}

c10::SymNode XLASymNodeImpl::lt(const c10::SymNode& other) {
Expand Down Expand Up @@ -876,8 +882,7 @@ c10::SymNode XLASymNodeImpl::wrap_bool(bool num) {
}

int64_t XLASymNodeImpl::guard_int(const char* file, int64_t line) {
XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
<< " has not been implemented.";
return int_();
}

double XLASymNodeImpl::guard_float(const char* file, int64_t line) {
Expand Down Expand Up @@ -934,6 +939,26 @@ void XLATensor::MarkDynamicDimension(uint32_t dim) {
xla_node->MarkDynamicDimension(dim);
}

void XLATensor::MarkBoundedDynamicDimension(
const std::vector<uint32_t>& dims, const std::vector<uint32_t>& bounds) {
auto* ir_node = GetIrValue().node.get();
auto* xla_node = dynamic_cast<XlaNode*>(ir_node);
torch::lazy::BackendDataPtr backend_data = CurrentDataHandle();
if (backend_data == nullptr) {
backend_data =
torch::lazy::getBackend()->GetComputationDataFromNode(ir_node);
}
for (int i = 0; i < dims.size(); i++) {
xla_node->MarkBoundedDynamicDimension(dims[i], bounds[i]);
if (backend_data) {
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(backend_data)
->MarkBoundedDynamicDimension(dims[i], bounds[i]);
}
}
// Update generation to XLATensorImpl::SetupSymSizeProperties
data()->generation += 1;
}

bool XLATensor::SetNodeUserMetadata(
std::shared_ptr<torch::lazy::UserMetaData> metadata) {
auto* node = dynamic_cast<XlaNode*>(CurrentIrValue().node.get());
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ class XLATensor : public torch::lazy::LazyTensor {
void SetScalarType(c10::optional<at::ScalarType> logical_element_type);

void MarkDynamicDimension(uint32_t dim);
void MarkBoundedDynamicDimension(const std::vector<uint32_t>& dims,
const std::vector<uint32_t>& bounds);
// We don't use the upstream shape to provide xla::shape.
runtime::util::MaybeRef<xla::Shape> shape() const;

Expand Down

0 comments on commit 51f25c8

Please sign in to comment.