
TransformerEnginePrecision _convert_layers(module) fails for FSDP zero2/zero3 #19989

Open
wprazuch opened this issue Jun 18, 2024 · 0 comments
Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.2.x

wprazuch commented Jun 18, 2024

Bug description

The TransformerEnginePrecision.convert_module function does not appear to work for an FSDP-wrapped model (ZeRO-2/ZeRO-3 sharding strategies).

What version are you seeing the problem on?

master

How to reproduce the bug

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision

# sharding_strategy, custom_wrap_policy, local_rank, and mesh are defined
# elsewhere in my setup (a ZeRO-2/ZeRO-3 sharding strategy, a custom wrap
# policy, the local device rank, and a device mesh).
model = FSDP(
    model,
    sharding_strategy=sharding_strategy,
    auto_wrap_policy=custom_wrap_policy,
    device_id=local_rank,
    use_orig_params=True,
    device_mesh=mesh,
)
te_precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16, replace_layers=True)
model = te_precision.convert_module(model)  # fails here

Error messages and logs

[rank1]:     self.model = te_precision.convert_module(self.model)
[rank1]:     _convert_layers(module)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/fabric/plugins/precision/transformer_engine.py", line 165, in _convert_layers
[rank1]:     replacement.weight.data = child.weight.data.clone()
[rank1]: RuntimeError: Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have incompatible tensor type.
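
For context, _convert_layers walks the module tree and swaps torch.nn.Linear children for their transformer_engine equivalents, copying the weights over through .data. Below is my simplified reading of that pattern, not the exact Lightning source. Once the model is FSDP-wrapped, child.weight seems to no longer be a plain tensor (possibly a sharded/flattened parameter or a DTensor when device_mesh is combined with use_orig_params=True), so the .data assignment into the fresh te.Linear parameter raises the incompatible tensor type error:

import torch.nn as nn
import transformer_engine.pytorch as te

def convert_layers_sketch(module: nn.Module) -> None:
    # Simplified sketch of Lightning's _convert_layers: swap each nn.Linear
    # child for a te.Linear and copy the weights over via .data.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            replacement = te.Linear(
                child.in_features, child.out_features, bias=child.bias is not None
            )
            # This is the failing step: after FSDP wrapping, child.weight is
            # no longer a plain torch.Tensor, and assigning its .data into the
            # regular parameter of the fresh te.Linear is rejected.
            replacement.weight.data = child.weight.data.clone()
            if child.bias is not None:
                replacement.bias.data = child.bias.data.clone()
            setattr(module, name, replacement)
        else:
            convert_layers_sketch(child)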

More info

I am actually seeing this on pytorch-lightning==2.3.0.
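
A possible workaround (untested beyond my setup, and assuming convert_module is otherwise fine to run on the unwrapped model) is to reverse the order of the two steps: convert the layers first, while their parameters are still plain tensors, and only then wrap with FSDP:

# Hypothetical reordering: convert while parameters are still plain tensors,
# then let FSDP shard the already-converted te.Linear modules.
te_precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16, replace_layers=True)
model = te_precision.convert_module(model)
model = FSDP(
    model,
    sharding_strategy=sharding_strategy,
    auto_wrap_policy=custom_wrap_policy,
    device_id=local_rank,
    use_orig_params=True,
    device_mesh=mesh,
)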

wprazuch added the bug and needs triage labels on Jun 18, 2024