Skip to content

Commit 24fa190

Browse files
ChromeHeartsGoogle-ML-Automation
authored and committed
Provide an API to look up the equivalent stacked-table row ids for a given stack_table_spec and table_spec
PiperOrigin-RevId: 747435196
1 parent 5542696 commit 24fa190

File tree

2 files changed

+235
-0
lines changed

2 files changed

+235
-0
lines changed

jax_tpu_embedding/sparsecore/lib/nn/table_stacking.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,3 +733,36 @@ def stack_and_shard_feature_tables(
733733
)
734734

735735
return ret
736+
737+
738+
def get_row_ids_in_stacked_table(
    stack_table_spec: "embedding_spec_pb2.StackedTableSpecProto",
    table_spec: "embedding_spec_pb2.TableSpecProto",
    row_ids: Sequence[int],
) -> Sequence[int]:
  """Returns the stacked table's row ids for the given unsharded table's row ids.

  Each unsharded row is assigned to a sparsecore shard by modular
  distribution (with the table's shard rotation applied), then offset by the
  table's starting row within that shard of the stacked table.

  Args:
    stack_table_spec: StackedTableSpecProto describing the stacked table.
    table_spec: TableSpecProto of the unsharded table.
    row_ids: Sequence of row ids of the unsharded table.

  Returns:
    Row ids of the stacked table, one per input row id, in input order.

  Raises:
    ValueError: If any row id is negative or not smaller than the table's
      vocabulary size.
  """
  num_sparse_cores = stack_table_spec.num_sparsecores
  # The stacked table is laid out shard-major: shard k occupies rows
  # [k * stack_shard_size, (k + 1) * stack_shard_size).
  stack_shard_size = stack_table_spec.stack_vocab_size // num_sparse_cores

  ret = []
  for row_id in row_ids:
    # Validate with an exception rather than `assert` so the check survives
    # `python -O` (asserts are stripped under optimization).
    if not 0 <= row_id < table_spec.vocab_size:
      raise ValueError(
          f"{row_id} exceeds available vocabulary size"
          f" [{table_spec.vocab_size}]."
      )
    # Mod-sharding with rotation: consecutive rows land on consecutive
    # sparsecores, rotated by the table's shard_rotation.
    shard_id = (
        row_id % num_sparse_cores + table_spec.shard_rotation
    ) % num_sparse_cores
    # Row within the shard: the table's own sharded row index, offset by
    # where this table begins inside the stacked shard.
    sharded_row_id = row_id // num_sparse_cores + table_spec.row_offset_in_shard
    ret.append(shard_id * stack_shard_size + sharded_row_id)

  return ret

jax_tpu_embedding/sparsecore/lib/nn/tests/table_stacking_test.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,6 +1132,208 @@ def test_stacking_and_sharding_table(self, delete_input):
11321132
for tbl in feature_tables.values():
11331133
self.assertEqual(tbl.is_deleted(), delete_input)
11341134

