
Merge pull request ReaLLMASIC#138 from gkielian/add_linear_variations
Add linear variations
klei22 authored Apr 10, 2024
2 parents 948a8bd + c337068 commit e55f418
Showing 4 changed files with 270 additions and 7 deletions.
20 changes: 20 additions & 0 deletions explorations/linear_sweep.json
@@ -0,0 +1,20 @@
[
{
"max_iters": ["100000"],
"n_layer": ["6"],
"n_kv_group": ["6"],
"n_head": ["6"],
"n_embd": ["384"],
"block_size":["256"],
"eval_interval": ["250"],
"patience": ["10"],
"device": ["cuda"],
"dtype": ["float16"],
"dataset": ["shakespeare_char"],
"linear_variant": ["bitlinear_1p58", "bitlinear", "bitlinear_optimized", "linear"],
"compile": [true],
"softmax_variant_attn": ["softmax", "polymax"],
"tensorboard_run_name": ["linear_variation_sweep"]
}
]
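
Each parameter in the sweep file maps to a list of candidate values. Below is a hedged sketch of how such a file can be expanded into individual run configurations; expand_sweep is an illustrative helper, not part of this commit, and it assumes the sweep runner takes the Cartesian product of all value lists.

# Illustrative only: expand a sweep file like explorations/linear_sweep.json
# into one flat config dict per combination (Cartesian product of the lists).
import itertools
import json

def expand_sweep(path):
    with open(path) as f:
        sweep = json.load(f)  # list of {param: [candidate values]} blocks
    runs = []
    for block in sweep:
        keys = list(block)
        for combo in itertools.product(*(block[k] for k in keys)):
            runs.append(dict(zip(keys, combo)))
    return runs

# For the block above: 4 linear variants x 2 softmax variants = 8 runs.
print(len(expand_sweep("explorations/linear_sweep.json")))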

19 changes: 13 additions & 6 deletions model.py
@@ -21,6 +21,7 @@
from variations.normalization_variations import LayerNorm, RMSNorm
from variations.position_encoding_variations import RotaryEmbedding, ShortRope, SymmetricalOverlapAngularPositions
from variations.activation_variations import SquaredReLU, activation_dictionary
from variations.linear_variations import BitLinear1p58, BitLinear, BitLinearOptimized, linear_dictionary

def create_shared_param_group(layer_type, config):
shared_size = None
@@ -246,12 +247,15 @@ class MLP(nn.Module):

def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)

# Select linear variant
self.linear_variant = linear_dictionary[config.linear_variant]
self.c_fc = self.linear_variant(config.n_embd, 4 * config.n_embd, bias=config.bias)

# Select activation variant
self.activation_variant = activation_dictionary[config.activation_variant]

self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
self.c_proj = self.linear_variant(4 * config.n_embd, config.n_embd, bias=config.bias)
self.dropout = nn.Dropout(config.dropout)

def forward(self, x):
@@ -365,20 +369,23 @@ class GPTConfig:
exppolymax_divisor: float = 1.0

# Positional Embeddings Variations
use_abs_pos_embeddings: bool = False # Note: one can use this AND rotary embeddings
use_rotary_embeddings: bool = True # If True, uses rotary embeddings, else use conventional absolute position encoding
use_abs_pos_embeddings: bool = True # Note: one can use this AND rotary embeddings
use_rotary_embeddings: bool = False # If True, uses rotary embeddings, else use conventional absolute position encoding
rope_variant: str = "rope" # options: "shortrope", "rope"
shortrope_length: int = 8 # number of embeddings to use in shortrope

# Structuring Options, remember to compile the model
use_post_ln: bool = True

# Layernorm Alternatives and Options
layernorm_variant: str = "rmsnorm" # Current options "rmsnorm" or "layernorm"
layernorm_variant: str = "rmsnorm"
bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

# Activation Alternatives
activation_variant: str = "gelu" # Current options "gelu", "relu", "squared_relu"
activation_variant: str = "gelu"

# Linear Alternatives
linear_variant: str = "linear"

class GPT(nn.Module):

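Putting the model.py changes together, the MLP now draws both its linear and activation layers from lookup dictionaries keyed by config strings. The sketch below mirrors the diff; the forward ordering (c_fc, activation, c_proj, dropout) and the assumption that activation_dictionary yields a callable module are inferred rather than shown in this commit.

import torch.nn as nn
from variations.linear_variations import linear_dictionary
from variations.activation_variations import activation_dictionary

class MLPSketch(nn.Module):
    """Illustrative reconstruction of the patched MLP, not the repo's class."""
    def __init__(self, config):
        super().__init__()
        linear_cls = linear_dictionary[config.linear_variant]  # e.g. BitLinear1p58
        self.c_fc = linear_cls(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.activation = activation_dictionary[config.activation_variant]  # assumed callable
        self.c_proj = linear_cls(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(self.activation(self.c_fc(x))))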
15 changes: 14 additions & 1 deletion train.py
@@ -37,7 +37,7 @@ def parse_args():
# Checkpoint args
training_group.add_argument('--only_save_checkpoint_at_end', action='store_true')
training_group.add_argument('--always_save_checkpoint', action='store_true')
training_group.add_argument('--patience', default=None, type=int)
training_group.add_argument('--patience', default=None, type=int, help="if set, stop training once the number of evaluations since the last decrease in val loss exceeds 'patience'.")
training_group.add_argument('--init_from', default='scratch', choices=['scratch', 'prev_run', 'resume', 'gpt2*'], type=str)
training_group.add_argument('--prev_run_ckpt', default='', type=str)
training_group.add_argument('--csv_ckpt_dir', default='', type=str)
@@ -95,6 +95,19 @@ def parse_args():
],
)

# LINEAR VARIATIONS
model_group.add_argument(
"--linear_variant",
type=str,
default="linear",
choices=[
"linear",
"bitlinear",
"bitlinear_1p58",
"bitlinear_optimized",
],
)

# POSITIONAL EMBEDDING VARIATIONS
model_group.add_argument('--use_rotary_embeddings', default=False, action=argparse.BooleanOptionalAction)
model_group.add_argument("--rope_variant", type=str, default="rope", choices=["shortrope", "rope"])
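The new --linear_variant flag is a plain string choice; at model-construction time it indexes into linear_dictionary (defined in the new file below). A minimal, hypothetical check of that mapping, not part of train.py:

import argparse
from variations.linear_variations import linear_dictionary

parser = argparse.ArgumentParser()
parser.add_argument("--linear_variant", type=str, default="linear",
                    choices=list(linear_dictionary.keys()))
args = parser.parse_args(["--linear_variant", "bitlinear_1p58"])
print(linear_dictionary[args.linear_variant])  # <class 'variations.linear_variations.BitLinear1p58'>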
223 changes: 223 additions & 0 deletions variations/linear_variations.py
@@ -0,0 +1,223 @@
import torch
import torch.nn as nn
import math

class BitLinear1p58(nn.Linear):
""" BitLinear from Era of 1.58 LLMs Paper
Source: https://huggingface.co/1bitLLM/bitnet_b1_58-large/blob/main/utils_quant.py
Source License: MIT
Paper Link: https://arxiv.org/abs/2402.17764
"""

def __init__(self, in_features, out_features, bias=True, num_groups=1):
super().__init__(in_features, out_features, bias)

"""
RMSNorm is placed outside BitLinear
"""
weight_bits=1
input_bits=8
self.weight_bits = weight_bits
self.input_bits = input_bits

def forward(self, x):

quant_input = x + (self.activation_quant(x, self.input_bits) - x).detach()
quant_weight = self.weight + (self.weight_quant(self.weight, self.weight_bits) - self.weight).detach()

out = nn.functional.linear(quant_input, quant_weight)
if self.bias is not None:
out += self.bias.view(1, -1).expand_as(out)

return out

def weight_quant(self, weight, num_bits=1):
dtype = weight.dtype
weight = weight.float()
s = 1 / weight.abs().mean().clamp(min=1e-5)
result = (weight * s).round().clamp(-1, 1) / s
return result.type(dtype)

def activation_quant(self, x, num_bits=8):
dtype = x.dtype
x = x.float()
Qn = -2 ** (num_bits - 1)
Qp = 2 ** (num_bits - 1) - 1
s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
result = (x * s).round().clamp(Qn, Qp) / s
return result.type(dtype)

class BitLinear(nn.Linear):
"""PyTorch BitLinear Layer
Source: https://github.com/Beomi/BitNet-Transformers/tree/main
Source License: Apache Version 2.0
"""

def __init__(self, in_features, out_features, bias=True, num_groups=1):
super(BitLinear, self).__init__(in_features, out_features, bias)
self.num_groups = num_groups
self.eps = 1e-5

def ste_binarize(self, x):
# Apply the sign function for binarization
binarized_x = torch.sign(x)
# Use STE: during backward pass, we bypass the binarization
binarized_x = (binarized_x - x).detach() + x
return binarized_x

def binarize_weights_groupwise(self):
# Divide weights into groups
group_size = self.weight.shape[0] // self.num_groups
binarized_weights = torch.zeros_like(self.weight)

for g in range(self.num_groups):
start_idx = g * group_size
end_idx = (g + 1) * group_size
weight_group = self.weight[start_idx:end_idx]

# Binarize each group using STE
alpha_g = weight_group.mean()
binarized_weights[start_idx:end_idx] = self.ste_binarize(
weight_group - alpha_g
)

return binarized_weights

def quantize_activations_groupwise(self, x, b=8):
Q_b = 2 ** (b - 1)

# Divide activations into groups
group_size = x.shape[0] // self.num_groups
quantized_x = torch.zeros_like(x)

for g in range(self.num_groups):
start_idx = g * group_size
end_idx = (g + 1) * group_size
activation_group = x[start_idx:end_idx]

# Quantize each group
gamma_g = activation_group.abs().max()
quantized_x[start_idx:end_idx] = torch.clamp(
activation_group * Q_b / (gamma_g + self.eps),
-Q_b + self.eps,
Q_b - self.eps,
)

return quantized_x

def forward(self, input):
# Binarize weights (group-wise) using STE
binarized_weights = self.binarize_weights_groupwise()

# Normal linear transformation with binarized weights
output = torch.nn.functional.linear(input, binarized_weights, self.bias)

# Quantize activations group-wise
output = self.quantize_activations_groupwise(output)

return output


class BitLinearOptimized(nn.Linear):
"""Memory Optimized BitLinear Layer
Source: https://github.com/Beomi/BitNet-Transformers/tree/main
Source License: Apache Version 2.0
"""

def __init__(self, in_features, out_features, bias=True, num_groups=1):
super(BitLinearOptimized, self).__init__(in_features, out_features, bias)
self.num_groups = num_groups
self.eps = 1e-5

# Initialize 1-bit quantized weights and store them as int8
self.register_buffer(
"quantized_weights", torch.sign(self.weight.data).to(torch.int8)
)
# Clear the original weights to save memory
del self.weight

@property
def weight(self):
# Return the dequantized weights when accessed
return self.dequantize_weights()

@weight.setter
def weight(self, value):
# Update the quantized_weights when the weight property is set
self.quantized_weights.data = torch.sign(value).to(torch.int8)

def dequantize_weights(self):
# Convert quantized_weights back to bfloat16 and compute alpha for the weights
bfloat16_weights = self.quantized_weights.to(torch.bfloat16)
alpha = bfloat16_weights.mean()
return bfloat16_weights * alpha

def ste_binarize(self, x):
# Apply the sign function for binarization
binarized_x = torch.sign(x)
# Use STE: during backward pass, we bypass the binarization
binarized_x = (binarized_x - x).detach() + x
return binarized_x

def binarize_weights_groupwise(self):
# Dequantize the weights before binarization
weights = self.dequantize_weights()

# Divide weights into groups
group_size = weights.shape[0] // self.num_groups
binarized_weights = torch.zeros_like(weights)

for g in range(self.num_groups):
start_idx = g * group_size
end_idx = (g + 1) * group_size
weight_group = weights[start_idx:end_idx]

# Binarize each group using STE
alpha_g = weight_group.mean()
binarized_weights[start_idx:end_idx] = self.ste_binarize(
weight_group - alpha_g
)

return binarized_weights

def quantize_activations_groupwise(self, x, b=8):
Q_b = 2 ** (b - 1)

# Divide activations into groups
group_size = x.shape[0] // self.num_groups
quantized_x = torch.zeros_like(x)

for g in range(self.num_groups):
start_idx = g * group_size
end_idx = (g + 1) * group_size
activation_group = x[start_idx:end_idx]

# Quantize each group
gamma_g = activation_group.abs().max()
quantized_x[start_idx:end_idx] = torch.clamp(
activation_group * Q_b / (gamma_g + self.eps),
-Q_b + self.eps,
Q_b - self.eps,
)

return quantized_x

def forward(self, input):
# Binarize weights (group-wise) using STE
binarized_weights = self.binarize_weights_groupwise()

# Normal linear transformation with binarized weights
output = torch.nn.functional.linear(input, binarized_weights, self.bias)

# Quantize activations group-wise
output = self.quantize_activations_groupwise(output)

return output


linear_dictionary = {
"linear": nn.Linear,
"bitlinear": BitLinear,
"bitlinear_optimized": BitLinearOptimized,
"bitlinear_1p58": BitLinear1p58,
}
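
A hedged usage sketch, not part of the commit: select a variant through linear_dictionary, run a forward pass, and confirm that BitLinear1p58 fake-quantizes its weights to at most three levels ({-1, 0, +1} times a per-tensor scale). The shapes and bias=False choice are illustrative.

import torch
from variations.linear_variations import linear_dictionary

layer_cls = linear_dictionary["bitlinear_1p58"]  # -> BitLinear1p58
layer = layer_cls(384, 4 * 384, bias=False)

x = torch.randn(2, 16, 384)      # (batch, seq, n_embd)
y = layer(x)                     # 8-bit activation quant + ternary weights via STE
print(y.shape)                   # torch.Size([2, 16, 1536])

q = layer.weight_quant(layer.weight)
print(torch.unique(q).numel())   # at most 3 distinct values

# BitLinear and BitLinearOptimized are drop-in replacements with the same
# (in_features, out_features, bias) signature; BitLinearOptimized additionally
# stores its sign-binarized weights in an int8 buffer to reduce memory.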
