Skip to content

Commit ee2355e

Browse files
Fei Yu authored and facebook-github-bot committed
directly pass update_util as int flag without syncing iter (#3293)
Summary: Pull Request resolved: #3293 As the title states, this issue continues to exist in the most recent ITEP experiments when we only apply ITEP on the baseline without changing the batch size and/or the number of trainers. From recent MAI results in f777920760, we see about a 3.5% QPS gap with ITEP enabled (393 vs 403 P90). Issues are visible in the trace https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree%2Ftraces%2Fdynocli%2Faps-aps-mai_to_flow-777920760-f777930060%2F0%2Frank-0.Aug_11_01_48_39.4443.pt.trace.json.gz&bucket=aps_traces {F1981311699} Reviewed By: doIIarplus Differential Revision: D67302872 fbshipit-source-id: dbb33a76e44f6dbeb62a80b19d4e7d97286ebec4
1 parent 3f8fdbb commit ee2355e

File tree

4 files changed

+456
-21
lines changed

4 files changed

+456
-21
lines changed

torchrec/distributed/itep_embeddingbag.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ def compute(
177177
ctx: ITEPEmbeddingBagCollectionContext,
178178
dist_input: KJTList,
179179
) -> List[torch.Tensor]:
180+
# We need to explicitly move iter to CPU since it might be moved to GPU
181+
# after __init__. This should be done once.
182+
self._iter = self._iter.cpu()
183+
180184
if not ctx.is_reindexed:
181185
dist_input = self._reindex(dist_input)
182186
ctx.is_reindexed = True
@@ -196,6 +200,10 @@ def output_dist(
196200
def compute_and_output_dist(
197201
self, ctx: ITEPEmbeddingBagCollectionContext, input: KJTList
198202
) -> LazyAwaitable[KeyedTensor]:
203+
# We need to explicitly move iter to CPU since it might be moved to GPU
204+
# after __init__. This should be done once.
205+
self._iter = self._iter.cpu()
206+
199207
# Insert forward() function of GenericITEPModule into compute_and_output_dist()
200208
for i, (sharding, features) in enumerate(
201209
zip(
@@ -424,6 +432,10 @@ def compute(
424432
ctx: ITEPEmbeddingCollectionContext,
425433
dist_input: KJTList,
426434
) -> List[torch.Tensor]:
435+
# We need to explicitly move iter to CPU since it might be moved to GPU
436+
# after __init__. This should be done once.
437+
self._iter = self._iter.cpu()
438+
427439
for i, (sharding, features) in enumerate(
428440
zip(
429441
self._embedding_collection._sharding_type_to_sharding.keys(),
@@ -450,6 +462,10 @@ def output_dist(
450462
def compute_and_output_dist(
451463
self, ctx: ITEPEmbeddingCollectionContext, input: KJTList
452464
) -> LazyAwaitable[Dict[str, JaggedTensor]]:
465+
# We need to explicitly move iter to CPU since it might be moved to GPU
466+
# after __init__. This should be done once.
467+
self._iter = self._iter.cpu()
468+
453469
# Insert forward() function of GenericITEPModule into compute_and_output_dist()
454470
""" """
455471
for i, (sharding, features) in enumerate(

torchrec/modules/itep_embedding_modules.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def forward(
7979

8080
features = self._itep_module(features, self._iter.item())
8181
pooled_embeddings = self._embedding_bag_collection(features)
82+
8283
self._iter += 1
8384

8485
return pooled_embeddings

torchrec/modules/itep_modules.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -514,13 +514,13 @@ def forward(
514514
feature_offsets,
515515
) = self.get_remap_info(sparse_features)
516516

517-
update_utils: bool = (
517+
update_util: bool = (
518518
(cur_iter < 10)
519519
or (cur_iter < 100 and (cur_iter + 1) % 19 == 0)
520520
or ((cur_iter + 1) % 39 == 0)
521521
)
522522
full_values_list = None
523-
if update_utils and sparse_features.variable_stride_per_key():
523+
if update_util and sparse_features.variable_stride_per_key():
524524
if sparse_features.inverse_indices_or_none() is not None:
525525
# full util update mode require reconstructing original input indicies from VBE input
526526
full_values_list = self.get_full_values_list(sparse_features)
@@ -531,7 +531,7 @@ def forward(
531531
)
532532

533533
remapped_values = torch.ops.fbgemm.remap_indices_update_utils(
534-
cur_iter,
534+
int(cur_iter),
535535
buffer_idx,
536536
feature_lengths,
537537
feature_offsets,
@@ -540,6 +540,7 @@ def forward(
540540
self.row_util,
541541
self.buffer_offsets,
542542
full_values_list=full_values_list,
543+
update_util=update_util,
543544
)
544545

545546
sparse_features._values = remapped_values

0 commit comments

Comments
 (0)