1135+
def test_get_stacked_row_ids(self):
  """Checks get_row_ids_in_stacked_table against actually stacked tables.

  Builds three tables of different vocab sizes / embedding dims, auto-stacks
  them, physically stacks and shards the table arrays, then verifies that the
  row ids returned by get_row_ids_in_stacked_table index the same rows in the
  stacked arrays as the original unsharded rows.
  """
  vocab_size_a = 64
  embedding_dim_a = 4

  vocab_size_b = 192
  embedding_dim_b = 5

  vocab_size_c = 224
  embedding_dim_c = 6

  batch_size = 16

  table_a_spec = embedding_spec.TableSpec(
      vocabulary_size=vocab_size_a,
      embedding_dim=embedding_dim_a,
      initializer=jax.nn.initializers.constant(0.0),
      optimizer=embedding_spec.SGDOptimizerSpec(),
      combiner='sum',
      name='table_a',
  )

  table_b_spec = embedding_spec.TableSpec(
      vocabulary_size=vocab_size_b,
      embedding_dim=embedding_dim_b,
      initializer=jax.nn.initializers.constant(1.0),
      optimizer=embedding_spec.SGDOptimizerSpec(),
      combiner='sum',
      name='table_b',
  )

  table_c_spec = embedding_spec.TableSpec(
      vocabulary_size=vocab_size_c,
      embedding_dim=embedding_dim_c,
      initializer=jax.nn.initializers.constant(2.0),
      optimizer=embedding_spec.SGDOptimizerSpec(),
      combiner='sum',
      name='table_c',
  )

  feature_a_spec = embedding_spec.FeatureSpec(
      table_spec=table_a_spec,
      input_shape=[batch_size, 1],
      output_shape=[batch_size, embedding_dim_a],
      name='feature_a',
  )

  feature_b_spec = embedding_spec.FeatureSpec(
      table_spec=table_b_spec,
      input_shape=[batch_size, 1],
      output_shape=[batch_size, embedding_dim_b],
      name='feature_b',
  )

  feature_c_spec = embedding_spec.FeatureSpec(
      table_spec=table_c_spec,
      input_shape=[batch_size, 1],
      output_shape=[batch_size, embedding_dim_c],
      name='feature_c',
  )

  # Prepare feature specs with stacking (mutates the specs in place).
  feature_specs = [feature_a_spec, feature_b_spec, feature_c_spec]
  table_stacking.auto_stack_tables(
      feature_specs,
      num_sc_per_device=self.num_sc_per_device,
      global_device_count=jax.device_count(),
  )
  logging.vlog(1, 'feature_specs_a_b: %s', feature_specs)

  # Table specs now carry the stacking settings (setting_in_stack).
  updated_table_spec_a = feature_specs[0].table_spec
  updated_table_spec_b = feature_specs[1].table_spec
  updated_table_spec_c = feature_specs[2].table_spec

  # Prepare arrays in unpadded and sharded forms. Distinct offsets (10,
  # 100, 1000) make rows from different tables distinguishable.
  mesh = jax.sharding.Mesh(jax.devices(), 'data')
  sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec('data')
  )
  table_a_sharded = jax.device_put(
      test_utils.row_id_initializer(
          (
              vocab_size_a,
              embedding_dim_a,
          ),
          offset=10,
      ),
      device=sharding,
  )
  table_b_sharded = jax.device_put(
      test_utils.row_id_initializer(
          (
              vocab_size_b,
              embedding_dim_b,
          ),
          offset=100,
      ),
      device=sharding,
  )
  table_c_sharded = jax.device_put(
      test_utils.row_id_initializer(
          (
              vocab_size_c,
              embedding_dim_c,
          ),
          offset=1000,
      ),
      device=sharding,
  )

  # Pad tables to the stacked (padded) vocab / embedding dimensions.
  table_a_padded = jnp.pad(
      table_a_sharded,
      (
          (
              0,
              updated_table_spec_a.setting_in_stack.padded_vocab_size
              - vocab_size_a,
          ),
          (
              0,
              updated_table_spec_a.setting_in_stack.padded_embedding_dim
              - embedding_dim_a,
          ),
      ),
  )
  table_b_padded = jnp.pad(
      table_b_sharded,
      (
          (
              0,
              updated_table_spec_b.setting_in_stack.padded_vocab_size
              - vocab_size_b,
          ),
          (
              0,
              updated_table_spec_b.setting_in_stack.padded_embedding_dim
              - embedding_dim_b,
          ),
      ),
  )
  table_c_padded = jnp.pad(
      table_c_sharded,
      (
          (
              0,
              updated_table_spec_c.setting_in_stack.padded_vocab_size
              - vocab_size_c,
          ),
          (
              0,
              updated_table_spec_c.setting_in_stack.padded_embedding_dim
              - embedding_dim_c,
          ),
      ),
  )

  logging.vlog(1, 'table_a_padded: \n%s', table_a_padded)
  logging.vlog(1, 'table_b_padded: \n%s', table_b_padded)
  logging.vlog(1, 'table_c_padded: \n%s', table_c_padded)

  feature_tables = {
      table_a_spec.name: table_a_sharded,
      table_b_spec.name: table_b_sharded,
      table_c_spec.name: table_c_sharded,
  }
  table_spec_proto = embedding.create_proto_from_feature_specs(
      feature_specs,
      global_device_count=jax.device_count(),
      num_sparsecore_per_device=self.num_sc_per_device,
  )
  stacked_tables = table_stacking.stack_and_shard_feature_tables(
      feature_tables,
      table_spec_proto,
      delete_input=False,
  )

  for stacked_table_spec in table_spec_proto.stacked_table_specs:
    for table_spec in stacked_table_spec.table_specs:
      # Look up all the row ids of this table at once.
      stack_row_ids = table_stacking.get_row_ids_in_stacked_table(
          stacked_table_spec,
          table_spec,
          list(range(table_spec.vocab_size)),
      )

      # Validate each row is equivalent: the stacked row (trimmed of
      # embedding-dim padding) must match the original unsharded row.
      for i in range(table_spec.vocab_size):
        stacked_row = stacked_tables[stacked_table_spec.stack_name][
            stack_row_ids[i]
        ][: table_spec.embedding_dim]

        # NOTE(review): failure message renders as "MISMATCH atrow=..."
        # (missing space between the literals) — confirm intended.
        self.assertTrue(
            jnp.array_equal(
                stacked_row,
                feature_tables[table_spec.table_name][i],
            ),
            'MISMATCH at'
            f'row={i}:\n'
            f'{stacked_row}\n'
            f'{feature_tables[table_spec.table_name][i]}',
        )
1336+
11351337

11361338
if __name__ == '__main__':
11371339
absltest.main()

0 commit comments

Comments
 (0)