@@ -1132,6 +1132,208 @@ def test_stacking_and_sharding_table(self, delete_input):
11321132 for tbl in feature_tables .values ():
11331133 self .assertEqual (tbl .is_deleted (), delete_input )
11341134
def test_get_stacked_row_ids(self):
  """Checks that stacked row ids point at the matching original rows.

  Three tables with different vocabulary sizes and embedding dims are passed
  through `auto_stack_tables` and `stack_and_shard_feature_tables`. For every
  row of every table, the id returned by `get_row_ids_in_stacked_table` must
  index a stacked row whose leading `embedding_dim` entries equal the
  original (unpadded) row.
  """
  batch_size = 16

  # One entry per table:
  # (name suffix, vocabulary size, embedding dim, init constant, row-id offset)
  table_configs = (
      ('a', 64, 4, 0.0, 10),
      ('b', 192, 5, 1.0, 100),
      ('c', 224, 6, 2.0, 1000),
  )

  # Keep the table specs as originally constructed; their names key the
  # unstacked tables below.
  original_table_specs = []
  feature_specs = []
  for suffix, vocab_size, embedding_dim, init_value, _ in table_configs:
    table_spec = embedding_spec.TableSpec(
        vocabulary_size=vocab_size,
        embedding_dim=embedding_dim,
        initializer=jax.nn.initializers.constant(init_value),
        optimizer=embedding_spec.SGDOptimizerSpec(),
        combiner='sum',
        name=f'table_{suffix}',
    )
    original_table_specs.append(table_spec)
    feature_specs.append(
        embedding_spec.FeatureSpec(
            table_spec=table_spec,
            input_shape=[batch_size, 1],
            output_shape=[batch_size, embedding_dim],
            name=f'feature_{suffix}',
        )
    )

  # Prepare feature specs with stacking.
  table_stacking.auto_stack_tables(
      feature_specs,
      num_sc_per_device=self.num_sc_per_device,
      global_device_count=jax.device_count(),
  )
  logging.vlog(1, 'feature_specs_a_b: %s', feature_specs)

  # Prepare the unpadded tables, sharded along the 'data' mesh axis.
  mesh = jax.sharding.Mesh(jax.devices(), 'data')
  sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec('data')
  )
  feature_tables = {}
  for table_spec, config in zip(original_table_specs, table_configs):
    _, vocab_size, embedding_dim, _, offset = config
    feature_tables[table_spec.name] = jax.device_put(
        test_utils.row_id_initializer(
            (vocab_size, embedding_dim),
            offset=offset,
        ),
        device=sharding,
    )

  # Pad each table to the sizes recorded on its spec after stacking. The
  # padded arrays are only used for verbose debug logging.
  for original_spec, feature_spec, config in zip(
      original_table_specs, feature_specs, table_configs
  ):
    suffix, vocab_size, embedding_dim, _, _ = config
    setting = feature_spec.table_spec.setting_in_stack
    padded_table = jnp.pad(
        feature_tables[original_spec.name],
        (
            (0, setting.padded_vocab_size - vocab_size),
            (0, setting.padded_embedding_dim - embedding_dim),
        ),
    )
    logging.vlog(1, 'table_%s_padded: \n %s', suffix, padded_table)

  table_spec_proto = embedding.create_proto_from_feature_specs(
      feature_specs,
      global_device_count=jax.device_count(),
      num_sparsecore_per_device=self.num_sc_per_device,
  )
  stacked_tables = table_stacking.stack_and_shard_feature_tables(
      feature_tables,
      table_spec_proto,
      delete_input=False,
  )

  for stacked_table_spec in table_spec_proto.stacked_table_specs:
    stacked_table = stacked_tables[stacked_table_spec.stack_name]
    for table_spec in stacked_table_spec.table_specs:
      expected_table = feature_tables[table_spec.table_name]

      # Look up the stacked row id of every row in this table.
      stack_row_ids = table_stacking.get_row_ids_in_stacked_table(
          stacked_table_spec,
          table_spec,
          list(range(table_spec.vocab_size)),
      )

      # Validate that each row is equivalent once the embedding-dim padding
      # of the stacked row is sliced away.
      for i in range(table_spec.vocab_size):
        stacked_row = stacked_table[stack_row_ids[i]][
            : table_spec.embedding_dim
        ]
        expected_row = expected_table[i]

        self.assertTrue(
            jnp.array_equal(stacked_row, expected_row),
            f'MISMATCH at row={i}:\n{stacked_row}\n{expected_row}',
        )
1336+
11351337
# Entry point: run the absltest runner when this file is executed as a script.
11361338if __name__ == '__main__' :
11371339 absltest .main ()
0 commit comments