[Q] How to properly save and load fp8 NumPy arrays? #207

@apivovarov

I would like to save and load a float8_e5m2 array. I first tried the standard numpy.save() and numpy.load() functions, but loading fails:

.local/lib/python3.10/site-packages/numpy/lib/format.py", line 325, in descr_to_dtype
    return numpy.dtype(descr)
TypeError: data type '<f1' not understood

.local/lib/python3.10/site-packages/numpy/lib/format.py", line 683, in _read_array_header
    raise ValueError(msg.format(d['descr'])) from e
ValueError: descr is not a valid dtype descriptor: '<f1'
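The failing lookup can be reproduced without touching the filesystem. This is a minimal sketch that assumes, as the traceback suggests, that the .npy header stores the dtype's short type string (x.dtype.str) and that loading passes that string back to np.dtype():

```python
import ml_dtypes
import numpy as np

x = np.array([.2, .4, .6], dtype=ml_dtypes.float8_e5m2)

# Assumption: this short type string is what ends up in the .npy header
# (the traceback above shows it as '<f1' for float8_e5m2).
descr = x.dtype.str
print("descr:", descr, "itemsize:", x.dtype.itemsize)

# Loading calls np.dtype(descr). NumPy has no built-in 1-byte float type,
# so the string alone cannot be mapped back to float8_e5m2.
try:
    np.dtype(descr)
except TypeError as err:
    print("np.dtype failed:", err)
```

The array itself is fine in memory; what is lost is the dtype's identity, which is why the dtype name has to travel through some other channel.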

I found that I can save and load float8 arrays with a lower-level approach (ndarray.tobytes / np.frombuffer), storing the shape and dtype name in a separate JSON file:

import ml_dtypes
import numpy as np
import json

# Create the array
x = np.array([.2, .4, .6], dtype=ml_dtypes.float8_e5m2)

# Save the raw bytes (this is NOT the .npy format, so use a neutral extension)
with open("a.bin", "wb") as f:
    f.write(x.tobytes())

# Save the array's shape and dtype name separately
meta = {"shape": x.shape, "dtype": str(x.dtype)}
with open("a_meta.json", "w") as f:
    json.dump(meta, f)

# Load the raw bytes
with open("a.bin", "rb") as f:
    data = f.read()

# Load the metadata
with open("a_meta.json", "r") as f:
    meta = json.load(f)

# Reconstruct the array; np.frombuffer returns a read-only view, so copy it
x2 = np.frombuffer(data, dtype=getattr(ml_dtypes, meta["dtype"])).reshape(meta["shape"]).copy()

print(x2)
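A variant that keeps everything in a single file is to store the raw bits under an unsigned-integer dtype that np.save does understand, together with the dtype name, and reinterpret the bits with .view() on load. This is a minimal sketch, assuming one .npz archive per array is acceptable; save_fp8 and load_fp8 are hypothetical helper names, not part of NumPy or ml_dtypes:

```python
import ml_dtypes
import numpy as np

def save_fp8(path, arr):
    """Hypothetical helper: store arr's bits as unsigned ints plus its dtype name."""
    # Same-itemsize view (float8 is 1 byte -> uint8): bits are reinterpreted, not converted.
    bits = arr.view(np.dtype(f"u{arr.itemsize}"))
    # np.savez appends ".npz" to a string path if it is missing.
    np.savez(path, bits=bits, dtype=np.array(str(arr.dtype)))

def load_fp8(path):
    """Hypothetical helper: inverse of save_fp8."""
    with np.load(path) as archive:
        # Look the dtype up by name, as in the JSON-based workaround above.
        dtype = getattr(ml_dtypes, archive["dtype"].item())
        return archive["bits"].view(dtype)

x = np.array([.2, .4, .6], dtype=ml_dtypes.float8_e5m2)
save_fp8("a.npz", x)
x2 = load_fp8("a.npz")
print(x2, x2.dtype)
```

Because the round trip only reinterprets bits, the values come back exactly; the getattr lookup does assume the dtype is one that ml_dtypes exports by that name.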

Is the solution above (ndarray.tobytes / np.frombuffer) considered best practice for this case?

@jakevdp Jake, can you comment on it?
