From 389af9479f91914bd16ca1195566e6b5d99af79f Mon Sep 17 00:00:00 2001
From: Parker Schuh <parkers@google.com>
Date: Thu, 29 Aug 2024 10:24:40 -0700
Subject: [PATCH] [named shape cleanup]: Named shape no longer does anything.

PiperOrigin-RevId: 668982860
---
 vmoe/checkpoints/serialization.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/vmoe/checkpoints/serialization.py b/vmoe/checkpoints/serialization.py
index 96a3218..3144e40 100644
--- a/vmoe/checkpoints/serialization.py
+++ b/vmoe/checkpoints/serialization.py
@@ -160,19 +160,16 @@ class _MsgpackExtType(enum.IntEnum):
 
 
 def _shaped_array_to_bytes(x: core.ShapedArray) -> bytes:
-  tpl = (x.shape, x.dtype.name, x.weak_type, x.named_shape)
-  assert all(isinstance(key, str) for key in x.named_shape)
+  tpl = (x.shape, x.dtype.name, x.weak_type, {})
   return msgpack.packb(tpl, use_bin_type=True)
 
 
 def _shaped_array_from_bytes(data: bytes) -> core.ShapedArray:
-  shape, dtype_name, weak_type, named_shape = msgpack.unpackb(data, raw=True)
-  named_shape = {k.decode('utf-8'): v for k, v in named_shape.items()}
+  shape, dtype_name, weak_type, _ = msgpack.unpackb(data, raw=True)
   return core.ShapedArray(
       shape=shape,
       dtype=_dtype_from_name(dtype_name),
-      weak_type=weak_type,
-      named_shape=named_shape)
+      weak_type=weak_type)
 
 
 def _slice_to_bytes(x: Slice) -> bytes: