diff --git a/vmoe/initialization/initialization.py b/vmoe/initialization/initialization.py index cb3a621..b778b9c 100644 --- a/vmoe/initialization/initialization.py +++ b/vmoe/initialization/initialization.py @@ -62,7 +62,7 @@ class PyTreeCheckpointHandlerWithStructure( orbax_checkpoint.PyTreeCheckpointHandler): def structure(self, directory): - return self._handler_impl._read_aggregate_file(directory) # pylint: disable=protected-access + return self._read_aggregate_file(directory) # pylint: disable=protected-access def initialize_from_orbax(