Summary:
Pull Request resolved: #3388
Pull Request resolved: #3387
# context
* APS uses a "variable batch size" during training, e.g., a smaller `batch_size` (like 32) to warm up, then a larger `batch_size` (like 64) for the rest of training:
```
batch_size_schedule:
- batch_size: 32
max_iters: 5
- batch_size: 64
max_iters: 999999999
```
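For illustration, a minimal sketch of how such a schedule might be consumed (the `batch_size` / `max_iters` fields follow the config above; the helper function itself is hypothetical, not part of APS):

```python
# Hypothetical helper: pick the batch size for a given training iteration
# from a schedule like the YAML above -- a list of {batch_size, max_iters}
# entries, where max_iters is how many iterations the entry covers.
def batch_size_for_iter(schedule, iteration):
    start = 0
    for entry in schedule:
        if iteration < start + entry["max_iters"]:
            return entry["batch_size"]
        start += entry["max_iters"]
    # Fall back to the last entry if the schedule is exhausted.
    return schedule[-1]["batch_size"]

schedule = [
    {"batch_size": 32, "max_iters": 5},
    {"batch_size": 64, "max_iters": 999999999},
]

print(batch_size_for_iter(schedule, 0))  # 32 (warm-up phase)
print(batch_size_for_iter(schedule, 5))  # 64 (rest of training)
```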
* however, this becomes a problem for torch.export (PT2 IR), because the exported program assumes that `batch_size` is constant.
NOTE: this "variable batch" concept is fundamentally different from the "variable length" concept (VLE/VBE):
* in the variable batch scenario, within the same batch/training iteration, every feature in the KJT shares the same `batch_size` (which can only vary across iterations), so it follows the correlation `batch_size = len(kjt._lengths) // len(kjt._keys)`, and `kjt.stride()` returns the `batch_size` computed from `_lengths` and `_keys`.
* in the variable length scenario, within the same batch/training iteration, each feature in the KJT could have a different `batch_size`, so there is no such correlation between `_lengths`, `_keys`, and `batch_size`.
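The two correlations above can be checked with plain Python (a minimal sketch; `keys` and `lengths` stand in for a KJT's `_keys` and `_lengths` fields, no torchrec required):

```python
# Variable-batch case: every feature (key) contributes exactly batch_size
# length entries, so the flat lengths list factors evenly.
keys = ["f1", "f2"]          # two features
batch_size = 3               # same for every feature within this iteration
# One length per (feature, sample) pair, laid out feature-major:
lengths = [2, 0, 1,          # f1: lengths for samples 0..2
           1, 1, 2]          # f2: lengths for samples 0..2

assert len(lengths) == len(keys) * batch_size
derived = len(lengths) // len(keys)   # what kjt.stride() computes: 3

# Variable-length (VBE) counterexample: per-feature batch sizes differ,
# so no single batch_size factors the lengths evenly.
vbe_lengths = [2, 0, 1,      # f1: 3 samples in this batch
               1, 2]         # f2: 2 samples in this batch
assert len(vbe_lengths) % len(keys) != 0
```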
* so this "variable batch size" **CANNOT** simply be resolved by marking all input KJTs as variable length; instead, the `batch_size` has to be marked as a dynamic shape via the `mark_dynamic_kjt` util function.
WARNING: it's the user's responsibility to ensure that `variable_batch` is only used when `variable_length` is set to `False`; otherwise it will cause unexpected behavior with the dynamic shapes in torch.export.
Reviewed By: spmex, malaybag
Differential Revision: D82792378
Updated docstring of `mark_dynamic_kjt`:

    it will use the default name "vlen" for values, and "llen", "lofs" if variable length.
    A passed-in dynamic dim is useful if the dynamic dim is already used in other places.

    variable batch size means the batch size is dynamic during different training iterations;
    the batch size for all features is the same within one iteration/batch, so it still follows
    the correlation: len(lengths) == len(keys) * batch_size

    in the variable length scenario, the batch size could be different for each feature within
    the iteration/batch, so it doesn't follow the correlation: len(lengths) == len(keys) * batch_size

    Args:
        kjt (KeyedJaggedTensor): The KJT to make dynamic.
        shapes_collection (Optional[ShapesCollection]): The collection to update.
        variable_length (bool): Whether the KJT is variable length, i.e., len(lengths) != len(keys) * batch_size.
        variable_batch (bool): Whether the KJT is variable batch size, i.e., len(lengths) == len(keys) * batch_size; it only takes effect when variable_length is False.
        vlen (Optional[DIM]): The dynamic length for the values. If it's None, it will use the default name "vlen".
        llen (Optional[DIM]): The dynamic length for the lengths; it's only used when variable_length is True. If it's None, it will use the default name "llen".
        batch_size (Optional[DIM]): The dynamic length for the batch_size; it's only used when variable_batch is True.