Skip to content

Commit

Permalink
unset accelerate envs temporarily when merging
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragjn committed Apr 29, 2024
1 parent bdc362f commit 23271c9
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
10 changes: 8 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
maybe_log_params_to_mlfoundry,
sanitize_name,
)
from utils import maybe_set_custom_tempdir, maybe_set_torch_max_memory, try_cleanup_gpus
from utils import (
maybe_set_custom_tempdir,
maybe_set_torch_max_memory,
temporarily_unset_accelerate_envs,
try_cleanup_gpus,
)

logger = logging.getLogger("axolotl")

Expand Down Expand Up @@ -212,7 +217,8 @@ def train_with_truefoundry(config_base: Path = Path("examples/"), **kwargs):
model_dir = cfg.output_dir
cleanup_checkpoints(output_dir=cfg.output_dir)
if cfg.adapter in {"lora", "qlora"}:
axolotl_merge_lora_cli(config=axolotl_config, deepspeed=None, fsdp=None, device_map="auto")
with temporarily_unset_accelerate_envs():
axolotl_merge_lora_cli(config=axolotl_config, deepspeed=None, fsdp=None, device_map="auto")
model_dir = os.path.join(model_dir, "merged")
model_parent_dir = os.path.dirname(model_dir)
# Copy tensorboard logs
Expand Down
11 changes: 11 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import gc
import json
import logging
Expand Down Expand Up @@ -102,6 +103,16 @@ def maybe_set_torch_max_memory(device: int):
torch.cuda.set_per_process_memory_fraction(0.95, device=device)


@contextlib.contextmanager
def temporarily_unset_accelerate_envs():
    """Remove every ``ACCELERATE_*`` environment variable for the duration of the block.

    On entry, all variables in ``os.environ`` whose name starts with
    ``ACCELERATE_`` are popped and remembered. On exit they are written back
    with their original values, whether the block finished normally or raised.

    Yields:
        None. The context manager is used purely for its side effect on
        ``os.environ``.
    """
    # Snapshot the keys first so that popping never mutates the mapping
    # while it is being iterated.
    accelerate_envs = {
        key: os.environ.pop(key)
        for key in list(os.environ)
        if key.startswith("ACCELERATE_")
    }
    try:
        yield
    finally:
        # Restore even when the wrapped code raises; otherwise the process
        # would continue with the accelerate configuration silently missing.
        os.environ.update(accelerate_envs)


# Notebook Utils


Expand Down

0 comments on commit 23271c9

Please sign in to comment.