From 73374893cc9eda3cfbe1f8179476507cf2197f8e Mon Sep 17 00:00:00 2001
From: Gregory Kielian
Date: Tue, 9 Apr 2024 15:27:23 -0700
Subject: [PATCH 1/8] Add an open-source bitlinear implementation

---
 variations/linear_variations.py | 209 ++++++++++++++++++++++++++++++++
 1 file changed, 209 insertions(+)
 create mode 100644 variations/linear_variations.py

diff --git a/variations/linear_variations.py b/variations/linear_variations.py
new file mode 100644
index 0000000000..1b920f2e7f
--- /dev/null
+++ b/variations/linear_variations.py
@@ -0,0 +1,209 @@
+# # coding=utf-8
+# # Copyright 2023 Beomi (L. Junbum)
+# # Licensed under the Apache License, Version 2.0 (the "License")
+""" PyTorch BitLinear Layer."""
+import torch
+import torch.nn as nn
+
+
+class BitLinearNaive(nn.Linear):
+    def __init__(self, in_features, out_features, bias=True, num_groups=1):
+        super(BitLinearNaive, self).__init__(in_features, out_features, bias)
+        self.num_groups = num_groups
+        self.eps = 1e-5  # Small epsilon value to avoid division by zero and overflow
+
+    def binarize_weights(self):
+        alpha = self.weight.mean()
+        binarized_weights = torch.sign(self.weight - alpha)
+        return binarized_weights
+
+    def quantize_activations(self, x, b=8):
+        Q_b = 2 ** (b - 1)
+        gamma = x.abs().max()
+        quantized_x = torch.clamp(
+            x * Q_b / (gamma + self.eps), -Q_b + self.eps, Q_b - self.eps
+        )
+        return quantized_x
+
+    def scale_activations(self, x, b=8):
+        Q_b = 2 ** (b - 1)
+        eta = x.min()
+        gamma = x.abs().max()
+        scaled_x = torch.clamp(
+            (x - eta) * Q_b / (gamma + self.eps), self.eps, Q_b - self.eps
+        )
+        return scaled_x
+
+    def forward(self, input):
+        # Binarize weights
+        binarized_weights = self.binarize_weights()
+
+        # Normal linear transformation with binarized weights
+        output = torch.nn.functional.linear(input, binarized_weights, self.bias)
+
+        # Quantize activations (before non-linear functions like ReLU)
+        output = self.quantize_activations(output)
+
+        # For the sake of demonstration, we'll also include the scaling step.
+        # In practice, this would be done before a non-linear function in a forward pass.
+        output = self.scale_activations(output)
+
+        return output
+
+
+class BitLinear(nn.Linear):
+    def __init__(self, in_features, out_features, bias=True, num_groups=1):
+        super(BitLinear, self).__init__(in_features, out_features, bias)
+        self.num_groups = num_groups
+        self.eps = 1e-5
+
+    def ste_binarize(self, x):
+        # Apply the sign function for binarization
+        binarized_x = torch.sign(x)
+        # Use STE: during backward pass, we bypass the binarization
+        binarized_x = (binarized_x - x).detach() + x
+        return binarized_x
+
+    def binarize_weights_groupwise(self):
+        # Divide weights into groups
+        group_size = self.weight.shape[0] // self.num_groups
+        binarized_weights = torch.zeros_like(self.weight)
+
+        for g in range(self.num_groups):
+            start_idx = g * group_size
+            end_idx = (g + 1) * group_size
+            weight_group = self.weight[start_idx:end_idx]
+
+            # Binarize each group using STE
+            alpha_g = weight_group.mean()
+            binarized_weights[start_idx:end_idx] = self.ste_binarize(
+                weight_group - alpha_g
+            )
+
+        return binarized_weights
+
+    def quantize_activations_groupwise(self, x, b=8):
+        Q_b = 2 ** (b - 1)
+
+        # Divide activations into groups
+        group_size = x.shape[0] // self.num_groups
+        quantized_x = torch.zeros_like(x)
+
+        for g in range(self.num_groups):
+            start_idx = g * group_size
+            end_idx = (g + 1) * group_size
+            activation_group = x[start_idx:end_idx]
+
+            # Quantize each group
+            gamma_g = activation_group.abs().max()
+            quantized_x[start_idx:end_idx] = torch.clamp(
+                activation_group * Q_b / (gamma_g + self.eps),
+                -Q_b + self.eps,
+                Q_b - self.eps,
+            )
+
+        return quantized_x
+
+    def forward(self, input):
+        # Binarize weights (group-wise) using STE
+        binarized_weights = self.binarize_weights_groupwise()
+
+        # Normal linear transformation with binarized weights
+        output = torch.nn.functional.linear(input, binarized_weights, self.bias)
+
+        # Quantize activations group-wise
+        output = self.quantize_activations_groupwise(output)
+
+        return output
+
+
+class BitLinearOptimized(nn.Linear):
+    def __init__(self, in_features, out_features, bias=True, num_groups=1):
+        super(BitLinearOptimized, self).__init__(in_features, out_features, bias)
+        self.num_groups = num_groups
+        self.eps = 1e-5
+
+        # Initialize 1-bit quantized weights and store them as int8
+        self.register_buffer(
+            "quantized_weights", torch.sign(self.weight.data).to(torch.int8)
+        )
+        # Clear the original weights to save memory
+        del self.weight
+
+    @property
+    def weight(self):
+        # Return the dequantized weights when accessed
+        return self.dequantize_weights()
+
+    @weight.setter
+    def weight(self, value):
+        # Update the quantized_weights when the weight property is set
+        self.quantized_weights.data = torch.sign(value).to(torch.int8)
+
+    def dequantize_weights(self):
+        # Convert quantized_weights back to bfloat16 and compute alpha for the weights
+        bfloat16_weights = self.quantized_weights.to(torch.bfloat16)
+        alpha = bfloat16_weights.mean()
+        return bfloat16_weights * alpha
+
+    def ste_binarize(self, x):
+        # Apply the sign function for binarization
+        binarized_x = torch.sign(x)
+        # Use STE: during backward pass, we bypass the binarization
+        binarized_x = (binarized_x - x).detach() + x
+        return binarized_x
+
+    def binarize_weights_groupwise(self):
+        # Dequantize the weights before binarization
+        weights = self.dequantize_weights()
+
+        # Divide weights into groups
+        group_size = weights.shape[0] // self.num_groups
+        binarized_weights = torch.zeros_like(weights)
+
+        for g in range(self.num_groups):
+            start_idx = g * group_size
+            end_idx = (g + 1) * group_size
+            weight_group = weights[start_idx:end_idx]
+
+            # Binarize each group using STE
+            alpha_g = weight_group.mean()
+            binarized_weights[start_idx:end_idx] = self.ste_binarize(
+                weight_group - alpha_g
+            )
+
+        return binarized_weights
+
+    def quantize_activations_groupwise(self, x, b=8):
+        Q_b = 2 ** (b - 1)
+
+        # Divide activations into groups
+        group_size = x.shape[0] // self.num_groups
+        quantized_x = torch.zeros_like(x)
+
+        for g in range(self.num_groups):
+            start_idx = g * group_size
+            end_idx = (g + 1) * group_size
+            activation_group = x[start_idx:end_idx]
+
+            # Quantize each group
+            gamma_g = activation_group.abs().max()
+            quantized_x[start_idx:end_idx] = torch.clamp(
+                activation_group * Q_b / (gamma_g + self.eps),
+                -Q_b + self.eps,
+                Q_b - self.eps,
+            )
+
+        return quantized_x
+
+    def forward(self, input):
+        # Binarize weights (group-wise) using STE
+        binarized_weights = self.binarize_weights_groupwise()
+
+        # Normal linear transformation with binarized weights
+        output = torch.nn.functional.linear(input, binarized_weights, self.bias)
+
+        # Quantize activations group-wise
+        output = self.quantize_activations_groupwise(output)
+
+        return output
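
A note on the straight-through estimator (STE) used above: (torch.sign(x) - x).detach() + x evaluates to sign(x) in the forward pass, while the .detach() keeps the non-differentiable sign out of the backward graph, so gradients reach the latent full-precision weights as if binarization were the identity. A minimal usage sketch of the layer this patch adds; the shapes and the gradient check are illustrative assumptions, not part of the patch:

    import torch
    from variations.linear_variations import BitLinear

    layer = BitLinear(384, 4 * 384, bias=False)   # drop-in replacement for nn.Linear
    x = torch.randn(8, 384, requires_grad=True)   # hypothetical mini-batch
    y = layer(x)                                  # binarized weights, quantized activations
    y.sum().backward()                            # STE lets gradients reach layer.weight
    assert layer.weight.grad is not None
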
From 64aad10e911220f4188e1e11282ac21b71620a55 Mon Sep 17 00:00:00 2001
From: Gregory Kielian
Date: Tue, 9 Apr 2024 15:39:34 -0700
Subject: [PATCH 2/8] Add formatting and keep standard and opt versions

This adds formatting, some comments for license and source, and python imports.
---
 variations/linear_variations.py | 60 ++++++---------------------------
 1 file changed, 11 insertions(+), 49 deletions(-)

diff --git a/variations/linear_variations.py b/variations/linear_variations.py
index 1b920f2e7f..0b0645cf7f 100644
--- a/variations/linear_variations.py
+++ b/variations/linear_variations.py
@@ -1,57 +1,14 @@
-# # coding=utf-8
-# # Copyright 2023 Beomi (L. Junbum)
-# # Licensed under the Apache License, Version 2.0 (the "License")
-""" PyTorch BitLinear Layer."""
 import torch
 import torch.nn as nn
-
-
-class BitLinearNaive(nn.Linear):
-    def __init__(self, in_features, out_features, bias=True, num_groups=1):
-        super(BitLinearNaive, self).__init__(in_features, out_features, bias)
-        self.num_groups = num_groups
-        self.eps = 1e-5  # Small epsilon value to avoid division by zero and overflow
-
-    def binarize_weights(self):
-        alpha = self.weight.mean()
-        binarized_weights = torch.sign(self.weight - alpha)
-        return binarized_weights
-
-    def quantize_activations(self, x, b=8):
-        Q_b = 2 ** (b - 1)
-        gamma = x.abs().max()
-        quantized_x = torch.clamp(
-            x * Q_b / (gamma + self.eps), -Q_b + self.eps, Q_b - self.eps
-        )
-        return quantized_x
-
-    def scale_activations(self, x, b=8):
-        Q_b = 2 ** (b - 1)
-        eta = x.min()
-        gamma = x.abs().max()
-        scaled_x = torch.clamp(
-            (x - eta) * Q_b / (gamma + self.eps), self.eps, Q_b - self.eps
-        )
-        return scaled_x
-
-    def forward(self, input):
-        # Binarize weights
-        binarized_weights = self.binarize_weights()
-
-        # Normal linear transformation with binarized weights
-        output = torch.nn.functional.linear(input, binarized_weights, self.bias)
-
-        # Quantize activations (before non-linear functions like ReLU)
-        output = self.quantize_activations(output)
-
-        # For the sake of demonstration, we'll also include the scaling step.
-        # In practice, this would be done before a non-linear function in a forward pass.
-        output = self.scale_activations(output)
-
-        return output
+import math
 
 
 class BitLinear(nn.Linear):
+    """PyTorch BitLinear Layer
+    Source: https://github.com/Beomi/BitNet-Transformers/tree/main
+    Source License: Apache Version 2.0
+    """
+
     def __init__(self, in_features, out_features, bias=True, num_groups=1):
         super(BitLinear, self).__init__(in_features, out_features, bias)
         self.num_groups = num_groups
@@ -118,6 +75,11 @@ def forward(self, input):
 
 
 class BitLinearOptimized(nn.Linear):
+    """Memory Optimized BitLinear Layer
+    Source: https://github.com/Beomi/BitNet-Transformers/tree/main
+    Source License: Apache Version 2.0
+    """
+
     def __init__(self, in_features, out_features, bias=True, num_groups=1):
         super(BitLinearOptimized, self).__init__(in_features, out_features, bias)
         self.num_groups = num_groups

From 1f10fe6f08ced36e3ab375f00b571fa55a2d5d13 Mon Sep 17 00:00:00 2001
From: Gregory Kielian
Date: Tue, 9 Apr 2024 15:49:51 -0700
Subject: [PATCH 3/8] Add dictionary method for selecting linear type

---
 model.py                        | 15 +++++++++++----
 train.py                        | 12 ++++++++++++
 variations/linear_variations.py |  6 ++++++
 3 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/model.py b/model.py
index 45e1749dcb..8721b4a071 100644
--- a/model.py
+++ b/model.py
@@ -21,6 +21,7 @@
 from variations.normalization_variations import LayerNorm, RMSNorm
 from variations.position_encoding_variations import RotaryEmbedding, ShortRope, SymmetricalOverlapAngularPositions
 from variations.activation_variations import SquaredReLU, activation_dictionary
+from variations.linear_variations import BitLinear, BitLinearOptimized, linear_dictionary
 
 def create_shared_param_group(layer_type, config):
     shared_size = None
@@ -246,12 +247,15 @@ class MLP(nn.Module):
 
     def __init__(self, config):
         super().__init__()
-        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+
+        # Select linear variant
+        self.linear_variant = linear_dictionary[config.linear_variant]
+        self.c_fc = self.linear_variant(config.n_embd, 4 * config.n_embd, bias=config.bias)
 
         # Select activation variant
         self.activation_variant = activation_dictionary[config.activation_variant]
 
-        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
+        self.c_proj = self.linear_variant(4 * config.n_embd, config.n_embd, bias=config.bias)
         self.dropout = nn.Dropout(config.dropout)
 
     def forward(self, x):
@@ -362,11 +366,14 @@ class GPTConfig:
     use_post_ln: bool = True
 
     # Layernorm Alternatives and Options
-    layernorm_variant: str = "rmsnorm" # Current options "rmsnorm" or "layernorm"
+    layernorm_variant: str = "rmsnorm"
     bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
 
     # Activation Alternatives
-    activation_variant: str = "gelu" # Current options "gelu", "relu", "squared_relu"
+    activation_variant: str = "gelu"
+
+    # Linear Alternatives
+    linear_variant: str = "linear"
 
 
 class GPT(nn.Module):
diff --git a/train.py b/train.py
index 2d933b801c..bf9b516092 100644
--- a/train.py
+++ b/train.py
@@ -94,6 +94,18 @@ def parse_args():
         ],
     )
 
+    # LINEAR VARIATIONS
+    model_group.add_argument(
+        "--linear_variant",
+        type=str,
+        default="linear",
+        choices=[
+            "linear",
+            "bitlinear",
+            "bitlinear_optimized",
+        ],
+    )
+
     # POSITIONAL EMBEDDING VARIATIONS
     model_group.add_argument('--use_rotary_embeddings', default=False, action=argparse.BooleanOptionalAction)
     model_group.add_argument("--rope_variant", type=str, default="rope", choices=["shortrope", "rope"])
diff --git a/variations/linear_variations.py b/variations/linear_variations.py
index 0b0645cf7f..2308e49694 100644
--- a/variations/linear_variations.py
+++ b/variations/linear_variations.py
@@ -2,6 +2,12 @@
 import torch.nn as nn
 import math
 
+linear_dictionary = {
+    "linear": nn.Linear(),
+    "bitlinear": BitLinear(),
+    "bitlinear_optimized": BitLinearOptimized(),
+}
+
 
 class BitLinear(nn.Linear):
     """PyTorch BitLinear Layer
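
With this patch the MLP block resolves both projections through a name-to-class lookup keyed by config.linear_variant (exposed on the command line as --linear_variant). A condensed, self-contained sketch of that path follows; DemoConfig and demo_dictionary are stand-ins for GPTConfig and the repo's linear_dictionary, which also maps "bitlinear" and "bitlinear_optimized":

    import torch.nn as nn

    demo_dictionary = {"linear": nn.Linear}  # stand-in for linear_dictionary

    class DemoConfig:                        # stand-in for GPTConfig
        n_embd = 384
        bias = False
        linear_variant = "linear"

    config = DemoConfig()
    linear_cls = demo_dictionary[config.linear_variant]                     # name -> layer class
    c_fc = linear_cls(config.n_embd, 4 * config.n_embd, bias=config.bias)   # replaces nn.Linear
    c_proj = linear_cls(4 * config.n_embd, config.n_embd, bias=config.bias)
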
From de71e63575186e178d8adc5386050d0e7b49a88f Mon Sep 17 00:00:00 2001
From: Gregory Kielian
Date: Tue, 9 Apr 2024 15:54:32 -0700
Subject: [PATCH 4/8] Set default to abs emb true rotary false

This increases test speed.
---
 model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/model.py b/model.py
index 8721b4a071..a1883c9536 100644
--- a/model.py
+++ b/model.py
@@ -357,8 +357,8 @@ class GPTConfig:
     exppolymax_divisor: float = 1.0
 
     # Positional Embeddings Variations
-    use_abs_pos_embeddings: bool = False # Note: one can use this AND rotary embeddings
-    use_rotary_embeddings: bool = True # If True, uses rotary embeddings, else use conventional absolute position encoding
+    use_abs_pos_embeddings: bool = True # Note: one can use this AND rotary embeddings
+    use_rotary_embeddings: bool = False # If True, uses rotary embeddings, else use conventional absolute position encoding
     rope_variant: str = "rope" # options: "shortrope", "rope"
     shortrope_length: int = 8 # number of embeddings to use in shortrope
 

From 821829c132e6e3ddfa93c9ff58a0902f0f372b65 Mon Sep 17 00:00:00 2001
From: Gregory Kielian
Date: Tue, 9 Apr 2024 15:59:20 -0700
Subject: [PATCH 5/8] Add help message to patience option

---
 train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train.py b/train.py
index bf9b516092..4a6a2a24e0 100644
--- a/train.py
+++ b/train.py
@@ -37,7 +37,7 @@ def parse_args():
     # Checkpoint args
     training_group.add_argument('--only_save_checkpoint_at_end', action='store_true')
     training_group.add_argument('--always_save_checkpoint', action='store_true')
-    training_group.add_argument('--patience', default=None, type=int)
+    training_group.add_argument('--patience', default=None, type=int, help="if set, will stop training if the number of evaluations since val loss was seen to decrease exceeds 'patience' setting.")
     training_group.add_argument('--init_from', default='scratch', choices=['scratch', 'prev_run', 'resume', 'gpt2*'], type=str)
     training_group.add_argument('--prev_run_ckpt', default='', type=str)
    training_group.add_argument('--csv_ckpt_dir', default='', type=str)
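
The new help text describes an early-stopping rule: training halts once the count of evaluations without a new best validation loss exceeds the patience value. A small sketch of that rule in isolation; the variable names and loss values are illustrative, not train.py's actual internals:

    patience = 2                                   # as passed via --patience
    best_val_loss = float("inf")
    evals_since_improvement = 0
    for val_loss in [1.9, 1.7, 1.75, 1.72, 1.8]:   # one entry per evaluation
        if val_loss < best_val_loss:
            best_val_loss, evals_since_improvement = val_loss, 0
        else:
            evals_since_improvement += 1
        if evals_since_improvement > patience:
            break                                  # stop after 3 evals with no improvement
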
From c7fff413f025d66faf2637d44d94770051f18c9b Mon Sep 17 00:00:00 2001
From: Gregory Kielian
Date: Tue, 9 Apr 2024 16:02:24 -0700
Subject: [PATCH 6/8] Add exploration sweep across linear variations

---
 explorations/linear_sweep.json | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)
 create mode 100644 explorations/linear_sweep.json

diff --git a/explorations/linear_sweep.json b/explorations/linear_sweep.json
new file mode 100644
index 0000000000..4e106c3c3d
--- /dev/null
+++ b/explorations/linear_sweep.json
@@ -0,0 +1,20 @@
+[
+  {
+    "max_iters": ["100000"],
+    "n_layer": ["6"],
+    "n_kv_group": ["6"],
+    "n_head": ["6"],
+    "n_embd": ["384"],
+    "block_size":["256"],
+    "eval_interval": ["250"],
+    "patience": ["10"],
+    "device": ["cuda"],
+    "dtype": ["float16"],
+    "dataset": ["shakespeare_char"],
+    "linear_variant": ["bitlinear", "bitlinear_optimized", "linear"],
+    "compile": [true],
+    "softmax_variant_attn": ["softmax", "polymax"],
+    "tensorboard_run_name": ["linear_variation_sweep"]
+  }
+]
+

From 5290cd41bd44c72f5b2c68ce064f9b596de487ae Mon Sep 17 00:00:00 2001
From: Gregory Kielian
Date: Tue, 9 Apr 2024 16:16:05 -0700
Subject: [PATCH 7/8] Moved linear dictionary to bottom for ordering

This allows dictionary to see classes.
---
 variations/linear_variations.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/variations/linear_variations.py b/variations/linear_variations.py
index 2308e49694..328f21f391 100644
--- a/variations/linear_variations.py
+++ b/variations/linear_variations.py
@@ -2,12 +2,6 @@
 import torch.nn as nn
 import math
 
-linear_dictionary = {
-    "linear": nn.Linear(),
-    "bitlinear": BitLinear(),
-    "bitlinear_optimized": BitLinearOptimized(),
-}
-
 
 class BitLinear(nn.Linear):
     """PyTorch BitLinear Layer
@@ -175,3 +169,10 @@ def forward(self, input):
         output = self.quantize_activations_groupwise(output)
 
         return output
+
+
+linear_dictionary = {
+    "linear": nn.Linear,
+    "bitlinear": BitLinear,
+    "bitlinear_optimized": BitLinearOptimized,
+}
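
The sweep file above lists one JSON object whose fields each hold candidate values, so a run corresponds to one element of the cartesian product of those lists. The repo's sweep runner is not part of this series, so the expansion below is only an illustrative sketch of how such an entry unrolls into individual configurations:

    import itertools
    import json

    with open("explorations/linear_sweep.json") as f:
        sweep = json.load(f)[0]                    # the file holds a list with one sweep block

    keys = list(sweep)
    for values in itertools.product(*(sweep[k] for k in keys)):
        run = dict(zip(keys, values))              # one concrete configuration per combination
        print(run["linear_variant"], run["softmax_variant_attn"])

With three linear variants and two attention softmax variants listed, that is six configurations; the next patch adds a fourth linear variant, making eight.
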
From c3370686eaf7601813f20c5999ae3ac85bb8d85e Mon Sep 17 00:00:00 2001
From: Gregory Kielian
Date: Tue, 9 Apr 2024 20:35:27 -0700
Subject: [PATCH 8/8] Add Era of 1.58 bit LLMs BitLinear implementation

Adding MIT Licensed ternary implementation of BitLinear:
https://huggingface.co/1bitLLM/bitnet_b1_58-large/blob/main/utils_quant.py

Ternary BitLinear Arxiv Paper Link: https://arxiv.org/abs/2402.17764
---
 explorations/linear_sweep.json  |  2 +-
 model.py                        |  2 +-
 train.py                        |  1 +
 variations/linear_variations.py | 45 +++++++++++++++++++++++++++++++++
 4 files changed, 48 insertions(+), 2 deletions(-)

diff --git a/explorations/linear_sweep.json b/explorations/linear_sweep.json
index 4e106c3c3d..db68c4fbaa 100644
--- a/explorations/linear_sweep.json
+++ b/explorations/linear_sweep.json
@@ -11,7 +11,7 @@
     "device": ["cuda"],
     "dtype": ["float16"],
     "dataset": ["shakespeare_char"],
-    "linear_variant": ["bitlinear", "bitlinear_optimized", "linear"],
+    "linear_variant": ["bitlinear_1p58", "bitlinear", "bitlinear_optimized", "linear"],
     "compile": [true],
     "softmax_variant_attn": ["softmax", "polymax"],
     "tensorboard_run_name": ["linear_variation_sweep"]
diff --git a/model.py b/model.py
index a1883c9536..c3e14f4524 100644
--- a/model.py
+++ b/model.py
@@ -21,7 +21,7 @@
 from variations.normalization_variations import LayerNorm, RMSNorm
 from variations.position_encoding_variations import RotaryEmbedding, ShortRope, SymmetricalOverlapAngularPositions
 from variations.activation_variations import SquaredReLU, activation_dictionary
-from variations.linear_variations import BitLinear, BitLinearOptimized, linear_dictionary
+from variations.linear_variations import BitLinear1p58, BitLinear, BitLinearOptimized, linear_dictionary
 
 def create_shared_param_group(layer_type, config):
     shared_size = None
diff --git a/train.py b/train.py
index 4a6a2a24e0..d01e922593 100644
--- a/train.py
+++ b/train.py
@@ -102,6 +102,7 @@ def parse_args():
         choices=[
             "linear",
             "bitlinear",
+            "bitlinear_1p58",
             "bitlinear_optimized",
         ],
     )
diff --git a/variations/linear_variations.py b/variations/linear_variations.py
index 328f21f391..141ccfec3e 100644
--- a/variations/linear_variations.py
+++ b/variations/linear_variations.py
@@ -2,6 +2,50 @@
 import torch.nn as nn
 import math
 
+class BitLinear1p58(nn.Linear):
+    """ BitLinear from Era of 1.58 LLMs Paper
+    Source: https://huggingface.co/1bitLLM/bitnet_b1_58-large/blob/main/utils_quant.py
+    Source License: MIT
+    Paper Link: https://arxiv.org/abs/2402.17764
+    """
+
+    def __init__(self, in_features, out_features, bias=True, num_groups=1):
+        super().__init__(in_features, out_features, bias)
+
+        """
+        RMSNorm is placed outside BitLinear
+        """
+        weight_bits=1
+        input_bits=8
+        self.weight_bits = weight_bits
+        self.input_bits = input_bits
+
+    def forward(self, x):
+
+        quant_input = x + (self.activation_quant(x, self.input_bits) - x).detach()
+        quant_weight = self.weight + (self.weight_quant(self.weight, self.weight_bits) - self.weight).detach()
+
+        out = nn.functional.linear(quant_input, quant_weight)
+        if not self.bias is None:
+            out += self.bias.view(1, -1).expand_as(out)
+
+        return out
+
+    def weight_quant(self, weight, num_bits=1):
+        dtype = weight.dtype
+        weight = weight.float()
+        s = 1 / weight.abs().mean().clamp(min=1e-5)
+        result = (weight * s).round().clamp(-1, 1) / s
+        return result.type(dtype)
+
+    def activation_quant(self, x, num_bits=8):
+        dtype = x.dtype
+        x = x.float()
+        Qn = -2 ** (num_bits - 1)
+        Qp = 2 ** (num_bits - 1) - 1
+        s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
+        result = (x * s).round().clamp(Qn, Qp) / s
+        return result.type(dtype)
 
 class BitLinear(nn.Linear):
     """PyTorch BitLinear Layer
@@ -175,4 +219,5 @@ def forward(self, input):
     "linear": nn.Linear,
     "bitlinear": BitLinear,
     "bitlinear_optimized": BitLinearOptimized,
+    "bitlinear_1p58": BitLinear1p58,
 }
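
To make the ternary scheme concrete, here is a worked sketch of what BitLinear1p58.weight_quant does to a tiny hand-picked tensor (the values are illustrative only). The weight matrix is scaled by the reciprocal of its mean absolute value, rounded to the nearest of {-1, 0, 1}, and divided back by the scale, so the matmul sees ternary levels while the straight-through trick in forward() keeps gradients flowing to the full-precision weights:

    import torch

    w = torch.tensor([[0.40, -0.05, -0.90],
                      [0.10,  0.00,  0.60]])
    s = 1 / w.abs().mean().clamp(min=1e-5)  # mean |w| ~ 0.342, so s ~ 2.93
    q = (w * s).round().clamp(-1, 1)        # ternary codes: [[1., -0., -1.], [0., 0., 1.]]
    w_hat = q / s                           # dequantized weights used in the matmul
    print(q, w_hat)

activation_quant applies the analogous round-trip per row of the input, but into the signed 8-bit range [-128, 127].
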