diff --git a/vmoe/checkpoints/serialization.py b/vmoe/checkpoints/serialization.py index 96a3218..3144e40 100644 --- a/vmoe/checkpoints/serialization.py +++ b/vmoe/checkpoints/serialization.py @@ -160,19 +160,16 @@ class _MsgpackExtType(enum.IntEnum): def _shaped_array_to_bytes(x: core.ShapedArray) -> bytes: - tpl = (x.shape, x.dtype.name, x.weak_type, x.named_shape) - assert all(isinstance(key, str) for key in x.named_shape) + tpl = (x.shape, x.dtype.name, x.weak_type, {}) return msgpack.packb(tpl, use_bin_type=True) def _shaped_array_from_bytes(data: bytes) -> core.ShapedArray: - shape, dtype_name, weak_type, named_shape = msgpack.unpackb(data, raw=True) - named_shape = {k.decode('utf-8'): v for k, v in named_shape.items()} + shape, dtype_name, weak_type, _ = msgpack.unpackb(data, raw=True) return core.ShapedArray( shape=shape, dtype=_dtype_from_name(dtype_name), - weak_type=weak_type, - named_shape=named_shape) + weak_type=weak_type) def _slice_to_bytes(x: Slice) -> bytes: