Since the paper The Era of 1-bit LLMs was released, I have been wondering whether training transformers with the proposed BitLinear can also work across modalities, in applications other than LLMs, for example vision-based models such as ViT (TerViT exists, but I could not find its source code), DETR, DINO, LLaVA, etc.
After some attempts to modify DETR based on some of the most popular computer vision libraries such as ultralytics, mmdet, and detectron2, it felt like I was editing YAML files most of the time, which was quite frustrating. The Hugging Face implementation seems more straightforward, but I got lost in too many if statements and found it hard to understand what the original DETR model is actually about. I therefore based this repo on the original detr repo, which is a bit outdated: for example, it doesn't support mixed precision and has low GPU utilization during training. So I decided to rewrite everything from scratch with the goal of making it easy to read, study, and hack around. (This is still a work in progress to remove the complexity; dataloading and preprocessing is another big mess.)
- The ternarization of a weight $W \in \mathbb{R}^{n \times m}$ can be formulated as $\tilde{W} = RoundClip(\dfrac{W}{\beta+\epsilon}, -1, 1)$, where $RoundClip(x, a, b)=max(a, min(b, round(x)))$ and $\displaystyle \beta = \frac{1}{nm}\sum_{ij}|W_{ij}|$.
- The activations are further quantized to $b$-bit precision using absmax quantization, which scales activations into the range $[-Q_b, Q_b]$ ($Q_b=2^b-1$) by multiplying with $Q_b$ and dividing by the absolute maximum of the input matrix: $\tilde{x} = Quant(x) = Clip(\dfrac{xQ_b}{\gamma}, -Q_b+\epsilon, Q_b-\epsilon)$, where $Clip(x, a, b)=max(a, min(b, x))$ and $\gamma = ||x||_{\infty}$.
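  Mapping the symbols above to tensor operations, a minimal sketch of the two quantizers could look like this (the function names are mine, and $Q_b$ is taken as 127 here to match the reference code further down, which also applies the activation quantization per last dimension and rounds):

  ```python
  import torch

  def ternarize(w: torch.Tensor, eps: float = 1e-5):
      # beta = mean(|W|); W_tilde = RoundClip(W / (beta + eps), -1, 1), values in {-1, 0, 1}
      beta = w.abs().mean()
      w_tilde = torch.clamp(torch.round(w / (beta + eps)), -1, 1)
      return w_tilde, beta

  def absmax_quantize(x: torch.Tensor, qb: int = 127, eps: float = 1e-5):
      # gamma = ||x||_inf; x_tilde = Clip(x * Qb / gamma, -Qb + eps, Qb - eps)
      gamma = x.abs().max()
      x_tilde = torch.clamp(x * qb / gamma, -qb + eps, qb - eps)
      return x_tilde, gamma
  ```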
- Based on the implementation provided in the FAQ, both `x` and `w` are still in `float16` during training. However, they do get quantized to maintain the properties of 8-bit values for `x` and ternary values for `w`. Both `x_quant` and `w_quant` are also rescaled before `F.linear`, which then computes $f(x)=(\beta\tilde{W})(\dfrac{\gamma\tilde{x}}{Q_b})$:

  ```python
  # Absmax quantization of the activations to the INT8 range, then rescale back.
  x_scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
  x_quant = (x * x_scale).round().clamp(-128, 127)
  x_quant = x + (x_quant / x_scale - x).detach()
  # Ternarization of the weights with beta = mean(|w|), then rescale back.
  w_scale = 1.0 / w.abs().mean().clamp(min=1e-5)
  w_quant = (w * w_scale).round().clamp(-1, 1)
  w_quant = w + (w_quant / w_scale - w).detach()
  output = F.linear(x_quant, w_quant)
  ```
  Using `.detach()` is a trick to employ the straight-through estimator: the backward pass behaves as if `F.linear` were still calculating $f(x)=Wx$ instead of $\tilde{W}\tilde{x}$, which bypasses the non-differentiable functions such as $RoundClip$ and $Clip$. The resulting weight gradient then becomes $\dfrac{\partial f}{\partial W} = x$. The FAQ also mentions that the standard `F.linear` operation is replaced with a customized low-bit kernel: with FP8 GEMM kernels, FP8 activations can be used for training, and activations are quantized to INT8 for inference on GPU devices with CUDA compute capability < 9.0. The source code of the custom kernels can be found in BitBLAS.
- This operation is mathematically equivalent to $f(x)=(\beta\tilde{W})(\dfrac{\gamma\tilde{x}}{Q_b})=\tilde{W}\tilde{x}(\dfrac{\beta\gamma}{Q_b})$, which means both scaling factors can be applied to the output of `F.linear` instead of to its inputs:

  ```python
  x_scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
  x_quant = (x * x_scale).round().clamp(-128, 127)
  x_quant = x + (x_quant - x).detach()
  w_scale = 1.0 / w.abs().mean().clamp(min=1e-5)
  w_quant = (w * w_scale).round().clamp(-1, 1)
  w_quant = w + (w_quant - w).detach()
  # Both scaling factors are now applied to the output of F.linear.
  output = F.linear(x_quant, w_quant) / (x_scale * w_scale)
  ```

  `x_quant` and `w_quant` are in $[-127, 127]$ (INT8) and $\{-1, 0, 1\}$ (~INT2) respectively.
- If $x$ is allowed to stay in FP16 and only $W$ is quantized and rescaled to ternary, this essentially becomes $f(x)=\beta\tilde{W}x$:

  ```python
  # Activations stay in FP16; only the weights are ternarized and rescaled.
  output = F.linear(x, w_quant) / w_scale
  ```
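Putting these pieces together, a drop-in replacement for `nn.Linear` could look roughly like the sketch below. This is a minimal illustration, not the exact interface of `models/bitlinear.py`; the `quantize_input` flag and the bias handling are my own assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BitLinear(nn.Linear):
    """nn.Linear with ternary weights and, optionally, INT8-simulated activations."""

    def __init__(self, in_features, out_features, bias=True, quantize_input=True):
        super().__init__(in_features, out_features, bias=bias)
        self.quantize_input = quantize_input  # False keeps x in FP16/FP32

    def forward(self, x):
        w = self.weight
        # Ternarize the weights: scale by 1 / mean(|W|), round, clamp to {-1, 0, 1}.
        w_scale = 1.0 / w.abs().mean().clamp(min=1e-5)
        w_quant = (w * w_scale).round().clamp(-1, 1)
        w_quant = w + (w_quant - w).detach()  # straight-through estimator

        if self.quantize_input:
            # Absmax-quantize the activations to the INT8 range per last dimension.
            x_scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
            x_quant = (x * x_scale).round().clamp(-128, 127)
            x_quant = x + (x_quant - x).detach()
            out = F.linear(x_quant, w_quant) / (x_scale * w_scale)
        else:
            # Keep the activations in full/half precision; only undo the weight scaling.
            out = F.linear(x, w_quant) / w_scale

        if self.bias is not None:
            out = out + self.bias
        return out
```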
Because floating-point arithmetic is not associative (rounding happens at different points in the two formulations), the outputs diverge slightly even though they are mathematically equivalent. A few tests in `models/bitlinear.py` were created to demonstrate this.
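For example, a quick check along these lines (not the actual tests in `models/bitlinear.py`) shows that rescaling the inputs and rescaling the output of `F.linear` produce results that are close but not bit-identical:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 64)
w = torch.randn(32, 64)

x_scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
x_quant = (x * x_scale).round().clamp(-128, 127)
w_scale = 1.0 / w.abs().mean().clamp(min=1e-5)
w_quant = (w * w_scale).round().clamp(-1, 1)

# Variant 1: rescale the inputs of F.linear.
out_inputs = F.linear(x_quant / x_scale, w_quant / w_scale)
# Variant 2: rescale the output of F.linear.
out_output = F.linear(x_quant, w_quant) / (x_scale * w_scale)

# Mathematically identical, but rounding happens at different points,
# so the difference below is small yet nonzero (more visible under FP16/AMP).
print((out_inputs - out_output).abs().max())
```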
To clarify, the ResNet-50 backbone is still in FP16, since there is no backbone with ternary weights available at this point. During training, all the weights of the transformer were completely ternarized, which results in around 17M ternarized parameters out of the 40M total model parameters.
Currently, the analysis is based on training the model for only 1 epoch; performing a full training run on the COCO dataset would take days with the compute resources available to me. As the loss curves show, when the inputs are fully quantized to 8 bits and the weights to ternary values, the model's training loss suffers quite significantly. However, when the inputs are kept at FP16 precision and only the model weights are ternarized, the gap between the quantized model and the original model becomes much smaller.
Comparison between using nn.Linear (fp16 X fp16) and BitLinear (fp16 X int1.58-simulated, int8-simulated X int1.58-simulated) in the transformer of DETR.

- Rewrite the model to make the code simpler, more readable, and easier to study.
- Implement `MultiheadAttention` from scratch but keep `F.scaled_dot_product_attention` to utilize the optimized flash attention kernels (a minimal sketch follows this list).
- Remove the entirety of `NestedTensor` from DETR; the forward pass now takes two arguments, the padded image and the padding mask.
- Simplify `SetCriterion`, which is the biggest bottleneck of the training (it still needs profiling); only `l1_loss`, `giou_loss`, and `cross_entropy` are used to compute the gradients. Additionally, the losses are kept in a `torch.Tensor` instead of a dictionary so that `all_reduce` can be applied automatically (see the second sketch after this list).
- Training in float16 using `amp`.
- DeepSpeed integration for multi-GPU training.
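For reference, a minimal `MultiheadAttention` built on `F.scaled_dot_product_attention` might look like the sketch below. This is an illustration of the idea, not a copy of the repo's implementation (DETR additionally adds positional embeddings to the queries and keys), and the `linear_cls` argument is my own shorthand for swapping in `BitLinear`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiheadAttention(nn.Module):
    """Minimal multi-head attention that defers the attention math to
    F.scaled_dot_product_attention so the optimized (flash) kernels can be used."""

    def __init__(self, embed_dim, num_heads, linear_cls=nn.Linear):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # linear_cls can be swapped for BitLinear to ternarize the projections.
        self.q_proj = linear_cls(embed_dim, embed_dim)
        self.k_proj = linear_cls(embed_dim, embed_dim)
        self.v_proj = linear_cls(embed_dim, embed_dim)
        self.out_proj = linear_cls(embed_dim, embed_dim)

    def forward(self, query, key, value, key_padding_mask=None):
        # query: (B, Tq, C), key/value: (B, Tk, C), key_padding_mask: (B, Tk), True = padding
        b, tq, _ = query.shape
        tk = key.shape[1]

        def split_heads(x, t):
            return x.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, D)

        q = split_heads(self.q_proj(query), tq)
        k = split_heads(self.k_proj(key), tk)
        v = split_heads(self.v_proj(value), tk)

        attn_mask = None
        if key_padding_mask is not None:
            # SDPA expects True = "may attend", so invert the padding mask
            # and broadcast it over heads and query positions.
            attn_mask = ~key_padding_mask[:, None, None, :]

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)  # (B, H, Tq, D)
        out = out.transpose(1, 2).reshape(b, tq, -1)
        return self.out_proj(out)
```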
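And a rough sketch of the loss-as-tensor idea. This is not the repo's `SetCriterion`: it assumes the Hungarian matching has already paired predictions with targets one-to-one, that boxes are in `(x1, y1, x2, y2)` format, and that a reasonably recent torchvision is available for `generalized_box_iou_loss`.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torchvision.ops import generalized_box_iou_loss  # torchvision >= 0.13

def criterion(pred_logits, pred_boxes, target_classes, target_boxes):
    """Return the three losses stacked into a single tensor: (ce, l1, giou)."""
    ce = F.cross_entropy(pred_logits, target_classes)
    l1 = F.l1_loss(pred_boxes, target_boxes)
    giou = generalized_box_iou_loss(pred_boxes, target_boxes, reduction="mean")
    return torch.stack([ce, l1, giou])

def reduce_losses(losses):
    """Average the whole loss tensor across ranks with a single all_reduce (for logging)."""
    if dist.is_available() and dist.is_initialized():
        losses = losses.clone()
        dist.all_reduce(losses, op=dist.ReduceOp.SUM)
        losses = losses / dist.get_world_size()
    return losses
```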
- Use custom kernels from BitBLAS for `F.linear`; however, it currently doesn't support autograd. In addition, based on their reported benchmarks, you only gain significant speed improvements when computing the GEMM in INT8xINT2.
- Train a ViT or Swin Transformer backbone with ternarized weights. Specifically, SwinV2 has a 3B-parameter model, which would put it at the same parameter scale as the model size reported in the BitNet b1.58 paper.
- Once there's a backbone with ternarized weights, perform a full COCO training comparison.
- Rewrite the image preprocessing from scratch utilizing `Albumentations`; this is surprisingly painful right now. Maybe even benchmark it against `torchvision.transforms.v2`.
- Try `BitLinear` on DINO and LLaVA.