diff --git a/csrc/dequantize_kernel/dequantize_kernel.cu b/csrc/dequantize_kernel/dequantize_kernel.cu
new file mode 100644
index 0000000..8b5d01a
--- /dev/null
+++ b/csrc/dequantize_kernel/dequantize_kernel.cu
@@ -0,0 +1,119 @@
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <torch/extension.h>
+#include <algorithm>
+#include "cutlass/array.h"
+#include "dequantize_kernel.h"
+#include <cstdint>
+
+template<typename result_type = cutlass::Array<half, 4>, typename source_type = cutlass::Array<uint8_t, 4>>
+__device__ static result_type convert(source_type const& source)
+{
+    result_type result;
+    uint32_t* h = reinterpret_cast<uint32_t*>(&result);
+    uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
+
+    // Spread the four int8 values into two fp16 pairs with byte-permute (prmt),
+    // then subtract the magic bias to recover the signed values as fp16.
+    static constexpr uint32_t mask_for_elt_01 = 0x5250;
+    static constexpr uint32_t mask_for_elt_23 = 0x5351;
+    static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
+    asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
+    asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));
+
+    static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
+    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
+    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));
+    return result;
+}
+
+
+__device__ const int reverse_permutation[32] = {0, 1, 4, 5, 8, 9, 12, 13,
+                                                2, 3, 6, 7, 10, 11, 14, 15,
+                                                16, 17, 20, 21, 24, 25, 28, 29,
+                                                18, 19, 22, 23, 26, 27, 30, 31};
+
+
+template<int TILE_DIM, int THREAD_NUM>
+__global__ void transpose_row_interleave_scale(const half *A, half *B, const half *scale, int m, int n)
+{
+    __shared__ half S[TILE_DIM][TILE_DIM + 1];
+    int bx = blockIdx.x * TILE_DIM;
+    int by = blockIdx.y * TILE_DIM;
+
+    const int NUM_ITERS = TILE_DIM * TILE_DIM / THREAD_NUM;
+    const int NUM_ROWS = THREAD_NUM / TILE_DIM;
+
+    #pragma unroll
+    for(int i = 0; i < NUM_ITERS; i++){
+        int thread_row = threadIdx.x / TILE_DIM + i * NUM_ROWS;
+        int thread_col = threadIdx.x % TILE_DIM;
+
+        int nx1 = bx + thread_col;
+        int ny1 = by + thread_row;
+        if(nx1 < n && ny1 < m)
+            S[thread_row][thread_col] = A[ny1 * n + nx1];
+    }
+    __syncthreads();
+    #pragma unroll
+    for(int i = 0; i < NUM_ITERS; i++){
+        int thread_row = threadIdx.x / TILE_DIM + i * NUM_ROWS;
+        int permuted_row = reverse_permutation[thread_row];
+        int thread_col = threadIdx.x % TILE_DIM;
+
+        int nx2 = by + thread_col;
+        int ny2 = bx + thread_row;
+        if (nx2 < m && ny2 < n)
+        {
+            B[ny2 * m + nx2] = S[thread_col][permuted_row] * scale[nx2];
+        }
+    }
+}
+
+
+__global__ void permute_cast(half* dequantized_weight, const uint8_t* weight, int m, int n){
+    half* out = &dequantized_weight[blockIdx.x * n];
+    const uint8_t* in = &weight[blockIdx.x * n];
+    const int interleave_block_size = 64;
+    const int block_per_row = n / interleave_block_size;
+
+    for(int i = threadIdx.x; i * 4 < n; i += blockDim.x){
+        cutlass::Array<half, 4> output;
+
+        int col_offset_global = i * 4;
+        int col_offset_local = col_offset_global % interleave_block_size;
+        int col_index = col_offset_global / interleave_block_size;
+        int global_index = blockIdx.x * block_per_row + col_index;
+        int is_second = global_index % 2;
+
+        // Undo the 64-element block interleaving applied by the weight preprocessor.
+        int origin_row = global_index / (block_per_row * 2) * 2 + is_second;
+        int origin_col = (global_index / 2) % block_per_row;
+
+        output = convert(reinterpret_cast<const cutlass::Array<uint8_t, 4>*>(&in[col_offset_global])[0]);
+        uint64_t* t = reinterpret_cast<uint64_t*>(&output);
+
+        half* out = &dequantized_weight[origin_row * n];
+        reinterpret_cast<uint64_t*>(&out[col_offset_local + origin_col * interleave_block_size])[0] = t[0];
+    }
+}
+
+
+void invoke_dequantize(half* dequantized_weight,
+                       const uint8_t* weight,
+                       const half* scale,
+                       int m,
+                       int n)
+{
+    half* tmp;
+    cudaMalloc(&tmp, m * n * sizeof(half));
+    dim3 block(std::min(256, m / 4));
+    dim3 grid(n);
+    permute_cast<<<grid, block>>>(tmp, weight, n, m);
+
+    constexpr int BLOCK_SZ = 32;
+    dim3 block_0(256);
+    dim3 grid_0((m + BLOCK_SZ - 1) / BLOCK_SZ, (n + BLOCK_SZ - 1) / BLOCK_SZ);
+    transpose_row_interleave_scale<BLOCK_SZ, 256><<<grid_0, block_0>>>(tmp, dequantized_weight, scale, n, m);
+    // cudaMemcpy(dequantized_weight, tmp, m * n * sizeof(half), cudaMemcpyDeviceToDevice);
+    cudaFree(tmp);
+}
\ No newline at end of file
diff --git a/csrc/dequantize_kernel/dequantize_kernel.h b/csrc/dequantize_kernel/dequantize_kernel.h
new file mode 100644
index 0000000..8ccbfbb
--- /dev/null
+++ b/csrc/dequantize_kernel/dequantize_kernel.h
@@ -0,0 +1,3 @@
+#include <torch/extension.h>
+
+void dequantize_weight_cuda(torch::Tensor _dequantized_weight, torch::Tensor _weight, torch::Tensor _scale);
diff --git a/csrc/eetpy.cpp b/csrc/eetpy.cpp
index 26fc26f..4ddb54a 100644
--- a/csrc/eetpy.cpp
+++ b/csrc/eetpy.cpp
@@ -3,6 +3,7 @@
 #include "cutlass_kernels/fpA_intB_gemm_wrapper.h"
 #include "embedding_kernels/pos_encoding.h"
 #include "layernorm_kernels/layernorm.h"
+#include "dequantize_kernel/dequantize_kernel.h"
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
@@ -17,4 +18,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
             py::arg("return_unprocessed_quantized_tensor") = false);
     m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
     m.def("layernorm_forward", &layernorm_forward_cuda, "LayerNorm kernel");
+    m.def("dequantize_weight", &dequantize_weight_cuda, "Dequantize kernel");
 }
\ No newline at end of file
diff --git a/csrc/utils/torch_utils.h b/csrc/utils/torch_utils.h
index 6c4429f..a059197 100644
--- a/csrc/utils/torch_utils.h
+++ b/csrc/utils/torch_utils.h
@@ -6,7 +6,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
diff --git a/examples/layers/benchmark_dequantize.py b/examples/layers/benchmark_dequantize.py
new file mode 100644
index 0000000..c9509d2
--- /dev/null
+++ b/examples/layers/benchmark_dequantize.py
@@ -0,0 +1,147 @@
+import itertools
+from typing import Optional, Tuple, Union
+
+import torch
+import triton
+from torch import nn
+# from vllm import _custom_ops as vllm_ops
+from EETQ import quant_weights, w8_a16_gemm, dequantize_weight
+
+
+def w8a16_dequant(qweight, scale):
+    dtype = torch.float16
+    device = qweight.device
+    I = torch.eye(qweight.shape[0], dtype=dtype, device=device)
+    gemm_dequantized_weight = w8_a16_gemm(I, qweight, scale)
+    return gemm_dequantized_weight
+
+
+def dequant_weight(qweight, scale):
+    dtype = torch.float16
+    device = qweight.device
+    dequantized_weight = torch.zeros_like(qweight, dtype=dtype, device=device)
+    dequantize_weight(dequantized_weight, qweight, scale)
+    return dequantized_weight
+
+
+def calculate_diff(m, n):
+    dtype = torch.float16
+    device = "cuda"
+    weight = torch.randn(m, n, dtype=dtype)
+    qweight, scale = quant_weights(weight, torch.int8, False)
+    qweight = qweight.to(device)
+    scale = scale.to(device)
+
+    # Calculate the dequantized weight using w8_a16_gemm
+    gemm_dequantized_weight = w8a16_dequant(qweight, scale)
+
+    # Calculate the dequantized weight using dequantize_weight
+    dequantized_weight = dequant_weight(qweight, scale)
+
+    if torch.allclose(
+        gemm_dequantized_weight, dequantized_weight, atol=1e-2, rtol=1e-2
+    ):
+        print(f"✅ ({m}, {n}) implementations match")
+    else:
+        print(f"❌ ({m}, {n}) implementations differ")
+    del gemm_dequantized_weight
+    del dequantized_weight
+
+
+M = [2048 + i * 1024 for i in range(0, 11, 2)]
+N = [2048 + i * 1024 for i in range(0, 11, 2)]
+
+configs = list(itertools.product(M, N))
+
+
+def get_benchmark():
+    @triton.testing.perf_report(
+        triton.testing.Benchmark(
+            x_names=["m", "n"],
+            x_vals=[list(_) for _ in configs],
+            line_arg="provider",
+            line_vals=["w8a16", "dequant"],
+            line_names=["w8a16", "dequant"],
+            styles=[("blue", "-"), ("green", "-")],
+            ylabel="us",
+            plot_name="dequantized_performance",
+            args={},
+        )
+    )
+    def benchmark(m, n, provider):
+        dtype = torch.float16
+        device = "cuda"
+        weight = torch.randn(m, n, dtype=dtype)
+        qweight, scale = quant_weights(weight, torch.int8, False)
+        qweight = qweight.to(device)
+        scale = scale.to(device)
+
+        quantiles = [0.5, 0.2, 0.8]
+
+        if provider == "w8a16":
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: w8a16_dequant(qweight, scale),
+                quantiles=quantiles,
+            )
+        elif provider == "dequant":
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: dequant_weight(qweight, scale),
+                quantiles=quantiles,
+            )
+
+        return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+    return benchmark
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--save_path",
+        type=str,
+        default="./configs/benchmark_ops/",
+        help="Path to save dequantize benchmark results",
+    )
+    args = parser.parse_args()
+
+    # M = [i * 256 for i in range(1, 50, 2)]
+    # N = [i * 256 for i in range(1, 50, 2)]
+    # for m in M:
+    #     for n in N:
+    #         calculate_diff(m, n)
+
+    # Get the benchmark function
+    benchmark = get_benchmark()
+    # Run performance benchmark
+    benchmark.run(print_data=True, save_path=args.save_path)
\ No newline at end of file
diff --git a/examples/layers/test_w8a16_gemm.py b/examples/layers/test_w8a16_gemm.py
index 487a23a..f154f31 100644
--- a/examples/layers/test_w8a16_gemm.py
+++ b/examples/layers/test_w8a16_gemm.py
@@ -56,4 +56,4 @@ def set_random_seed(seed):
     print(output)
     print(torch.sum(output - out_torch))
     t2 = time.perf_counter()
-    print("time: ", (t2 - t1) / 100)
+    print("time: ", (t2 - t1) / 100)
\ No newline at end of file
diff --git a/examples/models/eetq_train.py b/examples/models/eetq_train.py
new file mode 100644
index 0000000..0632af6
--- /dev/null
+++ b/examples/models/eetq_train.py
@@ -0,0 +1,80 @@
+import datasets
+from transformers import (
+    AutoTokenizer,
+    TrainingArguments,
+    Trainer,
+    DataCollatorForLanguageModeling,
+    EetqConfig,
+    AutoModelForCausalLM
+)
+from peft import get_peft_model, LoraConfig, TaskType
+
+def prepare_split(tokenizer):
+    data = datasets.load_dataset("mhenrichsen/alpaca_2k_test", split="train")
+    prompt_template = "[INST] {system} {prompt} [/INST] {output}"
+
+    def format_prompt(x):
+        return prompt_template.format(
+            system="",
+            prompt=x["instruction"],
+            output=x["output"]
+        )
+
+    data = data.map(
+        lambda x: {"text": format_prompt(x)},
+    ).select_columns(["text"])
+    data = data.map(lambda x: tokenizer(x["text"]), batched=True)
+
+    return data
+
+model_path = ""
+
+# Load the quantized model
+quantization_config = EetqConfig("int8")
+model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", quantization_config=quantization_config)
+
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+tokenizer.pad_token = tokenizer.eos_token
+
+# Prepare data
+data_train = prepare_split(tokenizer)
+
+# Configure LoRA
+lora_config = LoraConfig(
+    r=4,
+    lora_alpha=8,
+    lora_dropout=0.5,
+    bias="none",
+    task_type=TaskType.CAUSAL_LM,
+    inference_mode=False
+)
+
+model = get_peft_model(model, lora_config)
+
+model.print_trainable_parameters()
+
+training_arguments = TrainingArguments(
+    output_dir="./output",
+    per_device_train_batch_size=1,
+    optim="adamw_torch",
+    num_train_epochs=1,
+    learning_rate=1e-4,
+    # fp16=True,
+    evaluation_strategy="no",
+    save_strategy="epoch",
+    save_steps=100,
+    logging_steps=50,
+    eval_steps=None,
+    load_best_model_at_end=False
+)
+
+trainer = Trainer(
+    model=model,
+    train_dataset=data_train,
+    args=training_arguments,
+    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
+)
+
+trainer.train()
+trainer.save_model("output")
+
diff --git a/python/eetq/modules/qlinear.py b/python/eetq/modules/qlinear.py
index a1aee89..fed8d49 100644
--- a/python/eetq/modules/qlinear.py
+++ b/python/eetq/modules/qlinear.py
@@ -8,7 +8,7 @@ import math
 
 from torch.autograd import Function
 
-from EETQ import quant_weights, preprocess_weights, w8_a16_gemm
+from EETQ import quant_weights, preprocess_weights, w8_a16_gemm, dequantize_weight
 
 
 def quantize_and_preprocess_weights(weight, scales=None):
@@ -80,15 +80,12 @@ def forward(
     @staticmethod
     def backward(ctx, grad_output):
         input, weight, scales, bias = ctx.saved_tensors
-        identity = torch.eye(weight.shape[0]).to(weight.device).to(input.dtype)
-
-        # Dequantize the weight
-        weight = w8_a16_gemm(identity, weight, scales)
-
+        dequantized_weight = torch.zeros_like(weight, dtype=torch.float16).to(weight.device)
+        dequantize_weight(dequantized_weight, weight, scales)
         if ctx.needs_input_grad[0]:
             # 2D matrix multiplication, unsqueeze to 3D
             grad_input = grad_output.squeeze(0).matmul(
-                weight.transpose(0, 1)
+                dequantized_weight.transpose(0, 1)
             ).unsqueeze(0)
 
         return grad_input, None, None, None
diff --git a/setup.py b/setup.py
index cdd9dd8..f03b612 100644
--- a/setup.py
+++ b/setup.py
@@ -69,7 +69,8 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
                       "csrc/utils/cuda_utils.cc"
                       ]
     custom_sources = ["csrc/embedding_kernels/pos_encoding_kernels.cu",
-                      "csrc/layernorm_kernels/layernorm.cu"
+                      "csrc/layernorm_kernels/layernorm.cu",
+                      "csrc/dequantize_kernel/dequantize_kernel.cu"
                       ]
 
     tensorrt_llm_sources = ["csrc/weightOnlyBatchedGemv/kernelLauncher.cu",
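
For reference, benchmark_dequantize.py and the qlinear.py backward pass above both call the new dequantize_weight binding the same way: it fills a caller-allocated fp16 tensor in place with the dequantized, de-interleaved weight, replacing the identity-matrix GEMM. A minimal sketch, assuming a CUDA device; the shapes here are illustrative only:

    import torch
    from EETQ import quant_weights, dequantize_weight

    weight = torch.randn(4096, 4096, dtype=torch.float16)
    qweight, scale = quant_weights(weight, torch.int8, False)  # int8 weight + per-channel scales
    qweight, scale = qweight.cuda(), scale.cuda()

    # Preallocate the output and dequantize in place, avoiding the identity-matrix GEMM.
    dequantized = torch.zeros_like(qweight, dtype=torch.float16, device=qweight.device)
    dequantize_weight(dequantized, qweight, scale)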