This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

[WIP] Initially working multi-gpu training #71

Merged: 18 commits, Feb 19, 2021

Conversation

@pzelasko (Collaborator) commented Jan 7, 2021

I've only seen ~300 steps, but the model seems to be converging alright, and both GPUs are close to 100% load. This probably needs further work to make sure the checkpoints work OK and that the dev data scores are aggregated from all nodes. It also needs a few ifs to handle both single-GPU and multi-GPU runs correctly.

You can run this like:

# terminal 1
python mmi_bigram_train.py --world-size 2 --rank 0
# terminal 2
python mmi_bigram_train.py --world-size 2 --rank 1

You probably expected that to work already, but it's nice to see K2 running on multi-gpu :)
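
For reference, here is a minimal sketch of the kind of DDP wiring these flags would drive. The function name and init_method below are illustrative assumptions, not the actual mmi_bigram_train.py code:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(rank: int, world_size: int, model: torch.nn.Module):
    # Assumed init method; the real script may configure this differently.
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:23456',
        world_size=world_size,
        rank=rank,
    )
    torch.cuda.set_device(rank)
    # Wrapping in DDP makes the backward pass all-reduce gradients across ranks.
    return DDP(model.to(rank), device_ids=[rank])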

@pzelasko changed the title from "Initially working multi-gpu training" to "[dont merge] Initially working multi-gpu training" on Jan 7, 2021
@pzelasko (Collaborator, Author) commented Jan 7, 2021

I also need to check that the Lhotse dataset/dataloader does not return duplicate cuts within the same epoch.

@danpovey (Contributor) commented Jan 7, 2021

Great!!

@csukuangfj mentioned this pull request on Jan 19, 2021
@hegc commented Jan 26, 2021

K2SpeechRecognitionIterableDataset has no 'len()', so torch.utils.data.distributed.DistributedSampler cannot split the dataset into subsets.

@pzelasko (Collaborator, Author):

That is expected: it's an "iterable dataset", so it cannot use samplers. I still have to check how to make it compatible with distributed training, but I don't expect any hurdles; it already supports splitting datasets into partitions for parallel dataloader workers, so it will probably just have to split into more partitions for distributed training.
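
For a rough idea of how that could work, here is a hand-rolled sketch (not Lhotse's actual implementation) that keeps only the items whose index falls into the current (rank, worker) slice:

import torch.distributed as dist
from torch.utils.data import get_worker_info

def partition_by_rank_and_worker(items):
    # Sketch only: shard an iterable across DDP ranks and dataloader workers.
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    info = get_worker_info()
    worker_id = info.id if info is not None else 0
    num_workers = info.num_workers if info is not None else 1
    # One partition per (DDP replica, dataloader worker) pair.
    num_parts = world_size * num_workers
    part_id = rank * num_workers + worker_id
    for idx, item in enumerate(items):
        if idx % num_parts == part_id:
            yield item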

@pzelasko (Collaborator, Author) commented Feb 9, 2021

This seems to work correctly with Lhotse's PR lhotse-speech/lhotse#194

I verified that the cuts are not duplicated within an epoch by dumping the cut IDs from individual workers' partitions into files and comparing them. I then added loss/num_frames synchronization to the master node so that we're logging those quantities correctly. I'll let the training finish and see whether it works OK till the end (and what the WER is).
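
The check itself can be quite simple: each rank dumps the cut IDs it consumed to a file, and a small script verifies the per-rank sets are disjoint. A sketch (the file names are made up):

def check_no_duplicates(paths):
    # Union of all per-rank ID sets; if any cut appears on more than one rank,
    # the union will be smaller than the sum of the individual set sizes.
    seen, total = set(), 0
    for path in paths:
        with open(path) as f:
            ids = {line.strip() for line in f if line.strip()}
        seen |= ids
        total += len(ids)
    assert len(seen) == total, "Some cuts were duplicated across ranks!"

check_no_duplicates(['cut_ids_rank0.txt', 'cut_ids_rank1.txt'])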

I might have another idea that could make this a bit simpler to use -- DataLoader has a "batch_sampler" argument so maybe dynamic batching could be performed inside of that instead of in the Dataset. Then we could return to using "map-style" Datasets and I think Lhotse's code could be greatly simplified + the whole thing would be closer to standard PyTorch workflows. Let me check that out before we merge these things.
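
To make the batch_sampler idea concrete, here is a rough sketch of dynamic batching by a total frame budget over a map-style dataset; it is just the shape of the idea, not what Lhotse eventually implemented:

from torch.utils.data import DataLoader, Sampler

class DurationBatchSampler(Sampler):
    # Sketch only: groups example indexes so each batch stays under max_frames.
    # `frames` is assumed to hold the number of frames per cut, in dataset order;
    # real code would also handle shuffling and distributed partitioning.
    def __init__(self, frames, max_frames):
        self.frames = frames
        self.max_frames = max_frames

    def __iter__(self):
        batch, budget = [], 0
        for idx, num_frames in enumerate(self.frames):
            if batch and budget + num_frames > self.max_frames:
                yield batch
                batch, budget = [], 0
            batch.append(idx)
            budget += num_frames
        if batch:
            yield batch

# Hypothetical usage with a map-style dataset:
# loader = DataLoader(dataset, batch_sampler=DurationBatchSampler(frames, max_frames=90000))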

@danpovey (Contributor) commented Feb 9, 2021

Cool!

@pzelasko changed the title from "[dont merge] Initially working multi-gpu training" to "[WIP] Initially working multi-gpu training" on Feb 10, 2021
@pzelasko (Collaborator, Author):

The LFMMI training seems to work with both single and multi GPU now; I'll post the results for 2 GPUs once the training is done, and then, if it looks OK, we can merge.

@pzelasko (Collaborator, Author):

Now that I'm trying to train the full thing with 2 GPUs, I see the training consistently hanging in epoch 1 at about 1000 steps. Both GPU and CPU usage are shown as 100%, but the training stops progressing (and GPU power use is very low, 70 / 250 W, which I've observed in the past to often be an indicator that the GPU is not actually being utilized).

I started debugging by inspecting Python's stack with py-spy, attaching it to the hanged process. It pointed to SyncBatchNorm's forward pass, so I removed it and re-ran. It didn't help - this time it hung inside one of K2's functions (P.set_scores_stochastic_).

Then I checked the native stack trace by attaching gdb to the main process, and this is it:

(gdb) bt
#0  0x00007ffde9281b12 in clock_gettime ()
#1  0x00002b6762da4ba6 in __GI___clock_gettime (clock_id=4, tp=0x7ffde9229800) at ../sysdeps/unix/clock_gettime.c:115
#2  0x00002b67dbe1537e in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00002b67dbed94f7 in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00002b67dbefb439 in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#5  0x00002b67dbdeace8 in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#6  0x00002b67dbcf04e2 in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#7  0x00002b67dbcf2654 in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#8  0x00002b67dbe7ee93 in cuMemcpyDtoH_v2 () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#9  0x00002b677cd3755a in ?? ()
   from /home/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/lib/../../../../libcudart.so.10.2
#10 0x00002b677cd17266 in ?? ()
   from /home/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/lib/../../../../libcudart.so.10.2
#11 0x00002b677cd39f08 in cudaMemcpy ()
   from /home/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/lib/../../../../libcudart.so.10.2
#12 0x00002b67de5ee2af in k2::Array1<int>::operator[](int) const () from /home/pzelasko/k2-repo/build/lib/libk2context.so
#13 0x00002b67de65ef4a in k2::Array1<int>::Back() const () from /home/pzelasko/k2-repo/build/lib/libk2context.so
#14 0x00002b67de725456 in k2::RaggedShape::TotSize(int) const () from /home/pzelasko/k2-repo/build/lib/libk2context.so
#15 0x00002b67db382fcc in ?? ()
   from /home/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/_k2.cpython-37m-x86_64-linux-gnu.so
#16 0x00002b67db32af96 in ?? ()
   from /home/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/_k2.cpython-37m-x86_64-linux-gnu.so
#17 0x00005626b96a363d in _PyMethodDef_RawFastCallDict ()
    at /tmp/build/80754af9/python_1565725737370/work/Objects/call.c:515

I am not sure what to make of it, or whether it's a K2 issue or not (interestingly, practically the same code seems not to hang in a different project I'm working on).

@pzelasko (Collaborator, Author):

Clarification: by "the same code" I meant the same method of distributed training (DDP, sync batch norm, setup, cleanup, Lhotse's dataloaders) - not the model code, and the project is not using K2.

@pzelasko (Collaborator, Author):

I verified that the issue occurs on a completely different machine as well. I'm building the latest K2 from master, with CUDA 10.2 and PyTorch 1.7.1. The GPUs are GTX 1080 / RTX 2080.

@danpovey (Contributor):

Mm. See if you can run it in cuda-gdb; the command "info cuda kernels" should say what kernels are running - that should narrow it down.

@pzelasko (Collaborator, Author):

This is what I got:

(cuda-gdb) info cuda kernels
  Kernel Parent Dev Grid Status   SMs Mask GridDim BlockDim Invocation
*      0      -   0 52162520 Active 0x00000001 (1,1,1) (64,1,1) ncclReduceRingLLKernel_sum_f64()

The backtrace (bt) shows only this CUDA kernel too, so I attached with normal gdb again, and the stack trace was the same as the one I previously shared. Does it give you any ideas?

@danpovey (Contributor) commented Feb 12, 2021 via email

@pzelasko (Collaborator, Author):

Your suggestion gave me another idea, so this time I ran it with CUDA_LAUNCH_BLOCKING=1, and the call stack is different (maybe more informative? not sure). Anyway the process segfaulted as a consequence of attaching cuda-gdb to it...

Reading symbols from /home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/lib/libcaffe2_nvrtc.so...(no debugging symbols found)...done.                                                                                      [9/1227]
0x00002aaaaaacd6c2 in clock_gettime ()
The CUDA driver could not allocate operating system resources for attaching to the application.

An error occurred while in a function called from GDB.
Evaluation of the expression containing the function
(cudbgApiAttach) will be abandoned.
When the function is done executing, GDB will silently stop.
(cuda-gdb) bt
#0  0x00002aab296fae60 in cudbgReportDriverInternalError () from /cm/local/apps/cuda/libs/current/lib64/libcuda.so.1
#1  0x00002aab29700ff6 in cudbgReportDriverInternalError () from /cm/local/apps/cuda/libs/current/lib64/libcuda.so.1
#2  <function called from gdb>
#3  0x00002aaaaaacd6c2 in clock_gettime ()
#4  0x00002aaaaafff96d in clock_gettime () from /lib64/libc.so.6
#5  0x00002aab297d881e in cuEGLApiInit () from /cm/local/apps/cuda/libs/current/lib64/libcuda.so.1
#6  0x00002aab298bcd14 in cuVDPAUCtxCreate () from /cm/local/apps/cuda/libs/current/lib64/libcuda.so.1
#7  0x00002aab29775ef9 in cudbgApiDetach () from /cm/local/apps/cuda/libs/current/lib64/libcuda.so.1
#8  0x00002aab297760b0 in cudbgApiDetach () from /cm/local/apps/cuda/libs/current/lib64/libcuda.so.1
#9  0x00002aab297a5623 in cuEGLApiInit () from /cm/local/apps/cuda/libs/current/lib64/libcuda.so.1
#10 0x00002aab2993e790 in cuVDPAUCtxCreate () from /cm/local/apps/cuda/libs/current/lib64/libcuda.so.1
#11 0x00002aab296a83e7 in ?? () from /cm/local/apps/cuda/libs/current/lib64/libcuda.so.1
#12 0x00002aab296a8850 in ?? () from /cm/local/apps/cuda/libs/current/lib64/libcuda.so.1
#13 0x00002aab296a893e in ?? () from /cm/local/apps/cuda/libs/current/lib64/libcuda.so.1
#14 0x00002aab29868fc2 in cuLaunchKernel () from /cm/local/apps/cuda/libs/current/lib64/libcuda.so.1
#15 0x00002aaacb16a547 in __cudaInitModule () from /home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/lib/../../../../libcudart.so.10.2
#16 0x00002aaacb16a5d7 in __cudaInitModule () from /home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/lib/../../../../libcudart.so.10.2
#17 0x00002aaacb1a192b in cudaLaunchKernel () from /home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/lib/../../../../libcudart.so.10.2
#18 0x00002aaad9583cf2 in ncclBarrierEnqueueWait (comm=0x2aab90000e00) at enqueue.cc:215
#19 0x00002aaad9585afb in ncclGroupEnd () at group.cc:282
#20 0x00002aaace048551 in c10d::(anonymous namespace)::AutoNcclGroup::~AutoNcclGroup() [clone .isra.461] () from /home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/lib/libtorch_python.so
#21 0x00002aaace052df4 in c10d::ProcessGroup::Work c10d::ProcessGroupNCCL::collective<c10d::ProcessGroupNCCL::reduce(std::vector<at::Tensor, std::allocator<at::Tensor> >&, c10d::ReduceOptions const&)::{lambda(at::Tensor&, at::Tensor&, ncclComm*, c10::cuda::CUDAStream&)#1}, std::shared_ptr<c10d::ProcessGroup::Work> c10d::ProcessGroupNCCL::collective<{lambda(at::Tensor&, at::Tensor&, ncclComm*, c10::cuda::CUDAStream&)#1}>(std::vector<at::Tensor, std::allocator<at::Tensor> >&, std::vector<at::Tensor, std::allocator<at::Tensor> >&, {lambda(at::Tensor&, at::Tensor&, ncclComm*, c10::cuda::CUDAStream&)#1})::{lambda(std::vector<c10::cuda::CUDAStream, std::allocator<c10::cuda::CUDAStream> >&)#1}, c10d::ProcessGroup::Work c10d::ProcessGroupNCCL::reduce(std::vector<at::Tensor, std::allocator<at::Tensor> >&, c10d::ReduceOptions const&)::{lambda(at::Tensor&, at::Tensor&, ncclComm*, c10::cuda::CUDAStream&)#1}<{lambda(at::Tensor&, at::Tensor&, ncclComm*, c10::cuda::CUDAStream&)#1}>(std::vector<at::Tensor, std::allocator<at::Tensor> >&, std::vector<at::Tensor, std::allocator<at::Tensor> >&, std::shared_ptr<c10d::ProcessGroup::Work>)::{lambda(std::vector<c10::cuda::CUDAStream, std::allocator<c10::cuda::CUDAStream> >)#2}>(std::vector<at::Tensor, std::allocator<at::Tensor> >&, std::vector<at::Tensor, std::allocator<at::Tensor> >&, std::shared_ptr<c10d::ProcessGroup::Work>, std::shared_ptr<c10d::ProcessGroup::Work> c10d::ProcessGroupNCCL::collective<{lambda(at::Tensor&, at::Tensor&, ncclComm*, c10::cuda::CUDAStream&)#1}>(std::vector<at::Tensor, std::allocator<at::Tensor> >&, std::vector<at::Tensor, std::allocator<at::Tensor> >&, {lambda(at::Tensor&, at::Tensor&, ncclComm*, c10::cuda::CUDAStream&)#1})::{lambda(std::vector<c10::cuda::CUDAStream, std::allocator<c10::cuda::CUDAStream> >&)#1}, c10d::ProcessGroup::Work c10d::ProcessGroupNCCL::reduce(std::vector<at::Tensor, std::allocator<at::Tensor> >&, c10d::ReduceOptions const&)::{lambda(at::Tensor&, at::Tensor&, ncclComm*, c10::cuda::CUDAStream&)#1}<{lambda(at::Tensor&, at::Tensor&, ncclComm*, c10::cuda::CUDAStream&)#1}>(std::vector<at::Tensor, std::allocator<at::Tensor> >&, std::vector<at::Tensor, std::allocator<at::Tensor> >&, std::shared_ptr<c10d::ProcessGroup::Work>)::{lambda(std::vector<c10::cuda::CUDAStream, std::allocator<c10::cuda::CUDAStream> >)#2}) [clone .isra.793] () from /home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/lib/libtorch_python.so
#22 0x00002aaace053815 in c10d::ProcessGroupNCCL::reduce(std::vector<at::Tensor, std::allocator<at::Tensor> >&, c10d::ReduceOptions const&) () from /home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/lib/libtorch_python.so
#23 0x00002aaacdf16960 in void pybind11::cpp_function::initialize<pybind11::cpp_function::initialize<std::shared_ptr<c10d::ProcessGroup::Work>, c10d::ProcessGroup, std::vector<at::Tensor, std::allocator<at::Tensor> >&, c10d::ReduceOptions const&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::arg, pybind11::arg_v, pybind11::call_guard<pybind11::gil_scoped_release> >(std::shared_ptr<c10d::ProcessGroup::Work> (c10d::ProcessGroup::*)(std::vector<at::Tensor, std::allocator<at::Tensor> >&, c10d::ReduceOptions const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::call_guard<pybind11::gil_scoped_release> const&)::{lambda(c10d::ProcessGroup*, std::vector<at::Tensor, std::allocator<at::Tensor> >&, c10d::ReduceOptions const&)#1}, std::shared_ptr<c10d::ProcessGroup::Work>, c10d::ProcessGroup*, std::vector<at::Tensor, std::allocator<at::Tensor> >&, c10d::ReduceOptions const&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::arg, pybind11::arg_v, pybind11::call_guard<pybind11::gil_scoped_release> >(pybind11::cpp_function::initialize<std::shared_ptr<c10d::ProcessGroup::Work>, c10d::ProcessGroup, std::vector<at::Tensor, std::allocator<at::Tensor> >&, c10d::ReduceOptions const&, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::arg, pybind11::arg_v, pybind11::call_guard<pybind11::gil_scoped_release> >(std::shared_ptr<c10d::ProcessGroup::Work> (c10d::ProcessGroup::*)(std::vector<at::Tensor, std::allocator<at::Tensor> >&, c10d::ReduceOptions const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::call_guard<pybind11::gil_scoped_release> const&)::{lambda(c10d::ProcessGroup*, std::vector<at::Tensor, std::allocator<at::Tensor> >&, c10d::ReduceOptions const&)#1}&&, std::shared_ptr<c10d::ProcessGroup::Work> (*)(c10d::ProcessGroup*, std::vector<at::Tensor, std::allocator<at::Tensor> >&, c10d::ReduceOptions const&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::call_guard<pybind11::gil_scoped_release> const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call) () from /home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/lib/libtorch_python.so
#24 0x00002aaacd8ea7a0 in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) () from /home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/lib/libtorch_python.so
#25 0x00005555556b9914 in _PyMethodDef_RawFastCallKeywords () at /tmp/build/80754af9/python_1598874792229/work/Objects/call.c:693

@pzelasko (Collaborator, Author):

I found it! The culprit was "torch.distributed.reduce", which causes NCCL to hang for some reason. After I discovered that, I searched whether others have had this issue, and there are plenty of issues in PyTorch's repo about reduce/allreduce + NCCL.

I will see if it helps when I replace it with "gather" and sum manually.
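
For the record, here is a sketch of the manual alternative using all_gather (which the NCCL backend supports), assuming the logged quantities are scalar tensors on the current GPU. Note that it is still a collective, so it can deadlock for the same reasons if any rank skips the call:

import torch
import torch.distributed as dist

def sum_across_ranks(value: torch.Tensor) -> torch.Tensor:
    # Gather one copy of `value` from every rank, then sum locally.
    world_size = dist.get_world_size()
    buckets = [torch.zeros_like(value) for _ in range(world_size)]
    dist.all_gather(buckets, value)
    return torch.stack(buckets).sum()

# e.g. total_frames = sum_across_ranks(num_frames_tensor)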

@pzelasko (Collaborator, Author):

Something seems to get consistently stuck after about 3000 steps when I use any of reduce/gather/all_reduce/all_gather. I've simply removed the loss/num_frames sync across the GPUs, so the reported values come only from the master process. It is a half-measure, but one I can live with for now (unless somebody has a better idea). I set up the validation dataloader so that it evaluates the full set on each GPU, so the validation values are reported for the full dev set. The training runs for 3 epochs now - once it finishes, I will update the RESULTS file, resolve conflicts, and merge.

@danpovey (Contributor) commented Feb 13, 2021 via email

@pzelasko (Collaborator, Author):

Unfortunately, I didn't. After these changes, it just hangs later (in epoch 9). The WER at epoch 8 is 10.74%, so it seems to be working correctly until the NCCL deadlock.

I tried changing the environment - I used CUDA 11 + cuDNN 8.0.4 and rebuilt K2 with them. But then, curiously, I'm getting the following error:

Traceback (most recent call last):
  File "./mmi_bigram_train.py", line 568, in <module>
    main()
  File "./mmi_bigram_train.py", line 501, in main
    objf, valid_objf, global_batch_idx_train = train_one_epoch(
  File "./mmi_bigram_train.py", line 233, in train_one_epoch
    curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
  File "./mmi_bigram_train.py", line 147, in get_objf
    all_frames) = get_tot_objf_and_num_frames(tot_scores,
  File "./mmi_bigram_train.py", line 72, in get_tot_objf_and_num_frames
    ok_frames = frames_per_seq[finite_indexes].sum()
IndexError: index 12919261629344 is out of bounds for dimension 0 with size 32
Traceback (most recent call last):
  File "/home/hltcoe/pzelasko/miniconda3/envs/k2cu11/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/hltcoe/pzelasko/miniconda3/envs/k2cu11/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/hltcoe/pzelasko/miniconda3/envs/k2cu11/lib/python3.8/site-packages/torch/distributed/launch.py", line 260, in <module>
    main()
  File "/home/hltcoe/pzelasko/miniconda3/envs/k2cu11/lib/python3.8/site-packages/torch/distributed/launch.py", line 255, in main
    raise subprocess.CalledProcessError(returncode=process.returncode,
subprocess.CalledProcessError: Command '['/home/hltcoe/pzelasko/miniconda3/envs/k2cu11/bin/python', '-u', './mmi_bigram_train.py', '--local_rank=0', '--world_size', '1']' returned non-zero exit status 1.

Really not sure how to proceed.

@csukuangfj (Collaborator):

mask = torch.ne(tot_scores, -math.inf)
# finite_indexes is a tensor containing successful segment indexes, e.g.
# [ 0 1 3 4 5 ]
finite_indexes = torch.nonzero(mask).squeeze(1)

Can you print the values of the following?

  • tot_scores
  • mask
  • finite_indexes

@pzelasko (Collaborator, Author):

After I added the print statements, I got the following error (running with one GPU, i.e. --world_size 1):

Traceback (most recent call last):
  File "./mmi_bigram_train.py", line 569, in <module>
    main()
  File "./mmi_bigram_train.py", line 502, in main
    objf, valid_objf, global_batch_idx_train = train_one_epoch(
  File "./mmi_bigram_train.py", line 234, in train_one_epoch
    curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
  File "./mmi_bigram_train.py", line 148, in get_objf
    all_frames) = get_tot_objf_and_num_frames(tot_scores,
  File "./mmi_bigram_train.py", line 70, in get_tot_objf_and_num_frames
    print(tot_scores)
  File "/home/hltcoe/pzelasko/miniconda3/envs/k2cu11/lib/python3.8/site-packages/torch/tensor.py", line 179, in __repr__
    return torch._tensor_str._str(self)
  File "/home/hltcoe/pzelasko/miniconda3/envs/k2cu11/lib/python3.8/site-packages/torch/_tensor_str.py", line 372, in _str
    return _str_intern(self)
  File "/home/hltcoe/pzelasko/miniconda3/envs/k2cu11/lib/python3.8/site-packages/torch/_tensor_str.py", line 352, in _str_intern
    tensor_str = _tensor_str(self, indent)
  File "/home/hltcoe/pzelasko/miniconda3/envs/k2cu11/lib/python3.8/site-packages/torch/_tensor_str.py", line 241, in _tensor_str
    formatter = _Formatter(get_summarized_data(self) if summarize else self)
  File "/home/hltcoe/pzelasko/miniconda3/envs/k2cu11/lib/python3.8/site-packages/torch/_tensor_str.py", line 101, in __init__
    if value != torch.ceil(value):
RuntimeError: CUDA error: device-side assert triggered
terminate called after throwing an instance of 'std::runtime_error'
  what():  NCCL error in: /pytorch/torch/lib/c10d/../c10d/NCCLUtils.hpp:136, unhandled cuda error, NCCL version 2.7.8
bash: line 1: 103075 Aborted                 (core dumped) /home/hltcoe/pzelasko/miniconda3/envs/k2cu11/bin/python ./mmi_bigram_train.py --world_size 1

@pzelasko (Collaborator, Author):

Also, the NCCL hang is extremely deterministic - after I resume the training from the start of epoch 9, it hangs at exactly the same batch as in the previous run.

@danpovey (Contributor) commented Feb 16, 2021 via email

@pzelasko (Collaborator, Author):

Yes, the first minibatch. I’ll look at the nsys output and let you know. Isolating the minibatch sounds good, will try it as well.
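
One way to isolate it (a sketch; the helper names and batch structure are assumptions) is to dump the offending batch on each rank and replay just that forward/backward pass under the same 2-process DDP setup:

import torch

# In the training loop, right before the step that hangs (e.g. epoch 9, batch 1390):
# torch.save(batch, f'hang_batch_rank{rank}.pt')

# Then, in a minimal replay script started with the same world size:
def replay_one_batch(model, rank):
    batch = torch.load(f'hang_batch_rank{rank}.pt', map_location=f'cuda:{rank}')
    loss = compute_loss(model, batch)  # hypothetical helper mirroring get_objf
    loss.backward()                    # the DDP gradient all-reduce happens here
    torch.cuda.synchronize(rank)       # forces any pending hang to surface now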

@danpovey (Contributor):

Thanks...
cuda-memcheck may possibly show something up as well.

@danpovey (Contributor):

Also, export K2_SYNC_KERNELS=1 will tend to make errors show up earlier if you compiled in release mode (syncing is the default in debug mode).

@pzelasko (Collaborator, Author):

I started the training from epoch 9, batch 1390, which was the last batch printed out before the program hung (2 GPUs, CUDA 10.2). It did hang again, so I went on and ran it with nsys. Unfortunately, I fail to see a potential reason for the hang in the report. You can download the profile at this URL - maybe you'll be able to read it better than I can (it's about 20 MB, so don't worry about excessive size).

🔗 https://livejohnshopkins-my.sharepoint.com/:u:/g/personal/pzelask2_jh_edu/EeT_0llsWWVIrSKhlbsQpHAByMsZF4ZvDv6D-0-j7wWNLQ?e=zPrccA

As for 1 GPU + CUDA 11, I have the nsys profile too (~10 MB), but I can't seem to extract anything useful out of it...

🔗 https://livejohnshopkins-my.sharepoint.com/:u:/g/personal/pzelask2_jh_edu/EdUBHcDB7l9GmGH4iiqmGMsBiXX2L2to6OoSYCLv74bHLw?e=rdcs5t

@pzelasko (Collaborator, Author):

One more thing - in the CUDA 11 case, I noticed the following failed assertions just before the crash in print(tot_scores):

/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [47,0,0], thread: [32,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [47,0,0], thread: [33,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
[... the same assertion repeats for threads [34,0,0] through [63,0,0] of block [47,0,0] ...]

@danpovey (Contributor) commented Feb 17, 2021 via email

@danpovey (Contributor):

BTW, although we should definitely debug this, I don't think we need to drop everything to do this.
One possibility if it proves intractable is to just make other progress and come back to multi-GPU training later.

@danpovey (Contributor):

And BTW, if nsys isn't loading the other qdrep file it could be because the version of Nsight Systems is not new enough. But don't worry about that; likely it's not the best way to debug the CUDA 11 setup anyway.

For the hang: something that will be helpful is stack traces for all the processes. cuda-gdb may only give it for the parent process. It might be necessary to, individually for the other processes that are running, do something like:
gdb python3
(gdb) attach (process-id)
... and hope that that shows us some kind of stack information. If not, /proc/(process-id)/stack may show some useful information too.

@pzelasko (Collaborator, Author):

I'm OK with debugging it one step at a time in my own "background thread", as long as we have ideas for the next steps. I'll proceed with your suggestions and get back to you. About the qdrep files - I was able to open them and see the profile; I meant that I haven't learned anything useful from it.

@danpovey (Contributor):

The most important next steps are those about n-best list rescoring and extracting phone-synchronous features. I believe there are some issues on snowfall, and possibly k2, about that.

@pzelasko (Collaborator, Author) commented Feb 18, 2021

For the CUDA 11 bug, I ran it with CUDA_LAUNCH_BLOCKING=1 and K2_SYNC_KERNELS=1, but the output was the same.

Actually yes, let's put a pin in this one. I will resolve the conflicts and merge (if you run single-GPU training, the NCCL hanging issue does not arise; and if you use CUDA 11, the issue does not seem related to this PR's changes). We can debug this further in the future. Maybe the hanging problem won't arise with a different architecture that doesn't use LSTM (I wouldn't be shocked).

@pzelasko (Collaborator, Author):

One good bit of news - the BucketingSampler wastes less computation on padding and so allows increasing max_frames (effectively the batch size). On a V100 I cranked it up all the way to max_frames=130000, which uses all 32 GB of memory. The WER is 10.62% (I updated the results); I think it may need some learning rate tuning to account for the much larger batch size.
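
For context, plugging the bucketing sampler into the dataloader looks roughly like this; treat the exact Lhotse import paths and argument names as assumptions to check against the installed version:

from torch.utils.data import DataLoader
from lhotse.dataset import BucketingSampler, K2SpeechRecognitionDataset  # assumed imports

train_sampler = BucketingSampler(
    cuts_train,          # the training CutSet (assumed to be in scope)
    max_frames=130000,   # the per-batch frame budget mentioned above
    shuffle=True,
    num_buckets=30,      # assumed value; buckets group cuts of similar duration
)
train_dl = DataLoader(
    K2SpeechRecognitionDataset(cuts_train),
    sampler=train_sampler,
    batch_size=None,     # the sampler already yields whole batches of cuts
    num_workers=4,
)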

I've added an option to use bucketing in the Transformer MMI recipe too; it seems like the strongest recipe currently, so I'll check how useful it is there.

@danpovey (Contributor):

Thanks!! Merging.

@danpovey merged commit c20762e into k2-fsa:master on Feb 19, 2021
@csukuangfj (Collaborator):

#71 (comment)

I found it! The culprit was "torch.distributed.reduce" which causes the NCCL to hang for some reason. After I discovered that I started searching if others had this issue and there are plenty issues in PyTorch's repo about reduce/allreduce + NCCL.

@pzelasko
Maybe #152 (comment) explains why it hangs in allreduce.
