
Conversation

@xiaowangintel commented Jan 10, 2024

This PR adds initial Intel GPU support to gpt-fast via the device option "xpu" (i.e., --device xpu). Both single-device and multi-device execution via tensor parallelism are functionally supported, while performance is still being improved. Refer to the following steps to run generation on Intel GPUs. We will update the tutorial later as performance improves.

Installation

  1. Install PyTorch and Intel® Extension for PyTorch (IPEX):
    https://intel.github.io/intel-extension-for-pytorch/xpu/latest/tutorials/introduction.html#
  2. Install oneCCL for distributed execution:
    https://github.com/oneapi-src/oneCCL
  3. Install Intel® Extension for Triton (needed by torch.compile); a quick sanity check is sketched after this list:
    https://intel.github.io/intel-extension-for-pytorch/xpu/latest/tutorials/features/torch_compile_gpu.html
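
After installation, the sanity check below (a sketch, not part of the PR; it assumes an XPU build of PyTorch with IPEX installed) confirms the stack is functional:

import torch
import intel_extension_for_pytorch as ipex  # registers the 'xpu' device type

# Verify versions and that at least one Intel GPU is visible.
print("torch:", torch.__version__, "| ipex:", ipex.__version__)
print("XPU available:", torch.xpu.is_available())
print("XPU device count:", torch.xpu.device_count())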

How to run gpt-fast on Intel GPUs?

  1. Command for a single device:
    python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --speculate_k 5 --prompt "Hi my name is" --device xpu
  2. Command for multiple devices via tensor parallelism:
    ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=2 generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --device xpu

Note:

  1. Please export UR_L0_IN_ORDER_BARRIER_BY_SIGNAL=0 (a temporary configuration) to avoid unnecessary errors when running gpt-fast with torch.compile.
  2. Please export IPEX_ZE_TRACING=1 (a temporary configuration) to capture profiling events when running gpt-fast with the profiler; a sketch of setting both variables follows these notes.
  3. Currently, only bf16 is supported; int4/int8 will be supported later via IPEX without requiring code changes in gpt-fast.
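
For reference, a sketch of setting these temporary variables from Python rather than the shell. Whether the runtime picks them up from os.environ depends on when Level Zero reads them, so exporting them in the launching shell, as above, remains the safer route:

import os

# Temporary workarounds from the notes above; set before importing torch
# so the runtime can see them.
os.environ.setdefault("UR_L0_IN_ORDER_BARRIER_BY_SIGNAL", "0")  # for torch.compile runs
os.environ.setdefault("IPEX_ZE_TRACING", "1")                   # for profiling runs

import torch  # imported only after the variables are set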

Comment on lines +53 to +69
os.environ['CCL_PROCESS_LAUNCHER'] = 'none'
os.environ['CCL_LOCAL_SIZE'] = str(world_size)
os.environ['CCL_LOCAL_RANK'] = str(rank)

torch.xpu.set_device(rank)
dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)

Please move these lines inside the "try" block.

Author

Yeah, I'll do that.

generate.py Outdated
Comment on lines 16 to 19
try:
    import intel_extension_for_pytorch as ipex
except:
    pass

Suggest moving this into main and making it a conditional import when the user selects the "xpu" device. Raise an error when things go wrong.

Author

Yeah, I'll do that.
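
A minimal sketch of the suggested pattern (the helper name is hypothetical, not from the PR):

def maybe_import_ipex(device: str) -> None:
    # Import IPEX only when the user selects the XPU device, and fail loudly
    # if it is missing instead of silently passing.
    if device == "xpu":
        try:
            import intel_extension_for_pytorch as ipex  # noqa: F401
        except ImportError as e:
            raise ModuleNotFoundError(
                "intel_extension_for_pytorch is required for --device xpu"
            ) from e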

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
if hasattr(torch._inductor.config, "fx_graph_cache"):

Intel GPU currently uses a PyTorch fork based on 2.1 which doesn't have fx_graph_cache yet.

Author

I've added a comment.
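
The resulting guard looks roughly like this (a sketch of the pattern under discussion; nothing beyond the attribute check is assumed):

import torch

# Enable the experimental FX graph cache only on builds that expose the flag;
# the 2.1-based PyTorch fork used for Intel GPU does not have it yet.
if hasattr(torch._inductor.config, "fx_graph_cache"):
    torch._inductor.config.fx_graph_cache = True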

generate.py Outdated
Comment on lines 369 to 371
record_shapes=True,
profile_memory=False,
with_stack=True

Can we remove these extra configurations?

Author
@xiaowangintel commented Jan 11, 2024

Yes, they have been removed.
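
With those options dropped, the call reduces to the profiler defaults, roughly as below (a sketch; which profiler activities are available for XPU depends on the torch/IPEX build):

import torch

with torch.profiler.profile() as prof:
    ...  # the generation step being profiled
print(prof.key_averages().table(sort_by="self_cpu_time_total"))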


jgong5 commented Jan 10, 2024

Please add to the PR description: 1) how to build/install the prerequisite software components; 2) how to run inference with and without tensor parallelism.

generate.py Outdated
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
#Intel GPU currently uses a PyTorch fork based on 2.1 which doesn't have fx_graph_cache yet.

Suggested change
#Intel GPU currently uses a PyTorch fork based on 2.1 which doesn't have fx_graph_cache yet.
# To support devices (like Intel GPU) which still use PyTorch 2.1 that doesn't have fx_graph_cache yet.

generate.py Outdated
try:
    import intel_extension_for_pytorch as ipex
except:
    raise ModuleNotFoundError(f"No module named 'intel_extension_for_pytorch'")

Suggested change
raise ModuleNotFoundError(f"No module named 'intel_extension_for_pytorch'")
raise ModuleNotFoundError("Intel Extension for PyTorch (intel_extension_for_pytorch) is required to run Intel GPU on the XPU device. Please check https://github.com/intel/intel-extension-for-pytorch for details.")

tp.py Outdated
try:
    import oneccl_bindings_for_pytorch
except:
    raise ModuleNotFoundError(f"No module named 'oneccl_bindings_for_pytorch'")

Suggested change
raise ModuleNotFoundError(f"No module named 'oneccl_bindings_for_pytorch'")
raise ModuleNotFoundError(f"OneCCL bindings for PyTorch (oneccl_bindings_for_pytorch) is required to run tensor parallel on Intel GPU (XPU). Please check https://github.com/intel/torch-ccl for details.")

tp.py Outdated
Comment on lines 54 to 62
try:
    os.environ['CCL_PROCESS_LAUNCHER'] = 'none'
    os.environ['CCL_LOCAL_SIZE'] = str(world_size)
    os.environ['CCL_LOCAL_RANK'] = str(rank)

    torch.xpu.set_device(rank)
    dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)
except:
    raise ValueError(f"Failed to initialize 'ccl'")

Do we need a try-except here? CUDA doesn't need one.
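
What the comment suggests would look roughly like this: initialize directly and let any failure surface with its original traceback (a sketch; the function name is hypothetical):

import os

import torch
import torch.distributed as dist
import intel_extension_for_pytorch  # noqa: F401 -- provides torch.xpu
import oneccl_bindings_for_pytorch  # noqa: F401 -- registers the 'ccl' backend

def init_ccl(rank: int, world_size: int) -> None:
    # Same setup as the diff above, minus the try/except that swallows the
    # original error and re-raises a generic ValueError.
    os.environ['CCL_PROCESS_LAUNCHER'] = 'none'
    os.environ['CCL_LOCAL_SIZE'] = str(world_size)
    os.environ['CCL_LOCAL_RANK'] = str(rank)

    torch.xpu.set_device(rank)
    dist.init_process_group(backend="ccl", rank=rank, world_size=world_size)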


jgong5 commented Jan 12, 2024

@Chillee This is the initial PR to support Intel GPU. Most of the needed code changes should be there. Further performance optimizations will be applied inside IPEX. May I ask for your review? Thanks!

@xiaowangintel force-pushed the main branch 4 times, most recently from 0eaf4b5 to 3af7b93 on March 19, 2024
@xiaowangintel force-pushed the main branch 2 times, most recently from 4ccebef to ff223c8 on May 6, 2024
