Can't save checkpoint with orbax when using zero-size parameters #4309

Open
dionhaefner opened this issue Oct 18, 2024 · 3 comments

@dionhaefner

Trying to save a checkpoint when there are zero-size variables present raises an exception. Used to work fine pre-orbax. (This is part of a bigger model that has conditional logic where some of the variables are unused in certain configurations.)

Reproducer:

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training.checkpoints import save_checkpoint

class DummyModule(nn.Module):
    @nn.compact
    def __call__(self, x):
        var = self.variable('batch_stats', 'var', lambda: jnp.zeros((0,)))
        return x
    
state = DummyModule().init(jax.random.PRNGKey(0), jnp.zeros((1,)))
save_checkpoint('/tmp/foo', state, 1)

This prints:

$ python flax_bug_repro.py
Traceback (most recent call last):
  File "/Users/dion/codes/supersede/flax_bug_repro.py", line 17, in <module>
    save_checkpoint('/tmp/foo', state, 1)
  File "/Users/dion/.virtualenvs/tempenv-7b402269854a9/lib/python3.10/site-packages/flax/training/checkpoints.py", line 694, in save_checkpoint
    orbax_checkpointer.save(
  File "/Users/dion/.virtualenvs/tempenv-7b402269854a9/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 216, in save
    self._handler.finalize(tmpdir.get())
  File "/Users/dion/.virtualenvs/tempenv-7b402269854a9/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 998, in finalize
    self._handler_impl.finalize(directory)
  File "/Users/dion/.virtualenvs/tempenv-7b402269854a9/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 782, in finalize
    asyncio.run(
  File "/opt/homebrew/Cellar/[email protected]/3.10.13/Frameworks/Python.framework/Versions/3.10/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/homebrew/Cellar/[email protected]/3.10.13/Frameworks/Python.framework/Versions/3.10/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/Users/dion/.virtualenvs/tempenv-7b402269854a9/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py", line 657, in merge_ocdbt_per_process_files
    await _validate_params(ts_kv_store, use_zarr3=use_zarr3)
  File "/Users/dion/.virtualenvs/tempenv-7b402269854a9/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py", line 574, in _validate_params
    raise ValueError(
ValueError: Save failed: 1/1 params are missing in checkpoint:
batch_stats.var.
Tensorstore KvStore: KvStore({
  'base': {
    'driver': 'file',
    'path': '/tmp/foo/checkpoint_1.orbax-checkpoint-tmp-0/',
  },
  'cache_pool': 'cache_pool#ocdbt',
  'config': {
    'compression': {'id': 'zstd'},
    'max_decoded_node_bytes': 100000000,
    'max_inline_value_bytes': 1024,
    'uuid': '1dc83c4d929da7f10e13b4bd3f592ccd',
    'version_tree_arity_log2': 4,
  },
  'context': {
    'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
    'data_copy_concurrency': {},
    'file_io_concurrency': {'limit': 128},
    'file_io_sync': True,
    'ocdbt_coordinator': {},
  },
  'driver': 'ocdbt',
  'experimental_read_coalescing_interval': '1ms',
  'experimental_read_coalescing_merged_bytes': 500000000000,
  'experimental_read_coalescing_threshold_bytes': 1000000,
}).
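
To check which variables would trigger this before saving, a small helper can walk the state and list the zero-size leaves. This is only a sketch: `zero_size_leaves` is a hypothetical helper (not part of Flax or Orbax), and it assumes the state is a nested dict whose leaves expose a `.size` attribute, as JAX and NumPy arrays do.

```python
from types import SimpleNamespace


def zero_size_leaves(tree, prefix=''):
    """Return dotted paths of leaves whose .size is 0 in a nested dict."""
    found = []
    for key, value in tree.items():
        path = f'{prefix}.{key}' if prefix else str(key)
        if isinstance(value, dict):
            # Recurse into sub-collections such as 'batch_stats' or 'params'.
            found.extend(zero_size_leaves(value, path))
        elif getattr(value, 'size', None) == 0:
            found.append(path)
    return found


# Stand-ins for arrays: only the .size attribute matters for this check.
state = {
    'batch_stats': {'var': SimpleNamespace(size=0)},
    'params': {'w': SimpleNamespace(size=3)},
}
print(zero_size_leaves(state))  # ['batch_stats.var']
```

With the state from the reproducer above, the same call would be expected to report the path that appears in the error message.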

System information

Flax==0.10.0
orbax==0.1.9

@hrbigelow

This seems to be an Orbax issue rather than a Flax one. It looks like a recent change to Orbax assumes that every checkpoint parameter with a '.zarray' metadata entry also has at least one other entry (a data chunk). The relevant function is _validate_params.

For instance:

import jax, jax.numpy as jnp
from flax.training import checkpoints
import tempfile

with tempfile.TemporaryDirectory() as dir_path:
  test_object = {
    'a': jnp.array([1, 2, 3], jnp.int32),
    'z': jnp.zeros((0,)),
  }
  file_path = checkpoints.save_checkpoint(
    dir_path, target=test_object, step=0, prefix='test_', keep=1
  )
  restored_object = checkpoints.restore_checkpoint(
    file_path, target=None
  )

print(restored_object)

This raises:

ValueError: Save failed: 1/2 params are missing in checkpoint:
z.
...

This produces the tensorstore entries 'a/0', 'a/.zarray', and 'z/.zarray', but not 'z/0', since there is no data in the z tensor.
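
The kind of check that would produce this error can be sketched in plain Python. This is an illustration of the suspected logic only, not Orbax's actual _validate_params code; `find_missing_params` is a hypothetical name, and the key layout ('name/.zarray' for metadata, 'name/0' for a data chunk) is taken from the entries listed above.

```python
def find_missing_params(kvstore_keys):
    """Return names that have a '.zarray' metadata key but no data keys."""
    with_metadata = set()
    with_data = set()
    for key in kvstore_keys:
        # Split 'z/.zarray' into the parameter name 'z' and the entry '.zarray'.
        name, _, entry = key.rpartition('/')
        if entry == '.zarray':
            with_metadata.add(name)
        else:
            with_data.add(name)
    # A zero-size array writes metadata but no chunk, so it lands here.
    return sorted(with_metadata - with_data)


keys = ['a/0', 'a/.zarray', 'z/.zarray']
print(find_missing_params(keys))  # ['z']
```

Under this reading, a zero-size parameter is indistinguishable from one whose data failed to write, which would explain why 'z' is reported as missing.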

@dionhaefner

Sooo should I take this up with the orbax people or are you already in contact?

@hrbigelow

hrbigelow commented Nov 6, 2024

Hi @IvyZX do you mind if I take a look and try to solve this on the Orbax side?

EDIT: @dionhaefner I opened Orbax issue 1309 for this. It's a bug in either Orbax or tensorstore, not Flax.
