[Quantization/Parameter] WIP: Replace parameter subclasses with raw nn.Parameter with additional attributes #11622

Open

cennn wants to merge 2 commits into main
Conversation

cennn (Contributor) commented on Dec 30, 2024

FIX: issue-10612, pull-10609

Problem:

Parameter subclasses are not fully compatible with the following:
torch.compile: parameter subclasses cause compatibility issues when models are run under torch.compile.
OffloadedTensor: parameter subclasses also do not compose well with tensor subclasses, such as the offloaded-tensor wrapper used for CPU offloading (see the toy sketch below).
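To make the second point concrete, here is a toy sketch (not the vLLM code; PackedParam and OffloadedTensor are made-up stand-ins) of how wrapping a parameter subclass in another tensor subclass breaks isinstance-based dispatch:

import torch
from torch import nn

class PackedParam(nn.Parameter):
    """Toy parameter subclass that carries packing metadata through its type."""

class OffloadedTensor(torch.Tensor):
    """Toy stand-in for a tensor subclass that offloads weights to CPU."""

p = PackedParam(torch.empty(4, 4), requires_grad=False)
offloaded = p.detach().cpu().as_subclass(OffloadedTensor)

# Wrapping in another tensor subclass loses the parameter subclass, so any
# isinstance(..., PackedParam) dispatch silently stops matching.
print(isinstance(p, PackedParam))          # True
print(isinstance(offloaded, PackedParam))  # False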

Solution:

Remove all parameter subclasses and instead attach the necessary attributes and functions directly to a raw nn.Parameter, so that it carries the characteristics required for quantization parameters. This mainly means rewriting the code that defines and inherits the parameter subclasses, as shown below, and it requires minimal changes to the code that constructs these parameters.

Example Code Changes:

Original Definition:

class PackedvLLMParameter(ModelWeightParameter):
    def __init__(self,
                 packed_factor: Union[int, Fraction],
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size)

New Definition:

def PackedvLLMParameter(data: torch.Tensor, **kwargs) -> Parameter:
    param = Parameter(data, requires_grad=False)
    wrap_base_vllm_parameter(param, **kwargs)
    wrap_column_vllm_parameter(param, **kwargs)
    wrap_row_vllm_parameter(param, **kwargs)
    wrap_packed_vllm_parameter(param, **kwargs)
    return param


def wrap_packed_vllm_parameter(param: Parameter,
                               packed_factor: Union[int, Fraction],
                               packed_dim: int,
                               marlin_tile_size: Optional[int] = None,
                               **kwargs) -> None:
    def adjust_shard_indexes_for_packing(shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=packed_factor,
            marlin_tile_size=marlin_tile_size)

    param.packed_factor = packed_factor
    param.packed_dim = packed_dim
    param.marlin_tile_size = marlin_tile_size
    param.adjust_shard_indexes_for_packing = adjust_shard_indexes_for_packing
    add_param_feature(param, Features.Packed)
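The feature-tagging helpers referenced above (add_param_feature, Features, wrap_base_vllm_parameter) are not shown in this excerpt. A minimal sketch of how they might work, assuming features are simply recorded in a set attached to the parameter (the names has_param_feature and param_features are illustrative, not necessarily the PR's actual API):

from enum import Enum, auto

import torch
from torch.nn import Parameter


class Features(Enum):
    Base = auto()
    Column = auto()
    Row = auto()
    Packed = auto()


def add_param_feature(param: Parameter, feature: Features) -> None:
    # Record which wrap_* functions have been applied to this parameter.
    if not hasattr(param, "param_features"):
        param.param_features = set()
    param.param_features.add(feature)


def has_param_feature(param: torch.Tensor, feature: Features) -> bool:
    # Replaces isinstance checks against the old Parameter subclasses.
    return feature in getattr(param, "param_features", set())


def wrap_base_vllm_parameter(param: Parameter, weight_loader, **kwargs) -> None:
    # Attach the attributes every vLLM parameter needs, then tag the feature.
    param.weight_loader = weight_loader
    add_param_feature(param, Features.Base)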

Unchanged Call Sites:
The code that constructs these parameters does not need to be modified. For example:

qweight = PackedvLLMParameter(
    data=torch.empty(
        input_size_per_partition // self.quant_config.pack_factor,
        output_size_per_partition,
        dtype=torch.int32,
    ),
    input_dim=0,
    output_dim=1,
    packed_dim=0,
    packed_factor=self.quant_config.pack_factor,
    weight_loader=weight_loader)
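Construction call sites stay unchanged, as shown above. Code elsewhere that previously used isinstance to detect a packed parameter would presumably switch to a feature lookup instead; a hypothetical before/after, assuming the has_param_feature helper sketched earlier:

# Before: dispatch on the Parameter subclass.
if isinstance(param, PackedvLLMParameter):
    shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
        shard_size=shard_size, shard_offset=shard_offset)

# After: param is a plain nn.Parameter; dispatch on the recorded feature.
if has_param_feature(param, Features.Packed):
    shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
        shard_size=shard_size, shard_offset=shard_offset)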

Verified Tests:

vllm serve Qwen/Qwen2.5-0.5B-Instruct --quantization fp8
vllm serve Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4
vllm serve Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4 --quantization gptq
vllm serve Qwen/Qwen2-1.5B-Instruct-AWQ
vllm serve Qwen/Qwen2-1.5B-Instruct-AWQ --quantization awq
vllm serve nm-testing/tinyllama-oneshot-w4a16-channel-v2
vllm serve nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t
vllm serve nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@cennn cennn changed the title Replace parameter subclasses with raw nn.Parameter with additional attributes WIP: Replace parameter subclasses with raw nn.Parameter with additional attributes Dec 30, 2024
@cennn cennn changed the title WIP: Replace parameter subclasses with raw nn.Parameter with additional attributes [Quantization/Parameter] WIP: Replace parameter subclasses with raw nn.Parameter with additional attributes Dec 30, 2024
@youkaichao (Member) commented:

As discussed, please fix the format.

@youkaichao (Member) commented:

@dsikka can you please take a look?

Successfully merging this pull request may close these issues:

tracking torch.compile compatibility with cpu offloading