Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 54 additions & 9 deletions src/ninetoothed/jit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import importlib
import sys
import torch
import os

from ninetoothed.generation import CodeGenerator
from ninetoothed.utils import calculate_default_configs
Expand Down Expand Up @@ -97,13 +99,10 @@ def __call__(self):
self._max_num_configs,
self._prettify,
)
module = import_from_path(source_file, source_file)
module_vars = vars(module)


handle = _Handle(
module_vars[self._kernel_name],
module_vars[code_generator.launch_func_name],
source_file,
code_generator,
)

return handle
Expand All @@ -118,11 +117,57 @@ def import_from_path(module_name, file_path):
return module


def get_target_device(*args, **kwargs):
    """Return the device of the first tensor found among the call arguments.

    Positional arguments are scanned first, in order, followed by the
    keyword argument values in their passing order. Only top-level
    arguments are inspected.

    Returns:
        The ``device`` attribute of the first ``torch.Tensor`` encountered,
        or ``None`` when no argument is a tensor.
    """
    # A single pass over positionals-then-keywords keeps the precedence
    # explicit: any positional tensor wins over every keyword tensor.
    for candidate in (*args, *kwargs.values()):
        if isinstance(candidate, torch.Tensor):
            return candidate.device

    return None


def convert_to_cpu(source_file_path):
    """Write a copy of a source file with ``triton`` rewritten to ``triton_cpu``.

    The copy is created next to the original and named ``<name>_cpu<ext>``.
    An existing file at that location is overwritten.

    The rewrite is idempotent with respect to the file contents: occurrences
    of ``triton`` that are already followed by ``_cpu`` are left untouched,
    so converting an already-converted file does not produce
    ``triton_cpu_cpu``.

    Args:
        source_file_path: Path of the source file to convert.

    Returns:
        The path of the newly written ``_cpu`` file.

    Raises:
        FileNotFoundError: If ``source_file_path`` does not exist.
    """
    # Imported locally so this function does not rely on a module-level
    # ``re`` import being present.
    import re

    if not os.path.exists(source_file_path):
        raise FileNotFoundError(f"源文件不存在: {source_file_path}")

    dir_name, base_name = os.path.split(source_file_path)
    name, ext = os.path.splitext(base_name)
    new_file_path = os.path.join(dir_name, f"{name}_cpu{ext}")

    with open(source_file_path, "r", encoding="utf-8") as f:
        content = f.read()

    # The negative lookahead skips text that already reads ``triton_cpu``;
    # a plain ``str.replace`` would turn it into ``triton_cpu_cpu``.
    new_content = re.sub(r"triton(?!_cpu)", "triton_cpu", content)

    with open(new_file_path, "w", encoding="utf-8") as f:
        f.write(new_content)

    return new_file_path


class _Handle:
    """Callable that loads the generated launch function on demand.

    The generated source is not imported at construction time. On each call
    the target device is taken from the arguments (via
    ``get_target_device``); when it is a CPU device, a ``_cpu`` variant of
    the source produced by ``convert_to_cpu`` is used instead of the
    original file.

    Args:
        source: Path of the generated source file.
        code_generator: Object exposing ``launch_func_name``, the name of
            the launch function to look up in the imported module.
    """

    def __init__(self, source, code_generator):
        self._source = source
        self._code_generator = code_generator
        # Launch functions already loaded, keyed by backend ("cpu" or
        # "default"), so each variant is converted and imported only once.
        self._launch_cache = {}

    def __call__(self, *args, **kwargs):
        """Invoke the launch function matching the arguments' device.

        All positional and keyword arguments are forwarded unchanged, and
        the launch function's return value is returned.
        """
        target_device = get_target_device(*args, **kwargs)
        use_cpu = target_device is not None and target_device.type == "cpu"
        backend = "cpu" if use_cpu else "default"

        launch = self._launch_cache.get(backend)

        if launch is None:
            # ``self._source`` is deliberately left untouched: overwriting
            # it with the converted path would make later CPU calls convert
            # an already-converted file, and would make non-CPU calls load
            # the CPU variant.
            source = convert_to_cpu(self._source) if use_cpu else self._source

            module = import_from_path(source, source)
            launch = vars(module)[self._code_generator.launch_func_name]
            self._launch_cache[backend] = launch

        # Keep the most recently used launch function on the instance for
        # any code that reads ``_launch`` directly.
        self._launch = launch

        return launch(*args, **kwargs)
Loading