Fsdp pytorch draft PR #823
Draft: wants to merge 13 commits into dev
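Summary of the repeated change: each PyTorch workload's init_model_fn now wraps the model in FSDP with a size-based auto-wrap policy instead of DDP, and each model_fn calls zero_grad() before evaluation. Below is a minimal sketch of the wrapping pattern, pulled out for reference; it assumes torch.distributed is already initialized and that the rank argument is the local GPU index, as in the workload code. The helper name wrap_with_fsdp is illustrative and not part of the diff.

import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


def wrap_with_fsdp(model: torch.nn.Module, rank: int) -> FSDP:
  # Shard any submodule with more than 2**10 parameters into its own unit.
  auto_wrap_policy = functools.partial(
      size_based_auto_wrap_policy, min_num_params=2**10)
  return FSDP(
      model,
      use_orig_params=True,  # keep the original parameter names visible
      auto_wrap_policy=auto_wrap_policy,
      device_id=rank)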
6 changes: 5 additions & 1 deletion .gitignore
@@ -5,6 +5,7 @@ __pycache__
.vscode/
env/
venv/
.venv/
workdir/
makefile
*.out
@@ -23,4 +24,7 @@ wandb/
scoring/plots/

!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv
!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv
!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv

tags
*~
@@ -4,11 +4,18 @@
import functools
import random
from typing import Any, Dict, Optional, Tuple
from absl import logging

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)
from torchvision import transforms
from torchvision.datasets import CIFAR10

@@ -135,7 +142,16 @@ def init_model_fn(
self._model.to(DEVICE)
if N_GPUS > 1:
if USE_PYTORCH_DDP:
self._model = DDP(self._model, device_ids=[RANK], output_device=RANK)
cifar_auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=2 ** 10
)
self._model = FSDP(
self._model,
use_orig_params=True,
auto_wrap_policy=cifar_auto_wrap_policy,
device_id=RANK
)
logging.info(f"Model shape: {self._model}")
else:
self._model = torch.nn.DataParallel(self._model)
return self._model, None
@@ -155,6 +171,7 @@ def model_fn(
del rng
model = params
if mode == spec.ForwardPassMode.EVAL:
model.zero_grad()
if update_batch_norm:
raise ValueError(
'Batch norm statistics cannot be updated during evaluation.')
@@ -1,11 +1,18 @@
"""Criteo1TB workload implemented in PyTorch."""

import contextlib
import functools
from typing import Dict, Iterator, Optional, Tuple

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)

from algorithmic_efficiency import param_utils
from algorithmic_efficiency import spec
@@ -93,7 +100,15 @@ def init_model_fn(
model.to(DEVICE)
if N_GPUS > 1:
if USE_PYTORCH_DDP:
model = DDP(model, device_ids=[RANK], output_device=RANK)
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=2 ** 10
)
model = FSDP(
model,
use_orig_params=True,
auto_wrap_policy=auto_wrap_policy,
device_id=RANK
)
else:
model = torch.nn.DataParallel(model)
return model, None
@@ -117,6 +132,7 @@ def model_fn(
inputs = augmented_and_preprocessed_input_batch['inputs']

if mode == spec.ForwardPassMode.EVAL:
model.zero_grad()
model.eval()

if mode == spec.ForwardPassMode.TRAIN:
@@ -1,13 +1,20 @@
"""FastMRI workload implemented in PyTorch."""

import contextlib
import functools
import math
from typing import Dict, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)

from algorithmic_efficiency import param_utils
from algorithmic_efficiency import pytorch_utils
@@ -125,7 +132,15 @@ def init_model_fn(
model.to(DEVICE)
if N_GPUS > 1:
if USE_PYTORCH_DDP:
model = DDP(model, device_ids=[RANK], output_device=RANK)
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=2 ** 10
)
model = FSDP(
model,
use_orig_params=True,
auto_wrap_policy=auto_wrap_policy,
device_id=RANK
)
else:
model = torch.nn.DataParallel(model)
return model, None
@@ -148,6 +163,7 @@ def model_fn(
model = params

if mode == spec.ForwardPassMode.EVAL:
model.zero_grad()
model.eval()

if mode == spec.ForwardPassMode.TRAIN:
@@ -13,6 +13,12 @@
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)
from torchvision import transforms
from torchvision.datasets.folder import ImageFolder

@@ -181,7 +187,15 @@ def init_model_fn(
model.to(DEVICE)
if N_GPUS > 1:
if USE_PYTORCH_DDP:
model = DDP(model, device_ids=[RANK], output_device=RANK)
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=2 ** 10
)
model = FSDP(
model,
use_orig_params=True,
auto_wrap_policy=auto_wrap_policy,
device_id=RANK
)
else:
model = torch.nn.DataParallel(model)
return model, None
@@ -206,6 +220,7 @@ def model_fn(
if update_batch_norm:
raise ValueError(
'Batch norm statistics cannot be updated during evaluation.')
model.zero_grad()
model.eval()

if mode == spec.ForwardPassMode.TRAIN:
@@ -1,10 +1,17 @@
"""ImageNet ViT workload implemented in PyTorch."""

import contextlib
import functools
from typing import Dict, Optional, Tuple

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)

from algorithmic_efficiency import param_utils
from algorithmic_efficiency import pytorch_utils
@@ -43,7 +50,15 @@ def init_model_fn(
model.to(DEVICE)
if N_GPUS > 1:
if USE_PYTORCH_DDP:
model = DDP(model, device_ids=[RANK], output_device=RANK)
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=2 ** 10
)
model = FSDP(
model,
use_orig_params=True,
auto_wrap_policy=auto_wrap_policy,
device_id=RANK
)
else:
model = torch.nn.DataParallel(model)
return model, None
@@ -66,6 +81,9 @@ def model_fn(
model = params

if mode == spec.ForwardPassMode.EVAL:
# need to zero grad for FSDP or else error is thrown
# during evaluation
model.zero_grad()
model.eval()

if mode == spec.ForwardPassMode.TRAIN:
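The zero_grad() call before evaluation recurs in every workload's model_fn; per the in-code comments, FSDP otherwise raises an error during evaluation. A minimal sketch of that eval path in isolation (illustrative only; the helper name evaluate is not part of the diff, and the real model_fn also selects a forward-pass context as shown in the hunks):

import torch


def evaluate(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
  # Clear any gradients left over from training before the eval forward
  # pass; the PR comments note that FSDP errors out during eval otherwise.
  model.zero_grad()
  model.eval()
  with torch.no_grad():
    return model(batch)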
@@ -9,6 +9,12 @@
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)

from algorithmic_efficiency import data_utils
from algorithmic_efficiency import param_utils
@@ -101,7 +107,15 @@ def init_model_fn(
if N_GPUS > 1:
if USE_PYTORCH_DDP:
self.requires_sync_before_eval = True
model = DDP(model, device_ids=[RANK], output_device=RANK)
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=2 ** 10
)
model = FSDP(
model,
use_orig_params=True,
auto_wrap_policy=auto_wrap_policy,
device_id=RANK
)
else:
model = torch.nn.DataParallel(model)
return model, None
@@ -122,6 +136,7 @@ def model_fn(

model = params
if mode == spec.ForwardPassMode.EVAL:
model.zero_grad()
model.eval()
if mode == spec.ForwardPassMode.TRAIN:
model.train()
@@ -1,7 +1,14 @@
from typing import Optional
import functools

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)

from algorithmic_efficiency import param_utils
from algorithmic_efficiency import spec
@@ -58,7 +65,15 @@ def init_model_fn(
self.requires_sync_before_eval = False
if N_GPUS > 1:
if USE_PYTORCH_DDP:
model = DDP(model, device_ids=[RANK], output_device=RANK)
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=2 ** 10
)
model = FSDP(
model,
use_orig_params=True,
auto_wrap_policy=auto_wrap_policy,
device_id=RANK
)
else:
model = torch.nn.DataParallel(model)
return model, None
@@ -2,13 +2,20 @@

from collections import OrderedDict
import contextlib
import functools
from typing import Any, Dict, Iterator, Optional, Tuple

import torch
from torch import nn
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)

from algorithmic_efficiency import init_utils
from algorithmic_efficiency import param_utils
@@ -140,7 +147,15 @@ def init_model_fn(
self._model.to(DEVICE)
if N_GPUS > 1:
if USE_PYTORCH_DDP:
self._model = DDP(self._model, device_ids=[RANK], output_device=RANK)
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=2 ** 10
)
self._model = FSDP(
self._model,
use_orig_params=True,
auto_wrap_policy=auto_wrap_policy,
device_id=RANK
)
else:
self._model = torch.nn.DataParallel(self._model)
return self._model, None
@@ -161,6 +176,7 @@ def model_fn(
del update_batch_norm
model = params
if mode == spec.ForwardPassMode.EVAL:
model.zero_grad()
model.eval()
contexts = {
spec.ForwardPassMode.EVAL: torch.no_grad,
19 changes: 18 additions & 1 deletion algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py
@@ -1,12 +1,19 @@
"""OGBG workload implemented in PyTorch."""
import contextlib
import functools
from typing import Any, Callable, Dict, Optional, Tuple

import jax
from jraph import GraphsTuple
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)

from algorithmic_efficiency import param_utils
from algorithmic_efficiency import pytorch_utils
@@ -156,7 +163,15 @@ def init_model_fn(
model.to(DEVICE)
if N_GPUS > 1:
if USE_PYTORCH_DDP:
model = DDP(model, device_ids=[RANK], output_device=RANK)
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=2 ** 10
)
model = FSDP(
model,
use_orig_params=True,
auto_wrap_policy=auto_wrap_policy,
device_id=RANK
)
else:
model = torch.nn.DataParallel(model)
return model, None
@@ -183,6 +198,8 @@ def model_fn(
if mode == spec.ForwardPassMode.TRAIN:
model.train()
elif mode == spec.ForwardPassMode.EVAL:
# need to zero grad for FSDP eval - it is unclear why
model.zero_grad()
model.eval()

contexts = {
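For reviewing the min_num_params=2**10 threshold, here is a small helper (not part of this PR) that lists which submodules the auto-wrap policy turned into separate FSDP units; it assumes model is the already-wrapped module returned by init_model_fn:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def list_fsdp_units(model):
  # Walk the wrapped model and report each nested FSDP unit together with
  # the number of parameters it manages (counts include nested units).
  for name, module in model.named_modules():
    if isinstance(module, FSDP):
      numel = sum(p.numel() for p in module.parameters())
      print(f'{name or "<root>"}: {numel} parameters')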