File tree Expand file tree Collapse file tree 2 files changed +8
-6
lines changed Expand file tree Collapse file tree 2 files changed +8
-6
lines changed Original file line number Diff line number Diff line change @@ -133,9 +133,10 @@ def get_device_from_parameter_sharding(
133
133
if len (set (device_type_list )) == 1 :
134
134
return device_type_list [0 ]
135
135
else :
136
- assert (
137
- ps .sharding_type == "row_wise"
138
- ), "Only row_wise sharding supports sharding across multiple device types for a table"
136
+ assert ps .sharding_type in [
137
+ ShardingType .ROW_WISE .value ,
138
+ ShardingType .COLUMN_WISE .value ,
139
+ ], "Only row_wise or column_wise sharding supports sharding across multiple device types for a table"
139
140
return device_type_list
140
141
141
142
Original file line number Diff line number Diff line change @@ -166,9 +166,10 @@ def get_device_from_parameter_sharding(
166
166
if len (set (device_type_list )) == 1 :
167
167
return device_type_list [0 ]
168
168
else :
169
- assert (
170
- ps .sharding_type == "row_wise"
171
- ), "Only row_wise sharding supports sharding across multiple device types for a table"
169
+ assert ps .sharding_type in [
170
+ ShardingType .ROW_WISE .value ,
171
+ ShardingType .COLUMN_WISE .value ,
172
+ ], "Only row_wise or column_wise sharding supports sharding across multiple device types for a table"
172
173
return device_type_list
173
174
174
175
You can’t perform that action at this time.
0 commit comments