From aea7df8395ae183a479e8b99560d37fb0f67c5ef Mon Sep 17 00:00:00 2001 From: Alexander Weber <12560547+lllAlexanderlll@users.noreply.github.com> Date: Mon, 4 Mar 2024 16:06:27 +0000 Subject: [PATCH] fix: remove tmpdir removal from test --- .../test_convert_to_hf_checkpoint.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/tests/checkpointing/test_convert_to_hf_checkpoint.py b/tests/checkpointing/test_convert_to_hf_checkpoint.py index 7cf92c5a..57a11a95 100644 --- a/tests/checkpointing/test_convert_to_hf_checkpoint.py +++ b/tests/checkpointing/test_convert_to_hf_checkpoint.py @@ -1,5 +1,4 @@ import os -import shutil from pathlib import Path from unittest.mock import patch @@ -7,9 +6,9 @@ import torch from transformers import AutoConfig, AutoModelForCausalLM -from modalities.config.config import GPT2HuggingFaceAdapterConfig, PydanticModelIFType, load_app_config_dict -from modalities.models.gpt2.huggingface_model import HuggingFaceModel from modalities.checkpointing import checkpoint_conversion +from modalities.config.config import GPT2HuggingFaceAdapterConfig +from modalities.models.gpt2.huggingface_model import HuggingFaceModel @pytest.fixture @@ -35,13 +34,11 @@ def test_convert_to_hf_checkpoint(tmp_path, config_path, device): checkpoint_dir=config_path.parent, config_file_name=config_path.name, model_file_name="", - output_hf_checkpoint_dir=tmp_path + output_hf_checkpoint_dir=tmp_path, ) pytorch_model = cp._setup_model() with patch.object( - checkpoint_conversion.CheckpointConversion, - "_get_model_from_checkpoint", - return_value=pytorch_model + checkpoint_conversion.CheckpointConversion, "_get_model_from_checkpoint", return_value=pytorch_model ): cp.convert_pytorch_to_hf_checkpoint() @@ -65,7 +62,3 @@ def test_convert_to_hf_checkpoint(tmp_path, config_path, device): hf_model.eval() output_after_loading = hf_model.forward(test_tensor) assert (output_after_loading == output_before_loading).all() - - # Delete temporary model folder - shutil.rmtree(tmp_path.parent) - assert os.path.exists(tmp_path.parent) is False