feat: move parameters to decorator for type hints
meakbiyik committed Sep 25, 2023
1 parent 87a4346 commit d73c42a
Showing 2 changed files with 136 additions and 69 deletions.
46 changes: 45 additions & 1 deletion tests/test_torchcache.py
@@ -136,7 +136,20 @@ class CachedModule(SimpleModule):

# Test hashing functionality.
def test_hashing():
cache = _TorchCache()
cache = _TorchCache(
memory_cache_device="cpu",
subsample_count=10000,
persistent=False,
persistent_cache_dir=None,
persistent_module_hash=None,
max_persistent_cache_size=int(10e9),
max_memory_cache_size=int(1e9),
zstd_compression=False,
zstd_compression_level=3,
zstd_compression_threads=1,
cache_dtype=None,
use_mmap_on_load=False,
)
input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
hashes = cache.hash_tensor(input_tensor)

@@ -386,6 +399,37 @@ class CachedModule(SimpleModule):
assert torch.equal(output, input_tensor * 2)


def test_persistent_module_hash(tmp_path):
@torchcache(
persistent=True, persistent_module_hash="test", persistent_cache_dir=tmp_path
)
class CachedModule(SimpleModule):
pass

model = CachedModule()
input_tensor = torch.tensor([[1, 2, 3]], dtype=torch.float32)

output = model(input_tensor)

assert torch.equal(output, input_tensor * 2)

# Create another module with the same persistent_module_hash but a different forward pass
@torchcache(
persistent=True, persistent_module_hash="test", persistent_cache_dir=tmp_path
)
class CachedModule2(SimpleModule):
def forward(self, x):
return x * 3

model2 = CachedModule2()

# Second pass: should retrieve from the cache, so the result matches the
# first module's output since the persistent_module_hash is the same
output = model2(input_tensor)

assert torch.equal(output, input_tensor * 2)


@pytest.fixture(autouse=True)
def cleanup():
yield
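For readers skimming the diff, the following is a minimal sketch (not part of this commit) of the behavior the new test_persistent_module_hash test is meant to exercise: two modules decorated with the same persistent_module_hash share one persistent cache, so the second module returns the first module's cached output for an already-seen input. The module names and cache path below are illustrative.

import torch
import torch.nn as nn
from torchcache import torchcache

# Both modules point at the same (hypothetical) cache directory and share
# an explicit persistent_module_hash.
@torchcache(persistent=True, persistent_module_hash="shared", persistent_cache_dir="/tmp/torchcache-demo")
class Doubler(nn.Module):
    def forward(self, x):
        return x * 2

@torchcache(persistent=True, persistent_module_hash="shared", persistent_cache_dir="/tmp/torchcache-demo")
class Tripler(nn.Module):
    def forward(self, x):
        return x * 3

x = torch.tensor([[1.0, 2.0, 3.0]])
Doubler()(x)         # populates the shared cache with x * 2
print(Tripler()(x))  # served from the shared cache: x * 2, not x * 3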
159 changes: 91 additions & 68 deletions torchcache/torchcache.py
@@ -17,8 +17,19 @@


def torchcache(
*cache_args,
**cache_kwargs,
*,
memory_cache_device: str = "cpu",
subsample_count: int = 10000,
persistent: bool = False,
persistent_cache_dir: str = None,
persistent_module_hash: str = None,
max_persistent_cache_size: int = int(10e9),
max_memory_cache_size: int = int(1e9),
zstd_compression: bool = False,
zstd_compression_level: int = 3,
zstd_compression_threads: int = 1,
cache_dtype: torch.dtype = None,
use_mmap_on_load: bool = False,
) -> callable:
r"""Decorate a nn.Module class to cache the output of the forward pass.
@@ -34,20 +45,81 @@ def torchcache(
class CachedModule(nn.Module):
pass
Refer to the documentation of _TorchCache for more information. You can
also override the arguments of _TorchCache by setting class attributes
that starts with "torchcache\_". For example, to set the cache directory
for a module (persistent_cache_dir), you can do:
You can also override the arguments of the underlying class instance by
setting class attributes that start with "torchcache\_". For example,
to set the cache directory for a module (persistent_cache_dir), you can do:
.. code-block:: python
@torchcache(persistent=True)
class CachedModule(nn.Module):
def __init__(self, cache_dir: str | Path):
self.torchcache_persistent_cache_dir = cache_dir
Parameters
----------
subsample_count : int
Number of values to subsample from the tensor in hash computation,
by default 10000. This improves hashing performance at the cost of a
higher probability of hash collisions; the default should be enough
for most use cases.
memory_cache_device : str or torch.device, optional
Device to use for the cache, by default "cpu". If None, then the
original device of the tensor is used.
persistent : bool, optional
Whether to use a file-system-based cache, by default False
persistent_cache_dir : str or Path, optional
Directory to use for caching, by default None. If None, then a temporary
directory is used. Only used if persistent is True.
persistent_module_hash : str, optional
Hash of the module definition, args, and kwargs, by default None. If None,
then the module hash is automatically determined. You can explicitly
set this if you want to use the same cache for slightly different
modules. You can find the module hash in the following locations:
- In the logs, if you set the logging level to INFO or DEBUG
- In the cached module's self.cache_instance.module_hash attribute
- As the name of the subdirectory in the persistent cache
max_persistent_cache_size : int, optional
Maximum size of the persistent cache in bytes, by default 10e9 (10 GB)
max_memory_cache_size : int, optional
Maximum size of the memory cache in bytes, by default 1e9 (1 GB)
zstd_compression : bool, optional
Whether to use zstd compression, by default False. See
https://github.com/sergey-dryabzhinsky/python-zstd for more information
on the arguments below.
zstd_compression_level : int, optional
Compression level to use, by default 3. Must be between -100 and 22,
where -100 is the fastest compression and 22 is the slowest.
zstd_compression_threads : int, optional
Number of threads to use for compression, by default 1. If 0, then the
number of threads is automatically determined.
cache_dtype : torch.dtype, optional
Data type to use for the cache, by default None. If None, then the
data type of the first tensor that is processed is used.
use_mmap_on_load : bool, optional
Whether to use mmap when loading the cached embeddings from file, by
default False. If None, then it is automatically determined based on
the torch version. This is only used if persistent is True. This option
might be useful if you are using a torch version >= 2.0.1, as it should
improve performance for large files, but it cannot be used together
with compression.
"""
# Multiple initializations of the same class share the same cache
cache_instance = None
cache_kwargs = {
"memory_cache_device": memory_cache_device,
"subsample_count": subsample_count,
"persistent": persistent,
"persistent_cache_dir": persistent_cache_dir,
"persistent_module_hash": persistent_module_hash,
"max_persistent_cache_size": max_persistent_cache_size,
"max_memory_cache_size": max_memory_cache_size,
"zstd_compression": zstd_compression,
"zstd_compression_level": zstd_compression_level,
"zstd_compression_threads": zstd_compression_threads,
"cache_dtype": cache_dtype,
"use_mmap_on_load": use_mmap_on_load,
}
magic_prefix = "torchcache_"

def decorator(ModuleClass):
@@ -89,69 +161,20 @@ class _TorchCache:
def __init__(
self,
*,
memory_cache_device: str = "cpu",
subsample_count: int = 10000,
persistent: bool = False,
persistent_cache_dir: str = None,
persistent_module_hash: str = None,
max_persistent_cache_size: int = int(10e9),
max_memory_cache_size: int = int(1e9),
zstd_compression: bool = False,
zstd_compression_level: int = 3,
zstd_compression_threads: int = 1,
cache_dtype: torch.dtype = None,
use_mmap_on_load: bool = False,
memory_cache_device: str,
subsample_count: int,
persistent: bool,
persistent_cache_dir: str,
persistent_module_hash: str,
max_persistent_cache_size: int,
max_memory_cache_size: int,
zstd_compression: bool,
zstd_compression_level: int,
zstd_compression_threads: int,
cache_dtype: torch.dtype,
use_mmap_on_load: bool,
):
"""Initialize the torchcache.
Parameters
----------
subsample_count : int
Number of values to subsample from the tensor in hash computation,
by default 10000. This is used to improve hashing performance,
at the cost of a higher probability of hash collisions. Current
default is 10000, which should be enough for most use cases.
memory_cache_device : str or torch.device, optional
Device to use for the cache, by default "cpu". If None, then the
original device of the tensor is used.
persistent : bool, optional
Whether to use a file-system-based cache, by default False
persistent_cache_dir : str or Path, optional
Directory to use for caching, by default None. If None, then a temporary
directory is used. Only used if persistent is True.
persistent_module_hash : str, optional
Hash of the module definition, args, and kwargs, by default None. If None,
then the module hash is automatically determined. You can explicitly
set this if you want to use the same cache for slightly different
modules. You can find the module hash in the following locations:
- In the logs, if you set the logging level to INFO or DEBUG
- In the cached module's self.cache_instance.module_hash attribute
- As the name of the subdirectory in the persistent cache
max_persistent_cache_size : int, optional
Maximum size of the persistent cache in bytes, by default 10e9 (10 GB)
max_memory_cache_size : int, optional
Maximum size of the memory cache in bytes, by default 1e9 (1 GB)
zstd_compression : bool, optional
Whether to use zstd compression, by default False. See
https://github.com/sergey-dryabzhinsky/python-zstd for more information
on the arguments below.
zstd_compression_level : int, optional
Compression level to use, by default 3. Must be between -100 and 22,
where -100 is the fastest compression and 22 is the slowest.
zstd_compression_threads : int, optional
Number of threads to use for compression, by default 1. If 0, then the
number of threads is automatically determined.
cache_dtype : torch.dtype, optional
Data type to use for the cache, by default None. If None, then the
data type of the first tensor that is processed is used.
use_mmap_on_load : bool, optional
Whether to use mmap when loading the cached embeddings from file, by
default False. If None, then it is automatically determined based on
the torch version. This is only used if persistent is True. This option
might be useful if you are using a version >= 2.0.1, as it should
improve the performance for large files, but it cannot be used together
with compression.
"""
"""Initialize the torchcache."""
if not persistent and zstd_compression:
raise ValueError("Cannot use zstd compression without persistent cache")

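For context on the commit's intent, here is a minimal usage sketch (not part of the diff): with the options now declared as explicit keyword-only parameters on the decorator, call sites get type hints from an IDE, and per-instance overrides via "torchcache\_"-prefixed attributes still work as described in the docstring. The class name and cache path below are hypothetical.

from pathlib import Path
import torch.nn as nn
from torchcache import torchcache

@torchcache(
    persistent=True,                 # typed keyword argument on the decorator
    zstd_compression=True,           # compress cached embeddings on disk
    max_memory_cache_size=int(2e9),  # 2 GB in-memory budget
)
class CachedEncoder(nn.Module):
    def __init__(self, cache_dir: Path):
        super().__init__()
        # Instance-level override of persistent_cache_dir, per the docstring
        self.torchcache_persistent_cache_dir = cache_dir

    def forward(self, x):
        return x + 1

encoder = CachedEncoder(Path("/data/embedding-cache"))  # hypothetical path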
