
Commit 4ccebef

intel gpu : enable intel gpu

Parent: 30d69b3
File tree

4 files changed: +73 additions, -16 deletions


generate.py

Lines changed: 20 additions & 5 deletions
@@ -16,6 +16,8 @@
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
@@ -24,7 +26,8 @@ def device_sync(device):

 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+if hasattr(torch._inductor.config, "fx_graph_cache"):
+    torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future

 default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

@@ -271,7 +274,7 @@ def main(

     global print
     from tp import maybe_init_dist
-    rank = maybe_init_dist()
+    rank = maybe_init_dist(device)
     use_tp = rank is not None
     if use_tp:
         if rank != 0:
@@ -303,7 +306,7 @@ def main(
     torch.manual_seed(1234)
     model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
     if compile:
-        if is_speculative and use_tp: # and ("cuda" in device):
+        if is_speculative and use_tp and ("cuda" in device):
             torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case

         if is_speculative:
@@ -354,8 +357,15 @@ def callback(x):
         if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if "cuda" in device:
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile()
+            elif "xpu" in device:
+                prof = torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.XPU],
+                )
         with prof:
             y, metrics = generate(
                 model,
@@ -419,6 +429,11 @@ def callback(x):
     parser.add_argument('--device', type=str, default=default_device, help='Device to use')

     args = parser.parse_args()
+    if "xpu" in args.device:
+        try:
+            import intel_extension_for_pytorch as ipex
+        except:
+            raise ModuleNotFoundError(f"Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.")
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
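The generate.py changes route each device-specific step through the device string: synchronization, the cudagraph workaround, profiling, and the runtime import check all branch on "cuda" vs. "xpu". A minimal, self-contained sketch of the resulting device_sync helper (the print fallback in the else branch is an assumption, since the diff truncates that hunk):

import torch

def device_sync(device):
    # Block until all kernels queued on the target device have finished,
    # so wall-clock timings measure real work rather than async launches.
    if "cuda" in device:
        torch.cuda.synchronize(device)
    elif "xpu" in device:
        # torch.xpu is provided by intel_extension_for_pytorch
        # (and natively by recent PyTorch releases).
        torch.xpu.synchronize(device)
    elif ("cpu" in device) or ("mps" in device):
        pass
    else:
        print(f"device={device} is not handled")  # assumed fallback, not shown in the diff

The profiler branch follows the same dispatch: CUDA keeps the default torch.profiler.profile(), while the XPU path lists its activities explicitly (ProfilerActivity.CPU and ProfilerActivity.XPU), presumably because a bare profile() call does not include XPU activity on the targeted builds.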

mixtral-moe/generate.py

Lines changed: 21 additions & 5 deletions
@@ -16,6 +16,8 @@
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif "cpu" in device:
         pass
     else:
@@ -24,7 +26,8 @@ def device_sync(device):

 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+if hasattr(torch._inductor.config, "fx_graph_cache"):
+    torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future


 # support running without installing as a package
@@ -178,7 +181,7 @@ def main(
     assert tokenizer_path.is_file(), str(tokenizer_path)

     global print
-    rank = maybe_init_dist()
+    rank = maybe_init_dist(device)
     use_tp = rank is not None
     if use_tp:
         if rank != 0:
@@ -203,7 +206,8 @@ def main(
     torch.manual_seed(1234)
     model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
     if compile:
-        torch._inductor.config.assert_indirect_indexing = False
+        if hasattr(torch._inductor.config, "assert_indirect_indexing"):
+            torch._inductor.config.assert_indirect_indexing = False

         global decode_one_token, prefill
         decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
@@ -248,8 +252,15 @@ def callback(x):
         if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if "cuda" in device:
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile()
+            elif "xpu" in device:
+                prof = torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.XPU],
+                )
         with prof:
             y = generate(
                 model,
@@ -302,6 +313,11 @@ def callback(x):
     parser.add_argument('--device', type=str, default="cuda", help='device to use')

     args = parser.parse_args()
+    if "xpu" in args.device:
+        try:
+            import intel_extension_for_pytorch as ipex
+        except:
+            raise ModuleNotFoundError(f"Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run PyTorch code on Intel GPU (XPU). Please check https://github.com/intel/intel-extension-for-pytorch for details.")
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.device
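Both generate.py files also wrap newer inductor options in hasattr checks so the scripts stay usable across PyTorch versions: assigning a config attribute that a given build does not define can raise an AttributeError. A short sketch of the guard pattern:

import torch

# Probe for the option before setting it: older builds lack these attributes,
# and torch._inductor.config may reject unknown names outright.
if hasattr(torch._inductor.config, "fx_graph_cache"):
    torch._inductor.config.fx_graph_cache = True  # cache compiled FX graphs
if hasattr(torch._inductor.config, "assert_indirect_indexing"):
    torch._inductor.config.assert_indirect_indexing = False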

mixtral-moe/tp.py

Lines changed: 16 additions & 3 deletions
@@ -28,7 +28,7 @@ def local_break():
 def _get_world_size() -> int:
     return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

-def maybe_init_dist() -> Optional[int]:
+def maybe_init_dist(device) -> Optional[int]:
     try:
         # provided by torchrun
         rank = _get_rank()
@@ -41,8 +41,21 @@ def maybe_init_dist() -> Optional[int]:
         # not run via torchrun, no-op
         return None

-    torch.cuda.set_device(rank)
-    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    if "cuda" in device:
+        torch.cuda.set_device(rank)
+        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    elif "xpu" in device:
+        try:
+            import oneccl_bindings_for_pytorch
+        except:
+            raise ModuleNotFoundError(f"OneCCL bindings for PyTorch (oneccl_bindings_for_pytorch) is required to run tensor parallel on Intel GPU (XPU). Please check https://github.com/intel/torch-ccl for details.")
+
+        os.environ['CCL_PROCESS_LAUNCHER'] = 'none'
+        os.environ['CCL_LOCAL_SIZE'] = str(world_size)
+        os.environ['CCL_LOCAL_RANK'] = str(rank)
+
+        torch.xpu.set_device(rank)
+        dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)
     return rank

 rank = _get_rank()
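The new maybe_init_dist(device) signature threads the device string down from generate.py; the function still keys off torchrun's LOCAL_RANK / LOCAL_WORLD_SIZE environment variables. A hypothetical single-process smoke test of the no-op path (no torchrun, so tensor parallelism stays disabled):

import os
from tp import maybe_init_dist

# Without torchrun, LOCAL_WORLD_SIZE defaults to "1"; world_size < 2 means
# maybe_init_dist returns None before any process group is initialized.
os.environ.pop("LOCAL_WORLD_SIZE", None)
assert maybe_init_dist("xpu") is None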

tp.py

Lines changed: 16 additions & 3 deletions
@@ -33,7 +33,7 @@ def local_break():
 def _get_world_size() -> int:
     return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

-def maybe_init_dist() -> Optional[int]:
+def maybe_init_dist(device) -> Optional[int]:
     try:
         # provided by torchrun
         rank = _get_rank()
@@ -46,8 +46,21 @@ def maybe_init_dist() -> Optional[int]:
         # not run via torchrun, no-op
         return None

-    torch.cuda.set_device(rank)
-    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    if "cuda" in device:
+        torch.cuda.set_device(rank)
+        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    elif "xpu" in device:
+        try:
+            import oneccl_bindings_for_pytorch
+        except:
+            raise ModuleNotFoundError(f"OneCCL bindings for PyTorch (oneccl_bindings_for_pytorch) is required to run tensor parallel on Intel GPU (XPU). Please check https://github.com/intel/torch-ccl for details.")
+
+        os.environ['CCL_PROCESS_LAUNCHER'] = 'none'
+        os.environ['CCL_LOCAL_SIZE'] = str(world_size)
+        os.environ['CCL_LOCAL_RANK'] = str(rank)
+
+        torch.xpu.set_device(rank)
+        dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)
     return rank
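The two tp.py diffs carry the same change: pick the collective backend from the device string. NCCL serves CUDA, while for XPU the import of oneccl_bindings_for_pytorch registers the "ccl" backend with torch.distributed, and the CCL_* variables tell oneCCL that torchrun, not oneCCL's own launcher, owns process placement. A condensed sketch of the XPU branch (the helper name is illustrative):

import os
import torch
import torch.distributed as dist

def init_xpu_process_group(rank, world_size):
    # Importing the bindings registers the "ccl" backend with torch.distributed.
    import oneccl_bindings_for_pytorch  # noqa: F401

    # torchrun already spawned the workers, so disable oneCCL's own launcher
    # and pass the local rank/size through explicitly.
    os.environ["CCL_PROCESS_LAUNCHER"] = "none"
    os.environ["CCL_LOCAL_SIZE"] = str(world_size)
    os.environ["CCL_LOCAL_RANK"] = str(rank)

    torch.xpu.set_device(rank)
    dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)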