File tree Expand file tree Collapse file tree 2 files changed +8
-6
lines changed Expand file tree Collapse file tree 2 files changed +8
-6
lines changed Original file line number Diff line number Diff line change @@ -133,9 +133,10 @@ def get_device_from_parameter_sharding(
133133 if len (set (device_type_list )) == 1 :
134134 return device_type_list [0 ]
135135 else :
136- assert (
137- ps .sharding_type == "row_wise"
138- ), "Only row_wise sharding supports sharding across multiple device types for a table"
136+ assert ps .sharding_type in [
137+ ShardingType .ROW_WISE .value ,
138+ ShardingType .COLUMN_WISE .value ,
139+ ], "Only row_wise or column_wise sharding supports sharding across multiple device types for a table"
139140 return device_type_list
140141
141142
Original file line number Diff line number Diff line change @@ -166,9 +166,10 @@ def get_device_from_parameter_sharding(
166166 if len (set (device_type_list )) == 1 :
167167 return device_type_list [0 ]
168168 else :
169- assert (
170- ps .sharding_type == "row_wise"
171- ), "Only row_wise sharding supports sharding across multiple device types for a table"
169+ assert ps .sharding_type in [
170+ ShardingType .ROW_WISE .value ,
171+ ShardingType .COLUMN_WISE .value ,
172+ ], "Only row_wise or column_wise sharding supports sharding across multiple device types for a table"
172173 return device_type_list
173174
174175
You can’t perform that action at this time.
0 commit comments