diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py
index 9303d5d7..7f7774c1 100644
--- a/src/modalities/__main__.py
+++ b/src/modalities/__main__.py
@@ -34,7 +34,7 @@
 from modalities.registry.registry import Registry
 from modalities.running_env.cuda_env import CudaEnv
 from modalities.trainer import Trainer
-from modalities.util import compute_number_of_trainable_parameters
+from modalities.util import get_total_number_of_trainable_parameters
 
 
 @click.group()
@@ -255,7 +255,7 @@ def run(self, components: TrainingComponentsInstantiationModel):
             num_ranks=components.settings.cuda_env.world_size,
         )
         wrapped_model = components.wrapped_model
-        num_params = compute_number_of_trainable_parameters(wrapped_model)
+        num_params = get_total_number_of_trainable_parameters(wrapped_model)
         components.evaluation_subscriber.consume_dict({"No. Parameters": num_params})
         logging.info(f"Training model with {num_params} parameters.")
 
diff --git a/src/modalities/checkpointing/torch/torch_checkpoint_loading.py b/src/modalities/checkpointing/torch/torch_checkpoint_loading.py
index 24c6db9d..dde08f64 100644
--- a/src/modalities/checkpointing/torch/torch_checkpoint_loading.py
+++ b/src/modalities/checkpointing/torch/torch_checkpoint_loading.py
@@ -8,7 +8,7 @@
 
 from modalities.checkpointing.checkpoint_loading import CheckpointLoadingIF
 from modalities.config.config import PrecisionEnum
-from modalities.util import compute_number_of_trainable_parameters
+from modalities.util import get_local_number_of_trainable_parameters
 
 
 class TorchCheckpointLoading(CheckpointLoadingIF):
@@ -46,7 +46,7 @@ def load_model_checkpoint(self, model: nn.Module, file_path: Path) -> nn.Module:
         # set the model to the correct device and precision
         # model = model.to(self.precision.value)
         print(
-            f"Model loaded with {compute_number_of_trainable_parameters(model)} trainable parameters from {file_path}"
+            f"Model loaded with {get_local_number_of_trainable_parameters(model)} trainable parameters from {file_path}"
         )
         return model
 
diff --git a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py
index 09547b1f..04315baa 100644
--- a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py
+++ b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py
@@ -2,11 +2,11 @@
 from typing import Any, Dict
 
 import rich
+import wandb
 import yaml
 from rich.console import Group
 from rich.panel import Panel
 
-import wandb
 from modalities.batch import EvaluationResultBatch
 from modalities.config.config import WandbMode
 from modalities.logging_broker.messages import Message
diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py
index 84fa2a67..cb7dbe38 100644
--- a/src/modalities/models/model_factory.py
+++ b/src/modalities/models/model_factory.py
@@ -11,7 +11,7 @@
 from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
 from modalities.running_env.env_utils import MixedPrecisionSettings
 from modalities.running_env.fsdp.fsdp_auto_wrapper import FSDPTransformerAutoWrapPolicyFactory
-from modalities.util import compute_number_of_trainable_parameters
+from modalities.util import get_local_number_of_trainable_parameters
 
 
 class ModelFactory:
@@ -34,7 +34,8 @@ def get_fsdp_wrapped_model(
         sharding_strategy: ShardingStrategy,
     ) -> FSDP:
         print(
-            f"Unsharded number of parameters on rank {dist.get_rank()}: {compute_number_of_trainable_parameters(model)}"
+            f"Unsharded number of parameters on rank {dist.get_rank()}: "
+            f"{get_local_number_of_trainable_parameters(model)}"
         )
         # Here, FSDPTransformerAutoWrapPolicyFactory is hardcoded and should be passed in instead!
         # we also might want to have different auto wrap policies later...
@@ -52,7 +53,7 @@ def get_fsdp_wrapped_model(
         )
         print(
             f"Sharded number of parameters on rank {dist.get_rank()}:"
-            f"{compute_number_of_trainable_parameters(fsdp_model)}"
+            f"{get_local_number_of_trainable_parameters(fsdp_model)}"
         )
 
         return fsdp_model
diff --git a/src/modalities/optimizers/optimizer_factory.py b/src/modalities/optimizers/optimizer_factory.py
index 28540517..fbe5b380 100644
--- a/src/modalities/optimizers/optimizer_factory.py
+++ b/src/modalities/optimizers/optimizer_factory.py
@@ -9,7 +9,7 @@
 from modalities.checkpointing.checkpoint_loading import CheckpointLoadingIF
 from modalities.exceptions import OptimizerError
 from modalities.models.model import NNModel
