Skip to content

Commit 3f8fdbb

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Update sharding type assertions to support COLUMN_WISE sharding (#3344)
Summary: Pull Request resolved: #3344 astitled Reviewed By: Raahul46 Differential Revision: D81543306 fbshipit-source-id: 5160a94ab5842f134a459bb2d76920d940c13e40
1 parent 960581a commit 3f8fdbb

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

torchrec/distributed/embedding.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,10 @@ def get_device_from_parameter_sharding(
133133
if len(set(device_type_list)) == 1:
134134
return device_type_list[0]
135135
else:
136-
assert (
137-
ps.sharding_type == "row_wise"
138-
), "Only row_wise sharding supports sharding across multiple device types for a table"
136+
assert ps.sharding_type in [
137+
ShardingType.ROW_WISE.value,
138+
ShardingType.COLUMN_WISE.value,
139+
], "Only row_wise or column_wise sharding supports sharding across multiple device types for a table"
139140
return device_type_list
140141

141142

torchrec/distributed/quant_embedding.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,10 @@ def get_device_from_parameter_sharding(
166166
if len(set(device_type_list)) == 1:
167167
return device_type_list[0]
168168
else:
169-
assert (
170-
ps.sharding_type == "row_wise"
171-
), "Only row_wise sharding supports sharding across multiple device types for a table"
169+
assert ps.sharding_type in [
170+
ShardingType.ROW_WISE.value,
171+
ShardingType.COLUMN_WISE.value,
172+
], "Only row_wise or column_wise sharding supports sharding across multiple device types for a table"
172173
return device_type_list
173174

174175

0 commit comments

Comments
 (0)