diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 251968ecb..24b0a4885 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -133,9 +133,10 @@ def get_device_from_parameter_sharding(
     if len(set(device_type_list)) == 1:
         return device_type_list[0]
     else:
-        assert (
-            ps.sharding_type == "row_wise"
-        ), "Only row_wise sharding supports sharding across multiple device types for a table"
+        assert ps.sharding_type in [
+            ShardingType.ROW_WISE.value,
+            ShardingType.COLUMN_WISE.value,
+        ], "Only row_wise or column_wise sharding supports sharding across multiple device types for a table"
         return device_type_list


diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py
index 94d7574f6..af945dfe2 100644
--- a/torchrec/distributed/quant_embedding.py
+++ b/torchrec/distributed/quant_embedding.py
@@ -166,9 +166,10 @@ def get_device_from_parameter_sharding(
     if len(set(device_type_list)) == 1:
         return device_type_list[0]
     else:
-        assert (
-            ps.sharding_type == "row_wise"
-        ), "Only row_wise sharding supports sharding across multiple device types for a table"
+        assert ps.sharding_type in [
+            ShardingType.ROW_WISE.value,
+            ShardingType.COLUMN_WISE.value,
+        ], "Only row_wise or column_wise sharding supports sharding across multiple device types for a table"
         return device_type_list
