Commit a54ee13

Fixes to run quantized bumblebee models

Parent: ee8f855

File tree: 4 files changed, +39 -14 lines

  lib/axon.ex                        (+13, -3)
  lib/axon/quantization.ex           (+7, -4)
  lib/axon/quantization/layers.ex    (+18, -6)
  lib/axon/quantization/q_tensor.ex  (+1, -1)

lib/axon.ex (+13, -3)
@@ -3849,9 +3849,19 @@ defmodule Axon do
       {_node, model} = Axon.pop_node(model)
   """
   @doc type: :graph
-  def pop_node(%Axon{nodes: nodes, output: id} = axon) do
-    {%{parent: [parent_id]} = popped, nodes} = Map.pop!(nodes, id)
-    {popped, %{axon | nodes: nodes, output: parent_id}}
+  def pop_node(%Axon{nodes: nodes, output: id}) do
+    {popped, nodes} = Map.pop!(nodes, id)
+
+    case popped do
+      %{op_name: :container, parent: parents, op: fun} = popped ->
+        {popped, apply(fun, Enum.map(parents, &%Axon{nodes: nodes, output: &1}) ++ [[]])}
+
+      %{parent: [_ | _] = parents} = popped ->
+        {popped, Enum.map(parents, &%Axon{nodes: nodes, output: &1})}
+
+      %{parent: [parent_id]} = popped ->
+        {popped, %Axon{nodes: nodes, output: parent_id}}
+    end
   end
 
   @doc """

lib/axon/quantization.ex (+7, -4)
@@ -41,13 +41,16 @@ defmodule Axon.Quantization do
   All `:dense` layers in the model are replaced with `Axon.Quantization.weight_only_quantized_dense/3`.
   """
   def quantize_model(%Axon{} = model) do
-    quantized_dense_rewriter = fn [%Axon{} = x], _output, units, use_bias ->
-      weight_only_quantized_dense(x, units, use_bias: use_bias)
+    quantized_dense_rewriter = fn [%Axon{} = x], _output, name_fn, units, use_bias ->
+      weight_only_quantized_dense(x, units,
+        use_bias: use_bias,
+        name: name_fn
+      )
     end
 
     Axon.rewrite_nodes(model, fn
-      %Axon.Node{op: :dense, meta: meta} ->
-        &quantized_dense_rewriter.(&1, &2, meta[:units], meta[:use_bias])
+      %Axon.Node{op: :dense, meta: meta, name: name_fn} ->
+        &quantized_dense_rewriter.(&1, &2, name_fn, meta[:units], meta[:use_bias])
 
       _ ->
         :skip
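Note: the rewriter now threads the original node's `name` into the replacement layer. A hedged sketch of why that matters for checkpoints keyed by layer name (the layer names below are hypothetical):

    model =
      Axon.input("features")
      |> Axon.dense(16, name: "hidden")
      |> Axon.dense(1, name: "output")

    # The replacement weight-only dense layers should keep the names
    # "hidden" and "output" rather than receiving freshly generated ones,
    # so parameters from a named checkpoint (e.g. a Bumblebee model)
    # still line up after quantization.
    quantized = Axon.Quantization.quantize_model(model)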

lib/axon/quantization/layers.ex (+18, -6)
@@ -30,14 +30,26 @@ defmodule Axon.Quantization.Layers do
   end
 
   defnp weight_only_quantized_dense_impl(
-          input,
-          %QTensor{value: kernel, scale: scale},
+          x,
+          %QTensor{value: w_int8, scale: scales},
           bias,
           _opts
         ) do
-    input
-    |> Nx.dot([Nx.rank(input) - 1], Nx.as_type(kernel, Nx.type(input)), [0])
-    |> Nx.multiply(scale)
-    |> Nx.add(bias)
+    x_shape = Nx.shape(x)
+    last_dim = Nx.axis_size(x, -1)
+
+    x_view = Nx.reshape(x, {:auto, last_dim})
+
+    y = Nx.dot(x_view, Nx.as_type(Nx.transpose(w_int8), Nx.type(x)))
+    y = Nx.multiply(y, scales)
+    y = reshape_output(y, x_shape)
+
+    Nx.add(y, bias)
+  end
+
+  deftransformp reshape_output(output, x_shape) do
+    all_but_last = Tuple.delete_at(x_shape, tuple_size(x_shape) - 1)
+    new_shape = Tuple.append(all_but_last, :auto)
+    Nx.reshape(output, new_shape)
   end
 end
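Note: the rewritten kernel flattens the input to rank 2, multiplies by the transposed int8 weights, applies the per-output-channel scales, and restores the leading dimensions. A plain-Nx sketch of the same computation, with assumed shapes (input {2, 3, 4}, weights {5, 4}, scales {5}) chosen for illustration:

    x = Nx.iota({2, 3, 4}, type: {:f, 32})
    w_int8 = Nx.iota({5, 4}, type: {:s, 8})
    scales = Nx.broadcast(Nx.tensor(0.01), {5})

    x_2d = Nx.reshape(x, {:auto, 4})                              # {6, 4}
    y = Nx.dot(x_2d, Nx.as_type(Nx.transpose(w_int8), {:f, 32}))  # {6, 5}
    y = Nx.multiply(y, scales)                                    # dequantize per output channel
    y = Nx.reshape(y, {2, 3, :auto})                              # back to {2, 3, 5}

The `deftransformp` helper exists because the output shape is rebuilt from plain tuples, which is ordinary Elixir code rather than something expressible inside `defn`.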

lib/axon/quantization/q_tensor.ex (+1, -1)
@@ -22,7 +22,7 @@ defmodule Axon.Quantization.QTensor do
 
     case opts[:type] do
       {:s, 8} ->
-        dynamically_quantize_per_channel(x, min: -128, max: 127, type: {:s, 8})
+        dynamically_quantize_per_channel(Nx.transpose(x), min: -128, max: 127, type: {:s, 8})
 
       other ->
         raise "unsupported quantization type #{inspect(other)}"
