Merged
1 change: 1 addition & 0 deletions .gitignore
@@ -217,6 +217,7 @@ __marimo__/

# Benchmark output images
examples/images/
examples/outputs/

# Local training outputs and checkpoints
runs/
41 changes: 40 additions & 1 deletion README.md
@@ -213,6 +213,45 @@ Useful flags:
- `--train-size 12000 --val-size 3000`
- `--max-seq-len 192`

## Self-Supervised Example

`examples/language_selfsupervised.py` uses the same streaming Wikipedia data loader,
trains a self-supervised token model, exports hidden states to a memmap, and can
optionally plot UMAP projections from those exports.

Generated artifacts are written under `examples/outputs/language_selfsupervised/`
by default.

Install the example dependencies:

```bash
uv sync --group examples
```

Train, export, and plot with the default CPU UMAP backend:

```bash
uv run python examples/language_selfsupervised.py --plot-umap
```

To use the CUDA UMAP backend in the project environment:

```bash
uv sync --group examples --group cuda-umap
uv run python examples/language_selfsupervised.py --device cuda --plot-umap --umap-backend cuda
```

Useful flags:
- `--horizon single|multi`
- `--export-mode all|final`
- `--plot-umap`
- `--umap-backend auto|cpu|cuda`
- `--plot-time-indices 1,20,100,-5,-1,0` where `0` selects `H_T`, the hidden state after consuming `EOS` (aligned with the first `PAD` target), negative indices count back from `H_T`, and positive indices are one-indexed absolute timesteps
- `--n-per-lang 4000`

The CUDA UMAP path is optional. `--umap-backend auto` falls back to CPU UMAP if the
CUDA packages are unavailable.
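
The `auto` fallback can be sketched as below. This is an illustrative sketch rather than the example script's implementation: `resolve_umap_backend` is a hypothetical helper, and the probed module name `cuml` is an assumption about which package provides CUDA UMAP.

```python
import importlib.util

VALID_BACKENDS = {"auto", "cpu", "cuda"}


def resolve_umap_backend(requested: str, cuda_module: str = "cuml") -> str:
    """Resolve a --umap-backend value to a concrete backend name.

    "cpu" and "cuda" are returned unchanged. "auto" picks "cuda" only if
    the (assumed) CUDA UMAP module can be found, otherwise "cpu".
    """
    if requested not in VALID_BACKENDS:
        raise ValueError(f"unknown UMAP backend: {requested!r}")
    if requested != "auto":
        return requested
    # find_spec returns None for a missing top-level module without importing it.
    has_cuda = importlib.util.find_spec(cuda_module) is not None
    return "cuda" if has_cuda else "cpu"


backend = resolve_umap_backend("auto")
```

Probing with `find_spec` avoids importing a heavy GPU package just to check that it exists. An explicit `--umap-backend cuda` is passed through unchanged here, so a missing package would surface later as an import error instead of a silent fallback.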

## Benchmarking

To compare recurrent and parallel throughput across sequence lengths and hidden dimensions:
@@ -237,7 +276,7 @@ uv sync --dev

- Python ≥ 3.11
- PyTorch ≥ 2.8.0
- - NumPy ≥ 2.4.1
+ - NumPy ≥ 2

## License