-from modalities.util import compute_number_of_trainable_parameters
+from modalities.util import get_local_number_of_trainable_parameters
 
 OptimizerGroups = List[Dict[str, List[nn.Parameter] | float]]
 
@@ -166,7 +166,7 @@ def _assert_completeness_of_optimizer_groups(model: FSDP, optimizer_groups: Opti
     checks that the number of parameters in the optimizer groups sum up to the total number of model parameters
     as expected
     """
-    num_params_check = compute_number_of_trainable_parameters(model)
+    num_params_check = get_local_number_of_trainable_parameters(model)
     num_params = sum(p.numel() for optimizer_group in optimizer_groups for p in optimizer_group["params"])
     if num_params != num_params_check:
         raise OptimizerError(
diff --git a/src/modalities/util.py b/src/modalities/util.py
index 62b61c40..b3bc6392 100644
--- a/src/modalities/util.py
+++ b/src/modalities/util.py
@@ -10,6 +10,8 @@
 import torch
 import torch.distributed as dist
 from pydantic import ValidationError
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.types import Number
 
 from modalities.exceptions import TimeRecorderStateError
 from modalities.running_env.fsdp.reducer import Reducer
@@ -56,10 +58,18 @@ def format_metrics_to_gb(item):
     return metric_num
 
 
-def compute_number_of_trainable_parameters(model: torch.nn.Module):
+def get_local_number_of_trainable_parameters(model: torch.nn.Module) -> int:
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
 
 
+def get_total_number_of_trainable_parameters(model: FSDP) -> Number:
+    num_params = get_local_number_of_trainable_parameters(model)
+    num_params_tensor = torch.tensor(num_params).cuda()
+    dist.all_reduce(num_params_tensor, op=dist.ReduceOp.SUM)
+    total_num_params = num_params_tensor.item()
+    return total_num_params
+
+
 class TimeRecorderStates(Enum):
     RUNNING = "RUNNING"
     STOPPED = "STOPPED"
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 7b8ac28a..f5e9379d 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,6 +1,9 @@
 import torch
 
+import modalities
+import modalities.util
 from modalities.batch import DatasetBatch
+from modalities.util import get_local_number_of_trainable_parameters, get_total_number_of_trainable_parameters
 
 
 def configure_dataloader_mock(
@@ -23,3 +26,46 @@ def configure_dataloader_mock(
     llm_data_loader_mock.__len__ = lambda _: num_batches
 
     return llm_data_loader_mock, batches
+
+
+def test_get_local_number_of_trainable_parameters():
+    # Create a simple model with trainable parameters
+    model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))
+
+    # Calculate the expected number of trainable parameters
+    expected_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    # Call the function and check the result
+    assert get_local_number_of_trainable_parameters(model) == expected_params
+
+
+def test_get_total_number_of_trainable_parameters():
+    # Create a simple model with trainable parameters
+    model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))
+
+    # Calculate the expected number of trainable parameters
+    expected_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    # Create a mock FSDP model
+    class MockFSDP:
+        def __init__(self, model):
+            self.model = model
+
+    fsdp_model = MockFSDP(model)
+
+    # Mock the dist.all_reduce function
+    def mock_all_reduce(tensor, op):
+        tensor.item = lambda: tensor
+        return tensor
+
+    def mock_cuda(tensor):
+        return tensor
+
+    def mock_get_local_number_of_trainable_parameters(model: MockFSDP):
+        return get_local_number_of_trainable_parameters(model.model)
+
+    modalities.util.get_local_number_of_trainable_parameters = mock_get_local_number_of_trainable_parameters
+    torch.distributed.all_reduce = mock_all_reduce
+    torch.Tensor.cuda = mock_cuda
+
+    assert get_total_number_of_trainable_parameters(fsdp_model) == expected_params
diff --git a/tests/test_yaml_configs/config_lorem_ipsum.yaml b/tests/test_yaml_configs/config_lorem_ipsum.yaml
index 688986e1..527d36e2 100644
--- a/tests/test_yaml_configs/config_lorem_ipsum.yaml
+++ b/tests/test_yaml_configs/config_lorem_ipsum.yaml
@@ -299,14 +299,7 @@ batch_progress_subscriber:
     instance_key: eval_dataloaders
     pass_type: BY_REFERENCE
 
-
 evaluation_subscriber:
   component_key: results_subscriber
-  variant_key: wandb
-  config:
-    local_rank: ${settings.cuda_env.local_rank}
-    project: modalities_lorem_ipsum
-    mode: ONLINE
-    directory: "."
-    experiment_id: ${settings.experiment_id}
-    config_file_path: ${settings.config_file_path}
\ No newline at end of file
+  variant_key: dummy
+  config: {}
\ No newline at end of file