Resolved torchrun Bug: Fixed issue #2163
Updated `torch.distributed.launch` to `torchrun`.
anxiangsir committed Feb 8, 2023
1 parent b974acc commit bf32ec2
Showing 4 changed files with 21 additions and 26 deletions.
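With `torchrun`, the launcher no longer injects a `--local_rank` argument into the training script; instead it exports `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` into each worker's environment. That is why the diffs below drop the argparse flag and read the values from `os.environ`, keeping a single-process fallback when the variables are absent. A minimal sketch of that pattern (mirroring the fallback address used in the scripts; the trailing `rank`/`world_size` arguments are assumed, as the hunks below are truncated):

```python
import os

import torch
from torch import distributed

# torchrun exports these for every worker; without a launcher the lookups
# raise KeyError and we fall back to a single-process group.
try:
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    distributed.init_process_group("nccl")
except KeyError:
    rank, local_rank, world_size = 0, 0, 1
    distributed.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12584",
        rank=rank,
        world_size=world_size,
    )

torch.cuda.set_device(local_rank)  # bind this worker to its GPU
```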
8 changes: 4 additions & 4 deletions recognition/arcface_torch/README.md
@@ -22,27 +22,27 @@ To train a model, execute the `train.py` script with the path to the configuration
### 1. To run on a machine with 8 GPUs:

```shell
-python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=12581 train.py configs/ms1mv3_r50_lr02
+torchrun --nproc_per_node=8 train.py configs/ms1mv3_r50
```

### 2. To run on 2 machines with 8 GPUs each:

Node 0:

```shell
-python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=12581 train.py configs/webface42m_r100_lr01_pfc02_bs4k_16gpus
+torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100
```

Node 1:

```shell
-python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/webface42m_r100_lr01_pfc02_bs4k_16gpus
+torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100
```

### 3. Run ViT-B on a machine with 24k batchsize:

```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=12345 train_v2.py configs/wf42m_pfc03_40epoch_8gpu_vit_b.py
torchrun --nproc_per_node=8 train_v2.py configs/wf42m_pfc03_40epoch_8gpu_vit_b
```
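The single-machine commands above can drop `--nnodes`, `--node_rank`, `--master_addr`, and `--master_port` because `torchrun` defaults to one node (rank 0) with a localhost rendezvous; the flags can still be passed explicitly, for example to pin the port. A sketch of the fully spelled-out equivalent of the first command (not part of the README):

```shell
torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=12581 train.py configs/ms1mv3_r50
```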


9 changes: 1 addition & 8 deletions recognition/arcface_torch/run.sh
@@ -1,8 +1 @@

-CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 python -m torch.distributed.launch \
-    --nproc_per_node=7 \
-    --nnodes=1 \
-    --node_rank=0 \
-    --master_addr="127.0.0.1" \
-    --master_port=12345 train.py $@
-
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train_v2.py $@
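The simplified `run.sh` forwards all of its arguments to `train_v2.py` through `$@`, so the config is supplied on the command line; a hypothetical invocation, assuming the ViT-B config referenced in the README:

```shell
bash run.sh configs/wf42m_pfc03_40epoch_8gpu_vit_b
```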
13 changes: 7 additions & 6 deletions recognition/arcface_torch/train.py
@@ -22,12 +22,14 @@
we have upgraded the torch to 1.9.0. torch before than 1.9.0 may not work in the future."

try:
-    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
+    local_rank = int(os.environ["LOCAL_RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
    distributed.init_process_group("nccl")
except KeyError:
-    world_size = 1
    rank = 0
+    local_rank = 0
+    world_size = 1
    distributed.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12584",
@@ -43,7 +45,7 @@ def main(args):
    # global control random seed
    setup_seed(seed=cfg.seed, cuda_deterministic=False)

-    torch.cuda.set_device(args.local_rank)
+    torch.cuda.set_device(local_rank)

    os.makedirs(cfg.output, exist_ok=True)
    init_logging(rank, cfg.output)
@@ -82,7 +84,7 @@

    train_loader = get_dataloader(
        cfg.rec,
-        args.local_rank,
+        local_rank,
        cfg.batch_size,
        cfg.dali,
        cfg.seed,
@@ -93,7 +95,7 @@
        cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda()

    backbone = torch.nn.parallel.DistributedDataParallel(
-        module=backbone, broadcast_buffers=False, device_ids=[args.local_rank], bucket_cap_mb=16,
+        module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16,
        find_unused_parameters=True)

    backbone.train()
@@ -255,5 +257,4 @@ def main(args):
    parser = argparse.ArgumentParser(
        description="Distributed Arcface Training in Pytorch")
    parser.add_argument("config", type=str, help="py config file")
-    parser.add_argument("--local_rank", type=int, default=0, help="local_rank")
    main(parser.parse_args())
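Because the rank variables are now read from the environment with a `KeyError` fallback, `train.py` still runs without any launcher; both launch modes are sketched below, assuming the `ms1mv3_r50` config from the README:

```shell
# No launcher: RANK/LOCAL_RANK/WORLD_SIZE are unset, so the script falls back
# to a one-process group on tcp://127.0.0.1:12584.
python train.py configs/ms1mv3_r50

# torchrun: the launcher exports the variables read in the try-block above.
torchrun --nproc_per_node=8 train.py configs/ms1mv3_r50
```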
17 changes: 9 additions & 8 deletions recognition/arcface_torch/train_v2.py
@@ -22,12 +22,14 @@
we have upgraded the torch to 1.9.0. torch before than 1.9.0 may not work in the future."

try:
-    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
+    local_rank = int(os.environ["LOCAL_RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
    distributed.init_process_group("nccl")
except KeyError:
-    world_size = 1
    rank = 0
+    local_rank = 0
+    world_size = 1
    distributed.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12584",
@@ -43,7 +45,7 @@ def main(args):
    # global control random seed
    setup_seed(seed=cfg.seed, cuda_deterministic=False)

-    torch.cuda.set_device(args.local_rank)
+    torch.cuda.set_device(local_rank)

    os.makedirs(cfg.output, exist_ok=True)
    init_logging(rank, cfg.output)
@@ -55,7 +57,7 @@
    )

    wandb_logger = None
-    if using_wandb:
+    if cfg.using_wandb:
        import wandb
        # Sign in to wandb
        try:
@@ -78,11 +80,11 @@
                wandb_logger.config.update(cfg)
        except Exception as e:
            print("WandB Data (Entity and Project name) must be provided in config file (base.py).")
-            rint(f"Config Error: {e}")
+            print(f"Config Error: {e}")

    train_loader = get_dataloader(
        cfg.rec,
-        args.local_rank,
+        local_rank,
        cfg.batch_size,
        cfg.dali,
        cfg.seed,
@@ -93,7 +95,7 @@
        cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda()

    backbone = torch.nn.parallel.DistributedDataParallel(
-        module=backbone, broadcast_buffers=False, device_ids=[args.local_rank], bucket_cap_mb=16,
+        module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16,
        find_unused_parameters=True)

    backbone.train()
@@ -253,5 +255,4 @@ def main(args):
    parser = argparse.ArgumentParser(
        description="Distributed Arcface Training in Pytorch")
    parser.add_argument("config", type=str, help="py config file")
-    parser.add_argument("--local_rank", type=int, default=0, help="local_rank")
    main(parser.parse_args())
