Convert OpenELM to float16 Core ML #95
Comments
I'm interested.
Hey, I managed to track the issue down to the Core ML converter pass pipelines. When using the following pipeline:

```python
pipeline = ct.PassPipeline.EMPTY
pipeline.append_pass("common::const_elimination")
pipeline.append_pass("common::add_fp16_cast")
pipeline.append_pass("common::dedup_op_and_var_names")
```

These are the minimal passes required for the model to run. I do not see ANE usage when converting for ALL compute devices, but I do see medium-low usage when converting for CPU_AND_NE, though with a much lower inference time. This is related to the fact that, without the pass pipeline optimizations, there are still a lot of additional operations that involve FP32 precision. I tried adding the cast_optimization pass, but it causes the model predictions to become erroneous again, so the issue is probably related to that optimization.
I think the main issue is with the RMSNorm; I had to do some hacky hacks to partially execute it in fp32. I think the most accurate way to normalize in Core ML is using the MIL op `l2_norm`, which isn't accessible from PyTorch, so I patched the torch op registry:

```python
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import (
    _TORCH_OPS_REGISTRY,
    register_torch_op,
)

del _TORCH_OPS_REGISTRY["acos"]

eps = 1e-5

@register_torch_op
def acos(context, node):
    x, = _get_inputs(context, node, expected=1)
    x = mb.expand_dims(x=x, axes=[-1, -2])  # l2_norm works on the last 3 dimensions, so we have to expand 2 dims
    x = mb.l2_norm(x=x, epsilon=eps)
    x = mb.squeeze(x=x, axes=[-1, -2], name=node.name)
    context.add(x)
```

And made a `CustomRMSNorm` that utilizes the patched op:

```python
class CustomRMSNorm(nn.Module):
    def __init__(self, weight, eps):
        super().__init__()
        self.weight = weight
        self.hscale = weight.size(0) ** 0.5
        self.eps = eps

    def forward(self, x):
        # CoreML works with inputs up to 5 dimensions, so the queries and keys normalization would
        # fail because they have (batch, sequence, nheads, hdim) 4 dimensions, and we expand 2 additional dims,
        # so we squeeze the batch dim and unsqueeze it after.
        # THIS MEANS THAT THIS METHOD CURRENTLY WORKS WITH BATCH SIZE 1
        if len(x.size()) == 4:
            x = x.squeeze(0)
            unsqueeze = True
        else:
            unsqueeze = False
        x = x.acos()  # mapped to mb.l2_norm by the patched torch op above
        if unsqueeze:
            x = x.unsqueeze(0)
        # l2_norm does not perform scaling with the sqrt of dim (.pow().mean() in Pytorch), so we do it here
        return x * self.weight * x.size(-1) ** 0.5
```

And we replace all the RMSNorm layers with our custom layer:

```python
model.transformer.norm = CustomRMSNorm(model.transformer.norm.weight, model.transformer.norm.eps)
for layer in model.transformer.layers:
    layer.attn.q_norm = CustomRMSNorm(layer.attn.q_norm.weight, layer.attn.q_norm.eps)
    layer.attn.k_norm = CustomRMSNorm(layer.attn.k_norm.weight, layer.attn.k_norm.eps)
    layer.ffn_norm = CustomRMSNorm(layer.ffn_norm.weight, layer.ffn_norm.eps)
    layer.attn_norm = CustomRMSNorm(layer.attn_norm.weight, layer.attn_norm.eps)
```

Finally, we perform the conversion with an op selector that keeps `l2_norm` in float32:

```python
def selector(op):
    return op.op_type != "l2_norm"

compute_precision = ct.transform.FP16ComputePrecision(op_selector=selector)
coreml_model = ct.convert(
    ...,
    compute_precision=compute_precision,
)
```

This modification does not require modifying the pass pipelines I mentioned in my previous reply. These hacks provide a ~20-30% speedup over the fp32 Core ML model, while achieving very similar outputs. From what I saw it still runs mainly on GPU with low ANE usage. Some other considerations, on why I used ...

I think the best solution is to directly implement the model in MIL ops. I've been working on this for the past week and already have a functioning GPT-2 model; hopefully I'll be able to have the OpenELM implementation sometime next week.
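As a quick sanity check of the `sqrt(dim)` rescaling argument in the `forward` above, a small illustrative PyTorch snippet (eps ignored for simplicity):

```python
import torch

d = 64
x = torch.randn(2, d)
w = torch.randn(d)

# RMSNorm: x / sqrt(mean(x^2)) * w
rmsnorm = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True)) * w

# L2-normalize, then rescale by sqrt(d): x / ||x|| * sqrt(d) * w
via_l2 = torch.nn.functional.normalize(x, dim=-1) * d ** 0.5 * w

print(torch.allclose(rmsnorm, via_l2, atol=1e-5))  # True
```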
You don't need to use the ...

Also, it's interesting that you need to clamp. Llama 2 exhibits similar large activation outliers (as do other models, per this paper), but I didn't need to clamp for it to work. I am doing dim=1 (different tensor layout), so it's possible that is why.
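Purely as an illustration of the clamping being discussed (the exact clamping code is not shown in this thread), guarding activations against fp16 overflow usually looks like this:

```python
import torch

FP16_MAX = 65504.0  # largest finite float16 value

def clamp_for_fp16(x: torch.Tensor) -> torch.Tensor:
    # Clip outlier activations so a subsequent fp16 cast cannot overflow to inf.
    return x.clamp(min=-FP16_MAX, max=FP16_MAX)

x = torch.tensor([1.0, 7e4, -9e4])
print(clamp_for_fp16(x).half())  # tensor([ 1.0000e+00,  6.5504e+04, -6.5504e+04], dtype=torch.float16)
```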
Thanks for the ...
Thanks all for the great comments and analysis! It turns out the error on random inputs happens because inference runs on CPU for validation. When you run on GPU, generations are actually fine! I haven't tested the Neural Engine yet; I will do so and report back. It's interesting that some op is not working properly on CPU when using half precision; I don't think I've seen that before in other models. Worth diving deeper, in my opinion :)
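A minimal sketch of how to reproduce that kind of per-compute-unit discrepancy by running the same converted model on CPU only and on GPU (the model path and input name are placeholders):

```python
import coremltools as ct
import numpy as np

inputs = {"input_ids": np.zeros((1, 128), dtype=np.int32)}  # placeholder input

cpu_model = ct.models.MLModel("OpenELM.mlpackage", compute_units=ct.ComputeUnit.CPU_ONLY)
gpu_model = ct.models.MLModel("OpenELM.mlpackage", compute_units=ct.ComputeUnit.CPU_AND_GPU)

cpu_out = cpu_model.predict(inputs)
gpu_out = gpu_model.predict(inputs)

for name in cpu_out:
    diff = np.abs(np.asarray(cpu_out[name]) - np.asarray(gpu_out[name])).max()
    print(name, "max abs difference:", diff)
```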
Seems that ...
Made Optimization Guidelines for the Apple Neural Engine.txt for targeting the ANE, based on:

Also uploaded a palettized model: anthonymikinka/OpenELM-1_1B-Instruct-128-FP16ComputePrecision-Palettized-Kmeans-4bits. I am unable to run the performance report due to the RAM issue. Hopefully this helps @0seba.
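For reference, a 4-bit k-means palettized mlpackage like the one linked above can be produced with coremltools' optimization utilities; this sketch uses placeholder file names and is not necessarily how that exact model was built:

```python
import coremltools as ct
import coremltools.optimize.coreml as cto

mlmodel = ct.models.MLModel("OpenELM-1_1B-Instruct-FP16.mlpackage")

# 4-bit k-means weight palettization applied to every supported op.
op_config = cto.OpPalettizerConfig(mode="kmeans", nbits=4)
config = cto.OptimizationConfig(global_config=op_config)

palettized = cto.palettize_weights(mlmodel, config)
palettized.save("OpenELM-1_1B-Instruct-Palettized-Kmeans-4bits.mlpackage")
```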
Very nice summary and great collection of resources @antmikinka! 🙌 The model you linked produces 16-bit outputs, which are still a bit less compatible than float32. Could you export such that the output is float32, and I can try to run with ...
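One way to request float32 outputs at conversion time, assuming the model is converted with coremltools (the output name and traced model below are placeholders):

```python
import coremltools as ct
import numpy as np

mlmodel = ct.convert(
    traced_model,  # placeholder: the traced/scripted PyTorch model
    inputs=[ct.TensorType(name="input_ids", shape=(1, 128), dtype=np.int32)],
    # Ask the converter for float32 outputs even when compute runs in float16.
    outputs=[ct.TensorType(name="logits", dtype=np.float32)],
    compute_precision=ct.precision.FLOAT16,
    minimum_deployment_target=ct.target.iOS16,
)
```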
This is your model running on Xcode @antmikinka:
@pcuenca I'm glad that the resources are helpful. Original model: anthonymikinka/OpenELM-1_1B-Instruct-128-FP16ComputePrecision_v2
@antmikinka - as a reference: your v2 6-bit model running on a 2021 M1 Pro (10-core, 16 GB). Any news on ANE? I don't have enough memory to run the performance test on my machine.
If I am not mistaken, the repeat_interleave op is one of the ops we're deleting. This op also deals with the KV cache.

apple/ml-recurrent-drafter was just updated last week with 5499 file additions, a number of them copyrighted from 2020. Maybe we can take some of this information and apply it to swift-transformers?

ANE optimization: the ml-recurrent-drafter repo incorporates ANE principles with a Llama example. Here are some of the files:
Updated antmikinka/swift-transformers-test yesterday. Check it out for more information. I have two different viewing methods for the models:

- CoreMLInspect models: layers, ops, and precision
- layer-iteration.py model: layers, ops, and precision

Model chunking

I have chunked OpenELM-270M-Instruct. I am unsure how it performs; I need to update my chunk_mlprogram.py to calculate the PSNR. I noticed that after uploading to HF, the mlpackage went from 625 MB to 482 MB in size. Chunking did change what ops run where, shifting some to other compute units.
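For the PSNR calculation mentioned above, a simple sketch of the usual computation between original and chunked model outputs (the arrays are placeholders; in practice they would come from running both mlpackages on the same input):

```python
import numpy as np

def psnr(reference: np.ndarray, test: np.ndarray) -> float:
    # Peak signal-to-noise ratio in dB between two output tensors.
    mse = np.mean((reference.astype(np.float64) - test.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")
    peak = np.max(np.abs(reference))
    return 20.0 * np.log10(peak) - 10.0 * np.log10(mse)

# Placeholder outputs standing in for the original and chunked model logits.
original_logits = np.random.randn(1, 128, 32000).astype(np.float32)
chunked_logits = original_logits + np.random.randn(*original_logits.shape).astype(np.float32) * 1e-3
print(f"PSNR: {psnr(original_logits, chunked_logits):.2f} dB")
```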
Apple is updating a lot of work in coremltools, probably to showcase it at the upcoming WWDC 2024 Day 2 event.

- CoreMLTools 8.0b1 release link
- Optimization overview for CoreMLTools 8.0b1
I converted the models to float32 using this script: https://gist.github.com/pcuenca/23cd08443460bc90854e2a6f0f575084, but found precision problems when targeting float16. It'd be interesting to see what the performance is for float16, but we need to determine what layers/ops need to be kept in float32. Anyone interested, please let us know and we can work on it or test together :)
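To help narrow down which ops need float32, one experiment is to keep a configurable set of op types out of the fp16 cast via an op selector; the set below is only a guess for bisection, not a known-good answer:

```python
import coremltools as ct

# Op types to keep in float32 while everything else is cast to float16.
# This particular set is a starting point for experiments, not a verified answer.
KEEP_FP32 = {"softmax", "reduce_mean", "rsqrt"}

def selector(op):
    return op.op_type not in KEEP_FP32

mlmodel = ct.convert(
    traced_model,  # placeholder: the traced OpenELM model from the conversion script
    compute_precision=ct.transform.FP16ComputePrecision(op_selector=selector),
    minimum_deployment_target=ct.target.iOS16,
)
```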