diff --git a/explorations/linear_sweep.json b/explorations/linear_sweep.json
new file mode 100644
index 0000000000..db68c4fbaa
--- /dev/null
+++ b/explorations/linear_sweep.json
@@ -0,0 +1,20 @@
+[
+  {
+    "max_iters": ["100000"],
+    "n_layer": ["6"],
+    "n_kv_group": ["6"],
+    "n_head": ["6"],
+    "n_embd": ["384"],
+    "block_size": ["256"],
+    "eval_interval": ["250"],
+    "patience": ["10"],
+    "device": ["cuda"],
+    "dtype": ["float16"],
+    "dataset": ["shakespeare_char"],
+    "linear_variant": ["bitlinear_1p58", "bitlinear", "bitlinear_optimized", "linear"],
+    "compile": [true],
+    "softmax_variant_attn": ["softmax", "polymax"],
+    "tensorboard_run_name": ["linear_variation_sweep"]
+  }
+]
+
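Note: the sweep above crosses four linear variants with two attention softmax variants over a shared set of base hyperparameters, i.e. 4 x 2 = 8 training runs. The sketch below is illustrative only, not this repo's sweep runner; it assumes each JSON key maps to a train.py flag of the same name, as --linear_variant does.

# Illustrative grid expansion for explorations/linear_sweep.json (assumption: every
# key corresponds to a same-named train.py flag; boolean flag handling may differ).
import itertools
import json

with open("explorations/linear_sweep.json") as f:
    sweep = json.load(f)[0]

keys = list(sweep)
for combo in itertools.product(*(sweep[k] for k in keys)):
    flags = " ".join(f"--{k} {v}" for k, v in zip(keys, combo))
    print(f"python3 train.py {flags}")  # prints 8 command lines (4 linear x 2 softmax variants)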
diff --git a/model.py b/model.py
index 7214fb6488..3cef7f4dc0 100644
--- a/model.py
+++ b/model.py
@@ -21,6 +21,7 @@
 from variations.normalization_variations import LayerNorm, RMSNorm
 from variations.position_encoding_variations import RotaryEmbedding, ShortRope, SymmetricalOverlapAngularPositions
 from variations.activation_variations import SquaredReLU, activation_dictionary
+from variations.linear_variations import BitLinear1p58, BitLinear, BitLinearOptimized, linear_dictionary
 
 def create_shared_param_group(layer_type, config):
     shared_size = None
@@ -246,12 +247,15 @@ class MLP(nn.Module):
 
     def __init__(self, config):
         super().__init__()
-        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+
+        # Select linear variant
+        self.linear_variant = linear_dictionary[config.linear_variant]
+        self.c_fc = self.linear_variant(config.n_embd, 4 * config.n_embd, bias=config.bias)
 
         # Select activation variant
         self.activation_variant = activation_dictionary[config.activation_variant]
 
-        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
+        self.c_proj = self.linear_variant(4 * config.n_embd, config.n_embd, bias=config.bias)
         self.dropout = nn.Dropout(config.dropout)
 
     def forward(self, x):
@@ -365,8 +369,8 @@
     exppolymax_divisor: float = 1.0
 
     # Positional Embeddings Variations
-    use_abs_pos_embeddings: bool = False # Note: one can use this AND rotary embeddings
-    use_rotary_embeddings: bool = True # If True, uses rotary embeddings, else use conventional absolute position encoding
+    use_abs_pos_embeddings: bool = True # Note: one can use this AND rotary embeddings
+    use_rotary_embeddings: bool = False # If True, uses rotary embeddings, else use conventional absolute position encoding
     rope_variant: str = "rope" # options: "shortrope", "rope"
     shortrope_length: int = 8 # number of embeddings to use in shortrope
 
@@ -374,11 +378,14 @@
     use_post_ln: bool = True
 
     # Layernorm Alternatives and Options
-    layernorm_variant: str = "rmsnorm" # Current options "rmsnorm" or "layernorm"
+    layernorm_variant: str = "rmsnorm"
     bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
 
     # Activation Alternatives
-    activation_variant: str = "gelu" # Current options "gelu", "relu", "squared_relu"
+    activation_variant: str = "gelu"
+
+    # Linear Alternatives
+    linear_variant: str = "linear"
 
 
 class GPT(nn.Module):
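The MLP hunk above replaces the hard-coded nn.Linear with whichever class linear_dictionary maps config.linear_variant to. A minimal standalone sketch of that selection pattern follows; SimpleConfig is a hypothetical stand-in for GPTConfig that models only the fields used here.

# Sketch of the linear-variant selection used in model.py's MLP; SimpleConfig is a
# hypothetical stand-in for GPTConfig, not part of the repo.
from dataclasses import dataclass

import torch
import torch.nn as nn

from variations.linear_variations import linear_dictionary

@dataclass
class SimpleConfig:
    n_embd: int = 384
    bias: bool = False
    linear_variant: str = "bitlinear_1p58"

config = SimpleConfig()
linear_cls = linear_dictionary[config.linear_variant]   # resolves to BitLinear1p58
c_fc = linear_cls(config.n_embd, 4 * config.n_embd, bias=config.bias)
c_proj = linear_cls(4 * config.n_embd, config.n_embd, bias=config.bias)

x = torch.randn(2, 8, config.n_embd)                    # (batch, tokens, n_embd)
print(c_proj(nn.GELU()(c_fc(x))).shape)                 # torch.Size([2, 8, 384])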
diff --git a/train.py b/train.py
index 34c2c03c2f..b702302304 100644
--- a/train.py
+++ b/train.py
@@ -37,7 +37,7 @@ def parse_args():
     # Checkpoint args
     training_group.add_argument('--only_save_checkpoint_at_end', action='store_true')
     training_group.add_argument('--always_save_checkpoint', action='store_true')
-    training_group.add_argument('--patience', default=None, type=int)
+    training_group.add_argument('--patience', default=None, type=int, help="if set, stop training once the number of evaluations since the last decrease in val loss exceeds this patience value.")
     training_group.add_argument('--init_from', default='scratch', choices=['scratch', 'prev_run', 'resume', 'gpt2*'], type=str)
     training_group.add_argument('--prev_run_ckpt', default='', type=str)
     training_group.add_argument('--csv_ckpt_dir', default='', type=str)
@@ -95,6 +95,19 @@
         ],
     )
 
+    # LINEAR VARIATIONS
+    model_group.add_argument(
+        "--linear_variant",
+        type=str,
+        default="linear",
+        choices=[
+            "linear",
+            "bitlinear",
+            "bitlinear_1p58",
+            "bitlinear_optimized",
+        ],
+    )
+
     # POSITIONAL EMBEDDING VARIATIONS
     model_group.add_argument('--use_rotary_embeddings', default=False, action=argparse.BooleanOptionalAction)
     model_group.add_argument("--rope_variant", type=str, default="rope", choices=["shortrope", "rope"])
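The new --patience help text describes an early-stopping rule: training halts once the number of evaluations since the last improvement in val loss exceeds the patience value. The snippet below is a plain illustration of that rule, not the actual logic in train.py.

# Illustration of the early-stopping rule described by --patience (assumption: this
# mirrors the described behavior; it is not the train.py implementation).
def should_stop(val_losses, patience):
    if patience is None:
        return False
    best_idx = min(range(len(val_losses)), key=val_losses.__getitem__)
    evals_since_improvement = len(val_losses) - 1 - best_idx
    return evals_since_improvement > patience

print(should_stop([3.1, 2.8, 2.7, 2.9, 2.9, 3.0], patience=2))  # True: 3 evals since the best loss
print(should_stop([3.1, 2.8, 2.7, 2.9, 2.9], patience=2))       # False: only 2 evals since the best loss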
diff --git a/variations/linear_variations.py b/variations/linear_variations.py
new file mode 100644
index 0000000000..141ccfec3e
--- /dev/null
+++ b/variations/linear_variations.py
@@ -0,0 +1,223 @@
+import torch
+import torch.nn as nn
+import math
+
+class BitLinear1p58(nn.Linear):
+    """ BitLinear from Era of 1.58 LLMs Paper
+    Source: https://huggingface.co/1bitLLM/bitnet_b1_58-large/blob/main/utils_quant.py
+    Source License: MIT
+    Paper Link: https://arxiv.org/abs/2402.17764
+    """
+
+    def __init__(self, in_features, out_features, bias=True, num_groups=1):
+        super().__init__(in_features, out_features, bias)
+
+        """
+        RMSNorm is placed outside BitLinear
+        """
+        weight_bits = 1
+        input_bits = 8
+        self.weight_bits = weight_bits
+        self.input_bits = input_bits
+
+    def forward(self, x):
+
+        quant_input = x + (self.activation_quant(x, self.input_bits) - x).detach()
+        quant_weight = self.weight + (self.weight_quant(self.weight, self.weight_bits) - self.weight).detach()
+
+        out = nn.functional.linear(quant_input, quant_weight)
+        if self.bias is not None:
+            out += self.bias.view(1, -1).expand_as(out)
+
+        return out
+
+    def weight_quant(self, weight, num_bits=1):
+        dtype = weight.dtype
+        weight = weight.float()
+        s = 1 / weight.abs().mean().clamp(min=1e-5)
+        result = (weight * s).round().clamp(-1, 1) / s
+        return result.type(dtype)
+
+    def activation_quant(self, x, num_bits=8):
+        dtype = x.dtype
+        x = x.float()
+        Qn = -2 ** (num_bits - 1)
+        Qp = 2 ** (num_bits - 1) - 1
+        s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
+        result = (x * s).round().clamp(Qn, Qp) / s
+        return result.type(dtype)
+
+class BitLinear(nn.Linear):
+    """PyTorch BitLinear Layer
+    Source: https://github.com/Beomi/BitNet-Transformers/tree/main
+    Source License: Apache Version 2.0
+    """
+
+    def __init__(self, in_features, out_features, bias=True, num_groups=1):
+        super(BitLinear, self).__init__(in_features, out_features, bias)
+        self.num_groups = num_groups
+        self.eps = 1e-5
+
+    def ste_binarize(self, x):
+        # Apply the sign function for binarization
+        binarized_x = torch.sign(x)
+        # Use STE: during backward pass, we bypass the binarization
+        binarized_x = (binarized_x - x).detach() + x
+        return binarized_x
+
+    def binarize_weights_groupwise(self):
+        # Divide weights into groups
+        group_size = self.weight.shape[0] // self.num_groups
+        binarized_weights = torch.zeros_like(self.weight)
+
+        for g in range(self.num_groups):
+            start_idx = g * group_size
+            end_idx = (g + 1) * group_size
+            weight_group = self.weight[start_idx:end_idx]
+
+            # Binarize each group using STE
+            alpha_g = weight_group.mean()
+            binarized_weights[start_idx:end_idx] = self.ste_binarize(
+                weight_group - alpha_g
+            )
+
+        return binarized_weights
+
+    def quantize_activations_groupwise(self, x, b=8):
+        Q_b = 2 ** (b - 1)
+
+        # Divide activations into groups
+        group_size = x.shape[0] // self.num_groups
+        quantized_x = torch.zeros_like(x)
+
+        for g in range(self.num_groups):
+            start_idx = g * group_size
+            end_idx = (g + 1) * group_size
+            activation_group = x[start_idx:end_idx]
+
+            # Quantize each group
+            gamma_g = activation_group.abs().max()
+            quantized_x[start_idx:end_idx] = torch.clamp(
+                activation_group * Q_b / (gamma_g + self.eps),
+                -Q_b + self.eps,
+                Q_b - self.eps,
+            )
+
+        return quantized_x
+
+    def forward(self, input):
+        # Binarize weights (group-wise) using STE
+        binarized_weights = self.binarize_weights_groupwise()
+
+        # Normal linear transformation with binarized weights
+        output = torch.nn.functional.linear(input, binarized_weights, self.bias)
+
+        # Quantize activations group-wise
+        output = self.quantize_activations_groupwise(output)
+
+        return output
+
+
+class BitLinearOptimized(nn.Linear):
+    """Memory Optimized BitLinear Layer
+    Source: https://github.com/Beomi/BitNet-Transformers/tree/main
+    Source License: Apache Version 2.0
+    """
+
+    def __init__(self, in_features, out_features, bias=True, num_groups=1):
+        super(BitLinearOptimized, self).__init__(in_features, out_features, bias)
+        self.num_groups = num_groups
+        self.eps = 1e-5
+
+        # Initialize 1-bit quantized weights and store them as int8
+        self.register_buffer(
+            "quantized_weights", torch.sign(self.weight.data).to(torch.int8)
+        )
+        # Clear the original weights to save memory
+        del self.weight
+
+    @property
+    def weight(self):
+        # Return the dequantized weights when accessed
+        return self.dequantize_weights()
+
+    @weight.setter
+    def weight(self, value):
+        # Update the quantized_weights when the weight property is set
+        self.quantized_weights.data = torch.sign(value).to(torch.int8)
+
+    def dequantize_weights(self):
+        # Convert quantized_weights back to bfloat16 and compute alpha for the weights
+        bfloat16_weights = self.quantized_weights.to(torch.bfloat16)
+        alpha = bfloat16_weights.mean()
+        return bfloat16_weights * alpha
+
+    def ste_binarize(self, x):
+        # Apply the sign function for binarization
+        binarized_x = torch.sign(x)
+        # Use STE: during backward pass, we bypass the binarization
+        binarized_x = (binarized_x - x).detach() + x
+        return binarized_x
+
+    def binarize_weights_groupwise(self):
+        # Dequantize the weights before binarization
+        weights = self.dequantize_weights()
+
+        # Divide weights into groups
+        group_size = weights.shape[0] // self.num_groups
+        binarized_weights = torch.zeros_like(weights)
+
+        for g in range(self.num_groups):
+            start_idx = g * group_size
+            end_idx = (g + 1) * group_size
+            weight_group = weights[start_idx:end_idx]
+
+            # Binarize each group using STE
+            alpha_g = weight_group.mean()
+            binarized_weights[start_idx:end_idx] = self.ste_binarize(
+                weight_group - alpha_g
+            )
+
+        return binarized_weights
+
+    def quantize_activations_groupwise(self, x, b=8):
+        Q_b = 2 ** (b - 1)
+
+        # Divide activations into groups
+        group_size = x.shape[0] // self.num_groups
+        quantized_x = torch.zeros_like(x)
+
+        for g in range(self.num_groups):
+            start_idx = g * group_size
+            end_idx = (g + 1) * group_size
+            activation_group = x[start_idx:end_idx]
+
+            # Quantize each group
+            gamma_g = activation_group.abs().max()
+            quantized_x[start_idx:end_idx] = torch.clamp(
+                activation_group * Q_b / (gamma_g + self.eps),
+                -Q_b + self.eps,
+                Q_b - self.eps,
+            )
+
+        return quantized_x
+
+    def forward(self, input):
+        # Binarize weights (group-wise) using STE
+        binarized_weights = self.binarize_weights_groupwise()
+
+        # Normal linear transformation with binarized weights
+        output = torch.nn.functional.linear(input, binarized_weights, self.bias)
+
+        # Quantize activations group-wise
+        output = self.quantize_activations_groupwise(output)
+
+        return output
+
+
+linear_dictionary = {
+    "linear": nn.Linear,
+    "bitlinear": BitLinear,
+    "bitlinear_optimized": BitLinearOptimized,
+    "bitlinear_1p58": BitLinear1p58,
+}
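For a quick sanity check of the new layers, the snippet below is a standalone sketch, separate from the patch above: it runs a forward and backward pass through BitLinear1p58 and inspects the effective weight levels. Gradients flow because both quantizers use the straight-through estimator.

# Standalone sketch (separate from the patch): BitLinear1p58 forward/backward, plus the
# quantized weight levels, which should be a subset of {-1, 0, +1} up to the scale s.
import torch
from variations.linear_variations import BitLinear1p58

torch.manual_seed(0)
layer = BitLinear1p58(384, 4 * 384, bias=False)

x = torch.randn(2, 16, 384, requires_grad=True)     # (batch, tokens, n_embd)
y = layer(x)                                         # activations int8-quantized, weights ternarized
y.sum().backward()                                   # STE lets gradients reach x and layer.weight

s = 1 / layer.weight.abs().mean().clamp(min=1e-5)    # same scale computed inside weight_quant
levels = torch.unique((layer.weight * s).round().clamp(-1, 1))
print(y.shape, levels)                               # torch.Size([2, 16, 1536]) tensor([-1., 0., 1.])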