diff --git a/vmoe/checkpoints/serialization.py b/vmoe/checkpoints/serialization.py index f5b6723..96a3218 100644 --- a/vmoe/checkpoints/serialization.py +++ b/vmoe/checkpoints/serialization.py @@ -287,7 +287,7 @@ def _msgpack_ext_pack(x): return msgpack.ExtType(_MsgpackExtType.index_info, _index_info_to_bytes(x)) if isinstance(x, (np.ndarray, jax.Array)): return msgpack.ExtType(_MsgpackExtType.ndarray, _ndarray_to_bytes(x)) - if np.issctype(type(x)): + if isinstance(x, np.generic): return msgpack.ExtType(_MsgpackExtType.npscalar, _ndarray_to_bytes(np.asarray(x))) return x