diff --git a/torchrec/distributed/benchmark/base.py b/torchrec/distributed/benchmark/base.py
index 0821fe579..75cac5e35 100644
--- a/torchrec/distributed/benchmark/base.py
+++ b/torchrec/distributed/benchmark/base.py
@@ -633,6 +633,9 @@ def _run_benchmark_core(
         export_stacks: Whether to export flamegraph-compatible stack files.
         reset_accumulated_memory_stats: Whether to reset accumulated memory stats
             in addition to peak memory stats.
+        all_rank_traces: Whether to save traces for all ranks or just rank 0.
+        memory_snapshot: Whether to capture a memory snapshot during profiling;
+            usage: https://docs.pytorch.org/memory_viz
     """

     # Preparation & memory reset
@@ -912,6 +915,7 @@ def benchmark_func(
         export_stacks: Whether to export flamegraph-compatible stack files.
         all_rank_traces: Whether to export traces from all ranks.
         memory_snapshot: Whether to capture memory snapshot during the profiling
+            usage: https://docs.pytorch.org/memory_viz
     """
     if benchmark_func_kwargs is None:
         benchmark_func_kwargs = {}
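The `memory_snapshot` docstrings above point at the memory_viz workflow. A minimal sketch (not part of this patch) of the underlying PyTorch API such a flag typically wraps; the output filename is illustrative:

```python
# Sketch of the CUDA memory-snapshot workflow the `memory_snapshot` flag
# presumably wraps; the pickle path here is a made-up example.
import torch

torch.cuda.memory._record_memory_history(max_entries=100_000)  # start recording
# ... run the profiled benchmark iterations here ...
torch.cuda.memory._dump_snapshot("memory_snapshot_rank0.pickle")  # write snapshot
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
# Inspect the pickle by dragging it into https://docs.pytorch.org/memory_viz
```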
diff --git a/torchrec/distributed/benchmark/benchmark_train_pipeline.py b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
index 4439d4e63..c65dd033b 100644
--- a/torchrec/distributed/benchmark/benchmark_train_pipeline.py
+++ b/torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -37,6 +37,7 @@ from torchrec.distributed.test_utils.input_config import ModelInputConfig
 from torchrec.distributed.test_utils.model_config import (
     BaseModelConfig,
+    DLSeqConfig,
     generate_sharded_model_and_optimizer,
     ModelSelectionConfig,
 )

@@ -116,8 +117,7 @@ class RunOptions(BenchFuncConfig):
 def runner(
     rank: int,
     world_size: int,
-    tables: List[EmbeddingBagConfig],
-    weighted_tables: List[EmbeddingBagConfig],
+    table_list: List[List[EmbeddingBagConfig]],
     run_option: RunOptions,
     model_config: BaseModelConfig,
     pipeline_config: PipelineConfig,
@@ -136,21 +136,34 @@ def runner(
         backend="nccl",
         use_deterministic_algorithms=False,
     ) as ctx:
-        unsharded_model = model_config.generate_model(
-            tables=tables,
-            weighted_tables=weighted_tables,
-            dense_device=ctx.device,
-        )
-
-        # Create a planner for sharding based on the specified type
-        planner = planner_config.generate_planner(
-            tables=tables + weighted_tables,
-        )
-        bench_inputs = input_config.generate_batches(
-            tables=tables,
-            weighted_tables=weighted_tables,
-        )
+        if isinstance(model_config, DLSeqConfig):
+            bench_inputs = input_config.generate_list_batches(
+                tables=table_list,  # pyre-ignore
+                table_options=model_config.table_options,
+            )
+            unsharded_model = model_config.generate_model(
+                tables=table_list,  # pyre-ignore
+                dense_device=ctx.device,
+            )
+            planner = planner_config.generate_planner(
+                tables=[table for tables in table_list for table in tables],
+            )
+        else:
+            tables, weighted_tables, *_ = table_list
+            bench_inputs = input_config.generate_std_batches(
+                tables=tables,
+                weighted_tables=weighted_tables,
+            )
+            unsharded_model = model_config.generate_model(
+                tables=tables,
+                weighted_tables=weighted_tables,
+                dense_device=ctx.device,
+            )
+            # Create a planner for sharding based on the specified type
+            planner = planner_config.generate_planner(
+                tables=tables + weighted_tables,
+            )


         # Prepare fused_params for sparse optimizer
         fused_params = {
@@ -228,8 +241,7 @@ def run_pipeline(
     benchmark_res_per_rank = run_multi_process_func(
         func=runner,
         world_size=run_option.world_size,
-        tables=tables,
-        weighted_tables=weighted_tables,
+        table_list=[tables, weighted_tables],
         run_option=run_option,
         model_config=model_config,
         pipeline_config=pipeline_config,
@@ -270,14 +282,14 @@ def main(
     input_config: ModelInputConfig,
     planner_config: PlannerConfig,
 ) -> None:
-    tables, weighted_tables, *_ = table_config.generate_tables()
+    table_list = table_config.generate_tables()
     model_config = model_selection.create_model_config()

+    # launch trainers
     run_multi_process_func(
         func=runner,
         world_size=run_option.world_size,
-        tables=tables,
-        weighted_tables=weighted_tables,
+        table_list=table_list,
         run_option=run_option,
         model_config=model_config,
         pipeline_config=pipeline_config,
diff --git a/torchrec/distributed/benchmark/yaml/sparse_data_dist_seq.yml b/torchrec/distributed/benchmark/yaml/sparse_data_dist_seq.yml
new file mode 100644
index 000000000..facf0263b
--- /dev/null
+++ b/torchrec/distributed/benchmark/yaml/sparse_data_dist_seq.yml
@@ -0,0 +1,55 @@
+# this is a very basic sparse data dist config
+# runs on 2 ranks, showing traces with reasonable workloads
+RunOptions:
+  world_size: 2
+  num_batches: 5
+  num_benchmarks: 2
+  sharding_type: table_wise
+  profile_dir: "."
+  name: "sparse_data_dist_seq"
+  # export_stacks: True  # enable this to export stack traces
+PipelineConfig:
+  pipeline: "sparse"
+ModelInputConfig:
+  feature_pooling_avg: 10
+EmbeddingTablesConfig:
+  num_unweighted_features: 100
+  num_weighted_features: 100
+  embedding_feature_dim: 128
+  additional_tables:
+    - - name: FP16_table
+        embedding_dim: 512
+        num_embeddings: 100_000
+        feature_names: ["additional_0_0"]
+        data_type: FP16
+      - name: large_table
+        embedding_dim: 2048
+        num_embeddings: 1_000_000
+        feature_names: ["additional_0_1"]
+    - []
+    - - name: ec_likes
+        embedding_dim: 32
+        num_embeddings: 100_000
+        feature_names: ["user_likes"]
+        config_class: EmbeddingConfig
+      - name: ec_clicks
+        embedding_dim: 32
+        num_embeddings: 100_000
+        feature_names: ["user_clicks"]
+        config_class: EmbeddingConfig
+PlannerConfig:
+  additional_constraints:
+    large_table:
+      sharding_types: [row_wise]
+    ec_likes:
+      sharding_types: [column_wise]
+    ec_clicks:
+      sharding_types: [column_wise]
+ModelSelectionConfig:
+  model_name: "dlseq"
+  model_config:
+    num_float_features: 10
+    table_options:
+      - {}
+      - is_weighted: True
+      - {}
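If this config is launched the same way as the other YAMLs in this directory (an assumption; the CLI entry point is not part of this diff), the invocation would look like `python -m torchrec.distributed.benchmark.benchmark_train_pipeline --yaml_config=torchrec/distributed/benchmark/yaml/sparse_data_dist_seq.yml`. The three `table_options` entries line up positionally with the three table groups produced by `EmbeddingTablesConfig`: unweighted EBC tables, weighted EBC tables (`is_weighted: True`), and the two `EmbeddingConfig` (EC) tables.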
diff --git a/torchrec/distributed/test_utils/input_config.py b/torchrec/distributed/test_utils/input_config.py
index a0683f81f..67fe9bc8f 100644
--- a/torchrec/distributed/test_utils/input_config.py
+++ b/torchrec/distributed/test_utils/input_config.py
@@ -8,12 +8,16 @@
 # pyre-strict

 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Union

 import torch
-from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_configs import (
+    EmbeddingBagConfig,
+    EmbeddingConfig,
+    EmbeddingTableConfig,
+)

-from .model_input import ModelInput
+from .model_input import ListModelInput, ModelInput


 @dataclass
@@ -31,7 +35,7 @@ class ModelInputConfig:
     long_kjt_lengths: bool = True
     pin_memory: bool = True

-    def generate_batches(
+    def generate_std_batches(
         self,
         tables: List[EmbeddingBagConfig],
         weighted_tables: List[EmbeddingBagConfig],
@@ -63,3 +67,43 @@
             )
             for batch_size in range(self.num_batches)
         ]
+
+    def generate_list_batches(
+        self,
+        tables: List[
+            Union[
+                List[EmbeddingTableConfig],
+                List[EmbeddingBagConfig],
+                List[EmbeddingConfig],
+            ]
+        ],
+        table_options: Optional[List[Dict[str, Any]]] = None,
+    ) -> List[ListModelInput]:
+        """
+        Generate list-based model input batches for benchmarking.
+
+        Args:
+            tables: list of embedding table groups; one KJT is generated per group
+            table_options: optional per-group generation overrides (e.g. is_weighted)
+
+        Returns:
+            A list of ListModelInput objects representing the generated batches
+        """
+        device = torch.device(self.device) if self.device is not None else None
+
+        return [
+            ListModelInput.generate(
+                batch_size=self.batch_size,
+                tables=tables,
+                num_float_features=self.num_float_features,
+                pooling_avg=self.feature_pooling_avg,
+                use_offsets=self.use_offsets,
+                device=device,
+                indices_dtype=(torch.int64 if self.long_kjt_indices else torch.int32),
+                offsets_dtype=(torch.int64 if self.long_kjt_offsets else torch.int32),
+                lengths_dtype=(torch.int64 if self.long_kjt_lengths else torch.int32),
+                pin_memory=self.pin_memory,
+                table_options=table_options,
+            )
+            for _ in range(self.num_batches)
+        ]
diff --git a/torchrec/distributed/test_utils/model_config.py b/torchrec/distributed/test_utils/model_config.py
index 0341a00c6..b8669b02d 100644
--- a/torchrec/distributed/test_utils/model_config.py
+++ b/torchrec/distributed/test_utils/model_config.py
@@ -31,7 +31,11 @@
 from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
 from torchrec.distributed.sharding_plan import get_default_sharders
 from torchrec.distributed.test_utils.test_model import (
+    TestDenseArch,
+    TestMultiEmbSparseArch,
+    TestMultiSparseNN,
     TestOverArchLarge,
+    TestOverArchMultiEmb,
     TestSparseNN,
     TestTowerCollectionSparseNN,
     TestTowerSparseNN,
@@ -39,7 +43,7 @@
 from torchrec.distributed.types import ShardingEnv
 from torchrec.models.deepfm import SimpleDeepFMNNWrapper
 from torchrec.models.dlrm import DLRMWrapper
-from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection


@@ -213,6 +217,45 @@ def generate_model(
     )


+@dataclass
+class DLSeqConfig(BaseModelConfig):
+    """Configuration for the DLSeq test model (one embedding module per table group)."""
+
+    table_options: Optional[List[Dict[str, Any]]] = None
+    dense_arch_out_size: Optional[int] = None
+    dense_arch_layer_sizes: Optional[List[int]] = None
+    over_arch_out_size: Optional[int] = None
+    over_arch_hidden_layers: Optional[int] = None
+
+    # pyre-ignore[14]
+    def generate_model(
+        self,
+        tables: List[Union[List[EmbeddingBagConfig], List[EmbeddingConfig]]],
+        device: Optional[torch.device] = None,
+        **kwargs: Any,
+    ) -> nn.Module:
+        if device is None:
+            device = torch.device("cpu")
+
+        # one embedding module (EBC, weighted EBC, or EC) is built per table group
+        sparse = TestMultiEmbSparseArch(tables, self.table_options, device=device)
+        dense = TestDenseArch(
+            self.num_float_features,
+            device=device,
+            dense_arch_hidden_sizes=self.dense_arch_layer_sizes,
+            dense_arch_out_size=self.dense_arch_out_size,
+        )
+        over = TestOverArchMultiEmb(
+            tables,
+            table_options=self.table_options,
+            device=device,
+            dense_arch_out_size=self.dense_arch_out_size,
+            over_arch_out_size=self.over_arch_out_size,
+            over_arch_hidden_layers=self.over_arch_hidden_layers,
+        )
+        return TestMultiSparseNN(sparse, dense, over)
+
+
 # pyre-ignore[2]: Missing parameter annotation
 def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
     """
@@ -334,6 +376,8 @@ def get_model_config_class(self) -> Type[BaseModelConfig]:
             return DeepFMConfig
         case "dlrm":
             return DLRMConfig
+        case "dlseq":
+            return DLSeqConfig
         case _:
             raise ValueError(f"Unknown model name: {self.model_name}")

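A hypothetical usage sketch of the new `dlseq` selection path. The `ModelSelectionConfig` field names are taken from the YAML above, but the exact set of required `BaseModelConfig` fields is not shown in this diff, so defaults are assumed:

```python
# Hypothetical sketch: how runner() ends up with a DLSeqConfig. Assumes
# ModelSelectionConfig forwards the model_config dict to the class returned
# by get_model_config_class(), as main() does via create_model_config().
from torchrec.distributed.test_utils.model_config import ModelSelectionConfig

selection = ModelSelectionConfig(
    model_name="dlseq",
    model_config={
        "num_float_features": 10,
        "table_options": [{}, {"is_weighted": True}, {}],  # mirrors the YAML
    },
)
config_cls = selection.get_model_config_class()  # -> DLSeqConfig
model_config = selection.create_model_config()  # instance consumed by runner()
```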
diff --git a/torchrec/distributed/test_utils/model_input.py b/torchrec/distributed/test_utils/model_input.py
index 71e8e42c1..b31e75da1 100644
--- a/torchrec/distributed/test_utils/model_input.py
+++ b/torchrec/distributed/test_utils/model_input.py
@@ -8,7 +8,7 @@
 # pyre-strict

 from dataclasses import dataclass
-from typing import cast, List, Optional, Tuple, Union
+from typing import Any, cast, Dict, List, Optional, Tuple, Union

 import torch
 from tensordict import TensorDict
@@ -232,6 +232,27 @@ def generate_local_batches(
             for _ in range(world_size)
         ]

+    @classmethod
+    def generate_float_features(
+        cls,
+        batch_size: int,
+        num_float_features: Optional[int],
+        all_zeros: bool,
+        device: Optional[torch.device],
+    ) -> torch.Tensor:
+        if num_float_features is None:  # for label
+            return (
+                torch.zeros((batch_size,), device=device)
+                if all_zeros
+                else torch.rand((batch_size,), device=device)
+            )
+        else:
+            return (
+                torch.zeros((batch_size, num_float_features), device=device)
+                if all_zeros
+                else torch.rand((batch_size, num_float_features), device=device)
+            )
+
     @classmethod
     def generate(
         cls,
@@ -270,10 +291,8 @@ def generate(
             on pinned memory for a fast transfer to gpu. For more on pin_memory:
             https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html#pin-memory
         """
-        float_features = (
-            torch.zeros((batch_size, num_float_features), device=device)
-            if all_zeros
-            else torch.rand((batch_size, num_float_features), device=device)
+        float_features = cls.generate_float_features(
+            batch_size, num_float_features, all_zeros, device
         )
         idlist_features = (
             ModelInput.create_standard_kjt(
@@ -311,11 +330,7 @@ def generate(
             if weighted_tables is not None and len(weighted_tables) > 0
             else None
         )
-        label = (
-            torch.zeros((batch_size,), device=device)
-            if all_zeros
-            else torch.rand((batch_size,), device=device)
-        )
+        label = cls.generate_float_features(batch_size, None, all_zeros, device)
         if pin_memory:
             # all tensors in `ModelInput` should be on pinned memory otherwise
             # the `_to_copy` (host-to-device) data transfer still blocks cpu execution
@@ -555,17 +570,99 @@ def _create_batched_standard_kjts(
     return global_kjt, local_kjts


-# @dataclass
-# class VbModelInput(ModelInput):
-#     pass
+@dataclass
+class ListModelInput(ModelInput):
+    sparse_feature_list: List[KeyedJaggedTensor]
+
+    def to(self, device: torch.device, non_blocking: bool = False) -> "ListModelInput":
+        return ListModelInput(
+            float_features=self.float_features.to(
+                device=device, non_blocking=non_blocking
+            ),
+            idlist_features=(
+                self.idlist_features.to(device=device, non_blocking=non_blocking)
+                if self.idlist_features is not None
+                else None
+            ),
+            idscore_features=(
+                self.idscore_features.to(device=device, non_blocking=non_blocking)
+                if self.idscore_features is not None
+                else None
+            ),
+            label=self.label.to(device=device, non_blocking=non_blocking),
+            sparse_feature_list=[
+                kjt.to(device=device, non_blocking=non_blocking)
+                for kjt in self.sparse_feature_list
+            ],
+        )
+
+    @classmethod
+    def generate(  # pyre-ignore[14]
+        cls,
+        batch_size: int,
+        tables: List[
+            Union[
+                List[EmbeddingTableConfig],
+                List[EmbeddingBagConfig],
+                List[EmbeddingConfig],
+            ]
+        ],
+        num_float_features: int = 16,
+        pooling_avg: int = 10,
+        tables_pooling: Optional[List[int]] = None,
+        max_feature_lengths: Optional[List[int]] = None,
+        use_offsets: bool = False,
+        device: Optional[torch.device] = None,
+        indices_dtype: torch.dtype = torch.int64,
+        offsets_dtype: torch.dtype = torch.int64,
+        lengths_dtype: torch.dtype = torch.int64,
+        all_zeros: bool = False,
+        pin_memory: bool = False,  # pin_memory is needed for training job qps benchmark
+        table_options: Optional[List[Dict[str, Any]]] = None,
+    ) -> "ListModelInput":
+        kjt_list: List[KeyedJaggedTensor] = []
+        for idx, table in enumerate(tables):
+            option = table_options[idx] if table_options else {}
+            kjt = ModelInput.create_standard_kjt(
+                batch_size=batch_size,
+                tables=table,
+                pooling_avg=option.get("pooling_avg", pooling_avg),
+                tables_pooling=option.get("tables_pooling", tables_pooling),
+                weighted=(
+                    option.get("is_weighted", False)
+                    in (True, "is_weighted", "1", "true", "True")
+                ),
+                max_feature_lengths=option.get(
+                    "max_feature_lengths", max_feature_lengths
+                ),
+                use_offsets=option.get("use_offsets", use_offsets),
+                device=device,
+                indices_dtype=option.get("indices_dtype", indices_dtype),
+                offsets_dtype=option.get("offsets_dtype", offsets_dtype),
+                lengths_dtype=option.get("lengths_dtype", lengths_dtype),
+                all_zeros=option.get("all_zeros", all_zeros),
+            )
+            kjt = kjt.pin_memory() if pin_memory else kjt
+            kjt_list.append(kjt)

-# @staticmethod
-# def _create_variable_batch_kjt() -> KeyedJaggedTensor:
-#     pass
+        float_features = cls.generate_float_features(
+            batch_size=batch_size,
+            num_float_features=num_float_features,
+            all_zeros=all_zeros,
+            device=device,
+        )
+        label = cls.generate_float_features(batch_size, None, all_zeros, device)

-# @staticmethod
-# def _merge_variable_batch_kjts(kjts: List[KeyedJaggedTensor]) -> KeyedJaggedTensor:
-#     pass
+        if pin_memory:
+            float_features = float_features.pin_memory()
+            label = label.pin_memory()

+        return ListModelInput(
+            float_features=float_features,
+            idlist_features=None,
+            idscore_features=None,
+            label=label,
+            sparse_feature_list=kjt_list,
+        )


 @dataclass
diff --git a/torchrec/distributed/test_utils/table_config.py b/torchrec/distributed/test_utils/table_config.py
index 764e74370..beec62666 100644
--- a/torchrec/distributed/test_utils/table_config.py
+++ b/torchrec/distributed/test_utils/table_config.py
@@ -8,9 +8,9 @@
 # pyre-strict

 from dataclasses import dataclass, field
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union

-from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
 from torchrec.types import DataType


@@ -39,11 +39,19 @@ class EmbeddingTablesConfig:
     table_data_type: DataType = DataType.FP32
     additional_tables: List[List[Dict[str, Any]]] = field(default_factory=list)

-    def convert_to_ebconf(self, kwargs: Dict[str, Any]) -> EmbeddingBagConfig:
+    def convert_to_ebconf(
+        self, kwargs: Dict[str, Any]
+    ) -> Union[EmbeddingConfig, EmbeddingBagConfig]:
         if "data_type" in kwargs:
             kwargs["data_type"] = DataType[kwargs["data_type"]]
         else:
             kwargs["data_type"] = self.table_data_type
+        if "config_class" in kwargs:
+            config_class = kwargs.pop("config_class")
+            if config_class == "EmbeddingConfig":
+                return EmbeddingConfig(**kwargs)
+            elif config_class != "EmbeddingBagConfig":
+                raise ValueError(f"Unknown config class: {config_class}")
         return EmbeddingBagConfig(**kwargs)

     def generate_tables(
@@ -99,6 +107,7 @@ def generate_tables(
             tables = []
             for adt in adts:
                 tables.append(self.convert_to_ebconf(adt))
+            tables_list.append(tables)

         if len(tables_list) == 0:
             tables_list.append(unweighted_tables)
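A small usage sketch of the new `config_class` dispatch in `convert_to_ebconf`; it assumes the `EmbeddingTablesConfig` fields not set here have workable defaults, and the group comments reflect the grouping implied by the YAML above:

```python
# Sketch: a `config_class` entry routes a table dict to EmbeddingConfig
# instead of the default EmbeddingBagConfig.
from torchrec.distributed.test_utils.table_config import EmbeddingTablesConfig

cfg = EmbeddingTablesConfig(
    additional_tables=[
        [],  # group 0: extra unweighted EBC tables (none here, assumption)
        [],  # group 1: extra weighted EBC tables (none here, assumption)
        [
            {
                "name": "ec_likes",
                "embedding_dim": 32,
                "num_embeddings": 100_000,
                "feature_names": ["user_likes"],
                # routes this entry to EmbeddingConfig, not EmbeddingBagConfig
                "config_class": "EmbeddingConfig",
            }
        ],
    ],
)
table_list = cfg.generate_tables()  # yields a third, EC-only table group
```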
diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py
index cb7004670..28138b580 100644
--- a/torchrec/distributed/test_utils/test_model.py
+++ b/torchrec/distributed/test_utils/test_model.py
@@ -16,7 +16,10 @@
 import torch.nn as nn
 from tensordict import TensorDict
 from torchrec import EmbeddingCollection
+from torchrec.distributed.embedding import ShardedEmbeddingCollection
 from torchrec.distributed.embedding_types import EmbeddingTableConfig
+from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
+from torchrec.distributed.test_utils.model_input import ListModelInput
 from torchrec.distributed.utils import CopyableMixin
 from torchrec.modules.activation import SwishLayerNorm
 from torchrec.modules.embedding_configs import (
@@ -36,7 +39,12 @@
     MCHManagedCollisionModule,
 )
 from torchrec.modules.regroup import KTRegroupAsDict
-from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import (
+    _to_offsets,
+    JaggedTensor,
+    KeyedJaggedTensor,
+    KeyedTensor,
+)
 from torchrec.streamable import Pipelineable


@@ -835,6 +843,7 @@ def record_stream(self, stream: torch.Stream) -> None:

 DENSE_LAYER_OUT_SIZE = 8
 OVER_ARCH_OUT_SIZE = 16
+MAX_SEQUENCE_LENGTH = 20


 def _tables_dim_sum(
@@ -1263,6 +1272,109 @@ def forward(
         return self.overarch(_concat(dense, values))


+class TestOverArchMultiEmb(nn.Module):
+    """
+    A simple (but larger) over arch that merges a dense arch and a sparse arch:
+    it concatenates the dense output with the values of each sparse KeyedTensor
+    and feeds the concatenation through a sequential MLP.
+
+    Args:
+        tables: the list of embedding table groups
+        table_options: optional per-group options (e.g. max_sequence_length)
+        dense_arch_out_size: the size of the output dense embedding
+        over_arch_out_size: the output size of the over arch
+        over_arch_hidden_layers: the number of hidden layers in the MLP
+        device: the device on which this module will be placed.
+
+    Call Args:
+        dense: torch.Tensor, the output of a dense arch
+        sparse: List[KeyedTensor], the outputs of a sparse arch
+
+    Returns:
+        torch.Tensor
+
+    Example:
+        over_arch = TestOverArchMultiEmb(tables, table_options)
+
+    """
+
+    def __init__(
+        self,
+        tables: List[Union[List[EmbeddingBagConfig], List[EmbeddingConfig]]],
+        table_options: Optional[List[Dict[str, Any]]] = None,
+        device: Optional[torch.device] = None,
+        dense_arch_out_size: Optional[int] = None,
+        over_arch_out_size: Optional[int] = None,
+        over_arch_hidden_layers: Optional[int] = None,
+    ) -> None:
+        """
+        Args:
+            tables: the list of embedding table groups (one embedding
+                module output is expected per group)
+            table_options: optional per-group options (e.g. max_sequence_length)
+            dense_arch_out_size: the size of the output dense embedding
+            over_arch_out_size: the output size of the over arch
+            over_arch_hidden_layers: the number of hidden layers in the MLP
+            device: the device on which this module will be placed.
+ """ + super().__init__() + if device is None: + device = torch.device("cpu") + if dense_arch_out_size is None: + dense_arch_out_size = DENSE_LAYER_OUT_SIZE + if over_arch_out_size is None: + over_arch_out_size = OVER_ARCH_OUT_SIZE + if over_arch_hidden_layers is None: + over_arch_hidden_layers = 5 + + in_features = dense_arch_out_size + for idx, table in enumerate(tables): + options = table_options[idx] if table_options else {} + if isinstance(table[0], EmbeddingConfig): + max_sequence_length = options.get( + "max_sequence_length", MAX_SEQUENCE_LENGTH + ) + else: + max_sequence_length = 1 + # pyre-ignore[6] + in_features += _tables_dim_sum(table, max_sequence_length) + + out_features = over_arch_out_size + layers = [ + torch.nn.Linear( + in_features=in_features, + out_features=out_features, + ), + SwishLayerNorm([out_features]), + ] + + for _ in range(over_arch_hidden_layers): + layers += [ + torch.nn.Linear( + in_features=out_features, + out_features=out_features, + ), + SwishLayerNorm([out_features]), + ] + + self.overarch = torch.nn.Sequential(*layers) + + def forward( + self, + dense: torch.Tensor, + sparse: List[KeyedTensor], + ) -> torch.Tensor: + """ + Args: + dense: torch.Tensor, the output of a dense arch + sparse: KeyedTensor, the output of a sparse arch + + Returns: + torch.Tensor + """ + return self.overarch(_concat(dense, [kt.values() for kt in sparse])) + + def _pad_kt_values( kt: KeyedTensor, batch_size: Optional[int] = None, @@ -1295,6 +1407,38 @@ def _pad_kt_values( ) +@torch.fx.wrap +def _pad_jt_values( + djt: Dict[str, JaggedTensor], max_sequence_length: int +) -> KeyedTensor: + """ + Pad a Dict[str, JaggedTensor] to the given batch size. + + Args: + djt: the Dict[str, JaggedTensor] to pad + batch_size: the desired batch size + device: the device on which to pad + + Returns: + The padded KeyedTensor + """ + keys: List[str] = [] + length_per_key: List[int] = [] + values: List[torch.Tensor] = [] + for key, jt in djt.items(): + keys.append(key) + dim = jt.values().size(1) + length_per_key.append(dim) + + padded_embeddings = torch.ops.fbgemm.jagged_2d_to_dense( + values=jt.values(), + offsets=jt.offsets(), + max_sequence_length=max_sequence_length, + ) + values.append(padded_embeddings.view(-1, max_sequence_length * dim)) + return KeyedTensor.from_tensor_list(keys, values) + + @torch.fx.wrap def _post_sparsenn_forward( ebc: KeyedTensor, @@ -1484,6 +1628,129 @@ def forward( return result +class TestMultiEmbSparseArch(nn.Module): + """ + A simple sparse arch that wraps three EmbeddingBagCollection modules + + It does not merge the sparse module outputs into a single KeyedTensor. 
+
+    Args:
+        tables: the list of embedding table groups,
+        table_options: Optional[List[Dict[str, Any]]],
+        device: Optional[torch.device],
+
+    Call Args:
+        input_list: one KeyedJaggedTensor per wrapped embedding module
+        batch_size: optional batch size
+
+    Returns:
+        List[KeyedTensor]
+
+    Example:
+        sparse_arch = TestMultiEmbSparseArch(tables, table_options)
+        kt_list = sparse_arch(input_list)
+    """
+
+    def __init__(
+        self,
+        tables: List[Union[List[EmbeddingBagConfig], List[EmbeddingConfig]]],
+        table_options: Optional[List[Dict[str, Any]]] = None,
+        device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__()
+        if device is None:
+            device = torch.device("cpu")
+        self.embedding_modules = nn.ModuleList()
+        assert table_options is None or len(table_options) == len(tables)
+        self.max_sequence_lengths: Dict[int, int] = {}
+        for idx, table in enumerate(tables):
+            options = table_options[idx] if table_options else {}
+            self.embedding_modules.append(
+                self.create_embedding_module(table, device, **options)
+            )
+            max_sequence_length = options.get(
+                "max_sequence_length", MAX_SEQUENCE_LENGTH
+            )
+            self.max_sequence_lengths[idx] = max_sequence_length
+
+    @classmethod
+    def create_embedding_module(
+        cls,
+        tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]],
+        device: Optional[torch.device] = None,
+        is_weighted: bool = False,
+        max_feature_length: Optional[int] = None,
+        max_sequence_length: Optional[int] = None,
+    ) -> nn.Module:
+        if device is None:
+            device = torch.device("cpu")
+        if isinstance(tables[0], EmbeddingConfig):
+            # EC
+            return EmbeddingCollection(
+                tables=cast(List[EmbeddingConfig], tables),
+                device=device,
+            )
+        elif is_weighted:
+            # weighted EBC
+            return EmbeddingBagCollection(
+                tables=cast(List[EmbeddingBagConfig], tables),
+                device=device,
+                is_weighted=True,
+            )
+        elif max_feature_length:
+            # FP_EBC
+            max_feature_lengths: Dict[str, int] = {
+                str(table.name): max_feature_length for table in tables
+            }
+            return FeatureProcessedEmbeddingBagCollection(
+                embedding_bag_collection=EmbeddingBagCollection(
+                    tables=cast(List[EmbeddingBagConfig], tables),
+                    device=device,
+                    is_weighted=True,
+                ),
+                feature_processors=PositionWeightedModuleCollection(
+                    max_feature_lengths=max_feature_lengths,
+                    device=(
+                        device
+                        if device != torch.device("meta")
+                        else torch.device("cpu")
+                    ),
+                ),
+            )
+        else:
+            # EBC
+            return EmbeddingBagCollection(
+                tables=cast(List[EmbeddingBagConfig], tables),
+                device=device,
+            )
+
+    def forward(
+        self,
+        input_list: List[KeyedJaggedTensor],
+        batch_size: Optional[int] = None,
+    ) -> List[KeyedTensor]:
+        """
+        Args:
+            input_list: List[KeyedJaggedTensor],
+            batch_size: Optional[int],
+
+        Returns:
+            List[KeyedTensor]
+        """
+        results: List[KeyedTensor] = []
+
+        for idx, module in enumerate(self.embedding_modules):
+            res = module(input_list[idx])
+            if isinstance(module, ShardedEmbeddingBagCollection):
+                results.append(res)
+            elif isinstance(module, ShardedEmbeddingCollection):
+                results.append(_pad_jt_values(res, self.max_sequence_lengths[idx]))
+            else:
+                raise ValueError(f"Unsupported module type {type(module)}")
+
+        return results
+
+
 class TestSparseNNBase(nn.Module):
     """
     Base class for a SparseNN model.
@@ -1640,6 +1907,49 @@ def forward(
         return self.dense_forward(input, self.sparse_forward(input))


+class TestMultiSparseNN(nn.Module):
+    """
+    Simple version of a SparseNN model built from pre-constructed arches.
+
+    Args:
+        sparse: nn.Module,
+        dense: nn.Module,
+        over: nn.Module,
+
+    Call Args:
+        model_input: ListModelInput,
+
+    Returns:
+        torch.Tensor
+
+    Example:
+        TestMultiSparseNN(sparse, dense, over)
+    """
+
+    def __init__(self, sparse: nn.Module, dense: nn.Module, over: nn.Module) -> None:
+        super().__init__()
+        self.sparse = sparse
+        self.dense = dense
+        self.over = over
+
+    def forward(
+        self, model_input: ListModelInput
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        sparse_output = self.sparse(model_input.sparse_feature_list)
+        dense_output = self.dense(model_input.float_features)
+        over_output = self.over(dense_output, sparse_output)
+        pred = torch.sigmoid(torch.mean(over_output, dim=1))
+        if self.training:
+            return (
+                torch.nn.functional.binary_cross_entropy_with_logits(
+                    pred, model_input.label
+                ),
+                pred,
+            )
+        else:
+            return pred
+
+
 class TestTowerInteraction(nn.Module):
     """
     Basic nn.Module for testing
@@ -2341,7 +2651,28 @@
 class TestMixedSequenceOverArch(nn.Module):
-    """Simple overarch that handles both pooled and flattened sequence embeddings"""
+    """
+    Simple overarch that handles both pooled and flattened sequence embeddings
+
+    Args:
+        ebc_tables: List[EmbeddingBagConfig],
+        ec_tables: List[EmbeddingConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+        device: torch.device,
+        max_sequence_length: Optional[int],
+        dense_arch_out_size: Optional[int],
+        over_arch_out_size: Optional[int],
+
+    Call Args:
+        dense: torch.Tensor,
+        sparse: torch.Tensor,
+
+    Returns:
+        torch.Tensor
+
+    Example:
+        >>> TestMixedSequenceOverArch(ebc_tables, ec_tables, weighted_tables, device)
+    """

     def __init__(
         self,
@@ -2364,7 +2695,7 @@ def __init__(
         if device is None:
             device = torch.device("cpu")
         if max_sequence_length is None:
-            max_sequence_length = 20
+            max_sequence_length = MAX_SEQUENCE_LENGTH
         if dense_arch_out_size is None:
             dense_arch_out_size = DENSE_LAYER_OUT_SIZE
         if over_arch_out_size is None:
@@ -2413,7 +2744,7 @@ def __init__(
         embedding_groups: Optional[Dict[str, List[str]]] = None,
         dense_device: Optional[torch.device] = None,
         sparse_device: Optional[torch.device] = None,
-        feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
+        max_sequence_length: Optional[int] = None,
         over_arch_clazz: Type[nn.Module] = TestMixedSequenceOverArch,
         device: Optional[torch.device] = None,
     ) -> None:
@@ -2428,6 +2759,7 @@ def __init__(
         )
         if device is None:
             device = torch.device("cpu")
+        self.max_sequence_length: int = max_sequence_length or MAX_SEQUENCE_LENGTH

         ebc_tables: List[EmbeddingBagConfig] = []
         ec_tables: List[EmbeddingConfig] = []
@@ -2541,8 +2873,8 @@ def sparse_forward(
             torch.ops.fbgemm.jagged_2d_to_dense(
                 values=ec_result[e].values(),
                 offsets=ec_result[e].offsets(),
-                max_sequence_length=20,
-            ).view(-1, 20 * self.ec_embedding_dim)
+                max_sequence_length=self.max_sequence_length,
+            ).view(-1, self.max_sequence_length * self.ec_embedding_dim)
             for e in self._ec_features
         ]
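A minimal sketch of driving the new list-based input type directly; the two table groups are hypothetical stand-ins for the groups `generate_tables()` would produce:

```python
# Sketch: build a ListModelInput batch from two hand-made table groups.
from torchrec.distributed.test_utils.model_input import ListModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig

ebc_group = [  # pooled-embedding group
    EmbeddingBagConfig(
        num_embeddings=1_000, embedding_dim=16, name="t0", feature_names=["f0"]
    )
]
ec_group = [  # sequence-embedding group
    EmbeddingConfig(
        num_embeddings=1_000, embedding_dim=16, name="e0", feature_names=["s0"]
    )
]

batch = ListModelInput.generate(
    batch_size=4,
    tables=[ebc_group, ec_group],
    num_float_features=10,
    table_options=[{}, {}],  # per-group overrides, e.g. {"is_weighted": True}
)
assert len(batch.sparse_feature_list) == 2  # one KJT per table group
assert batch.float_features.shape == (4, 10)
```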