From 21662e1b9319ed57a543e5b3cd33a6fe4bf67f84 Mon Sep 17 00:00:00 2001
From: "M. Eren Akbiyik"
Date: Sat, 2 Sep 2023 00:46:26 +0200
Subject: [PATCH] test: compression and cache limits

---
 tests/test_torchcache.py | 101 +++++++++++++++++++++++++++++++++++++++
 torchcache/torchcache.py |  23 ++++++---
 2 files changed, 118 insertions(+), 6 deletions(-)

diff --git a/tests/test_torchcache.py b/tests/test_torchcache.py
index 954d3af..7e1915e 100644
--- a/tests/test_torchcache.py
+++ b/tests/test_torchcache.py
@@ -46,6 +46,28 @@ class CachedModule(SimpleModule):
     output_cached = model(input_tensor)
     assert torch.equal(output, output_cached)
 
+    # Third time is the charm, but let's use a bigger batch size
+    input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32)
+    output_cached = model(input_tensor)
+    assert torch.equal(output, output_cached[:2])
+
+    # Argument checks
+    with pytest.raises(ValueError):
+
+        @torchcache(persistent=True, zstd_compression=True, use_mmap_on_load=True)
+        class CachedModule(SimpleModule):
+            pass
+
+        CachedModule()
+
+    with pytest.raises(ValueError):
+
+        @torchcache(persistent=False, zstd_compression=True)
+        class CachedModule(SimpleModule):
+            pass
+
+        CachedModule()
+
 
 # Test caching mechanism with persistent storage.
 def test_persistent_caching(tmp_path):
@@ -121,6 +143,85 @@ def test_hashing():
     assert hashes.shape[0] == input_tensor.shape[0]
 
 
+def test_compression(tmp_path):
+    @torchcache(persistent=True, persistent_cache_dir=tmp_path, zstd_compression=True)
+    class CachedModule(SimpleModule):
+        pass
+
+    model = CachedModule()
+    input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
+
+    # First pass, caching should occur and save to file
+    output = model(input_tensor)
+    assert torch.equal(output, input_tensor * 2)
+
+    # Check if cache files were created
+    assert len(list((tmp_path / model.cache_instance.module_hash).iterdir())) == 2
+
+    # Second pass, should retrieve from cache from memory
+    output_cached = model(input_tensor)
+    assert torch.equal(output, output_cached)
+
+    # Now create a new instance of the model and check if the cache is loaded from disk
+    # We re-define the class to flush the cache in memory
+    @torchcache(persistent=True, persistent_cache_dir=tmp_path, zstd_compression=True)
+    class CachedModule(SimpleModule):
+        pass
+
+    model2 = CachedModule()
+    original_load_from_file = model2.cache_instance._load_from_file
+    model2.cache_instance.original_load_from_file = original_load_from_file
+    load_from_file_called = False
+
+    def _load_from_file(*args, **kwargs):
+        nonlocal load_from_file_called
+        load_from_file_called = True
+        original_load_from_file(*args, **kwargs)
+
+    model2.cache_instance._load_from_file = _load_from_file
+    output_cached = model2(input_tensor)
+    assert torch.equal(output, output_cached)
+    assert load_from_file_called
+
+
+# Test cache size limits
+def test_cache_size(tmp_path):
+    # Overhead of saving a tensor in disk is around 700 bytes
+    @torchcache(
+        persistent=True,
+        persistent_cache_dir=tmp_path,
+        max_persistent_cache_size=1500,
+        max_memory_cache_size=20,
+    )
+    class CachedModule(SimpleModule):
+        pass
+
+    model = CachedModule()
+    input_tensor1 = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
+    input_tensor2 = torch.tensor([[7, 8, 9], [10, 11, 12]], dtype=torch.float32)
+
+    # First pass, caching should occur and save to file
+    output = model(input_tensor1)
+    assert torch.equal(output, input_tensor1 * 2)
+
+    # Check if cache files were created
+    assert len(list((tmp_path / model.cache_instance.module_hash).iterdir())) == 2
+
+    # Check that the persistent flag is not set, but the memory flag is
+    assert not model.cache_instance.is_persistent_cache_full
+    assert model.cache_instance.is_memory_cache_full
+
+    # Now pass a tensor that is bigger than the cache size
+    output = model(input_tensor2)
+    assert torch.equal(output, input_tensor2 * 2)
+
+    # Check if cache files were not created
+    assert len(list((tmp_path / model.cache_instance.module_hash).iterdir())) == 2
+
+    # Check that the flag is set
+    assert model.cache_instance.is_persistent_cache_full
+
+
 # Test for mixed cache hits
 def test_mixed_cache_hits():
     @torchcache(persistent=False)
diff --git a/torchcache/torchcache.py b/torchcache/torchcache.py
index 3c38f1a..a1b674d 100644
--- a/torchcache/torchcache.py
+++ b/torchcache/torchcache.py
@@ -152,6 +152,14 @@ def __init__(
             improve the performance for large files, but it cannot be used
             together with compression.
         """
+        if not persistent and zstd_compression:
+            raise ValueError("Cannot use zstd compression without persistent cache")
+
+        if zstd_compression and use_mmap_on_load:
+            raise ValueError(
+                "Cannot use zstd compression and mmap on load at the same time"
+            )
+
         # Rolling powers of the hash base, up until 2**15 to fit in float16
         roll_powers = torch.arange(0, subsample_count * 2) % 15
         self.subsample_count = subsample_count
@@ -171,11 +179,6 @@ def __init__(
         self.is_memory_cache_full = False
         self.cache_dtype = cache_dtype
 
-        if self.zstd_compression and self.use_mmap_on_load:
-            raise ValueError(
-                "Cannot use zstd compression and mmap on load at the same time"
-            )
-
         # We allow explicit overloading of mmap option despite version
         # check so that people can use it with nightly versions
         torch_version = torch.__version__.split(".")
@@ -598,7 +601,15 @@ def _load_from_file(self, hash_val: int) -> Union[Tensor, None]:
         else:
             if self.use_mmap_on_load:
                 load_kwargs["mmap"] = True
-            embedding = torch.load(str(file_path), **load_kwargs)
+            try:
+                embedding = torch.load(str(file_path), **load_kwargs)
+            except Exception as e:
+                logger.error(
+                    f"Could not read file {file_path}, skipping loading from file. "
+                    f"Error: {e}\nRemoving the file to avoid future errors."
+                )
+                file_path.unlink(missing_ok=True)
+                return None
 
         logger.debug("Caching to memory before returning")
         self._cache_to_memory(embedding, hash_val)