4 changes: 4 additions & 0 deletions torchrec/distributed/benchmark/base.py
@@ -633,6 +633,9 @@ def _run_benchmark_core(
export_stacks: Whether to export flamegraph-compatible stack files.
reset_accumulated_memory_stats: Whether to reset accumulated memory
stats in addition to peak memory stats.
all_rank_traces: Whether to save traces for all ranks or just rank 0.
memory_snapshot: Whether to capture a memory snapshot during profiling;
    usage: https://docs.pytorch.org/memory_viz
"""

# Preparation & memory reset
@@ -912,6 +915,7 @@ def benchmark_func(
export_stacks: Whether to export flamegraph-compatible stack files.
all_rank_traces: Whether to export traces from all ranks.
memory_snapshot: Whether to capture a memory snapshot during profiling;
    usage: https://docs.pytorch.org/memory_viz
"""
if benchmark_func_kwargs is None:
benchmark_func_kwargs = {}
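For reference, a minimal sketch of what `memory_snapshot` capture typically involves with the stock PyTorch allocator-history API; the exact hook points inside `_run_benchmark_core` are not shown in this diff, so treat this as an assumption:

```python
import torch

def profiled_region() -> None:  # hypothetical stand-in for the benchmarked func
    ...

# Assumed pattern: record allocator history around the profiled region, then
# dump a pickle that https://docs.pytorch.org/memory_viz can load and render.
torch.cuda.memory._record_memory_history(max_entries=100_000)
try:
    profiled_region()
finally:
    torch.cuda.memory._dump_snapshot("rank0_memory_snapshot.pickle")
    torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```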
54 changes: 33 additions & 21 deletions torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -37,6 +37,7 @@
from torchrec.distributed.test_utils.input_config import ModelInputConfig
from torchrec.distributed.test_utils.model_config import (
BaseModelConfig,
DLSeqConfig,
generate_sharded_model_and_optimizer,
ModelSelectionConfig,
)
@@ -116,8 +117,7 @@ class RunOptions(BenchFuncConfig):
def runner(
rank: int,
world_size: int,
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
table_list: List[List[EmbeddingBagConfig]],
run_option: RunOptions,
model_config: BaseModelConfig,
pipeline_config: PipelineConfig,
@@ -136,21 +136,34 @@ def runner(
backend="nccl",
use_deterministic_algorithms=False,
) as ctx:
unsharded_model = model_config.generate_model(
tables=tables,
weighted_tables=weighted_tables,
dense_device=ctx.device,
)

# Create a planner for sharding based on the specified type
planner = planner_config.generate_planner(
tables=tables + weighted_tables,
)

bench_inputs = input_config.generate_batches(
tables=tables,
weighted_tables=weighted_tables,
)
if isinstance(model_config, DLSeqConfig):
bench_inputs = input_config.generate_list_batches(
tables=table_list, # pyre-ignore
table_options=model_config.table_options,
)
unsharded_model = model_config.generate_model(
tables=table_list, # pyre-ignore
dense_device=ctx.device,
)
planner = planner_config.generate_planner(
tables=[table for tables in table_list for table in tables],
)
else:
tables, weighted_tables, *_ = table_list
bench_inputs = input_config.generate_std_batches(
tables=tables,
weighted_tables=weighted_tables,
)
unsharded_model = model_config.generate_model(
tables=tables,
weighted_tables=weighted_tables,
dense_device=ctx.device,
)
# Create a planner for sharding based on the specified type
planner = planner_config.generate_planner(
tables=tables + weighted_tables,
)

# Prepare fused_params for sparse optimizer
fused_params = {
@@ -228,8 +241,7 @@ def run_pipeline(
benchmark_res_per_rank = run_multi_process_func(
func=runner,
world_size=run_option.world_size,
tables=tables,
weighted_tables=weighted_tables,
table_list=[tables, weighted_tables],
run_option=run_option,
model_config=model_config,
pipeline_config=pipeline_config,
@@ -270,14 +282,14 @@ def main(
input_config: ModelInputConfig,
planner_config: PlannerConfig,
) -> None:
tables, weighted_tables, *_ = table_config.generate_tables()
table_list = table_config.generate_tables()
model_config = model_selection.create_model_config()

# launch trainers
run_multi_process_func(
func=runner,
world_size=run_option.world_size,
tables=tables,
weighted_tables=weighted_tables,
table_list=table_list,
run_option=run_option,
model_config=model_config,
pipeline_config=pipeline_config,
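A quick sketch of the `table_list` contract introduced above; the group contents below are illustrative, but the unpacking and flattening mirror the branches in `runner`:

```python
from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig

# table_list packs one list of table configs per embedding group.
unweighted = [EmbeddingBagConfig(name="t0", embedding_dim=16, num_embeddings=100, feature_names=["f0"])]
weighted = [EmbeddingBagConfig(name="t1", embedding_dim=16, num_embeddings=100, feature_names=["f1"])]
seq = [EmbeddingConfig(name="t2", embedding_dim=16, num_embeddings=100, feature_names=["f2"])]
table_list = [unweighted, weighted, seq]

# Standard models unpack only the first two groups:
tables, weighted_tables, *_ = table_list

# DLSeqConfig models flatten every group when building the sharding planner:
planner_tables = [table for tables in table_list for table in tables]
```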
55 changes: 55 additions & 0 deletions torchrec/distributed/benchmark/yaml/sparse_data_dist_seq.yml
@@ -0,0 +1,55 @@
# A basic sparse data dist config for the DLSeq model.
# Runs on 2 ranks and produces traces with a reasonable workload.
RunOptions:
world_size: 2
num_batches: 5
num_benchmarks: 2
sharding_type: table_wise
profile_dir: "."
name: "sparse_data_dist_seq"
# export_stacks: True # enable this to export stack traces
PipelineConfig:
pipeline: "sparse"
ModelInputConfig:
feature_pooling_avg: 10
EmbeddingTablesConfig:
num_unweighted_features: 100
num_weighted_features: 100
embedding_feature_dim: 128
additional_tables:
- - name: FP16_table
embedding_dim: 512
num_embeddings: 100_000
feature_names: ["additional_0_0"]
data_type: FP16
- name: large_table
embedding_dim: 2048
num_embeddings: 1_000_000
feature_names: ["additional_0_1"]
- []
- - name: ec_likes
embedding_dim: 32
num_embeddings: 100_000
feature_names: ["user_likes"]
config_class: EmbeddingConfig
- name: ec_clicks
embedding_dim: 32
num_embeddings: 100_000
feature_names: ["user_clicks"]
config_class: EmbeddingConfig
PlannerConfig:
additional_constraints:
large_table:
sharding_types: [row_wise]
ec_likes:
sharding_types: [column_wise]
ec_clicks:
sharding_types: [column_wise]
ModelSelectionConfig:
model_name: "dlseq"
model_config:
num_float_features: 10
table_options:
- {}
- is_weighted: True
- {}
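For clarity, the nested `additional_tables` entries above map to a list of table groups roughly like the following; this is a sketch, assuming `EmbeddingTablesConfig` materializes one inner list per YAML group:

```python
from torchrec.modules.embedding_configs import (
    DataType,
    EmbeddingBagConfig,
    EmbeddingConfig,
)

additional_tables = [
    [  # group 0: extra unweighted EmbeddingBagConfig tables
        EmbeddingBagConfig(
            name="FP16_table",
            embedding_dim=512,
            num_embeddings=100_000,
            feature_names=["additional_0_0"],
            data_type=DataType.FP16,
        ),
        EmbeddingBagConfig(
            name="large_table",
            embedding_dim=2048,
            num_embeddings=1_000_000,
            feature_names=["additional_0_1"],
        ),
    ],
    [],  # group 1: no extra weighted tables
    [  # group 2: sequence tables built from EmbeddingConfig
        EmbeddingConfig(
            name="ec_likes",
            embedding_dim=32,
            num_embeddings=100_000,
            feature_names=["user_likes"],
        ),
        EmbeddingConfig(
            name="ec_clicks",
            embedding_dim=32,
            num_embeddings=100_000,
            feature_names=["user_clicks"],
        ),
    ],
]
```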
51 changes: 47 additions & 4 deletions torchrec/distributed/test_utils/input_config.py
@@ -8,12 +8,16 @@
# pyre-strict

from dataclasses import dataclass
from typing import List, Optional
from typing import Any, Dict, List, Optional, Union

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_configs import (
EmbeddingBagConfig,
EmbeddingConfig,
EmbeddingTableConfig,
)

from .model_input import ModelInput
from .model_input import ListModelInput, ModelInput


@dataclass
@@ -31,7 +35,7 @@ class ModelInputConfig:
long_kjt_lengths: bool = True
pin_memory: bool = True

def generate_batches(
def generate_std_batches(
self,
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
@@ -63,3 +67,42 @@ def generate_batches(
)
for _ in range(self.num_batches)
]

def generate_list_batches(
self,
tables: List[
Union[
List[EmbeddingTableConfig],
List[EmbeddingBagConfig],
List[EmbeddingConfig],
]
],
table_options: Optional[List[Dict[str, Any]]] = None,
) -> List[ListModelInput]:
"""
Generate list-style model input batches for multi-embedding benchmarking.

Args:
    tables: List of table config groups, one inner list per embedding arch
    table_options: Optional per-group options (e.g. is_weighted), aligned
        with tables by index

Returns:
    A list of ListModelInput objects representing the generated batches
"""
device = torch.device(self.device) if self.device is not None else None

return [
ListModelInput.generate(
batch_size=self.batch_size,
tables=tables,
num_float_features=self.num_float_features,
pooling_avg=self.feature_pooling_avg,
use_offsets=self.use_offsets,
device=device,
indices_dtype=(torch.int64 if self.long_kjt_indices else torch.int32),
offsets_dtype=(torch.int64 if self.long_kjt_offsets else torch.int32),
lengths_dtype=(torch.int64 if self.long_kjt_lengths else torch.int32),
pin_memory=self.pin_memory,
table_options=table_options,
)
for _ in range(self.num_batches)
]
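A usage sketch for the new method; the `input_config` instance and table groups are assumed to be constructed as elsewhere in this PR, and `table_options` aligns with `tables` by index, mirroring the YAML's `table_options` list:

```python
# Hypothetical groups: EBC tables, weighted EBC tables, EmbeddingConfig tables.
batches = input_config.generate_list_batches(
    tables=[ebc_tables, weighted_ebc_tables, ec_tables],
    table_options=[{}, {"is_weighted": True}, {}],  # per-group options by index
)
for batch in batches:  # one ListModelInput per benchmark batch
    ...
```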
46 changes: 45 additions & 1 deletion torchrec/distributed/test_utils/model_config.py
@@ -31,15 +31,19 @@
from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
from torchrec.distributed.sharding_plan import get_default_sharders
from torchrec.distributed.test_utils.test_model import (
TestDenseArch,
TestMultiEmbSparseArch,
TestMultiSparseNN,
TestOverArchLarge,
TestOverArchMultiEmb,
TestSparseNN,
TestTowerCollectionSparseNN,
TestTowerSparseNN,
)
from torchrec.distributed.types import ShardingEnv
from torchrec.models.deepfm import SimpleDeepFMNNWrapper
from torchrec.models.dlrm import DLRMWrapper
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection


@@ -213,6 +217,44 @@ def generate_model(
)


@dataclass
class DLSeqConfig(BaseModelConfig):
"""Configuration for DLRM model."""

table_options: Optional[List[Dict[str, Any]]] = None
dense_arch_out_size: Optional[int] = None
dense_arch_layer_sizes: Optional[List[int]] = None
over_arch_out_size: Optional[int] = None
over_arch_hidden_layers: Optional[int] = None

# pyre-ignore[14]
def generate_model(
self,
tables: List[Union[List[EmbeddingBagConfig], List[EmbeddingConfig]]],
device: Optional[torch.device] = None,
**kwargs: Any,
) -> nn.Module:
if device is None:
device = torch.device("cpu")

# Build the sparse arch over all table groups; per-group behavior
# (e.g. is_weighted) comes from table_options
sparse = TestMultiEmbSparseArch(tables, self.table_options, device=device)
dense = TestDenseArch(
self.num_float_features,
device=device,
dense_arch_hidden_sizes=self.dense_arch_layer_sizes,
dense_arch_out_size=self.dense_arch_out_size,
)
over = TestOverArchMultiEmb(
tables,
device=device,
dense_arch_out_size=self.dense_arch_out_size,
over_arch_out_size=self.over_arch_out_size,
over_arch_hidden_layers=self.over_arch_hidden_layers,
)
return TestMultiSparseNN(sparse, dense, over)


# pyre-ignore[2]: Missing parameter annotation
def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
"""
@@ -334,6 +376,8 @@ def get_model_config_class(self) -> Type[BaseModelConfig]:
return DeepFMConfig
case "dlrm":
return DLRMConfig
case "dlseq":
return DLSeqConfig
case _:
raise ValueError(f"Unknown model name: {self.model_name}")

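A sketch of selecting the new model through the existing factory; the kwargs mirror the `model_config` section of the YAML above, with `table_options` aligned to the table groups by index:

```python
from torchrec.distributed.test_utils.model_config import create_model_config

# "dlseq" now resolves to DLSeqConfig; the second options entry marks the
# second table group as weighted, matching the YAML example.
model_config = create_model_config(
    "dlseq",
    num_float_features=10,
    table_options=[{}, {"is_weighted": True}, {}],
)
```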