
Avoid OOM on TPU #1690

Answered by jheek
borisdayma asked this question in General
Nov 3, 2021 · 5 comments · 5 replies

This is quite odd for sure. Memory fragmentation, combined with running close to the TPU memory limit, could of course produce OOM errors that appear almost at random. One thing you could try is to initialize the model on CPU with jax.jit(model.init, backend="cpu"). The params are then moved to the TPU automatically, either during training or when the state is replicated (e.g. with jax_utils.replicate).
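A minimal sketch of the CPU-initialization idea in JAX. It assumes a hypothetical init_params function standing in for model.init (a small two-layer MLP with arbitrary sizes). To pin initialization to the host, it uses the jax.default_device context manager instead of the backend="cpu" argument from the answer. The intent is the same: the parameters are built on CPU, so the TPU never holds the temporary buffers from initialization, which may help with fragmentation. The finished parameters are then placed on the accelerator explicitly with jax.device_put.

```python
import jax
import jax.numpy as jnp


def init_params(key):
    """Hypothetical stand-in for model.init: a two-layer MLP."""
    k1, k2 = jax.random.split(key)
    return {
        "dense1": {"kernel": jax.random.normal(k1, (784, 512)) * 0.01,
                   "bias": jnp.zeros((512,))},
        "dense2": {"kernel": jax.random.normal(k2, (512, 10)) * 0.01,
                   "bias": jnp.zeros((10,))},
    }


cpu = jax.devices("cpu")[0]

# Run the jitted initialization on the host CPU, so all intermediate
# buffers created while building the parameters live in host memory.
with jax.default_device(cpu):
    params = jax.jit(init_params)(jax.random.PRNGKey(0))

# Copy the finished parameters to the default accelerator
# (a TPU core if one is available, otherwise this is the CPU again).
accel = jax.devices()[0]
params = jax.device_put(params, accel)
```

In a pmap-style training loop, the explicit device_put step would typically be replaced by replicating the state across devices (e.g. flax's jax_utils.replicate, as the answer mentions), which copies the CPU-resident parameters onto each TPU core.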

Answer selected by borisdayma
7 participants

This discussion was converted from issue #1658 on November 29, 2021 13:09.