
Conversation

@davidkoski (Collaborator) commented Sep 24, 2025

NOTE

This change contains some breaking API changes in the area of quantization. Specifically:

  • the quantized / dequantized methods now take a mode parameter (not breaking)
  • the biases result from quantized is now optional, e.g. (wq: MLXArray, scales: MLXArray, biases: MLXArray?)

We are keeping the same semver here to match Python mlx. Although the change is breaking, its impact will likely be limited to implementations of quantized layers, e.g. QuantizedLinear, or other code that uses quantization directly. mlx-swift-examples will have a synchronized release to reflect this change.

If you need to make a similar change, consider the changes from QuantizedLinear:

The properties changed from this:

    public let scales: MLXArray
    public let biases: MLXArray

to:

    public let mode: QuantizationMode
    public let scales: MLXArray
    public let biases: MLXArray?

A mode parameter with a default value (mode: QuantizationMode = .affine) was added where needed, and the mode parameter was used in calls to the quantization APIs:

        var x = quantizedMatmul(
            x,
            weight,
            scales: scales,
            biases: biases,
            transpose: true,
            groupSize: groupSize,
            bits: bits,
            mode: mode
        )

and the Quantizable protocol was updated to have a mode parameter (protocol methods can't have default values):

    /// Return the module as a quantized representation
    func toQuantized(groupSize: Int, bits: Int, mode: QuantizationMode) -> Module
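
As a rough sketch of what a conforming custom layer can look like after this change (MyLayer and MyQuantizedLayer are hypothetical; quantized, Quantizable, QuantizationMode, and Module are the APIs discussed above, and the optional biases mirrors the property change shown for QuantizedLinear):

    import MLX
    import MLXNN

    // Hypothetical quantized counterpart of a custom layer: it stores the
    // result of `quantized` along with the parameters needed to use it later.
    class MyQuantizedLayer: Module {
        let wq: MLXArray
        let scales: MLXArray
        let biases: MLXArray?
        let groupSize: Int
        let bits: Int
        let mode: QuantizationMode

        init(weight: MLXArray, groupSize: Int, bits: Int, mode: QuantizationMode) {
            let q = quantized(weight, groupSize: groupSize, bits: bits, mode: mode)
            self.wq = q.wq
            self.scales = q.scales
            self.biases = q.biases
            self.groupSize = groupSize
            self.bits = bits
            self.mode = mode
            super.init()
        }
    }

    // The full precision layer adopts Quantizable and forwards the new mode.
    class MyLayer: Module, Quantizable {
        let weight: MLXArray

        init(weight: MLXArray) {
            self.weight = weight
            super.init()
        }

        func toQuantized(groupSize: Int, bits: Int, mode: QuantizationMode) -> Module {
            MyQuantizedLayer(weight: weight, groupSize: groupSize, bits: bits, mode: mode)
        }
    }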

@davidkoski requested a review from @awni, September 24, 2025 20:48
"mlx/mlx/backend/metal/no_metal.cpp",

// special handling for cuda -- we need to keep one file:
// mlx/mlx/backend/cuda/no_cuda.cpp

@davidkoski (Collaborator, author):
This is a little more complicated than I wish, but we can't exclude the directory + include one file, so I need to just list them.
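
Roughly what the workaround looks like in the manifest (the target name and the cuda file names other than no_cuda.cpp are illustrative; only the two paths above come from the diff):

    // SwiftPM cannot exclude a directory and then add one file back from it,
    // so each cuda source is excluded individually and no_cuda.cpp is simply
    // left off the list.
    .target(
        name: "Cmlx",
        exclude: [
            "mlx/mlx/backend/metal/no_metal.cpp",
            "mlx/mlx/backend/cuda/allocator.cpp",  // illustrative entry
            "mlx/mlx/backend/cuda/device.cpp",     // illustrative entry
            // ... every other file under backend/cuda except no_cuda.cpp
        ]
    )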

/// - ``asArray(_:)``
/// - ``asData(access:)``
public func asMTLBuffer(device: any MTLDevice, noCopy: Bool = false) -> (any MTLBuffer)? {
    let data = asData(access: noCopy ? .noCopyIfContiguous : .copy)

@davidkoski (Collaborator, author):
From #259 -- this line is unused.


// If it's just a simple slice, just do a slice update and return
if operations.count == 1, case let .slice(slice) = operations[0] {
if operations.count == 1, case .slice(let slice) = operations[0] {

@davidkoski (Collaborator, author):
Just the new swift-format.

/// - values: values with shape `[B, N_kv, T_kv, D]`
/// - scale: scale for queries, typically `1 / sqrt(q.dim(-1))`
/// - mask: mask array
/// - sinks: optional array of attention sinks

@davidkoski (Collaborator, author):
New optional argument
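
A rough sketch of passing the new argument (the sinks: label comes from the doc comment above; the MLXFast.scaledDotProductAttention entry point and the one-sink-per-head shape are assumptions, not taken from this diff):

    import MLX
    import MLXFast
    import MLXRandom

    let (B, H, T, D) = (1, 8, 16, 64)
    let q = MLXRandom.normal([B, H, T, D])
    let k = MLXRandom.normal([B, H, T, D])
    let v = MLXRandom.normal([B, H, T, D])

    // Assumed shape: one learned sink logit per head.
    let sinks = MLXRandom.normal([H])

    let out = MLXFast.scaledDotProductAttention(
        queries: q, keys: k, values: v,
        scale: 1 / Float(D).squareRoot(),
        sinks: sinks
    )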

}

let x = MLXArray(1)
let x = MLXArray([1])

@davidkoski (Collaborator, author):
This was incorrect before -- a dimensionless parameter is not the same as a shaped array. Now it throws as the back end rejects it.
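
Concretely, the two initializers produce different shapes:

    let scalar = MLXArray(1)    // dimensionless scalar: shape []
    let vector = MLXArray([1])  // shaped array: shape [1]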

/// MX (Microscaling) FP4 quantization format.
///
/// MXFP4 is a specialized 4-bit floating-point format designed for neural network inference.
/// It uses a shared exponent across a block of values with individual 3-bit mantissas plus sign bits.

@awni (Member):
The individual elements are e2m1 (so 1 sign bit, 2 exponent bits, 1 mantissa bit).
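
For illustration only (not part of the MLX API), a nibble in that layout decodes to the eight magnitudes an MXFP4 element can represent:

    // e2m1: sign in bit 3, exponent in bits 2-1 (bias 1), mantissa in bit 0.
    func decodeE2M1(_ nibble: UInt8) -> Float {
        let sign: Float = (nibble & 0b1000) != 0 ? -1 : 1
        let exp = Int((nibble >> 1) & 0b11)
        let man = Float(nibble & 0b1)
        // exp == 0 is the subnormal range
        let magnitude = exp == 0 ? man * 0.5 : (1 + man * 0.5) * Float(1 << (exp - 1))
        return sign * magnitude
    }

    // (0...7).map { decodeE2M1(UInt8($0)) } -> [0, 0.5, 1, 1.5, 2, 3, 4, 6]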

///
/// MXFP4 is a specialized 4-bit floating-point format designed for neural network inference.
/// It uses a shared exponent across a block of values with individual 3-bit mantissas plus sign bits.
/// This format can provide better accuracy than standard 4-bit integer quantization for certain

@awni (Member):
I would just remove that, as it's not usually right (MLX Q4 is probably more accurate for most cases). We support this mostly because of GPT OSS (and probably future models), which were trained in mxfp4 (since the hardware has native support for it).

///
/// The format consists of:
/// - Shared 8-bit exponent per block
/// - Individual 3-bit mantissas + 1 sign bit per element

@awni (Member):
Update as comment above.

/// - Parameters:
/// - w: The quantized weight matrix to dequantize
/// - scales: Scaling factors used during quantization. Should have shape compatible with the quantized groups
/// - biases: Bias values used during quantization. Should have shape compatible with the quantized groups

@awni (Member):
Worth commenting that it is optional for some modes?

@davidkoski (Collaborator, author):
The type is already marked as optional so we are covered there

/// - bits: The number of bits occupied by each element of `w` in the returned quantized matrix. Default is `4`
/// - mode: The quantization mode. Default is `.affine`
/// - stream: Stream or device to evaluate on
/// - Returns: A tuple containing the quantized weights (`wq`), scaling factors (`scales`), and bias values (`biases`)

@awni (Member):
How does it work if the mode is mxfp4? Is the bias null?

@davidkoski (Collaborator, author):
That is a good question -- as written the values are not optional. Let me write a test and see what shows up.

@davidkoski (Collaborator, author):
Very crashy in that case. Hrm, this is going to change the signature of the method slightly.
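
Roughly what such a test probes, sketched against the post-change signature (the .mxfp4 case name and its group size of 32 are assumptions here, not taken from the PR):

    import MLX
    import MLXRandom

    let w = MLXRandom.normal([128, 256])

    // Affine mode (the default) still produces scales and biases.
    let affine = quantized(w, groupSize: 64, bits: 4, mode: .affine)
    assert(affine.biases != nil)

    // MXFP4 carries no affine bias, so `biases` is expected to be nil...
    let mx = quantized(w, groupSize: 32, bits: 4, mode: .mxfp4)
    assert(mx.biases == nil)

    // ...and nil is what flows back through dequantized.
    let back = dequantized(
        mx.wq, scales: mx.scales, biases: mx.biases,
        groupSize: 32, bits: 4, mode: .mxfp4)
    print(back.shape)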

@awni (Member) left a comment:
Very nice, thanks for the update! Left a few comments / questions on the new quantization stuff.

@davidkoski merged commit 072b684 into main Oct 16, 2025
1 check passed
@davidkoski deleted the mlx-0291 branch October 16, 2025 17:34