
Conversation

TroyGarden

Summary:

# context

* APS uses a "variable batch size" schedule during training, e.g., a smaller `batch_size` (such as 32) to warm up, then a larger `batch_size` (such as 64) for the rest of training.
```
    batch_size_schedule:
      - batch_size: 32
        max_iters: 5
      - batch_size: 64
        max_iters: 999999999
```
* However, this becomes a problem for torch.export (PT2 IR), because the exported program assumes the `batch_size` is constant.
  NOTE: this "variable batch" concept is fundamentally different from "variable length" (VLE/VBE).
* In the variable-batch scenario, within the same batch/training iteration every feature in the KJT shares the same `batch_size` (it can only change in a later iteration), so the correlation `batch_size = len(kjt._lengths) // len(kjt._keys)` holds, and `kjt.stride()` returns the `batch_size` by calculation from `_lengths` and `_keys`.
* In the variable-length scenario, within the same batch/training iteration each feature in the KJT can have a different `batch_size`, and there is no such correlation between `_lengths`, `_keys`, and `batch_size`.
* So this "variable batch size" **CANNOT** simply be resolved by marking all input KJTs as variable-length; instead, the `batch_size` has to be used as a dynamic shape implicitly, via the `mark_dynamic_kjt` util function.
  WARNING: it is the user's responsibility to ensure that `variable_batch` is only used when `variable_length` is set to `False`; otherwise it will cause unexpected behavior with the dynamic shapes in torch.export.
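The variable-batch invariant described above can be sketched in plain Python, with lists standing in for a KJT's `_keys` and `_lengths` (the feature names and length values below are made up for illustration; no torchrec import is needed):

```python
# Plain-Python sketch of the variable-batch invariant: every key in the
# KJT shares one batch_size, so _lengths is laid out key-major with
# exactly batch_size entries per key.
keys = ["f1", "f2", "f3"]      # stand-in for kjt._keys
batch_size = 4                 # shared by every feature this iteration
lengths = [
    1, 2, 0, 3,                # per-sample lengths for f1
    2, 2, 1, 0,                # per-sample lengths for f2
    0, 1, 1, 2,                # per-sample lengths for f3
]                              # stand-in for kjt._lengths

def stride(lengths, keys):
    # Mirrors how kjt.stride() can recover batch_size by calculation
    # in the variable-batch (non-VBE) case.
    return len(lengths) // len(keys)

assert stride(lengths, keys) == batch_size   # 12 // 3 == 4
```

In the variable-length (VBE) case this division would be meaningless, since each key could contribute a different number of length entries; that is why variable batch needs its own dynamic-shape treatment.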

Reviewed By: spmex, malaybag

Differential Revision: D82792378

@meta-cla bot added the CLA Signed label Sep 19, 2025
@facebook-github-bot

@TroyGarden has exported this pull request. If you are a Meta employee, you can view the originating diff in D82792378.

Summary:
Pull Request resolved: pytorch#3388

Pull Request resolved: pytorch#3387

