28 changes: 22 additions & 6 deletions generate.py
@@ -16,6 +16,8 @@
def device_sync(device):
if "cuda" in device:
torch.cuda.synchronize(device)
elif "xpu" in device:
torch.xpu.synchronize(device)
elif ("cpu" in device) or ("mps" in device):
pass
else:
@@ -24,7 +26,8 @@ def device_sync(device):

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
if hasattr(torch._inductor.config, "fx_graph_cache"):
Intel GPU currently uses a PyTorch fork based on 2.1 which doesn't have fx_graph_cache yet.

Author: I've added a comment.

torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future

default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

@@ -238,8 +241,9 @@ def _load_model(checkpoint_path, device, precision, use_tp):
model.load_state_dict(checkpoint, assign=True)

if use_tp:
from tp import apply_tp
from tp import apply_tp, global_device
print("Applying tensor parallel to model ...")
global_device(device)
apply_tp(model)

model = model.to(device=device, dtype=precision)
@@ -271,7 +275,7 @@ def main(

global print
from tp import maybe_init_dist
rank = maybe_init_dist()
rank = maybe_init_dist(device)
use_tp = rank is not None
if use_tp:
if rank != 0:
@@ -303,7 +307,7 @@ def main(
torch.manual_seed(1234)
model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
if compile:
if is_speculative and use_tp: # and ("cuda" in device):
if is_speculative and use_tp and ("cuda" in device):
torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case

if is_speculative:
@@ -354,8 +358,15 @@ def callback(x):
if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
prof = contextlib.nullcontext()
else:
torch.profiler._utils._init_for_cuda_graphs()
prof = torch.profiler.profile()
if "cuda" in device:
torch.profiler._utils._init_for_cuda_graphs()
prof = torch.profiler.profile()
elif "xpu" in device:
prof = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.XPU],
)
with prof:
y, metrics = generate(
model,
@@ -419,6 +430,11 @@ def callback(x):
parser.add_argument('--device', type=str, default=default_device, help='Device to use')

args = parser.parse_args()
if "xpu" in args.device:
try:
import intel_extension_for_pytorch as ipex
except:
raise ModuleNotFoundError(f"Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.")
main(
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
26 changes: 21 additions & 5 deletions mixtral-moe/generate.py
@@ -16,6 +16,8 @@
def device_sync(device):
if "cuda" in device:
torch.cuda.synchronize(device)
elif "xpu" in device:
torch.xpu.synchronize(device)
elif "cpu" in device:
pass
else:
@@ -24,7 +26,8 @@ def device_sync(device):

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
if hasattr(torch._inductor.config, "fx_graph_cache"):
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future


# support running without installing as a package
@@ -178,7 +181,7 @@ def main(
assert tokenizer_path.is_file(), str(tokenizer_path)

global print
rank = maybe_init_dist()
rank = maybe_init_dist(device)
use_tp = rank is not None
if use_tp:
if rank != 0:
@@ -203,7 +206,8 @@ def main(
torch.manual_seed(1234)
model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
if compile:
torch._inductor.config.assert_indirect_indexing = False
if hasattr(torch._inductor.config, "assert_indirect_indexing"):
torch._inductor.config.assert_indirect_indexing = False

global decode_one_token, prefill
decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
@@ -248,8 +252,15 @@ def callback(x):
if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
prof = contextlib.nullcontext()
else:
torch.profiler._utils._init_for_cuda_graphs()
prof = torch.profiler.profile()
if "cuda" in device:
torch.profiler._utils._init_for_cuda_graphs()
prof = torch.profiler.profile()
elif "xpu" in device:
prof = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.XPU],
)
with prof:
y = generate(
model,
@@ -302,6 +313,11 @@ def callback(x):
parser.add_argument('--device', type=str, default="cuda", help='device to use')

args = parser.parse_args()
if "xpu" in args.device:
try:
import intel_extension_for_pytorch as ipex
except:
raise ModuleNotFoundError(f"Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.")
main(
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.device
19 changes: 16 additions & 3 deletions mixtral-moe/tp.py
@@ -28,7 +28,7 @@ def local_break():
def _get_world_size() -> int:
return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

def maybe_init_dist() -> Optional[int]:
def maybe_init_dist(device) -> Optional[int]:
try:
# provided by torchrun
rank = _get_rank()
@@ -41,8 +41,21 @@ def maybe_init_dist() -> Optional[int]:
# not run via torchrun, no-op
return None

torch.cuda.set_device(rank)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
if "cuda" in device:
torch.cuda.set_device(rank)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
elif "xpu" in device:
try:
import oneccl_bindings_for_pytorch
except:
raise ModuleNotFoundError(f"OneCCL bindings for PyTorch (oneccl_bindings_for_pytorch) is required to run tensor parallel on Intel GPU (XPU). Please check https://github.com/intel/torch-ccl for details.")

os.environ['CCL_PROCESS_LAUNCHER'] = 'none'
os.environ['CCL_LOCAL_SIZE'] = str(world_size)
os.environ['CCL_LOCAL_RANK'] = str(rank)

torch.xpu.set_device(rank)
dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)
return rank

rank = _get_rank()
6 changes: 5 additions & 1 deletion quantize.py
@@ -539,7 +539,6 @@ def quantize(
device: str = default_device,
) -> None:
assert checkpoint_path.is_file(), checkpoint_path
device = 'cpu'
precision = torch.bfloat16

print("Loading model ...")
@@ -621,4 +620,9 @@ def quantize(
parser.add_argument('--device', type=str, default=default_device, help='device to use')

args = parser.parse_args()
if "xpu" in args.device:
try:
import intel_extension_for_pytorch as ipex
except:
raise ModuleNotFoundError(f"Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.")
quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label, args.device)
47 changes: 42 additions & 5 deletions tp.py
@@ -18,6 +18,12 @@
from model import Attention, FeedForward, Transformer
from quantize import WeightOnlyInt4Linear

Int4Device = "cpu"

def global_device(device: str = "cpu"):
global Int4Device
Int4Device = device


def _get_rank() -> int:
return int(os.environ.get("LOCAL_RANK", "0"))
@@ -33,7 +39,7 @@ def local_break():
def _get_world_size() -> int:
return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

def maybe_init_dist() -> Optional[int]:
def maybe_init_dist(device) -> Optional[int]:
try:
# provided by torchrun
rank = _get_rank()
@@ -46,8 +52,21 @@ def maybe_init_dist() -> Optional[int]:
# not run via torchrun, no-op
return None

torch.cuda.set_device(rank)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
if "cuda" in device:
torch.cuda.set_device(rank)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
elif "xpu" in device:
try:
import oneccl_bindings_for_pytorch
except:
raise ModuleNotFoundError(f"OneCCL bindings for PyTorch (oneccl_bindings_for_pytorch) is required to run tensor parallel on Intel GPU (XPU). Please check https://github.com/intel/torch-ccl for details.")

os.environ['CCL_PROCESS_LAUNCHER'] = 'none'
os.environ['CCL_LOCAL_SIZE'] = str(world_size)
os.environ['CCL_LOCAL_RANK'] = str(rank)

torch.xpu.set_device(rank)
dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)
Comment on lines +64 to +69
Please move these lines inside "try" region.

Author: Yeah, I do that.
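For reference, a minimal sketch of the restructuring the reviewer asks for, with the CCL environment setup, device selection, and process-group init moved inside the existing try block. Narrowing the handler to ImportError is an assumption added here, so that failures from the setup calls themselves are not reported as a missing module.

elif "xpu" in device:
    try:
        # importing oneccl_bindings_for_pytorch registers the "ccl" backend with torch.distributed
        import oneccl_bindings_for_pytorch  # noqa: F401

        os.environ['CCL_PROCESS_LAUNCHER'] = 'none'
        os.environ['CCL_LOCAL_SIZE'] = str(world_size)
        os.environ['CCL_LOCAL_RANK'] = str(rank)

        torch.xpu.set_device(rank)
        dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)
    except ImportError:
        raise ModuleNotFoundError("OneCCL bindings for PyTorch (oneccl_bindings_for_pytorch) are required to run tensor parallel on Intel GPU (XPU). Please check https://github.com/intel/torch-ccl for details.")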

return rank


@@ -83,14 +102,32 @@ def shard_qkv(qkv, dim, weight_splits):
assert len(weight_splits) == 3

if isinstance(linear, WeightOnlyInt4Linear):
sharded_weight = shard_qkv(linear.weight, shard_dim, [i//8 for i in weight_splits])
if ("xpu" in Int4Device):
in_features = linear.in_features
out_features = linear.out_features//8
sharded_weight_size = list(linear.weight.size())
sharded_weight_size[shard_dim] = -1
weight = linear.weight.reshape((in_features, out_features))
sharded_weight = shard_qkv(weight, 1 - shard_dim, [i//8 for i in weight_splits])
sharded_weight = sharded_weight.reshape(sharded_weight_size)
else:
sharded_weight = shard_qkv(linear.weight, shard_dim, [i//8 for i in weight_splits])
linear.scales_and_zeros = shard_qkv(linear.scales_and_zeros, 1 - shard_dim, weight_splits)
else:
sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
if hasattr(linear, "scales") and style == "colwise":
linear.scales = shard_qkv(linear.scales, 0, weight_splits)
else:
sharded_weight = shard(linear.weight, shard_dim)
if isinstance(linear, WeightOnlyInt4Linear) and ("xpu" in Int4Device):
in_features = linear.in_features
out_features = linear.out_features//8
sharded_weight_size = list(linear.weight.size())
sharded_weight_size[shard_dim] = -1
weight = linear.weight.reshape((in_features, out_features))
sharded_weight = shard(weight, 1 - shard_dim)
sharded_weight = sharded_weight.reshape(sharded_weight_size)
else:
sharded_weight = shard(linear.weight, shard_dim)
if isinstance(linear, WeightOnlyInt4Linear):
linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim)
if style == "rowwise":