
Commit

feat: refactor module hash function
meakbiyik committed Jul 20, 2024
1 parent ffa422b commit db758d7
Showing 1 changed file with 48 additions and 16 deletions.
64 changes: 48 additions & 16 deletions torchcache/torchcache.py
@@ -1,4 +1,5 @@
"""This module implements the torchcache decorator and the underlying class."""

import atexit
import hashlib
import inspect
@@ -343,23 +344,11 @@ def wrap_module(
module.register_forward_pre_hook(self.forward_pre_hook)
module.register_forward_hook(self.forward_hook)
if self.module_hash is None:
# Create a hash of the module definition, args, and kwargs
# So that we do not mistakenly use the cache for a different module
logger.debug("Creating module hash")
try:
module_definition = inspect.getsource(moduleClass)
hash_string = module_definition + repr(args) + repr(kwargs)
except OSError as e:
logger.error(f"Could not retrieve the module source: {e}")
# If the module source cannot be retrieved, we use the module name
hash_string = module.__class__.__name__ + repr(args) + repr(kwargs)
# Also add the crucial parameters of torchcache
hash_string += repr(self.subsample_count) + repr(self.zstd_compression)
logger.debug(f"Module hash string: {hash_string}")
self.module_hash = hashlib.blake2b(
hash_string.encode(),
digest_size=32,
).hexdigest()
self.module_hash = self._generate_module_hash(
module, moduleClass, *args, **kwargs
)

logger.info(f"Module hash: {self.module_hash}")
# If we are using a persistent cache, create a subdirectory for the module
if self.persistent:
@@ -380,6 +369,49 @@ def forward_wrapper(*args, **kwargs):

return module

def _generate_module_hash(
self,
module: torch.nn.Module,
moduleClass: Type[torch.nn.Module],
*args,
**kwargs,
) -> str:
"""Generate a hash of the module definition, args, and kwargs.
If possible, create a hash of the module definition, args, and kwargs
so that we do not mistakenly use the cache for a different module.
Parameters
----------
module : torch.nn.Module
Module to hash.
moduleClass : Type[torch.nn.Module]
Module class to hash.
*args
Positional arguments to hash.
**kwargs
Keyword arguments to hash.
Returns
-------
str
Hash of the module definition, args, and kwargs.
"""
try:
module_definition = inspect.getsource(moduleClass)
hash_string = module_definition + repr(args) + repr(kwargs)
except OSError as e:
logger.error(f"Could not retrieve the module source: {e}")
# If the module source cannot be retrieved, we use the module name
hash_string = module.__class__.__name__ + repr(args) + repr(kwargs)
# Also add the crucial parameters of torchcache
hash_string += repr(self.subsample_count) + repr(self.zstd_compression)
logger.debug(f"Module hash string: {hash_string}")
return hashlib.blake2b(
hash_string.encode(),
digest_size=32,
).hexdigest()

def _fetch_cached_embeddings(
self,
) -> Tensor:
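The diff above extracts the inline hashing logic into a dedicated _generate_module_hash method: it hashes the module class source (falling back to the class name when the source cannot be retrieved), the constructor args and kwargs, and the cache-relevant torchcache settings with BLAKE2b. The following is a minimal standalone sketch of that approach, not part of the torchcache API: the toy TinyEncoder module, the module_hash helper, and the default values shown for subsample_count and zstd_compression are illustrative assumptions.

import hashlib
import inspect

import torch


class TinyEncoder(torch.nn.Module):
    """Toy module used only to demonstrate the hashing approach."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


def module_hash(module_cls, *args, subsample_count=10000, zstd_compression=False, **kwargs):
    """Sketch of the hash computed in _generate_module_hash above."""
    try:
        # Hashing the class source means edits to the module definition
        # automatically invalidate the cache.
        hash_string = inspect.getsource(module_cls) + repr(args) + repr(kwargs)
    except OSError:
        # The source is unavailable e.g. for classes defined interactively;
        # fall back to the class name.
        hash_string = module_cls.__name__ + repr(args) + repr(kwargs)
    # Mix in cache-relevant settings so different configurations do not
    # collide in the same cache.
    hash_string += repr(subsample_count) + repr(zstd_compression)
    return hashlib.blake2b(hash_string.encode(), digest_size=32).hexdigest()


print(module_hash(TinyEncoder, dim=16))  # deterministic 64-character hex digest

Because the digest depends only on the class source, the arguments, and the cache settings, it stays stable across processes, which is what lets the persistent cache subdirectory created in wrap_module be reused between runs.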
