MPS support #43

@m-clark

PyTorch's MPS support has advanced quite a bit in recent years, but some functionality can still be problematic. Here is at least one example where torchcast has issues.

Summary: Training a torchcast Kalman / state-space model on MPS runs autograd through a graph that includes batched `torch.linalg.solve` (for the Kalman gain) and related matrix ops over many timesteps. On MPS, that backward path is unreliable in current PyTorch and can fail with `torch.AcceleratorError` or bogus indexing. This is not because the data or venv is wrong, but because the MPS implementation of backward for this graph is possibly buggy or immature. Forward on MPS should work; `fit` (which needs backward) is what breaks.

Workaround: run `fit` on CPU, then optionally call `model.to("mps")` for inference-only use on MPS.
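To make the failing pattern concrete, here is a toy, hypothetical reproduction in plain PyTorch. It is not torchcast's implementation. It assumes a random-walk state-space model with identity transition and observation matrices, so each timestep needs one batched `torch.linalg.solve` for the Kalman gain and one for the likelihood, and the whole loop is differentiated at the end. The names `kalman_nll` and `backward_smoke_test` are illustrative only, and it has not been checked whether this toy graph triggers the same MPS failure as torchcast.

```python
import math

import torch


def kalman_nll(y: torch.Tensor, log_q: torch.Tensor, log_r: torch.Tensor) -> torch.Tensor:
    """Mean per-step negative log-likelihood of a toy batched Kalman filter.

    Hypothetical model (not torchcast's): random walk with F = H = I.
    y: (batch, time, dim) observations; log_q, log_r: (dim,) log-variances.
    """
    batch, n_time, dim = y.shape
    eye = torch.eye(dim, dtype=y.dtype, device=y.device)
    Q = torch.diag(log_q.exp())  # process noise covariance
    R = torch.diag(log_r.exp())  # observation noise covariance
    x = y.new_zeros((batch, dim))  # state mean
    P = eye.expand(batch, dim, dim).clone()  # state covariance
    nll = y.new_zeros(())
    for t in range(n_time):
        P_pred = P + Q  # predict step (F = I)
        S = P_pred + R  # innovation covariance (H = I)
        v = y[:, t] - x  # innovation
        # Kalman gain K = P_pred @ S^{-1}; both are symmetric, so K^T = solve(S, P_pred).
        K = torch.linalg.solve(S, P_pred).transpose(-1, -2)
        x = x + (K @ v.unsqueeze(-1)).squeeze(-1)
        P = (eye - K) @ P_pred
        # Gaussian log-density of the innovation, again via a batched solve.
        s_inv_v = torch.linalg.solve(S, v.unsqueeze(-1)).squeeze(-1)
        nll = nll + 0.5 * ((v * s_inv_v).sum(-1) + torch.logdet(S) + dim * math.log(2 * math.pi)).sum()
    return nll / (batch * n_time)


def backward_smoke_test(device: str, n_time: int = 200) -> float:
    """Run forward + backward through the whole filter on `device`; return the loss."""
    torch.manual_seed(0)
    y = torch.randn(8, n_time, 3).cumsum(dim=1).to(device)
    log_q = torch.zeros(3, device=device, requires_grad=True)
    log_r = torch.zeros(3, device=device, requires_grad=True)
    loss = kalman_nll(y, log_q, log_r)
    loss.backward()  # the step the issue reports as unreliable on MPS
    assert torch.isfinite(log_q.grad).all() and torch.isfinite(log_r.grad).all()
    return loss.item()


if __name__ == "__main__":
    print("cpu:", backward_smoke_test("cpu"))
    if torch.backends.mps.is_available():
        print("mps:", backward_smoke_test("mps"))
```

Running the same function with `"cpu"` and `"mps"` isolates the device as the only variable, which is the comparison the summary above describes.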

See the `mps-support` branch for changes that allow torchcast to run on MPS; for now, model fitting falls back to CPU. I also ran the quickstart example as an additional test.
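The CPU fallback for fitting can be sketched for a generic `nn.Module`. This is a minimal illustration of the fit-on-CPU / infer-on-MPS pattern, assuming a simple MSE objective and full-batch SGD; `pick_device` and `fit_then_move` are hypothetical helpers, not torchcast API and not the code on the `mps-support` branch.

```python
import torch
from torch import nn


def pick_device(for_training: bool) -> torch.device:
    """Hypothetical helper: keep autograd-heavy fitting on CPU, allow MPS for inference only."""
    if not for_training and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def fit_then_move(model: nn.Module, x: torch.Tensor, y: torch.Tensor,
                  n_steps: int = 500, lr: float = 0.1) -> nn.Module:
    """Fit `model` with full-batch SGD on CPU, then move it to the inference device."""
    train_device = pick_device(for_training=True)
    model = model.to(train_device)
    x, y = x.to(train_device), y.to(train_device)
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(n_steps):
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()  # backward pass only ever runs on CPU
        opt.step()
    model.eval()
    # Forward-only use afterwards, so MPS is acceptable if present.
    return model.to(pick_device(for_training=False))


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(64, 2)
    y = x @ torch.tensor([[2.0], [-1.0]]) + 0.5
    fitted = fit_then_move(nn.Linear(2, 1), x, y)
    device = next(fitted.parameters()).device
    with torch.no_grad():
        preds = fitted(x.to(device))
    print(device, preds.shape)
```

Keeping the device choice in one helper means the fallback can be removed in a single place once MPS backward for this kind of graph is reliable.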

Related PyTorch issues:
