W4A8 based on CUTLASS #880

Open · wants to merge 1 commit into main from w4a8-cutlass
Conversation

@alexsamardzic (Collaborator) commented Sep 12, 2024

pytorch-bot commented Sep 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/880

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit d6cf052 with merge base a6f8676:

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following job failed but was already present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot

Hi @alexsamardzic!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g., your employer), the individual CLA may not be sufficient, and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@facebook-github-bot added the "CLA Signed" label Sep 12, 2024
@alexsamardzic (Collaborator, Author) commented Sep 12, 2024

The kernel implements W4A8 GEMM, with float16 scaling factors. Zero-point support is to be added later; for now, several hacks (to be removed) are put in the code that force int8_dynamic_activation_int4_weight to do symmetric quantization for both activation and weight.

There are several points to discuss:

CUTLASS would have to be made a dependency. IMO, the best approach to satisfy the dependency would be to install the nvidia-cutlass package; the only problem is that it doesn't always contain the latest changes in CUTLASS. An alternative would be to have the CUTLASS repo as a submodule of this repo, as in PyTorch.

Group quantization may be a problem. Let's say X is the input matrix of size MxK, with Xs the vector of input scales of size M, and W is the weight matrix of size NxK. If the group size parameter is equal to K, then the weight scales Ws will be a vector of size N, and an element of the output matrix Y of a linear operator would be calculated as follows (let's ignore bias for now, as it's not relevant):

$$y_{i,j}=\sum_{k}xs_{i}\cdot x_{i,k}\cdot w_{j,k}\cdot ws_{j}=xs_{i}\cdot ws_{j}\cdot \sum_{k}x_{i,k}\cdot w_{j,k}$$

The sum in the last expression can be efficiently calculated as a mixed integer data types GEMM on tensor cores, and the result can then be updated by multiplying the scale factors in. However, if the group size parameter is less than K, say 32 for example (32 < K, K % 32 == 0), then the weight scales will be a matrix of size Nx(K/32). In this case, an element of the output matrix Y of a linear operator would be calculated as follows:

$$y_{i,j}=\sum_{k}xs_{i}\cdot x_{i,k}\cdot w_{j,k}\cdot ws_{j,k/32}=xs_{i}\cdot \sum_{k}x_{i,k}\cdot w_{j,k}\cdot ws_{j,k/32}$$

Now, the only approach possible in CUTLASS to do this calculation in integer mixed data types on tensor cores would be to split it into K/32 GEMMs and try to run them at the same time as a so-called grouped GEMM (a sketch of this reference computation is given below). The code would be much more complicated, and the update with the scaling factors would still be different for each of these individual GEMMs, so I don't think this approach would be performant. So my question here is: does it make sense to create a quantization method different from int8_dynamic_activation_int4_weight that would match this kernel better, in particular one that would not use group quantization for the weight at all? (BTW, creating a new quantization method, or at least adding a variant of int8_dynamic_activation_int4_weight, is needed anyway, as this one does not pack two 4-bit weight values into a byte, which is required by CUTLASS for the int8/int4 GEMM.)
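For reference, here is a minimal PyTorch sketch (just an illustration, not the PR's kernel) contrasting the per-channel case, where a single integer GEMM can be scaled afterwards, with the grouped case, which decomposes into K/G separate GEMMs:

import torch

M, N, K, G = 4, 3, 64, 32                      # G: weight group size
X = torch.randint(-128, 128, (M, K))           # int8-range activations
W = torch.randint(-8, 8, (N, K))               # int4-range weights
Xs = torch.rand(M)                             # per-row activation scales

# Group size == K: weight scales are a vector, so a single GEMM can be
# computed first and the scale factors multiplied in afterwards.
Ws_vec = torch.rand(N)
Y_per_channel = Xs[:, None] * Ws_vec[None, :] * (X.float() @ W.float().T)

# Group size == G < K: weight scales are an N x (K/G) matrix, and the scale
# has to be applied inside the K-reduction, e.g. as K/G separate GEMMs.
Ws_grp = torch.rand(N, K // G)
Y_grouped = Xs[:, None] * sum(
    (X[:, g * G:(g + 1) * G].float() @ W[:, g * G:(g + 1) * G].float().T)
    * Ws_grp[:, g][None, :]
    for g in range(K // G)
)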

Another related issue is zero-point handling. Let's say Xz is the vector of size M of input zero-point values, and Wz is the vector of size N of weight zero-point values. Then the linear operator calculation, in PyTorch notation, would be as follows: Y=((X-Xz)*Xs)@((W-Wz)*Ws).T (again, let's ignore bias), which translates into the following calculation for an individual element of the output matrix Y:

$$ \begin{array}{lcl} y_{i,j} & = & \sum_{k}xs_{i}\cdot (x_{i,k}-xz_{i})\cdot (w_{j,k}-wz_{j})\cdot ws_{j} \\ & = & xs_{i}\cdot ws_{j}\cdot \left(\sum_{k}x_{i,k}\cdot w_{j,k}-wz_{j}\sum_{k}x_{i,k}-xz_{i}\sum_{k}w_{j,k}+K\cdot xz_{i}\cdot wz_{j}\right) \\ \end{array} $$

Only the first expression within the parentheses can be calculated on tensor cores as a mixed integer data types GEMM, while the sums in the next two expressions are best pre-calculated in the case of the weight values, or calculated on the fly during input quantization. So it seems to me these also call for a specialized type of quantization. (Note also that if group quantization is used, the above-mentioned complications for Ws extend to Wz too.) A small numeric check of this decomposition is sketched below.
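For illustration only (not code from this PR), a small PyTorch check that the decomposition above matches the dequantize-then-matmul reference:

import torch

M, N, K = 4, 3, 8
X = torch.randint(-128, 128, (M, K))           # int8-range activations
W = torch.randint(-8, 8, (N, K))               # int4-range weights
Xs = torch.rand(M, dtype=torch.float64)        # per-row activation scales
Ws = torch.rand(N, dtype=torch.float64)        # per-output-channel weight scales
Xz = torch.randint(-8, 8, (M,))                # activation zero points
Wz = torch.randint(-2, 2, (N,))                # weight zero points

# Reference: dequantize both operands, then matmul.
Y_ref = ((X - Xz[:, None]) * Xs[:, None]) @ ((W - Wz[:, None]) * Ws[:, None]).T

# Decomposition: one integer GEMM plus correction terms built from row sums.
acc = (X @ W.T
       - Wz[None, :] * X.sum(dim=1, keepdim=True)
       - Xz[:, None] * W.sum(dim=1)[None, :]
       + K * Xz[:, None] * Wz[None, :])
Y_dec = Xs[:, None] * Ws[None, :] * acc

torch.testing.assert_close(Y_ref, Y_dec)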

All comments/suggestions welcome; in particular, I'm pretty new to quantization specifics, so please let me know if I'm missing something obvious.

@msaroufim (Member)

I'm on PTO today and tomorrow so will review asap, apologies for the delay

@cpuhrsch (Contributor)

@alexsamardzic - Can we use the CUTLASS that ships with PyTorch? As in, should we change PyTorch to ship the headers used to build its CUTLASS kernels / does the PyTorch nightly already ship those? I see the test is using group size 128. I think it's ok if we don't necessarily support all group sizes or shapes right away.

We have some int4 support via the pattern matched in https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/post_grad.py#L345-L403 which dispatches to https://github.com/pytorch/pytorch/blob/dab7d646d55a2b6696d51dee4816a6743ec1ae5a/torch/_inductor/kernel/unpack_mixed_mm.py#L76 - would an extension for int4x2 X int8 of this be interesting here?

@alexsamardzic (Collaborator, Author)

I'm on PTO today and tomorrow so will review asap, apologies for the delay

Thanks Mark - it's really just a draft, so not yet ready for review, but it would be useful to discuss the points I mentioned in my comment above.

@alexsamardzic (Collaborator, Author)

@alexsamardzic - Can we use the CUTLASS that ships with PyTorch? As in, should we change PyTorch to ship the headers used to build its CUTLASS kernels / does the PyTorch nightly already ship those?

This CUTLASS version is also lagging behind. My CUTLASS PR with mixed int4/int8 GEMM was merged after the latest (3.5.1) CUTLASS release; hopefully there will be a new release soon. But in any case, this is the kind of problem we'll have if we use more CUTLASS from torchao - much of the time, the torchao build will have to be pointed at a bleeding-edge CUTLASS checkout.

I see the test is using group size 128. I think it's ok if we don't necessarily support all group sizes or shapes right away.

It uses group size 128 in order to force the weight scale to be a vector, not a matrix. I tried to explain the issue in my comment above: if group quantization is obligatory here, it's going to be rather complicated to make this work.

We have some int4 support via the pattern matched in https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/post_grad.py#L345-L403 which dispatches to https://github.com/pytorch/pytorch/blob/dab7d646d55a2b6696d51dee4816a6743ec1ae5a/torch/_inductor/kernel/unpack_mixed_mm.py#L76 - would an extension for int4x2 X int8 of this be interesting here?

I'm looking into the quantization code to see whether it's possible to do it there - it's not hard to make this change, but CUTLASS in general doesn't support doing things before the GEMM (while fusing operations after the GEMM is calculated is reasonably well supported), so it would be best if the quantization code actually put the weight values in int4x2 format.

@alexsamardzic (Collaborator, Author)

Updated so that there is a new int8_dynamic_activation_int4_weight_cutlass quantization method available that, for now, quantizes both input and weight symmetrically and doesn't use group quantization for the weight (so weight scales are always a vector). It should now be possible to try the kernel on arbitrary models quantized with this method, e.g. as sketched below.
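A minimal usage sketch (the method name is the one introduced by this PR; the model here is just a placeholder):

import torch
from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int4_weight_cutlass,
)

model = torch.nn.Linear(1024, 2048).half().cuda()  # placeholder model
quantize_(model, int8_dynamic_activation_int4_weight_cutlass())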

@@ -506,6 +508,41 @@ def int8_dynamic_activation_int4_weight(group_size=32, mapping_type=MappingType.
return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size, mapping_type=mapping_type)


def apply_int8_dynamic_activation_int4_weight_quant_cutlass(weight):
Contributor

Can this be represented as a different Layout for int8 dynamic activation / int4 weight quantization? Docs for packing/layout can be found in #391 "Layout and Packing", and a simplified example in https://github.com/pytorch/ao/blob/main/tutorials/developer_api_guide/my_dtype_tensor_subclass.py

Collaborator Author

Thanks for the pointer! Yes, this and several other places will need refinement as I learn about doing things the "torchao way"; but my main goal initially is to connect the dots, so that some benchmarks can be run and we can verify that CUTLASS provides some value here.

@alexsamardzic force-pushed the w4a8-cutlass branch 7 times, most recently from f6383ca to 02f8805, on September 17, 2024
@alexsamardzic (Collaborator, Author) commented Sep 17, 2024

Made some minor updates, including adding support for bfloat16.

Micro-benchmarking script
import copy

import torch

from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_5,
    unwrap_tensor_subclass,
)
from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int4_weight_cutlass,
)

# FIXME: change this!
_CUTLASS_DIR = ".../cutlass"


class ToyModel(torch.nn.Module):
    def __init__(self, nin, nout1, nout2):
        super().__init__()
        self.linear1 = torch.nn.Linear(nin, nout1)
        self.linear2 = torch.nn.Linear(nout1, nout2, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


methodq = int8_dynamic_activation_int4_weight_cutlass()
compile = False
dtype = torch.float16  # dtype = torch.bfloat16
device = "cuda"
bs, nin, nout1, nout2 = 256, 1024, 2048, 128

inputs = (torch.randn((1, bs, nin), dtype=dtype, device=device),)
model = ToyModel(nin, nout1, nout2).eval().to(dtype).to(device)
modelq = copy.deepcopy(model)

if compile:
    model = torch.compile(model, mode="max-autotune")

quantize_(modelq, methodq)
if not TORCH_VERSION_AT_LEAST_2_5:
    unwrap_tensor_subclass(modelq)

if compile:
    modelq = torch.compile(
        modelq,
        options={
            "max_autotune": True,
            "autotune_in_subproc": False,
            "max_autotune_gemm_backends": "Triton,CUTLASS",
            "cuda.cutlass_dir": _CUTLASS_DIR,
            "use_mixed_mm": True,
        },
    )


if __name__ == "__main__":
    from torchao.utils import benchmark_model

    nruns = 100
    torch._dynamo.reset()
    time = benchmark_model(model, nruns, inputs)
    timeq = benchmark_model(modelq, nruns, inputs)
    print(f"original model mean time  : {time:8.3f}")
    print(f"quantized model mean time : {timeq:8.3f}")
    print(f"speedup by quantization   : {time / timeq:8.3f}")

For the particular shapes given in the script above, on A100 the micro-benchmark shows around a 2x speedup over the case when float16 MM is used, and around a 1.8x speedup over the case when bfloat16 MM is used. (Note that this is for eager mode execution, as compilation to the corresponding CUTLASS kernel is not yet supported by PyTorch.)

Patch to run torchao/_models/llama/generate.py
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 5fb905d..e5b891b 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -206,6 +206,7 @@ def main(
             quantize_,
             int8_weight_only,
             int8_dynamic_activation_int8_weight,
+            int8_dynamic_activation_int4_weight_cutlass,
             int4_weight_only,
             fpx_weight_only,
             uintx_weight_only,
@@ -216,6 +217,8 @@ def main(
             quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
             quantize_(model, int8_dynamic_activation_int8_weight())
+        if "w4a8-cutlass" in quantization:
+            quantize_(model, int8_dynamic_activation_int4_weight_cutlass())
         if "int4wo" in quantization:
             if "hqq" in quantization:
                 use_hqq=True
@@ -414,7 +417,7 @@ if __name__ == '__main__':
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
     parser.add_argument('-q', '--quantization', type=str, 
         help=(
-            'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
+            'Which quantization techniques to apply: int8dq, w4a8-cutlass, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
             +'autoquant-int4, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
         )
     )
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 1df3549..1252bb8 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -1158,6 +1158,7 @@ implements = AffineQuantizedTensor.implements
 # so that these can be shared by F.linear, aten.mm, aten.addmm dispatches
 
 def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
+    return False
     return (
         isinstance(input_tensor, AffineQuantizedTensor) and
         _aqt_is_int8_reduced_range(input_tensor) and
diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py
index 3005cb1..451d0e6 100644
--- a/torchao/kernel/intmm.py
+++ b/torchao/kernel/intmm.py
@@ -54,6 +54,8 @@ if TORCH_VERSION_AT_LEAST_2_2:
             and k_is_nonzero_multiple_of_8
         )
 
+        bad_dimensions_for_cublas = False
+
         if device_cpu or bad_dimensions_for_cublas:
             # fallback path
             return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(

With the patch above, I was able to run the Llama generate.py script. The command to run is as follows:

python generate.py -q w4a8-cutlass

and the output is as follows (again, this is run on A100):

==========
Average tokens/sec: 10.21
Average Bandwidth: 33.78 GB/s
Peak Memory Usage: 14.22 GB
Model Size: 3.31 GB

while the reference output, for the case when no arguments supplied to generate.py, is as follows:

==========
Average tokens/sec: 32.87
Average Bandwidth: 434.31 GB/s
Peak Memory Usage: 13.62 GB
Model Size: 13.21 GB

So tokens/sec is more than 3x slower, but this is not even that bad, considering that the batch size is 1 here, and that the CUTLASS code hard-codes a thread-block input tile size of 128 for the same dimension, so most of the work is wasted.

So there is room for improvement regarding speed. The generated text is garbage, however. Even for the micro-benchmark above, the output values visibly deviate from the values produced when native precision is used (but at least they resemble each other).

@alexsamardzic force-pushed the w4a8-cutlass branch 2 times, most recently from 575e074 to 956fc80, on September 18, 2024
@alexsamardzic (Collaborator, Author) commented Sep 18, 2024

Made an update - it turns out that CUTLASS itself needs a fix (posted below for now); with it, the generate.py script for the Llama model generates meaningful content.

CUTLASS fix
diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
index 1692cc30..5a1b164c 100644
--- a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
+++ b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
@@ -263,6 +263,44 @@ struct DefaultIteratorsTensorOp<
   static int const kFragmentsPerIteration = 2;
 };
 
+/// Partial specialization for bfloat16 <= int32_t x 8 epilogues avoids shared memory bank conflicts.
+template <
+  typename ThreadblockShape,
+  typename WarpShape,
+  typename InstructionShape,
+  typename ThreadMap
+>
+struct DefaultIteratorsTensorOp<
+  bfloat16_t, 
+  int32_t, 
+  8, 
+  ThreadblockShape, 
+  WarpShape, 
+  InstructionShape, 
+  ThreadMap> {
+  
+  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
+    WarpShape,
+    InstructionShape,
+    int32_t,
+    32,
+    16,
+    8,
+    8
+  >;
+
+  using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
+    ThreadMap,
+    int32_t,
+    32,
+    16,
+    8,
+    8
+  >;
+
+  static int const kFragmentsPerIteration = 2;
+};
+
 /// Partial specialization for int8/int4b_t <= int32 x 16/8 epilogues avoids shared memory bank conflicts.
 /// Threadblock::kN = 256 still has bank conflicts.
 template <

On the other side, I tried adapting the tile sizes processed by a block/warp of threads in the corresponding CUTLASS kernel, to account for the fact that the batch size is 1 here. Here is an example of such a change:

+++ b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
@@ -418,8 +418,8 @@ s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale,
   using ElementA = int8_t;
   using ElementB = cutlass::int4b_t;
   using ElementAccumulator = int32_t;
-  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
-  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
+  using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>;
+  using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>;
   using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
   AT_DISPATCH_SWITCH(
     input_scale.scalar_type(),

However, tokens/sec is not much improved this way. Thus, the performance of this kernel for the Llama model will require more work.

Edit: CUTLASS fix posted upstream here.

@msaroufim (Member) left a comment

will make a second pass for the kernel code

setup.py Outdated
@@ -65,6 +65,12 @@ def get_extensions():
    extension = CUDAExtension if use_cuda else CppExtension

    if not IS_WINDOWS:
        import cutlass_library
Member

Interesting - not too familiar with CUTLASS packaging, but what is cutlass_library exactly? The only reference I found is https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library

Collaborator Author

It's a recent addition to CUTLASS: a Python library that is able to generate C++ code for instantiating CUTLASS GEMM templates (which is nice to have, as these templates have a dozen or more arguments, and it's often hard to get them right). It's used in the CUTLASS codegen for TorchInductor, like here. However, CUTLASS itself recently also added functionality to generate and compile C++ code for GEMM kernels from a high-level specification in Python - this is part of the cutlass Python package, see here. Both cutlass and cutlass_library are available through the nvidia-cutlass pip package. It's important to note that this package also contains all of the CUTLASS C++ header files, in order to make it possible to compile the generated C++ kernels.

setup.py Outdated
        cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
        cutlass_include_dir = os.path.join(cutlass_library_dir, "source", "include")
        # FIXME: remove this once CUTLASS package updated to include int4/int8 MM
        cutlass_include_dir = "/data/quansight/scratch/cutlass/include"
Member

n00b q: what is this exactly? Do you need any help packaging CUTLASS?

Collaborator Author

I discussed this a bit in my first comment on this PR. In order for ao to compile after this PR is eventually merged, the CUTLASS C++ header files have to be made available. There are at least two ways to do it:

  1. Make the CUTLASS repo a submodule of the ao repo, just like PyTorch did.
  2. Make the above-mentioned nvidia-cutlass package a dependency of ao.

I'm leaning towards the latter, and this is what the code above the "FIXME" expects. However, in both of the above cases, we'll certainly face the issue of having to depend on stuff that is not yet merged into CUTLASS but that we need. For example, at this very moment:

  1. My CUTLASS PR with int4/int8 GEMM support is merged, but the CUTLASS team has not made a release in the meantime, so this functionality is only available in the CUTLASS main branch, and the above-mentioned nvidia-cutlass package doesn't contain it yet.
  2. As mentioned in one of my comments above, while working on this PR I found an omission in CUTLASS. I created a CUTLASS PR with a fix, but it is not yet merged, so neither the CUTLASS main branch nor the nvidia-cutlass package contains the fix at the moment; it's only available in my branch. So the only way to proceed with the development of this PR was to create a local copy of that branch - I created it in the /data/quansight/scratch/cutlass directory on my machine; in order to try this PR, a local copy of that branch is to be created, and the last line in the snippet above is to be changed to point to that local directory.

From my experience with CUTLASS-based development in PyTorch, this is going to be a permanent issue - if we decide to use CUTLASS in ao, then most of the time we'll need bleeding-edge features. So this is to be discussed further; IMO the best approach would be to build our own nvidia-cutlass package, from whatever CUTLASS branch we find the most appropriate.

@@ -85,6 +85,7 @@
"_get_subclass_inserter",
"quantize_",
"int8_dynamic_activation_int4_weight",
"int8_dynamic_activation_int4_weight_cutlass",
Member

Do you have some baseline numbers vs. int8_dynamic_activation_int4_weight?

Collaborator Author

Now that I have the dots connected, in the sense that I can run a micro-benchmark and also the Llama model using this kernel, I'm working on more detailed profiling; part of this is also comparing the performance of this kernel with the int8_dynamic_activation_int4_weight kernel. I'll report all my findings here when I'm done with the profiling.

Collaborator Author

As a quick update here: using the micro-benchmarking script above, it seems this PR is just 3-5% faster than int8_dynamic_activation_int4_weight. However, on the Llama generator it seems about 2x faster, when tokens/sec numbers are compared. (Remember that all the caveats from my first comment above still apply, so let's not jump to any conclusions for now.)

@@ -0,0 +1,51 @@
# FIXME: move this test to the appropriate test file!!!
Member

yeah maybe make yourself a cutlass folder to park all your work

Collaborator Author

Yes. Again, as mentioned in one of my comments above: at the moment, most of the "FIXME"s in the PR are there because I'm aware I took shortcuts to make things work. If/when we're happy with the main stuff, I'll revisit all of these and redo them the proper "ao way".

output_ref = model(input)

modelq = copy.deepcopy(model)
quantize_(modelq, int8_dynamic_activation_int4_weight_cutlass())
Member

Maybe another reference would be the non-CUTLASS variant.

# then corresponding changes made in
# _linear_int8_act_int4_weight_cutlass_check and for the check in
# the CUTLASS kernel!!!
weight.original_weight_tensor.layout_tensor.int_data = (
Member

Maybe a comment like

# Combine pairs of 4-bit values into single bytes
weight.original_weight_tensor.layout_tensor.int_data = (
    # Take odd-indexed columns, keep lower 4 bits, shift left by 4 bits
    (weight.original_weight_tensor.layout_tensor.int_data[:, 1::2] & 0xF) << 4
) | (
    # Take even-indexed columns, keep lower 4 bits
    weight.original_weight_tensor.layout_tensor.int_data[:, 0::2] & 0xF
)

"""
return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant_cutlass)


def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=False):
Contributor

Unrelated comment, what is this use_hqq? @jerryzh168 do you know?

Contributor

Yeah, this means using the HQQ algorithm to choose the qparams and quantize the weight; since it reuses the tinygemm kernel, we just added this as a separate option here.

const int n = tensor_b.size(0);
const int k = tensor_a.size(1);

constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
Member

nit: mind adding a comment for why 128

Also how do you think about padding vs erroring

Collaborator Author

The 128 bits here is because of how tensor cores work (so it's not CUTLASS-specific), at least for SM 8.x. It's related to the layout of the tiles of the matrix operands that a single warp of threads multiplies cooperatively. The best explanation I've found so far is in the GTC 2020 talk by the CUTLASS team, around slide 15.

We can consider padding (maybe at a later stage?); I believe it would be best to incorporate padding together with the quantization.
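As a side note, a tiny sketch (illustration only, not the PR's code) of how the 128-bit requirement translates into per-dtype element alignments like the AlignmentA above:

def min_alignment_elements(element_bits: int) -> int:
    # Tensor-core loads are 128-bit wide, so the contiguous dimension of an
    # operand must be a multiple of 128 / sizeof_bits(element).
    return 128 // element_bits

assert min_alignment_elements(8) == 16   # int8 operand: multiple of 16 elements
assert min_alignment_elements(4) == 32   # int4 operand: multiple of 32 elements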

using SmArch = cutlass::arch::Sm80;
using ThreadblockSwizzle =
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;
constexpr auto NumStages = 4;
Member

cutlass n00b but how do you pick these hyperparams?

Collaborator Author

These, and others, are the CUTLASS GEMM C++ template arguments. As mentioned above, there are a dozen of these to set, but on the other side only a small number of combinations of these arguments actually works. The above-mentioned cutlass_library package enumerates some of these working combinations. CUTLASS itself doesn't include any sort of heuristic for selecting these parameters, for example based on GEMM operand shapes, so I had to hard-code some values, at least for now. The values selected here are based on my previous experimentation with different combinations and different operand shapes, in the sense that they should provide acceptable performance for a number of cases. But there are certainly cases where these values are not a good fit - Llama inference, having batch size 1, is one such example. So we may want to consider adding some heuristic here (a hypothetical sketch is given below), but in the longer term we'd probably prefer to support some auto-tuning, just like what is possible with Triton kernels.
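Purely as a hypothetical illustration of the kind of shape-based heuristic mentioned above (the threshold is made up; the two tile configurations are the ones appearing in the earlier diff in this thread):

def pick_gemm_config(m: int) -> dict:
    # Small M (e.g. batch size 1 inference) favors thread-block tiles that are
    # narrow in M and deep in K; larger M favors the default square-ish tile.
    if m <= 32:
        return {"threadblock_shape": (32, 128, 128), "warp_shape": (32, 32, 128)}
    return {"threadblock_shape": (128, 128, 64), "warp_shape": (64, 64, 64)}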

@alexsamardzic (Collaborator, Author)

(Pushed an update, where the branch is just rebased on the latest main.)

I did lots of profiling in the meantime, focusing primarily on running the Llama generator (torchao/_models/llama/generate.py), using tokens/sec as the performance measure, and comparing this PR against the W8A8DQ case (i.e. the model quantized using int8_dynamic_activation_int8_weight). All of the results presented below are from A100 runs; the W8A8DQ run was as follows:

python generate.py -q int8dq

and the run for this PR was as follows (with the patch mentioned above applied beforehand):

python generate.py -q w4a8-cutlass

TLDR (note that each of these items could be verified by profiling W8A8DQ alone, without using this PR at all):

  1. The CUTLASS MM kernel in the case of this PR, and also the Triton MM kernel for W8A8DQ, are not, at least at the moment, the most critical for performance. Instead, the other parts of the code that run each time along with the linear operator take more execution time - see the remaining two items in the list.
  2. The dispatch checks registered here are re-run over and over again. These take considerable time, and they also make performance depend on the position at which the kernel and its corresponding check are registered in this list: if the corresponding item is moved between the top and the bottom of the list, tokens/sec differs by up to 10%. (@jerryzh168)
  3. The dynamic quantization takes considerable time too, more than the MM kernel itself. This could be improved by fusing the PyTorch operators used to perform quantization, or by implementing dedicated kernel(s) for dynamic quantization; for the Llama generator in particular, also by adjusting the configs of these kernels to the fact that the number of inputs is 1. (Still, IMO it's questionable whether there is any performance benefit in using dynamic quantization vs. weight-only quantization.)

As an example for item 2 above, here are the performance results, as printed by the generate.py script, when the item registering the given kernel and check is moved to the first place in the list:

python generate.py -q int8dq
# ... lots of output here
==========
Average tokens/sec: 4.81
Average Bandwidth: 31.83 GB/s
Peak Memory Usage: 14.86 GB
Model Size: 6.62 GB

python generate.py -q w4a8-cutlass
# ... lots of output here
==========
Average tokens/sec: 10.31
Average Bandwidth: 34.11 GB/s
Peak Memory Usage: 14.22 GB
Model Size: 3.31 GB

and when moved to the last place in the list:

python generate.py -q int8dq
# ... lots of output here
==========
Average tokens/sec: 4.35
Average Bandwidth: 28.82 GB/s
Peak Memory Usage: 14.86 GB
Model Size: 6.62 GB

python generate.py -q w4a8-cutlass
# ... lots of output here
==========
Average tokens/sec: 9.92
Average Bandwidth: 32.82 GB/s
Peak Memory Usage: 14.22 GB
Model Size: 3.31 GB

The generator runs were profiled using pyinstrument, and verified using the cProfile and nsys profilers. With the profiling run launched as follows:

python -m pyinstrument generate.py -q w4a8-cutlass

here is the relevant part of the pyinstrument output:

[pyinstrument output screenshot]

So, for the attention segment of the model, one can see that everything related to running the linear operator takes about 34s in total. Out of this time, 24s are spent in dynamic quantization, while only about 9.4s are spent on the linear operator itself; and out of these 9.4s, only 2.4s are spent on the CUTLASS MM kernel execution, while the rest is spent checking which kernel to dispatch to (note that for this run, the check for applicability of the CUTLASS kernel is added last to the list). These checks are not visible in this snippet, as pyinstrument by default suppresses calls that take a short time, but the full pyinstrument output attached below verifies it. The distribution of time spent is similar for the feed-forward part of the network - this can also be seen from the full output below.

Here is the pyinstrument --show-all ... output for the run above: pyinstrument.txt.

As mentioned above, the profiling results were verified using cProfile and nsys. For example, for an nsys run launched as follows:

nsys profile -w true -t cuda,nvtx,osrt,cudnn,cublas -s cpu --python-sampling=true $(which python) generate.py -q w4a8-cutlass

here is a screenshot of the timeline as shown by nsys-ui:

[nsys-ui timeline screenshot]

Here, one can see that loading the model takes about 30s, then there is a short sequence of copying the model to the GPU and quantizing the weights, and then the rest of the timeline is the inference. The CUTLASS MM kernel, designated Kernel2 here, takes less than 30% of the time of all of the CUDA kernels executed. If the timeline is zoomed into a segment of time during inference, one can see that the CUDA kernels are not actually executed back-to-back (because the checks and dynamic quantization are actually a sequence of calls to PyTorch kernels that are not fused):

[nsys-ui zoomed timeline screenshot]

@cpuhrsch (Contributor) commented Oct 9, 2024

@alexsamardzic - Was the model torch.compile'd with mode 'max-autotune'? Also you can use torch.profiler to generate kernel traces potentially a bit more quickly than with nsys (at least for rapid iteration). You can then open these with https://ui.perfetto.dev/
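For reference, a minimal torch.profiler sketch for the kind of kernel trace mentioned (model and inputs as in the micro-benchmarking script above; the output file name is arbitrary):

import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model(*inputs)

# The resulting trace can be opened at https://ui.perfetto.dev/
prof.export_chrome_trace("w4a8_cutlass_trace.json")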

@alexsamardzic force-pushed the w4a8-cutlass branch 4 times, most recently from 01cfdca to 2a9e52a, on December 5, 2024
@cpuhrsch added the "topic: improvement" label Dec 5, 2024
@cpuhrsch (Contributor) commented Dec 5, 2024

I don't see the job in the list of jobs for this PR.

on:
  push:
    branches:
      - w4a8-cutlass
Contributor

I think you can enable this more generally to prevent regressions:

@alexsamardzic (Collaborator, Author)

I don't see the job in the list of jobs for this PR.

I'm getting an email for this workflow saying something along the lines that it could not be run across organizations; I'm going to check this...

@cpuhrsch (Contributor) commented Dec 5, 2024

@alexsamardzic - Do we need to add a new job for this or can we ensure that CUTLASS builds and gets tested as part of the existing tests?

@@ -17,6 +17,7 @@
from .uintx import (
BlockSparseLayout,
Int4CPULayout,
Int4PackedLayout,
Contributor

nit: should we include cutlass in the name?

Collaborator Author

I can do it, but the layout is not specific to CUTLASS - it's just 4-bit values packed two per byte into an 8-bit tensor, e.g. as in the sketch below.
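For illustration (not the PR's code), a small PyTorch pack/unpack round-trip using the same convention as the packing snippet suggested earlier in this thread (even-indexed values in the low nibble, odd-indexed in the high nibble):

import torch

def pack_int4x2(vals: torch.Tensor) -> torch.Tensor:
    # vals: int8 tensor holding int4 values in [-8, 7], last dimension even.
    lo = vals[..., 0::2].to(torch.int16) & 0xF
    hi = vals[..., 1::2].to(torch.int16) & 0xF
    return ((hi << 4) | lo).to(torch.uint8).view(torch.int8)

def unpack_int4x2(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of pack_int4x2, recovering sign-extended int4 values.
    p = packed.view(torch.uint8).to(torch.int16)
    nibbles = torch.stack((p & 0xF, (p >> 4) & 0xF), dim=-1).flatten(-2)
    return torch.where(nibbles >= 8, nibbles - 16, nibbles).to(torch.int8)

w = torch.randint(-8, 8, (4, 8), dtype=torch.int8)
assert torch.equal(unpack_int4x2(pack_int4x2(w)), w)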

@@ -33,6 +33,16 @@ def _aqt_is_tensor_core_tile_uint4(aqt):
)


def _aqt_is_tensor_core_tile_int4(aqt):
Contributor

need to rename this to remove tensor_core_tile in the name I think

Collaborator Author

Changed, and moved out of this file, as indeed it doesn't belong here.

@@ -738,6 +746,25 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
)


def _int8_symm_per_token_reduced_range_quant_cutlass(x: torch.Tensor) -> torch.Tensor:
Contributor

Probably need to add this to

torch.serialization.add_safe_globals(

for serialization support; can you also add the cutlass int4 layout, so it's tested for serialization as well?

Collaborator Author

Made both changes.

@alexsamardzic force-pushed the w4a8-cutlass branch 2 times, most recently from 4c7b5da to 10179d2, on December 6, 2024
@alexsamardzic (Collaborator, Author) commented Dec 6, 2024

@alexsamardzic - Do we need to add a new job for this or can we ensure that CUTLASS builds and gets tested as part of the existing tests?

It will be tested through both the regression and nightly workflows; it's only that until CUTLASS 3.6 is released, and until we decide how to make CUTLASS a dependency (more on this below), the CUTLASS headers have to be provided in a particular way. I removed the workflow I had added, and instead temporarily changed the regression workflow (regression_test.yml) so that it gets tested in CI.

@alexsamardzic (Collaborator, Author)

About CUTLASS:

CUTLASS is a heavily templated C++ header library by NVIDIA, primarily implementing GEMM (but also convolutions etc.) on tensor cores, for various combinations of operands. However, there is also some Python code in there that basically makes it possible to write a given GEMM call in Python and have the corresponding C++ code implementing that call generated and compiled automatically. Thus, CUTLASS gets packaged as a Python library, nvidia-cutlass, that contains both its Python modules and its C++ header files.

For writing CUTLASS-based GEMM kernels, in our case it's typically needed to first extend CUTLASS to support the combination of operands that we're interested in. As usual, the change is typically made first in a private fork of the CUTLASS repo, then it eventually gets merged into the main branch of the CUTLASS repo, and then, after a new revision of CUTLASS is released, it appears in the nvidia-cutlass package too. Note that the delay of including changes increases along the way. For us, it's also relevant that PyTorch keeps the CUTLASS repo as a submodule, and they move the submodule pin from time to time, so the changes appear there with even more delay.

Now, the question is: if we include CUTLASS-based kernels in torchao, how are we going to provide the CUTLASS header files during the torchao build? This PR already changes setup.py so that building CUTLASS-based kernels is optional, depending on whether the CUTLASS header files are found when torchao is built; but obviously, if we add CUTLASS-based kernels to torchao, we'd eventually want them to be built and used. There are several options, each with its own trade-offs:

  1. Like PyTorch, we can make CUTLASS a submodule of the torchao repo. The advantage is that we could move the pin at will; the disadvantage is that it complicates the structure of the torchao repo. Furthermore, the question is: do we have the submodule pointing to some kind of fork of the CUTLASS repo (in which case we have our eventual CUTLASS changes available as soon as we make them), or to the official CUTLASS repo (in which case we still face the delay of merging our eventual CUTLASS changes into CUTLASS main)?
  2. We can add the nvidia-cutlass package to dev-requirements.txt. This PR is written with this approach in mind. IMO, this approach is the least hassle for torchao, but then we face an extended delay in using our eventual updates to CUTLASS upstream - we'll see them only when the CUTLASS team makes a release. CUTLASS releases typically happen every two months, but that's not a hard rule - right now, it's been more than three months since the last release. Another problem with this approach is that a developer waiting for a nvidia-cutlass package release with the needed updates has, in the meantime, to somehow point torchao's setup.py to a private copy of a CUTLASS tree containing the needed changes; the same holds for CI testing (that's what I'm trying to do now with these temporary changes in the torchao CI workflow files), which makes development cumbersome. Maybe a solution is that we build and host our own version of the nvidia-cutlass package, based on whatever fork of the CUTLASS repo suits us?
  3. We can use the CUTLASS provided by PyTorch. For @torch.compile auto-tuning, I think the PyTorch package should contain the CUTLASS headers, but it seems that this is not the case at the moment; I'll have to check this more carefully. The problem with this approach is that the delay of seeing the CUTLASS updates we need is the longest - for example, at the moment PyTorch still has its CUTLASS pin set to version 3.4.1, which is rather old, and even when the PyTorch team updates the pin, we'd have to wait for the next PyTorch release to have these changes available.

A variation of the last approach is to split this PR into two, and just move the C++ part into PyTorch. The disadvantage remains the same, but torchao would then have no worries about CUTLASS. This may be worth considering, as there are already CUTLASS-based kernels in PyTorch that may be interesting for torchao (at least there are two of mine: the 2:4 sparsity kernel and the W8A16/W4A16 kernel - and I'm wondering, whatever we do, whether these should sit together with this W4A8 kernel, be it in the PyTorch or the torchao code base). Moreover, as already mentioned, PyTorch has some CUTLASS-based support for @torch.compile auto-tuning and we'll probably want to extend on this: CUTLASS-based GEMM kernels are, in my experience, oftentimes slightly faster than the GEMM kernels that @torch.compile generates in Triton (on the other hand, CUTLASS-based GEMM kernels have disadvantages here too: they are very slow to compile compared to Triton, and cannot be fused with other code compiled to Triton). I don't think it's possible to do CUTLASS-based auto-tuning for torch.compile outside of PyTorch, so at least for this we'll have to employ CUTLASS from within PyTorch. Finally, unless we employ the last approach mentioned, we'll have to take care not to clash with whatever part of CUTLASS comes within the PyTorch package.

So there is no clear-cut answer, and we'll have to be careful about what we decide here... Any comments welcome.

@cpuhrsch @drisspg @jcaip

@cpuhrsch (Contributor) commented Dec 6, 2024

So with torchao we can compile and ship our own CUDA binaries, and since CUTLASS is header-only, why is option 1 so much more of a hassle than option 2? I don't think it's something that the users of our nightlies or stable releases will see.

If we want to prevent conflicts with other versions of CUTLASS available, I think we should guard on the CUTLASS version within our own CUDA code, but that also seems reasonable to do. Otherwise there's a risk that we might compile different CUTLASS versions at once. Maybe there's also a risk we could overwrite symbols at link time, but the guards should help with that.

I don't think we should depend on a fork within the submodule, but I do think it's ok to land our own more complex modifications / copies for various CUTLASS templates for a specific kernel first and then eventually deduplicate again.

@alexsamardzic can you give option 1 a go and see if it works in CI?

@alexsamardzic force-pushed the w4a8-cutlass branch 3 times, most recently from 907f0ab to 0835f75, on December 7, 2024
@alexsamardzic (Collaborator, Author)

@alexsamardzic can you give option 1 a go and see if it works in CI?

It works - this is kind of how I run the tests at the moment: I've changed .github/workflows/regression_test.yml to clone the main branch of the CUTLASS repo into the torchao top-level directory, and I've changed torchao's setup.py to find the CUTLASS include files there (of course, these changes are just temporary). It can be seen here that all tests from test/test_s8s4_linear_cutlass.py pass. A rough sketch of this kind of optional-CUTLASS detection is given below.
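For illustration only (a hypothetical sketch, not the PR's actual setup.py logic; the TORCHAO_CUTLASS_DIR variable name is made up), the header detection could look roughly like this:

import os

def find_cutlass_include_dir():
    # 1. Explicit override, e.g. pointing at a local CUTLASS checkout.
    explicit = os.environ.get("TORCHAO_CUTLASS_DIR")
    if explicit:
        return os.path.join(explicit, "include")
    # 2. Headers shipped with the nvidia-cutlass pip package, if installed.
    try:
        import cutlass_library
        return os.path.join(
            os.path.dirname(cutlass_library.__file__), "source", "include"
        )
    except ImportError:
        return None

cutlass_include_dir = find_cutlass_include_dir()
build_cutlass_kernels = (
    cutlass_include_dir is not None and os.path.isdir(cutlass_include_dir)
)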

yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
@drisspg (Contributor) commented Dec 10, 2024

Sorry for responding late, but yeah, I am pro option 1 and anti CUTLASS fork. I think we should reach out to the CUTLASS folks to get a sense of when the 3.6 release will come out.

@alexsamardzic (Collaborator, Author)

Pushed a variant that uses "option 1"; let me check how it runs through CI.

@jcaip (Contributor) commented Dec 10, 2024

Yeah +1 on including the CUTLASS headers in AO. I think it's advantageous to not be tied to PyTorch core for CUTLASS template updates. I know that Daniel is working on some custom kernels for 2:4 activation that need the latest CUTLASS version.

A variation of the last approach is to split this PR into two, and just move the C++ part into PyTorch. The disadvantage remains the same, but torchao would then have no worries about CUTLASS. This may be worth considering, as there are already CUTLASS-based kernels in PyTorch that may be interesting for torchao (at least there are two of mine: the 2:4 sparsity kernel and the W8A16/W4A16 kernel - and I'm wondering, whatever we do, whether these should sit together with this W4A8 kernel, be it in the PyTorch or the torchao code base?).

Do you have an idea on how this would affect 2:4 support in core? I don't think we have a way to expose torchao ops / kernels in core but I may be wrong here.

@alexsamardzic (Collaborator, Author)

Do you have an idea on how this would affect 2:4 support in core? I don't think we have a way to expose torchao ops / kernels in core but I may be wrong here.

I guess some period of deprecation could be used, during which users are pointed to the same functionality in torchao - just like it was discussed recently for moving the quantization-related functionality. On the other side, the CUTLASS-based 2:4 sparsity operator is indeed somewhat coupled with the cuSPARSELt-based one in core; but maybe the whole SparseSemiStructuredTensor should be moved to torchao?

A CUTLASS-based s8s4_linear_cutlass() operator is introduced, performing
a linear transformation over a quantized 8-bit input tensor and a
quantized 4-bit weight tensor, with corresponding floating point scale
tensors attached.

A benchmark script, for comparing the performance of MM based on this
linear operator with MM over 16-bit floating point tensors, is supplied
in benchmarks/benchmarks/benchmark_s8s4_cutlass.py.

The Llama generator script torchao/_models/llama/generate.py is changed
to add "int8adq-int4w-symm" quantization as an option, which will in
turn activate the s8s4_linear_cutlass() operator.  With this type of
quantization activated, i.e. if the generate.py script is run as
follows:

python generate.py --compile --precision=torch.float16 -q int8adq-int4w-symm

the generator achieves around 133 tok/sec on A100, vs. around 93
tok/sec without quantization, i.e. when the generate.py script is run
as follows:

python generate.py --compile --precision=torch.float16