Skip to content

Commit

Permalink
Upgrade ZLUDA.
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed Jan 17, 2025
1 parent 5adfe6b commit 3cf5301
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 29 deletions.
2 changes: 2 additions & 0 deletions modules/launch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,8 @@ def prepare_environment():
from modules import zluda_installer
zluda_installer.set_default_agent(device)
try:
if zluda_installer.is_old_zluda():
zluda_installer.uninstall()
zluda_installer.install()
except Exception as e:
error = e
Expand Down
28 changes: 1 addition & 27 deletions modules/zluda_hijacks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from modules import zluda_installer, rocm
from modules import rocm


_topk = torch.topk
Expand All @@ -9,32 +9,6 @@ def topk(input: torch.Tensor, *args, **kwargs): # pylint: disable=redefined-buil
return torch.return_types.topk((values.to(device), indices.to(device),))


_fft_fftn = torch.fft.fftn
def fft_fftn(input: torch.Tensor, *args, **kwargs) -> torch.Tensor: # pylint: disable=redefined-builtin
return _fft_fftn(input.cpu(), *args, **kwargs).to(input.device)


_fft_ifftn = torch.fft.ifftn
def fft_ifftn(input: torch.Tensor, *args, **kwargs) -> torch.Tensor: # pylint: disable=redefined-builtin
return _fft_ifftn(input.cpu(), *args, **kwargs).to(input.device)


_fft_rfftn = torch.fft.rfftn
def fft_rfftn(input: torch.Tensor, *args, **kwargs) -> torch.Tensor: # pylint: disable=redefined-builtin
return _fft_rfftn(input.cpu(), *args, **kwargs).to(input.device)


def jit_script(f, *_, **__): # experiment / provide dummy graph
f.graph = torch._C.Graph() # pylint: disable=protected-access
return f


def do_hijack():
torch.version.hip = rocm.version
torch.topk = topk
torch.fft.fftn = fft_fftn
torch.fft.ifftn = fft_ifftn
torch.fft.rfftn = fft_rfftn

if not zluda_installer.get_blaslt_enabled():
torch.jit.script = jit_script
10 changes: 8 additions & 2 deletions modules/zluda_installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
DLL_MAPPING = {
'cublas.dll': 'cublas64_11.dll',
'cusparse.dll': 'cusparse64_11.dll',
'cufft.dll': 'cufft64_10.dll',
'cufftw.dll': 'cufftw64_10.dll',
'nvrtc.dll': 'nvrtc64_112_0.dll',
}
HIPSDK_TARGETS = ['rocblas.dll', 'rocsolver.dll']
HIPSDK_TARGETS = ['rocblas.dll', 'rocsolver.dll', 'hipfft.dll',]
ZLUDA_TARGETS = ('nvcuda.dll', 'nvml.dll',)

path = os.path.abspath(os.environ.get('ZLUDA', '.zluda'))
Expand All @@ -28,12 +30,16 @@ def set_default_agent(agent: rocm.Agent):
default_agent = agent


def is_old_zluda() -> bool: # ZLUDA<3.8.7
return not os.path.exists(os.path.join(path, "cufftw.dll"))


def install() -> None:
if os.path.exists(path):
return

platform = "windows"
commit = os.environ.get("ZLUDA_HASH", "d60bddbc870827566b3d2d417e00e1d2d8acc026")
commit = os.environ.get("ZLUDA_HASH", "c4994b3093e02231339d22e12be08418b2af781f")
if nightly:
platform = "nightly-" + platform
urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/rel.{commit}/ZLUDA-{platform}-rocm{rocm.version[0]}-amd64.zip', '_zluda')
Expand Down

0 comments on commit 3cf5301

Please sign in to comment.