
BitLinear-Vision-Transformers

Since the paper The Era of 1-bit LLMs was released, I have wondered whether training transformers with the proposed BitLinear also works across modalities on applications other than LLMs, for example vision-based models such as ViT (TerViT, but I could not find any source code), DETR, DINO, LlaVa, etc.

DETR (Detection Transformer)

After some attempts to modify DETR based on some of the most popular computer vision libraries such as ultralytics, mmdet, and detectron2, it felt like I was editing YAML files most of the time, which was quite frustrating. The implementation from huggingface seems more straightforward, but I got lost in too many if statements and found it hard to understand what the original DETR model is about. Therefore I based this repo on the original detr repo, which is a bit outdated, e.g., it doesn't support mixed precision and has low GPU utilization during training. I decided to rewrite everything from scratch with the goal of making it easy to read, study, and hack around. (Still a work in progress to remove the complexity; data loading and preprocessing are another big mess.)

Notes on BitLinear

Formulation

$y = f(x) = \tilde{W}\tilde{x}$

  • The ternarization of a weight $W \in \mathbb{R}^{n \times m}$ can be formulated as:

    $\tilde{W} = {RoundClip}(\dfrac{W}{\beta+\epsilon}, -1, 1)$

    where $RoundClip(x, a, b)=max(a, min(b, round(x)))$, and $\displaystyle \beta = \frac{1}{nm}\sum_{ij}|W_{ij}|$.

  • The activations are further quantized to $b$-bit precision using absmax quantization, which scales activations into the range $[-Q_b, Q_b]$ ($Q_b=2^{b-1}$) by multiplying by $Q_b$ and dividing by the absolute maximum of the input matrix:

    $\tilde{x} = Quant(x) = Clip(\dfrac{xQ_b}{\gamma}, -Q_b+\epsilon, Q_b-\epsilon)$

    where $Clip(x, a, b)=max(a, min(b, x))$, and $\gamma = ||x||_{\infty}$
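As a concrete illustration, both quantizers can be evaluated on a toy tensor in a few lines of PyTorch. The values and the roundclip helper below are made up for illustration, not code from this repo:

    import torch

    def roundclip(x, a, b):
        # RoundClip(x, a, b) = max(a, min(b, round(x)))
        return x.round().clamp(a, b)

    # toy weight matrix W (n x m = 2 x 3)
    W = torch.tensor([[ 0.40, -0.05,  0.90],
                      [-0.70,  0.02,  0.31]])
    beta = W.abs().mean()                       # beta = mean(|W_ij|) ≈ 0.397
    W_t = roundclip(W / (beta + 1e-5), -1, 1)   # ternary: [[1, 0, 1], [-1, 0, 1]]

    # absmax quantization of an activation vector to 8 bits (Q_b = 2^(b-1) = 128)
    x = torch.tensor([0.10, -1.20, 0.45])
    gamma = x.abs().max()                       # gamma = ||x||_inf = 1.2
    Qb = 128
    x_t = (x * Qb / gamma).clamp(-Qb + 1e-5, Qb - 1e-5)   # ≈ [10.67, -128.00, 48.00]
    # (the training code in the Implementation section uses 127 and clamp for the same purpose)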

Implementation

  1. Based on the implementation provided in the FAQ, both x and w remain in FP16 during training. However, they are quantized on the fly to preserve the 8-bit property of x and the ternary property of w. Both x_quant and w_quant are also rescaled before F.linear, which becomes

    $f(x)=(\beta\tilde{W})(\dfrac{\gamma\tilde{x}}{Q_b})$

    # activation quantization: per-token absmax scaling to the INT8 range
    x_scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    x_quant = (x * x_scale).round().clamp(-128, 127)
    # rescale back, with .detach() acting as a straight-through estimator for the gradient
    x_quant = x + (x_quant / x_scale - x).detach()
    
    # weight quantization: scale by mean(|w|), round to {-1, 0, 1}, rescale back
    w_scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    w_quant = (w * w_scale).round().clamp(-1, 1)
    w_quant = w + (w_quant / w_scale - w).detach()
    
    output  = F.linear(x_quant, w_quant)

    Using .detach() is a trick to implement the straight-through estimator: F.linear behaves as if it were computing $f(x)=Wx$ instead of $\tilde{W}\tilde{x}$, so the gradient bypasses the non-differentiable $RoundClip$ and $Clip$ functions and is computed as if no quantization had happened, e.g. $\dfrac{\partial f}{\partial W} = x$.

    The FAQ also mentions:

    the standard F.linear operation is replaced with a customized low-bit kernel.

    With FP8 GEMM kernels, we can use FP8 activations for training and quantize to INT8 for inference on GPU devices with CUDA compute capability < 9.0.

    Source code of the custom kernels can be found in BitBLAS.

  2. This operation is mathematically equivalent to $f(x)=(\beta\tilde{W})(\dfrac{\gamma\tilde{x}}{Q_b})=\tilde{W}\tilde{x}(\dfrac{\beta\gamma}{Q_b})$, which means both scaling factors can be applied to the output of F.linear instead of to its inputs.

    x_scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    # keep x_quant as integer values in [-127, 127] (no rescaling here)
    x_quant = (x * x_scale).round().clamp(-128, 127)
    x_quant = x + (x_quant - x).detach()
    
    w_scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    # keep w_quant as ternary values in {-1, 0, 1}
    w_quant = (w * w_scale).round().clamp(-1, 1)
    w_quant = w + (w_quant - w).detach()
    
    # apply both scaling factors to the output instead
    output  = F.linear(x_quant, w_quant) / (x_scale * w_scale)

    x_quant and w_quant take values in $[-127, 127]$ (INT8) and $\{-1, 0, 1\}$ (~INT2), respectively.

  3. If $x$ is allowed to stay in FP16 and only $W$ is quantized and rescaled to ternary, it essentially becomes $f(x)=\beta\tilde{W}x$

    output  = F.linear(x, w_quant) / w_scale

Because floating-point arithmetic is not always associative, the outputs of these formulations diverge slightly even though they are mathematically equivalent. A few tests in models/bitlinear.py were created to demonstrate this.
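For reference, a standalone sanity check along the same lines might look like the sketch below. The shapes, the tolerance, and the re-implementation of the two fake-quantization variants are assumptions for illustration; this is not the actual test code in models/bitlinear.py:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(4, 64, requires_grad=True)   # activations (batch, features)
    w = torch.randn(32, 64, requires_grad=True)  # weights (out_features, in_features)

    x_scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    w_scale = 1.0 / w.abs().mean().clamp(min=1e-5)

    # Variant 1: rescale x_quant and w_quant back before F.linear
    xq1 = x + ((x * x_scale).round().clamp(-128, 127) / x_scale - x).detach()
    wq1 = w + ((w * w_scale).round().clamp(-1, 1) / w_scale - w).detach()
    y1  = F.linear(xq1, wq1)

    # Variant 2: keep integer-valued tensors and rescale the output instead
    xq2 = x + ((x * x_scale).round().clamp(-128, 127) - x).detach()
    wq2 = w + ((w * w_scale).round().clamp(-1, 1) - w).detach()
    y2  = F.linear(xq2, wq2) / (x_scale * w_scale)

    # Mathematically equivalent, but floating-point arithmetic is not associative,
    # so the outputs agree only up to a small numerical error
    print(torch.allclose(y1, y2, atol=1e-4), (y1 - y2).abs().max())

    # Straight-through estimator: gradients flow as if no rounding had happened
    y2.sum().backward()
    print(x.grad.abs().sum() > 0, w.grad.abs().sum() > 0)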

Results

DETR

To clarify, the ResNet-50 backbone is still in FP16, since there is no ternary-weight backbone available at this point. During training, all the weights in the transformer were fully ternarized, resulting in around 17M ternarized parameters out of the 40M total model parameters.
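For context on how such a swap can be wired up, one option (not necessarily what this repo does) is a small recursive helper that replaces every nn.Linear in the transformer with BitLinear; it assumes models/bitlinear.py exposes a BitLinear class with an nn.Linear-compatible constructor:

    import torch.nn as nn
    from models.bitlinear import BitLinear  # assumed to mirror nn.Linear's constructor

    def replace_linear_with_bitlinear(module: nn.Module) -> None:
        # Recursively swap nn.Linear layers for BitLinear, copying the weights over.
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                bit = BitLinear(child.in_features, child.out_features,
                                bias=child.bias is not None)
                bit.weight.data.copy_(child.weight.data)
                if child.bias is not None:
                    bit.bias.data.copy_(child.bias.data)
                setattr(module, name, bit)
            else:
                replace_linear_with_bitlinear(child)

    # e.g. ternarize only the transformer and keep the ResNet-50 backbone untouched:
    # replace_linear_with_bitlinear(detr.transformer)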

Currently, the analysis is based on training the model for only 1 epoch; a full training run on the COCO dataset would take days with the compute resources available to me. As the loss curves show, fully quantizing the inputs to 8 bits and the weights to ternary values hurts the model's training loss quite significantly. However, if the inputs are kept at FP16 precision and only the model weights are ternarized, the gap between the quantized model and the original model becomes much smaller.

Comparison between using nn.Linear (fp16 X fp16) and BitLinear (fp16 X int1.58-simulated, int8-simulated X int1.58-simulated) in the transformer of DETR.

TODO

  • rewrite the model to make the code simpler, more readable, and easier to study.
    • implement MultiheadAttention from scratch but keep F.scaled_dot_product_attention to utilize the optimized flash attention kernels (a minimal sketch is included after this list)
    • remove the entirety of NestedTensor from DETR; the forward pass now takes two arguments, the padded image and the padding mask
    • simplify SetCriterion, which is the biggest bottleneck of the training (still needs profiling); only l1_loss, giou_loss, and cross_entropy are used to compute the gradients. Additionally, use a torch.Tensor instead of a dictionary so that all_reduce can be applied automatically.
    • training in float16 using amp
    • deepspeed integration for multi-GPU training
  • Use custom kernels from BitBLAS for F.linear; however, it currently doesn't support autograd. In addition, based on their reported benchmarks, significant speedups only come from computing the GEMM in INT8xINT2.
  • Train a ViT / SwinViT backbone with ternarized weights. Specifically, Swin-V2 has a 3B-parameter model, which would put it at the same parameter scale as the model sizes reported in the BitNet b1.58 paper.
  • Once there is a backbone with ternarized weights, perform a full COCO training comparison.
  • Rewrite the image preprocessing from scratch using Albumentations; this is surprisingly painful right now. Maybe even benchmark it against torchvision.transforms.v2.
  • Try BitLinear on DINO, LlaVa.
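For the MultiheadAttention rewrite mentioned in the list above, a minimal sketch could look like the code below: the projections are built from a configurable linear class so nn.Linear can be swapped for BitLinear, while the attention itself stays on F.scaled_dot_product_attention to keep the fused/flash kernels. The class name and argument conventions are illustrative, not the repo's final API:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SimpleMultiheadAttention(nn.Module):
        # Minimal multi-head attention; linear_cls can be nn.Linear or BitLinear.
        def __init__(self, embed_dim: int, num_heads: int, linear_cls=nn.Linear):
            super().__init__()
            assert embed_dim % num_heads == 0
            self.num_heads = num_heads
            self.head_dim = embed_dim // num_heads
            self.q_proj = linear_cls(embed_dim, embed_dim)
            self.k_proj = linear_cls(embed_dim, embed_dim)
            self.v_proj = linear_cls(embed_dim, embed_dim)
            self.out_proj = linear_cls(embed_dim, embed_dim)

        def forward(self, query, key, value, key_padding_mask=None):
            # query/key/value: (batch, seq, embed_dim)
            # key_padding_mask: bool (batch, seq_k), True = padded position
            B, Lq, D = query.shape
            Lk = key.shape[1]
            q = self.q_proj(query).view(B, Lq, self.num_heads, self.head_dim).transpose(1, 2)
            k = self.k_proj(key).view(B, Lk, self.num_heads, self.head_dim).transpose(1, 2)
            v = self.v_proj(value).view(B, Lk, self.num_heads, self.head_dim).transpose(1, 2)
            attn_mask = None
            if key_padding_mask is not None:
                # True means "attend" for scaled_dot_product_attention, so invert the padding mask
                attn_mask = ~key_padding_mask[:, None, None, :]
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
            out = out.transpose(1, 2).reshape(B, Lq, D)
            return self.out_proj(out)

Constructing it with linear_cls=BitLinear would ternarize all four projection matrices while leaving the attention computation itself in higher precision.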
