119 changes: 119 additions & 0 deletions csrc/dequantize_kernel/dequantize_kernel.cu
@@ -0,0 +1,119 @@
#include <torch/extension.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <c10/cuda/CUDAGuard.h>
#include "cutlass/array.h"
#include "dequantize_kernel.h"
#include <iostream>
#include <algorithm>

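// Fast packed uint8 -> fp16 conversion (the prmt/sub trick used in CUTLASS/FasterTransformer):
// prmt.b32 splices each source byte b into a half whose high byte is 0x64, so the half equals 1024 + b;
// sub.f16x2 then subtracts 1152.0 (0x6480) to recover b - 128, undoing the +128 bias of the stored bytes.
// Note the two prmt masks emit the elements in (0,2)/(1,3) order, which presumably matches the
// column interleaving applied during weight preprocessing.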
template<typename result_type = cutlass::Array<half, 4>, typename source_type = cutlass::Array<uint8_t, 4>>
__device__ static result_type convert(source_type const& source)
{
    result_type result;
    uint32_t* h = reinterpret_cast<uint32_t*>(&result);
    uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

    static constexpr uint32_t mask_for_elt_01 = 0x5250;
    static constexpr uint32_t mask_for_elt_23 = 0x5351;
    static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
    asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
    asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));

    static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));
    return result;
}


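// Permutation applied while writing the transposed tile below; it reverses the row interleaving
// that the weight preprocessing step applies within each 32-row tile (hence the name).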
__device__ const int reverse_permutation[32] = {0, 1, 4, 5, 8, 9, 12, 13,
2, 3, 6, 7, 10, 11, 14, 15,
16, 17, 20, 21, 24, 25, 28, 29,
18, 19, 22, 23, 26, 27, 30, 31};


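// Tiled shared-memory transpose that also applies the reverse row permutation above and multiplies
// each element by scale[] indexed by the output column. The +1 padding on the shared tile avoids
// bank conflicts; each block handles one TILE_DIM x TILE_DIM tile of the (m x n) input A.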
template<const int THREAD_NUM=256, const int TILE_DIM = 32>
__global__ void transpose_row_interleave_scale(const half *A, half *B, const half *scale, int m, int n)
{
    __shared__ half S[TILE_DIM][TILE_DIM + 1];
    int bx = blockIdx.x * TILE_DIM;
    int by = blockIdx.y * TILE_DIM;

    const int NUM_ITERS = TILE_DIM * TILE_DIM / THREAD_NUM;
    const int NUM_ROWS = THREAD_NUM / TILE_DIM;

#pragma unroll
    for(int i = 0; i < NUM_ITERS; i++){
        int thread_row = threadIdx.x / TILE_DIM + i * NUM_ROWS;
        int thread_col = threadIdx.x % TILE_DIM;

        int nx1 = bx + thread_col;
        int ny1 = by + thread_row;
        if(nx1 < n && ny1 < m)
            S[thread_row][thread_col] = A[ny1 * n + nx1];
    }
    __syncthreads();
#pragma unroll
    for(int i = 0; i < NUM_ITERS; i++){
        int thread_row = threadIdx.x / TILE_DIM + i * NUM_ROWS;
        int permuted_row = reverse_permutation[thread_row];
        int thread_col = threadIdx.x % TILE_DIM;

        int nx2 = by + thread_col;
        int ny2 = bx + thread_row;
        if (nx2 < m && ny2 < n)
        {
            B[ny2 * m + nx2] = S[thread_col][permuted_row] * scale[nx2];
        }
    }
}


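// De-interleaves the quantized weight (stored in 64-column blocks with pairs of rows interleaved,
// judging by the index math) and casts int8 -> fp16. One thread block processes one row of the
// interleaved layout, converting 4 values (8 bytes) per iteration and writing them back to their
// original positions.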
__global__ void permute_cast(half* dequantized_weight, const uint8_t* weight, int m, int n){
    const uint8_t* in = &weight[blockIdx.x * n];
    const int interleave_block_size = 64;
    const int block_per_row = n / interleave_block_size;

    for(int i = threadIdx.x; i * 4 < n; i += blockDim.x){
        cutlass::Array<half, 4> output;

        int col_offset_global = i * 4;
        int col_offset_local = col_offset_global % interleave_block_size;
        int col_index = col_offset_global / interleave_block_size;
        int global_index = blockIdx.x * block_per_row + col_index;
        int is_second = global_index % 2;

        int origin_row = global_index / (block_per_row * 2) * 2 + is_second;
        int origin_col = (global_index / 2) % block_per_row;

        output = convert(reinterpret_cast<const uint32_t*>(&in[col_offset_global])[0]);
        uint64_t* t = reinterpret_cast<uint64_t*>(&output);

        // Write the 4 converted halves (8 bytes) to their de-interleaved position.
        half* out = &dequantized_weight[origin_row * n];
        reinterpret_cast<uint64_t*>(&out[col_offset_local + origin_col * interleave_block_size])[0] = t[0];
    }
}



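// Two-step dequantization: permute_cast de-interleaves and casts the int8 weight into a temporary
// fp16 buffer, then transpose_row_interleave_scale transposes it and applies the scale into the
// final output. m and n are swapped at the call sites since the preprocessed weight is stored with
// the opposite orientation to the dequantized result.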
void invoke_dequantize(half* dequantized_weight,
                       const uint8_t* weight,
                       const half* scale,
                       int m,
                       int n)
{
    half* tmp;
    cudaMalloc(&tmp, m * n * sizeof(half));
    dim3 block(std::min(256, m / 4));
    dim3 grid(n);
    permute_cast<<<grid, block>>>(tmp, weight, n, m);

    constexpr int BLOCK_SZ = 32;
    dim3 block_0(256);
    dim3 grid_0((m + BLOCK_SZ - 1) / BLOCK_SZ, (n + BLOCK_SZ - 1) / BLOCK_SZ);
    transpose_row_interleave_scale<<<grid_0, block_0>>>(tmp, dequantized_weight, scale, n, m);
    // cudaMemcpy(dequantized_weight, tmp, m * n * sizeof(half), cudaMemcpyDeviceToDevice);
    cudaFree(tmp);
}
3 changes: 3 additions & 0 deletions csrc/dequantize_kernel/dequantize_kernel.h
@@ -0,0 +1,3 @@
#include <torch/extension.h>

void dequantize_weight_cuda(torch::Tensor _dequantized_weight, torch::Tensor _weight, torch::Tensor _scale);
2 changes: 2 additions & 0 deletions csrc/eetpy.cpp
@@ -3,6 +3,7 @@
#include "cutlass_kernels/fpA_intB_gemm_wrapper.h"
#include "embedding_kernels/pos_encoding.h"
#include "layernorm_kernels/layernorm.h"
#include "dequantize_kernel/dequantize_kernel.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
@@ -17,4 +18,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
          py::arg("return_unprocessed_quantized_tensor") = false);
    m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
    m.def("layernorm_forward", &layernorm_forward_cuda, "LayerNorm kernel");
    m.def("dequantize_weight", &dequantize_weight_cuda, "Dequantize kernel");
}
1 change: 0 additions & 1 deletion csrc/utils/torch_utils.h
@@ -6,7 +6,6 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include <nvToolsExt.h>
#include <torch/custom_class.h>
#include <torch/script.h>
#include <vector>
147 changes: 147 additions & 0 deletions examples/layers/benchmark_dequantize.py
@@ -0,0 +1,147 @@
import itertools
from typing import Optional, Tuple, Union

import torch
import triton
from torch import nn
# from vllm import _custom_ops as vllm_ops
from EETQ import quant_weights, w8_a16_gemm, dequantize_weight


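# Reference path: multiplying an identity matrix through w8_a16_gemm reconstructs the dequantized
# weight, which the dedicated dequantize_weight kernel is checked and benchmarked against.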
def w8a16_dequant(qweight, scale):
    dtype = torch.float16
    device = qweight.device
    I = torch.eye(qweight.shape[0], dtype=dtype, device=device)
    gemm_dequantized_weight = w8_a16_gemm(I, qweight, scale)
    return gemm_dequantized_weight


def dequant_weight(qweight, scale):
    dtype = torch.float16
    device = qweight.device
    dequantized_weight = torch.zeros_like(qweight, dtype=dtype, device=device)
    dequantize_weight(dequantized_weight, qweight, scale)
    return dequantized_weight


def calculate_diff(m, n):
    dtype = torch.float16
    device = "cuda"
    weight = torch.randn(m, n, dtype=dtype)
    qweight, scale = quant_weights(weight, torch.int8, False)
    qweight = qweight.to(device)
    scale = scale.to(device)

    # Calculate the dequantized weight using w8_a16_gemm
    gemm_dequantized_weight = w8a16_dequant(qweight, scale)

    # Calculate the dequantized weight using dequantize_weight
    dequantized_weight = dequant_weight(qweight, scale)

    if torch.allclose(
        gemm_dequantized_weight, dequantized_weight, atol=1e-2, rtol=1e-2
    ):
        print(f"✅ ({m}, {n}) implementations match")
    else:
        print(f"❌ ({m}, {n}) implementations differ")
    del gemm_dequantized_weight
    del dequantized_weight



M = [2048 + i * 1024 for i in range(0, 11, 2)]
N = [2048 + i * 1024 for i in range(0, 11, 2)]


configs = list(itertools.product(M, N))


def get_benchmark():
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["m", "n"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["w8a16", "dequant"],
            line_names=["w8a16", "dequant"],
            styles=[("blue", "-"), ("green", "-")],
            ylabel="us",
            plot_name="dequantized_performance",
            args={},
        )
    )
    def benchmark(m, n, provider):
        dtype = torch.float16
        device = "cuda"
        weight = torch.randn(m, n, dtype=dtype)
        qweight, scale = quant_weights(weight, torch.int8, False)
        qweight = qweight.to(device)
        scale = scale.to(device)

        quantiles = [0.5, 0.2, 0.8]

        if provider == "w8a16":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: w8a16_dequant(
                    qweight,
                    scale
                ),
                quantiles=quantiles,
            )
        elif provider == "dequant":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: dequant_weight(
                    qweight,
                    scale
                ),
                quantiles=quantiles,
            )

        # Convert ms to us
        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark



if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/",
        help="Path to save dequantize benchmark results",
    )
    args = parser.parse_args()

    # M = [i * 256 for i in range(1, 50, 2)]
    # N = [i * 256 for i in range(1, 50, 2)]

    # for m in M:
    #     for n in N:
    #         calculate_diff(m, n)

    # Get the benchmark function
    benchmark = get_benchmark()
    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)
2 changes: 1 addition & 1 deletion examples/layers/test_w8a16_gemm.py
@@ -56,4 +56,4 @@ def set_random_seed(seed):
print(output)
print(torch.sum(output - out_torch))
t2 = time.perf_counter()
print("time: ", (t2 - t1) / 100)
print("time: ", (t2 - t1) / 100)
80 changes: 80 additions & 0 deletions examples/models/eetq_train.py
@@ -0,0 +1,80 @@
import datasets
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    EetqConfig,
    AutoModelForCausalLM
)
from peft import get_peft_model, LoraConfig, TaskType

def prepare_split(tokenizer):
    data = datasets.load_dataset("mhenrichsen/alpaca_2k_test", split="train")
    prompt_template = "<s>[INST] {system} {prompt} [/INST] {output}</s>"

    def format_prompt(x):
        return prompt_template.format(
            system="",
            prompt=x["instruction"],
            output=x["output"]
        )

    data = data.map(
        lambda x: {"text": format_prompt(x)},
    ).select_columns(["text"])
    data = data.map(lambda x: tokenizer(x["text"]), batched=True)

    return data

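# Path to the base model (local directory or Hugging Face model id); set before running.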
model_path = ""

# Load model
quantization_config = EetqConfig("int8")
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", quantization_config=quantization_config)

tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

# Prepare data
data_train = prepare_split(tokenizer)

# Config Lora
lora_config = LoraConfig(
    r=4,
    lora_alpha=8,
    lora_dropout=0.5,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False
)

model = get_peft_model(model, lora_config)

model.print_trainable_parameters()

training_arguments = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=1,
    optim="adamw_torch",
    num_train_epochs=1,
    learning_rate=1e-4,
    # fp16=True,
    evaluation_strategy="no",
    save_strategy="epoch",
    save_steps=100,
    logging_steps=50,
    eval_steps=None,
    load_best_model_at_end=False
)

trainer = Trainer(
    model=model,
    train_dataset=data_train,
    args=training_arguments,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train()
trainer.save_model("output")
