From f8e5f14bb8a9f8e200897810b2ce71d54dfb147b Mon Sep 17 00:00:00 2001
From: mali-git
Date: Mon, 15 Jul 2024 11:32:29 +0200
Subject: [PATCH 1/6] fix: total num of parameter computation

---
 src/modalities/__main__.py | 4 ++--
 src/modalities/util.py | 11 ++++++++++-
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py
index 9303d5d7..7f7774c1 100644
--- a/src/modalities/__main__.py
+++ b/src/modalities/__main__.py
@@ -34,7 +34,7 @@
 from modalities.registry.registry import Registry
 from modalities.running_env.cuda_env import CudaEnv
 from modalities.trainer import Trainer
-from modalities.util import compute_number_of_trainable_parameters
+from modalities.util import get_total_number_of_trainable_parameters
 
 
 @click.group()
@@ -255,7 +255,7 @@ def run(self, components: TrainingComponentsInstantiationModel):
             num_ranks=components.settings.cuda_env.world_size,
         )
 
         wrapped_model = components.wrapped_model
-        num_params = compute_number_of_trainable_parameters(wrapped_model)
+        num_params = get_total_number_of_trainable_parameters(wrapped_model)
         components.evaluation_subscriber.consume_dict({"No. Parameters": num_params})
         logging.info(f"Training model with {num_params} parameters.")
diff --git a/src/modalities/util.py b/src/modalities/util.py
index 62b61c40..bb9aca14 100644
--- a/src/modalities/util.py
+++ b/src/modalities/util.py
@@ -10,6 +10,7 @@
 import torch
 import torch.distributed as dist
 from pydantic import ValidationError
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
 from modalities.exceptions import TimeRecorderStateError
 from modalities.running_env.fsdp.reducer import Reducer
@@ -56,10 +57,18 @@ def format_metrics_to_gb(item):
     return metric_num
 
 
-def compute_number_of_trainable_parameters(model: torch.nn.Module):
+def get_local_number_of_trainable_parameters(model: torch.nn.Module):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
 
 
+def get_total_number_of_trainable_parameters(model: FSDP):
+    num_params = get_local_number_of_trainable_parameters(model)
+    num_params_tensor = torch.tensor(num_params).cuda()
+    dist.all_reduce(num_params_tensor, op=dist.ReduceOp.SUM)
+    total_num_params = num_params_tensor.item()
+    return total_num_params
+
+
 class TimeRecorderStates(Enum):
     RUNNING = "RUNNING"
     STOPPED = "STOPPED"
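Context for the fix above: once a model is wrapped with FSDP, `model.parameters()` on a given rank only yields that rank's shard, so the old `compute_number_of_trainable_parameters` reported roughly total/world_size parameters for the wrapped model. The renamed `get_local_number_of_trainable_parameters` keeps the per-rank meaning, while the new `get_total_number_of_trainable_parameters` sums the per-rank counts with an all-reduce. The standalone script below is an illustrative sketch only (the file name, model sizes, and launch command are assumptions, not part of the patch):

# param_count_demo.py -- illustrative only; run e.g. `torchrun --nproc_per_node=2 param_count_demo.py`
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from modalities.util import (
    get_local_number_of_trainable_parameters,
    get_total_number_of_trainable_parameters,
)


def main() -> None:
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    model = torch.nn.Sequential(torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024))
    full_count = get_local_number_of_trainable_parameters(model)  # unwrapped: full parameter count

    fsdp_model = FSDP(model.cuda())
    local_count = get_local_number_of_trainable_parameters(fsdp_model)  # this rank's shard only
    total_count = get_total_number_of_trainable_parameters(fsdp_model)  # sum of all shards via all_reduce

    # local_count shrinks as the world size grows; total_count matches full_count
    # (up to FlatParameter padding added by FSDP).
    print(f"rank {dist.get_rank()}: full={full_count} local_shard={local_count} total={total_count}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

With a single rank the three numbers coincide, which is why the under-count only shows up in multi-GPU runs.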
From e909849b8fe3214682c856640bf9177bc873a90b Mon Sep 17 00:00:00 2001
From: mali-git
Date: Mon, 15 Jul 2024 11:35:32 +0200
Subject: [PATCH 2/6] refactor: update to new API

---
 .../checkpointing/torch/torch_checkpoint_loading.py | 4 ++--
 src/modalities/models/model_factory.py | 7 ++++---
 src/modalities/optimizers/optimizer_factory.py | 4 ++--
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/modalities/checkpointing/torch/torch_checkpoint_loading.py b/src/modalities/checkpointing/torch/torch_checkpoint_loading.py
index 24c6db9d..dde08f64 100644
--- a/src/modalities/checkpointing/torch/torch_checkpoint_loading.py
+++ b/src/modalities/checkpointing/torch/torch_checkpoint_loading.py
@@ -8,7 +8,7 @@
 from modalities.checkpointing.checkpoint_loading import CheckpointLoadingIF
 from modalities.config.config import PrecisionEnum
-from modalities.util import compute_number_of_trainable_parameters
+from modalities.util import get_local_number_of_trainable_parameters
 
 
 class TorchCheckpointLoading(CheckpointLoadingIF):
@@ -46,7 +46,7 @@ def load_model_checkpoint(self, model: nn.Module, file_path: Path) -> nn.Module:
         # set the model to the correct device and precision
         # model = model.to(self.precision.value)
         print(
-            f"Model loaded with {compute_number_of_trainable_parameters(model)} trainable parameters from {file_path}"
+            f"Model loaded with {get_local_number_of_trainable_parameters(model)} trainable parameters from {file_path}"
         )
         return model
diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py
index 84fa2a67..cb7dbe38 100644
--- a/src/modalities/models/model_factory.py
+++ b/src/modalities/models/model_factory.py
@@ -11,7 +11,7 @@
 from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
 from modalities.running_env.env_utils import MixedPrecisionSettings
 from modalities.running_env.fsdp.fsdp_auto_wrapper import FSDPTransformerAutoWrapPolicyFactory
-from modalities.util import compute_number_of_trainable_parameters
+from modalities.util import get_local_number_of_trainable_parameters
 
 
 class ModelFactory:
@@ -34,7 +34,8 @@ def get_fsdp_wrapped_model(
         sharding_strategy: ShardingStrategy,
     ) -> FSDP:
         print(
-            f"Unsharded number of parameters on rank {dist.get_rank()}: {compute_number_of_trainable_parameters(model)}"
+            f"Unsharded number of parameters on rank {dist.get_rank()}: "
+            f"{get_local_number_of_trainable_parameters(model)}"
         )
         # Here, FSDPTransformerAutoWrapPolicyFactory is hardcoded and should be passed in instead!
         # we also might want to have different auto wrap policies later...
@@ -52,7 +53,7 @@ def get_fsdp_wrapped_model(
         )
         print(
             f"Sharded number of parameters on rank {dist.get_rank()}:"
-            f"{compute_number_of_trainable_parameters(fsdp_model)}"
+            f"{get_local_number_of_trainable_parameters(fsdp_model)}"
         )
         return fsdp_model
diff --git a/src/modalities/optimizers/optimizer_factory.py b/src/modalities/optimizers/optimizer_factory.py
index 28540517..fbe5b380 100644
--- a/src/modalities/optimizers/optimizer_factory.py
+++ b/src/modalities/optimizers/optimizer_factory.py
@@ -9,7 +9,7 @@
 from modalities.checkpointing.checkpoint_loading import CheckpointLoadingIF
 from modalities.exceptions import OptimizerError
 from modalities.models.model import NNModel
-from modalities.util import compute_number_of_trainable_parameters
+from modalities.util import get_local_number_of_trainable_parameters
 
 OptimizerGroups = List[Dict[str, List[nn.Parameter] | float]]
@@ -166,7 +166,7 @@ def _assert_completeness_of_optimizer_groups(model: FSDP, optimizer_groups: Opti
     checks that the number of parameters in the optimizer groups sum up to the total
     number of model parameters as expected
     """
-    num_params_check = compute_number_of_trainable_parameters(model)
+    num_params_check = get_local_number_of_trainable_parameters(model)
     num_params = sum(p.numel() for optimizer_group in optimizer_groups for p in optimizer_group["params"])
     if num_params != num_params_check:
         raise OptimizerError(

From ba6e6fabb516d481bd7342d40536443e3d62e025 Mon Sep 17 00:00:00 2001
From: mali-git
Date: Mon, 15 Jul 2024 12:04:22 +0200
Subject: [PATCH 3/6] refactor: fix import order

---
 .../logging_broker/subscriber_impl/results_subscriber.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py
index 09547b1f..04315baa 100644
--- a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py
+++ b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py
@@ -2,11 +2,11 @@
 from typing import Any, Dict
 
 import rich
+import wandb
 import yaml
 from rich.console import Group
 from rich.panel import Panel
 
-import wandb
 from modalities.batch import EvaluationResultBatch
 from modalities.config.config import WandbMode
 from modalities.logging_broker.messages import Message
From fd93b14fb1eb76a729fe18a0a66fb736219fd072 Mon Sep 17 00:00:00 2001
From: mali-git
Date: Mon, 15 Jul 2024 12:05:25 +0200
Subject: [PATCH 4/6] test: add tests for num parameter computation

---
 tests/test_utils.py | 47 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/tests/test_utils.py b/tests/test_utils.py
index 7b8ac28a..719616be 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,6 +1,9 @@
 import torch
 
+import modalities
+import modalities.util
 from modalities.batch import DatasetBatch
+from modalities.util import get_local_number_of_trainable_parameters, get_total_number_of_trainable_parameters
 
 
 def configure_dataloader_mock(
@@ -23,3 +26,47 @@ def configure_dataloader_mock(
     llm_data_loader_mock.__len__ = lambda _: num_batches
 
     return llm_data_loader_mock, batches
+
+
+def test_get_local_number_of_trainable_parameters():
+    # Create a simple model with trainable parameters
+    model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))
+
+    # Calculate the expected number of trainable parameters
+    expected_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    # Call the function and check the result
+    assert get_local_number_of_trainable_parameters(model) == expected_params
+
+
+def test_get_total_number_of_trainable_parameters():
+    # Create a simple model with trainable parameters
+    model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))
+
+    # Calculate the expected number of trainable parameters
+    expected_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    # Create a mock FSDP model
+    class MockFSDP:
+        def __init__(self, model):
+            self.model = model
+
+    fsdp_model = MockFSDP(model)
+
+    # Mock the dist.all_reduce function
+    def mock_all_reduce(tensor, op):
+        tensor.item = lambda: tensor
+        return tensor
+
+    def mock_cuda(tensor):
+        return tensor
+
+    def mock_get_local_number_of_trainable_parameters(model: MockFSDP):
+        return get_local_number_of_trainable_parameters(model.model)
+
+    modalities.util.get_local_number_of_trainable_parameters = mock_get_local_number_of_trainable_parameters
+    torch.distributed.all_reduce = mock_all_reduce
+    torch.Tensor.cuda = mock_cuda
+
+    # Call the function and check the result
+    assert get_total_number_of_trainable_parameters(fsdp_model) == expected_params
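The new test above fakes the distributed pieces by assigning directly to `torch.distributed.all_reduce`, `torch.Tensor.cuda`, and `modalities.util.get_local_number_of_trainable_parameters`, which leaves the replacements in place for any test that runs later in the same process. A possible alternative is pytest's `monkeypatch` fixture, which restores the originals automatically; the sketch below is illustrative only (`FakeFSDP`, the test name, and the lambdas are assumptions, not part of the patch):

import torch

import modalities.util
from modalities.util import get_total_number_of_trainable_parameters


class FakeFSDP:
    # Minimal stand-in for an FSDP-wrapped module.
    def __init__(self, module: torch.nn.Module):
        self.module = module


def test_total_number_of_trainable_parameters_with_monkeypatch(monkeypatch):
    model = torch.nn.Linear(10, 5)
    expected = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Single-process stand-ins: the local count is already the full count,
    # all_reduce is a no-op, and .cuda() keeps the tensor on CPU.
    monkeypatch.setattr(modalities.util, "get_local_number_of_trainable_parameters", lambda m: expected)
    monkeypatch.setattr(torch.distributed, "all_reduce", lambda tensor, op: tensor)
    monkeypatch.setattr(torch.Tensor, "cuda", lambda self: self)

    assert get_total_number_of_trainable_parameters(FakeFSDP(model)) == expected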
From 6319ff08d2475b31ec31e52daf25291594178b12 Mon Sep 17 00:00:00 2001
From: Max Luebbering
Date: Mon, 15 Jul 2024 13:54:13 +0200
Subject: [PATCH 5/6] refactor: replaced the wand results subscriber with the dummy one, to allow for parallel test execution

---
 tests/test_utils.py | 1 -
 tests/test_yaml_configs/config_lorem_ipsum.yaml | 11 ++---------
 2 files changed, 2 insertions(+), 10 deletions(-)

diff --git a/tests/test_utils.py b/tests/test_utils.py
index 719616be..f5e9379d 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -68,5 +68,4 @@ def mock_get_local_number_of_trainable_parameters(model: MockFSDP):
     torch.distributed.all_reduce = mock_all_reduce
     torch.Tensor.cuda = mock_cuda
 
-    # Call the function and check the result
     assert get_total_number_of_trainable_parameters(fsdp_model) == expected_params
diff --git a/tests/test_yaml_configs/config_lorem_ipsum.yaml b/tests/test_yaml_configs/config_lorem_ipsum.yaml
index 688986e1..527d36e2 100644
--- a/tests/test_yaml_configs/config_lorem_ipsum.yaml
+++ b/tests/test_yaml_configs/config_lorem_ipsum.yaml
@@ -299,14 +299,7 @@ batch_progress_subscriber:
       instance_key: eval_dataloaders
       pass_type: BY_REFERENCE
 
-
 evaluation_subscriber:
   component_key: results_subscriber
-  variant_key: wandb
-  config:
-    local_rank: ${settings.cuda_env.local_rank}
-    project: modalities_lorem_ipsum
-    mode: ONLINE
-    directory: "."
-    experiment_id: ${settings.experiment_id}
-    config_file_path: ${settings.config_file_path}
\ No newline at end of file
+  variant_key: dummy
+  config: {}
\ No newline at end of file

From 2715fbb79da406d678f7dc7314c6e69f0c0937e6 Mon Sep 17 00:00:00 2001
From: mali-git
Date: Mon, 15 Jul 2024 15:34:46 +0200
Subject: [PATCH 6/6] refactor: add missing type annotations

---
 src/modalities/util.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/modalities/util.py b/src/modalities/util.py
index bb9aca14..b3bc6392 100644
--- a/src/modalities/util.py
+++ b/src/modalities/util.py
@@ -11,6 +11,7 @@
 import torch.distributed as dist
 from pydantic import ValidationError
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.types import Number
 
 from modalities.exceptions import TimeRecorderStateError
 from modalities.running_env.fsdp.reducer import Reducer
@@ -57,11 +58,11 @@ def format_metrics_to_gb(item):
     return metric_num
 
 
-def get_local_number_of_trainable_parameters(model: torch.nn.Module):
+def get_local_number_of_trainable_parameters(model: torch.nn.Module) -> int:
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
 
 
-def get_total_number_of_trainable_parameters(model: FSDP):
+def get_total_number_of_trainable_parameters(model: FSDP) -> Number:
     num_params = get_local_number_of_trainable_parameters(model)
     num_params_tensor = torch.tensor(num_params).cuda()
     dist.all_reduce(num_params_tensor, op=dist.ReduceOp.SUM)
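A note on the `-> Number` annotation added above: `Tensor.item()` returns a plain Python scalar rather than a tensor, and `torch.types.Number` is the scalar union (int/float/bool) torch uses for such return values, so the annotation matches what the helper actually returns (effectively an `int` here, since the count tensor is integer-typed). A small illustrative check, not part of the patch:

import torch
from torch.types import Number

# .item() on an integer tensor yields a built-in Python int, which Number covers.
total: Number = torch.tensor(12345).item()
assert isinstance(total, int) and not isinstance(total, torch.Tensor)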