Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix too sensitive "Unsloth currently does not support multi GPU setups" when training with a single GPU in a multi-GPU environment. #1295

Open
This pull request wants to merge 5 commits into base branch: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions unsloth/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,11 +1684,11 @@ def from_pretrained(
else:
inner_training_loop = Trainer._original_training_loop
except:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
raise RuntimeError('llama.py:1687 Unsloth currently does not support multi GPU setups - but we are working on it!')
pass

if ((post_check - pre_check) >= 1).sum() > 1:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
raise RuntimeError('llama.py:1691 Unsloth currently does not support multi GPU setups - but we are working on it!')

import transformers.trainer
items_in_trainer = dir(transformers.trainer)
Expand All @@ -1715,17 +1715,23 @@ def from_pretrained(
f"{chr(92)} / Total batch size = {total_train_batch_size:,} | Total steps = {max_steps:,}\\n"\\
f' "-____-" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}'
logger.warning(debug_info)
import subprocess, re, gc, numpy as np
import subprocess, os, re, gc, numpy as np
index_for_cuda = os.environ.get("CUDA_VISIBLE_DEVICES", -1)
if "," in index_for_cuda:
raise RuntimeError("llama.py:1681 Unsloth currently does not support multi GPU setups - but we are working on it!")
index_for_cuda = int(index_for_cuda)
a = np.array([0,])
try:
a = subprocess.check_output('nvidia-smi --query-gpu=memory.used --format=csv', shell = True)
a = re.findall(rb'([\\d]{1,})[\\s]{1,}M', a)
a = np.array([int(x.decode('utf-8'))/1024 for x in a])
if index_for_cuda != -1:
a = np.array([a[index_for_cuda],])
except:
if not torch.cuda.is_available():
raise RuntimeError('Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!')
if ((a - PRE_CHECK) >= 1).sum() > 1:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
raise RuntimeError('llama.py:1694 Unsloth currently does not support multi GPU setups - but we are working on it!')
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()"""
Expand Down Expand Up @@ -1786,7 +1792,7 @@ def from_pretrained(
"False",
)
if "n_total_devices >" not in inner_training_loop:
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')
raise RuntimeError('llama.py:1795 Unsloth currently does not support multi GPU setups - but we are working on it!')
pass
inner_training_loop = inner_training_loop.replace(
"is_sagemaker_mp_enabled()",
Expand Down
14 changes: 12 additions & 2 deletions unsloth/tokenizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,12 +830,18 @@ def check_tokenizer(


def check_nvidia():
index_for_cuda = os.environ.get("CUDA_VISIBLE_DEVICES", -1)
if "," in index_for_cuda:
raise RuntimeError("Unsloth currently does not support multi GPU setups - but we are working on it!")
index_for_cuda = int(index_for_cuda)
# Unsloth doesn't work yet on AMD devices - we're working on it!
output = np.array([0,])
try:
output = subprocess.check_output("nvidia-smi --query-gpu=memory.used --format=csv", shell = True)
output = re.findall(rb'([\d]{1,})[\s]{1,}M', output)
output = np.array([int(x.decode('utf-8'))/1024 for x in output])
if index_for_cuda != -1:
output = np.array([output[index_for_cuda],])
except:
if not torch.cuda.is_available():
raise RuntimeError("Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!")
Expand Down Expand Up @@ -958,7 +964,11 @@ def patch_sft_trainer_tokenizer():

check_text = \
"\n"\
"import subprocess, re, gc, numpy as np\n"\
"import subprocess, os, re, gc, numpy as np\n"\
"index_for_cuda = os.environ.get(\"CUDA_VISIBLE_DEVICES\", -1)\n"\
"if \",\" in index_for_cuda:\n"\
" raise RuntimeError(\"tokenizer_utils.py:970 Unsloth currently does not support multi GPU setups - but we are working on it!\")\n"\
"index_for_cuda = int(index_for_cuda)\n"\
"a = np.array([0,])\n"\
"try:\n"\
" a = subprocess.check_output('nvidia-smi --query-gpu=memory.used --format=csv', shell = True)\n"\
Expand All @@ -968,7 +978,7 @@ def patch_sft_trainer_tokenizer():
" if not torch.cuda.is_available():\n"\
" raise RuntimeError('Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!')\n"\
"if ((a - PRE_CHECK) >= 1).sum() > 1:\n"\
" raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')\n"\
" raise RuntimeError('tokenizer_utils.py:981 Unsloth currently does not support multi GPU setups - but we are working on it!')\n"\
"for _ in range(3):\n"\
" gc.collect()\n"\
" torch.cuda.empty_cache()\n"\
Expand Down