Skip to content

Commit

Permalink
fix: remove tmpdir removal from test
Browse files Browse the repository at this point in the history
  • Loading branch information
lllAlexanderlll committed Mar 4, 2024
1 parent f85c071 commit aea7df8
Showing 1 changed file with 4 additions and 11 deletions.
15 changes: 4 additions & 11 deletions tests/checkpointing/test_convert_to_hf_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import os
import shutil
from pathlib import Path
from unittest.mock import patch

import pytest
import torch
from transformers import AutoConfig, AutoModelForCausalLM

from modalities.config.config import GPT2HuggingFaceAdapterConfig, PydanticModelIFType, load_app_config_dict
from modalities.models.gpt2.huggingface_model import HuggingFaceModel
from modalities.checkpointing import checkpoint_conversion
from modalities.config.config import GPT2HuggingFaceAdapterConfig
from modalities.models.gpt2.huggingface_model import HuggingFaceModel


@pytest.fixture
Expand All @@ -35,13 +34,11 @@ def test_convert_to_hf_checkpoint(tmp_path, config_path, device):
checkpoint_dir=config_path.parent,
config_file_name=config_path.name,
model_file_name="",
output_hf_checkpoint_dir=tmp_path
output_hf_checkpoint_dir=tmp_path,
)
pytorch_model = cp._setup_model()
with patch.object(
checkpoint_conversion.CheckpointConversion,
"_get_model_from_checkpoint",
return_value=pytorch_model
checkpoint_conversion.CheckpointConversion, "_get_model_from_checkpoint", return_value=pytorch_model
):
cp.convert_pytorch_to_hf_checkpoint()

Expand All @@ -65,7 +62,3 @@ def test_convert_to_hf_checkpoint(tmp_path, config_path, device):
hf_model.eval()
output_after_loading = hf_model.forward(test_tensor)
assert (output_after_loading == output_before_loading).all()

# Delete temporary model folder
shutil.rmtree(tmp_path.parent)
assert os.path.exists(tmp_path.parent) is False

0 comments on commit aea7df8

Please sign in to comment.