
Commit ba4fb07

kausv authored and facebook-github-bot committed
Add support for Write Dist (#3347)
Summary: Pull Request resolved: #3347 Differential Revision: D78749760
1 parent d8f7e28 commit ba4fb07

11 files changed: +701 -15 lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 4 additions & 0 deletions
@@ -1944,6 +1944,10 @@ def __init__(
         assert (
             config.is_using_virtual_table
         ), "Try to create ZeroCollisionKeyValueEmbedding for non virtual tables"
+        assert embedding_cache_mode == config.enable_embedding_update, (
+            f"Embedding_cache kernel is {embedding_cache_mode} "
+            f"but embedding config has enable_embedding_update {config.enable_embedding_update}"
+        )
         for table in config.embedding_tables:
             assert table.local_cols % 4 == 0, (
                 f"table {table.name} has local_cols={table.local_cols} "
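A minimal, framework-free sketch of the invariant the new assert enforces: the kernel-level embedding_cache_mode flag must agree with the table config's enable_embedding_update setting. TableConfig and check_cache_mode below are illustrative stand-ins, not torchrec types.

from dataclasses import dataclass

@dataclass
class TableConfig:
    enable_embedding_update: bool

def check_cache_mode(embedding_cache_mode: bool, config: TableConfig) -> None:
    # Same check as the assert added above, with the same error message.
    assert embedding_cache_mode == config.enable_embedding_update, (
        f"Embedding_cache kernel is {embedding_cache_mode} "
        f"but embedding config has enable_embedding_update {config.enable_embedding_update}"
    )

check_cache_mode(True, TableConfig(enable_embedding_update=True))    # passes
# check_cache_mode(True, TableConfig(enable_embedding_update=False))  # would raise AssertionError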

torchrec/distributed/dist_data.py

Lines changed: 115 additions & 2 deletions
@@ -368,8 +368,12 @@ def __init__(
             # https://github.com/pytorch/pytorch/issues/122788
             with record_function(f"## all2all_data:kjt {label} ##"):
                 if self._pg._get_backend_name() == "custom":
+                    if input_tensor.dim() == 2:
+                        output_size = [sum(output_split), input_tensor.size(1)]
+                    else:
+                        output_size = [sum(output_split)]
                     output_tensor = torch.empty(
-                        sum(output_split),
+                        output_size,
                         device=self._device,
                         dtype=input_tensor.dtype,
                     )
@@ -391,8 +395,12 @@ def __init__(
                     )
                     self._output_tensors.append(output_tensor)
                 else:
+                    if input_tensor.dim() == 2:
+                        output_size = [sum(output_split), input_tensor.size(1)]
+                    else:
+                        output_size = [sum(output_split)]
                     output_tensor = torch.empty(
-                        sum(output_split), device=self._device, dtype=input_tensor.dtype
+                        output_size, device=self._device, dtype=input_tensor.dtype
                     )
                     with record_function(f"## all2all_data:kjt {label} ##"):
                         awaitable = dist.all_to_all_single(
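A minimal sketch (plain PyTorch, no process group) of the shape logic added in the two hunks above: a 2-D input keeps its trailing dimension in the receive buffer, while a 1-D input does not. alloc_output is a hypothetical helper, not part of this diff.

import torch

def alloc_output(input_tensor: torch.Tensor, output_split: list) -> torch.Tensor:
    # Same branch as the diff: preserve the trailing dim for 2-D inputs (e.g. embedding rows).
    if input_tensor.dim() == 2:
        output_size = [sum(output_split), input_tensor.size(1)]
    else:
        output_size = [sum(output_split)]
    return torch.empty(output_size, dtype=input_tensor.dtype)

lengths = torch.ones(6, dtype=torch.int64)   # 1-D lengths tensor
values = torch.randn(6, 16)                  # 2-D values tensor, e.g. embedding rows
print(alloc_output(lengths, [2, 4]).shape)   # torch.Size([6])
print(alloc_output(values, [2, 4]).shape)    # torch.Size([6, 16])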
@@ -542,6 +550,111 @@ def _wait_impl(self) -> KJTAllToAllTensorsAwaitable:
         )
 
 
+class KJEAllToAll(nn.Module):
+    """
+    Redistributes `KeyedJaggedTensor` to a `ProcessGroup` according to splits.
+
+    Implementation utilizes AlltoAll collective as part of torch.distributed.
+
+    The input provides the necessary tensors and input splits to distribute.
+    The first collective call in `KJTAllToAllSplitsAwaitable` will transmit output
+    splits (to allocate correct space for tensors) and batch size per rank. The
+    following collective calls in `KJTAllToAllTensorsAwaitable` will transmit the actual
+    tensors asynchronously.
+
+    Args:
+        pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication.
+        splits (List[int]): List of len(pg.size()) which indicates how many features to
+            send to each pg.rank(). It is assumed the `KeyedJaggedTensor` is ordered by
+            destination rank. Same for all ranks.
+        stagger (int): stagger value to apply to recat tensor, see `_get_recat` function
+            for more detail.
+
+    Example::
+
+        keys=['A','B','C']
+        splits=[2,1]
+        kjtA2A = KJTAllToAll(pg, splits)
+        awaitable = kjtA2A(rank0_input)
+
+        # where:
+        # rank0_input is KeyedJaggedTensor holding
+
+        #         0           1           2
+        # 'A'    [A.V0]       None        [A.V1, A.V2]
+        # 'B'    None         [B.V0]      [B.V1]
+        # 'C'    [C.V0]       [C.V1]      None
+
+        # rank1_input is KeyedJaggedTensor holding
+
+        #         0           1           2
+        # 'A'    [A.V3]       [A.V4]      None
+        # 'B'    None         [B.V2]      [B.V3, B.V4]
+        # 'C'    [C.V2]       [C.V3]      None
+
+        rank0_output = awaitable.wait()
+
+        # where:
+        # rank0_output is KeyedJaggedTensor holding
+
+        #         0           1           2             3           4           5
+        # 'A'    [A.V0]       None        [A.V1, A.V2]  [A.V3]      [A.V4]      None
+        # 'B'    None         [B.V0]      [B.V1]        None        [B.V2]      [B.V3, B.V4]
+
+        # rank1_output is KeyedJaggedTensor holding
+        #         0           1           2             3           4           5
+        # 'C'    [C.V0]       [C.V1]      None          [C.V2]      [C.V3]      None
+    """
+
+    def __init__(
+        self,
+        pg: dist.ProcessGroup,
+        splits: List[int],
+        stagger: int = 1,
+    ) -> None:
+        super().__init__()
+        torch._check(len(splits) == pg.size())
+        self._pg: dist.ProcessGroup = pg
+        self._splits = splits
+        self._splits_cumsum: List[int] = [0] + list(itertools.accumulate(splits))
+        self._stagger = stagger
+
+    def forward(
+        self, input: KeyedJaggedTensor
+    ) -> Awaitable[KJTAllToAllTensorsAwaitable]:
+        """
+        Sends input to relevant `ProcessGroup` ranks.
+
+        The first wait will get the output splits for the provided tensors and issue
+        tensors AlltoAll. The second wait will get the tensors.
+
+        Args:
+            input (KeyedJaggedTensor): `KeyedJaggedTensor` of values to distribute.
+
+        Returns:
+            Awaitable[KJTAllToAllTensorsAwaitable]: awaitable of a `KJTAllToAllTensorsAwaitable`.
+        """
+
+        with torch.no_grad():
+            assert len(input.keys()) == sum(self._splits)
+            rank = dist.get_rank(self._pg)
+            local_keys = input.keys()[
+                self._splits_cumsum[rank] : self._splits_cumsum[rank + 1]
+            ]
+
+            return KJTAllToAllSplitsAwaitable(
+                pg=self._pg,
+                input=input,
+                splits=self._splits,
+                labels=input.dist_labels(),
+                tensor_splits=input.dist_splits(self._splits),
+                input_tensors=input.dist_tensors(),
+                keys=local_keys,
+                device=input.device(),
+                stagger=self._stagger,
+            )
+
+
 class KJTAllToAll(nn.Module):
     """
     Redistributes `KeyedJaggedTensor` to a `ProcessGroup` according to splits.
torchrec/distributed/embedding.py

Lines changed: 39 additions & 1 deletion
@@ -468,13 +468,17 @@ def __init__(
             for sharding_type, embedding_confings in sharding_type_to_sharding_infos.items()
         }
 
+        self.create_write_dist: bool = True  # TODO[Kaus]: Create config to control this
         self._device = device
         self._input_dists: List[nn.Module] = []
+        self._write_dists: List[nn.Module] = []
         self._lookups: List[nn.Module] = []
+        self._updates: List[nn.Module] = []
         self._create_lookups()
         self._output_dists: List[nn.Module] = []
         self._create_output_dist()
 
+        self._write_splits: List[int] = []
         self._feature_splits: List[int] = []
         self._features_order: List[int] = []
 
@@ -631,6 +635,7 @@ def create_grouped_sharding_infos(
                     total_num_buckets=config.total_num_buckets,
                     use_virtual_table=config.use_virtual_table,
                     virtual_table_eviction_policy=config.virtual_table_eviction_policy,
+                    enable_embedding_update=config.enable_embedding_update,
                 ),
                 param_sharding=parameter_sharding,
                 param=param,
@@ -1308,7 +1313,9 @@ def _create_input_dist(
 
     def _create_lookups(self) -> None:
         for sharding in self._sharding_type_to_sharding.values():
-            self._lookups.append(sharding.create_lookup())
+            lookup = sharding.create_lookup()
+            self._updates.append(sharding.create_update(lookup))
+            self._lookups.append(lookup)
 
     def _create_output_dist(
         self,
@@ -1627,6 +1634,37 @@ def fused_optimizer(self) -> KeyedOptimizer:
     def create_context(self) -> EmbeddingCollectionContext:
         return EmbeddingCollectionContext(sharding_contexts=[])
 
+    def _create_write_dist(self) -> None:
+        for sharding in self._sharding_type_to_sharding.values():
+            self._write_dists.append(sharding.create_write_dist())
+            self._write_splits.append(sharding._get_num_writable_features())
+
+    # pyre-ignore [14]
+    def write_dist(
+        self, ctx: EmbeddingCollectionContext, embeddings: KeyedJaggedTensor
+    ) -> Awaitable[Awaitable[KJTList]]:
+        if not self._write_dists:
+            self._create_write_dist()
+        with torch.no_grad():
+            embeddings_by_shards = embeddings.split(self._write_splits)
+            awaitables = []
+            for write_dist, embeddings in zip(self._write_dists, embeddings_by_shards):
+                awaitables.append(write_dist(embeddings))
+
+            return KJTListSplitsAwaitable(
+                awaitables,
+                ctx,
+                self._module_fqn,
+                list(self._sharding_type_to_sharding.keys()),
+            )
+
+    def update(self, ctx: EmbeddingCollectionContext, dist_input: KJTList) -> None:
+        for update, embeddings in zip(
+            self._updates,
+            dist_input,
+        ):
+            return update(embeddings)
+
 
 class EmbeddingCollectionSharder(BaseEmbeddingSharder[EmbeddingCollection]):
     def __init__(
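A small sketch of the per-sharding split that the new write_dist performs before redistributing, assuming torchrec is installed. The keys, lengths, and the [2, 1] segments are illustrative values standing in for self._write_splits.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Three features, batch size 2, two values per (feature, sample) slot.
embeddings = KeyedJaggedTensor(
    keys=["t1", "t2", "t3"],
    values=torch.arange(12, dtype=torch.float32),
    lengths=torch.tensor([2, 2, 2, 2, 2, 2]),
)

# Analogous to embeddings.split(self._write_splits): the first sharding receives
# two features, the second receives one, and each slice goes to its write dist.
by_shard = embeddings.split([2, 1])
print([kjt.keys() for kjt in by_shard])  # [['t1', 't2'], ['t3']]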

torchrec/distributed/embedding_lookup.py

Lines changed: 44 additions & 7 deletions
@@ -10,7 +10,7 @@
 import logging
 from abc import ABC
 from collections import OrderedDict
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -39,6 +39,7 @@
     BatchedFusedEmbeddingBag,
     KeyValueEmbedding,
     KeyValueEmbeddingBag,
+    ZeroCollisionEmbeddingCache,
     ZeroCollisionKeyValueEmbedding,
     ZeroCollisionKeyValueEmbeddingBag,
 )
@@ -49,6 +50,7 @@
 from torchrec.distributed.embedding_kernel import BaseEmbedding
 from torchrec.distributed.embedding_types import (
     BaseEmbeddingLookup,
+    BaseEmbeddingUpdate,
     BaseGroupedFeatureProcessor,
     EmbeddingComputeKernel,
     GroupedEmbeddingConfig,
@@ -249,12 +251,20 @@ def _create_embedding_kernel(
             )
         elif config.compute_kernel == EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE:
             # for dram kv
-            return ZeroCollisionKeyValueEmbedding(
-                config=config,
-                pg=pg,
-                device=device,
-                backend_type=BackendType.DRAM,
-            )
+            if config.enable_embedding_update:
+                return ZeroCollisionEmbeddingCache(
+                    config=config,
+                    pg=pg,
+                    device=device,
+                    backend_type=BackendType.DRAM,
+                )
+            else:
+                return ZeroCollisionKeyValueEmbedding(
+                    config=config,
+                    pg=pg,
+                    device=device,
+                    backend_type=BackendType.DRAM,
+                )
         else:
             raise ValueError(f"Compute kernel not supported {config.compute_kernel}")
 
@@ -411,6 +421,33 @@ def purge(self) -> None:
             emb_module.purge()
 
 
+class GroupedEmbeddingsUpdate(BaseEmbeddingUpdate[KeyedJaggedTensor]):
+    """
+    Update modules for Sequence embeddings (i.e Embeddings)
+    """
+
+    def __init__(
+        self,
+        grouped_embeddings_lookup: GroupedEmbeddingsLookup,
+    ) -> None:
+        super().__init__()
+        self._emb_modules: List[BaseEmbedding] = []
+        self._feature_splits: List[int] = []
+        for emb_module in grouped_embeddings_lookup._emb_modules:
+            emb_module = cast(BaseBatchedEmbedding[torch.Tensor], emb_module)
+            if emb_module.config.enable_embedding_update:
+                self._feature_splits.append(emb_module.config.num_features())
+                self._emb_modules.append(emb_module)
+
+    def forward(self, embeddings: KeyedJaggedTensor) -> None:
+        features_by_group = embeddings.split(
+            self._feature_splits,
+        )
+        for emb_module, features in zip(self._emb_modules, features_by_group):
+            # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
+            emb_module.update(features)
+
+
 class CommOpGradientScaling(torch.autograd.Function):
     @staticmethod
     # pyre-ignore
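A framework-free sketch of the dispatch pattern GroupedEmbeddingsUpdate follows: keep only kernels that accept updates, record how many features each owns, then split the incoming features by those counts and hand each slice to its kernel. FakeKernel and the feature counts are illustrative stand-ins, not torchrec types.

from dataclasses import dataclass
from typing import List

@dataclass
class FakeKernel:
    name: str
    num_features: int
    enable_embedding_update: bool

    def update(self, features: List[str]) -> None:
        print(f"{self.name} <- {features}")

kernels = [
    FakeKernel("kv_cache_0", 2, True),
    FakeKernel("read_only_0", 3, False),
    FakeKernel("kv_cache_1", 1, True),
]

# Mirror of __init__: keep updatable kernels and record their feature counts.
updatable = [k for k in kernels if k.enable_embedding_update]
feature_splits = [k.num_features for k in updatable]  # [2, 1]

# Mirror of forward: split the (already filtered) features and dispatch each slice.
incoming = ["f0", "f1", "f2"]
offset = 0
for kernel, count in zip(updatable, feature_splits):
    kernel.update(incoming[offset : offset + count])
    offset += count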
