diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 00000000..0d644a87 --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,49 @@ +# Codecov Configuration +# Documentation: https://docs.codecov.com/docs/codecovyml-reference + +coverage: + precision: 2 # Number of decimal places (0-5) + round: down # How to round coverage (down/up/nearest) + range: 70..100 # Color coding range (red at 70%, green at 100%) + + status: + # Project coverage: overall repository coverage + project: + default: + target: 90% # Minimum coverage threshold + threshold: null # No additional threshold; fail check if below target + base: auto # Compare against base branch + informational: false # Fail the check if below target + + # Patch coverage: coverage on changed lines only + patch: + default: + target: 80% # New code should have at least 80% coverage + threshold: 0% # No wiggle room for patch coverage + base: auto + informational: false # Fail if new code doesn't meet target + +# Pull request comment configuration +comment: + layout: "diff, flags, files, footer" # What to show in PR comments + behavior: default # Comment on all PRs + require_changes: false # Comment even if coverage unchanged + require_base: false # Comment even without base report + require_head: true # Only comment if head report exists + +# Paths to ignore in coverage reports +ignore: + - "tests/*" # Test files + - "tests/**/*" # All test subdirectories + - "docs/**/*" # Documentation + - "site/*" # Built documentation site + - "htmlcov/*" # Coverage HTML reports + - ".venv/*" # Virtual environment + - ".tox/*" # Tox environments + - "**/__pycache__/*" # Python cache + - "**/conftest.py" # Pytest configuration + - "update_tests.py" # Utility scripts + +# GitHub Checks configuration +github_checks: + annotations: true # Show coverage annotations on changed files diff --git a/.dockerignore b/.dockerignore deleted file mode 100644 index f40ac81c..00000000 --- a/.dockerignore +++ /dev/null @@ -1,39 +0,0 @@ -# Git 
-.git -.gitignore -.github - -# Docker -.dockerignore - -# IDE -.idea -.vscode - -# Byte-compiled / optimized / DLL files -__pycache__/ -**/__pycache__/ -*.pyc -*.pyo -*.pyd -.Python -*.py[cod] -*$py.class -.pytest_cache/ -..mypy_cache/ - -# poetry -.venv -venv - -# C extensions -*.so - -# Virtual environment -.venv -venv - -.DS_Store -.AppleDouble -.LSOverride -._* diff --git a/.editorconfig b/.editorconfig deleted file mode 100644 index 7f578f15..00000000 --- a/.editorconfig +++ /dev/null @@ -1,24 +0,0 @@ -# Check http://editorconfig.org for more information -# This is the main config file for this project: -root = true - -[*] -charset = utf-8 -end_of_line = lf -insert_final_newline = true -indent_style = space -indent_size = 2 -trim_trailing_whitespace = true - -[*.{py, pyi}] -indent_style = space -indent_size = 4 - -[Makefile] -indent_style = tab - -[*.md] -trim_trailing_whitespace = false - -[*.{diff,patch}] -trim_trailing_whitespace = false diff --git a/.github/actions/setup/action.yml b/.github/actions/setup/action.yml new file mode 100644 index 00000000..efc5fbce --- /dev/null +++ b/.github/actions/setup/action.yml @@ -0,0 +1,28 @@ +name: 'Setup Environment' +description: 'Setup Python with uv and install dependencies' +inputs: + python-version: + description: 'Python version' + required: false + default: '3.12' + install-deps: + description: 'Install dependencies' + required: false + default: 'true' + install-groups: + description: 'Dependency groups to install' + required: false + default: '--all-extras --all-groups' +runs: + using: 'composite' + steps: + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + python-version: ${{ inputs.python-version }} + enable-cache: true + + - name: Install dependencies + if: inputs.install-deps == 'true' + run: uv sync ${{ inputs.install-groups }} + shell: bash diff --git a/.github/scripts/test_summary.py b/.github/scripts/test_summary.py new file mode 100755 index 00000000..bff5d1f6 --- /dev/null +++ 
b/.github/scripts/test_summary.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +"""Generate GitHub Actions test summary from JUnit XML.""" + +import sys +import xml.etree.ElementTree as ET +from pathlib import Path + + +def main() -> int: + """Parse JUnit XML and write summary to GitHub Actions step summary.""" + xml_path = Path("test-results.xml") + + if not xml_path.exists(): + print("⚠️ No test results found", file=sys.stderr) + return 1 + + try: + tree = ET.parse(xml_path) + root = tree.getroot() + + tests = int(root.get("tests", 0)) + failures = int(root.get("failures", 0)) + errors = int(root.get("errors", 0)) + skipped = int(root.get("skipped", 0)) + passed = tests - failures - errors - skipped + + # Determine status emoji + if failures + errors > 0: + status = "❌" + elif skipped == tests: + status = "⏭️" + else: + status = "✅" + + # Print summary lines + print(f"{status} **Test Results Summary**") + print(f"- ✅ Passed: {passed}") + print(f"- ❌ Failed: {failures}") + print(f"- ⚠️ Errors: {errors}") + print(f"- ⏭️ Skipped: {skipped}") + print(f"- **Total: {tests}**") + + # Exit with error if tests failed + return 1 if (failures + errors > 0) else 0 + + except ET.ParseError as e: + print(f"❌ Failed to parse XML: {e}", file=sys.stderr) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/workflows/check_pull_request_title.yml b/.github/workflows/check_pull_request_title.yml index 403661e1..2c68fbc0 100644 --- a/.github/workflows/check_pull_request_title.yml +++ b/.github/workflows/check_pull_request_title.yml @@ -1,35 +1,47 @@ name: "Check PR title" + on: pull_request: types: [edited, opened, synchronize, reopened] +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number }} + cancel-in-progress: true + +permissions: + contents: read + pull-requests: read + statuses: write + jobs: pr-title-check: runs-on: ubuntu-latest + timeout-minutes: 5 if: ${{ github.event.pull_request.user.login != 'allcontributors[bot]' 
}} steps: - # Echo the user's login - - name: Echo user login - run: echo ${{ github.event.pull_request.user.login }} - - - uses: naveenk1223/action-pr-title@master + - uses: amannn/action-semantic-pull-request@v5.5.3 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: + # Require imperative mood (e.g. "Add feature" not "Adds feature") # ^ Start of string # [A-Z] First character must be an uppercase ASCII letter - # [a-zA-Z]* Followed by zero or more ASCII letters - # (?> $GITHUB_STEP_SUMMARY + python3 .github/scripts/test_summary.py >> $GITHUB_STEP_SUMMARY + + - name: Upload test results + if: always() + uses: actions/upload-artifact@v4 + with: + name: test-results-py${{ matrix.python }} + retention-days: 7 + path: | + test-results.xml + test-report.html + .coverage + + - name: Generate coverage report + if: matrix.python == '3.12' + run: uv run coverage xml + + - name: Upload coverage to Codecov + if: matrix.python == '3.12' + uses: codecov/codecov-action@v5.4.3 + with: + files: ./coverage.xml + fail_ci_if_error: true + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..d191d93a --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,100 @@ +name: CI + +on: + pull_request: + push: + branches: + - main + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} + +permissions: + contents: read + checks: write + +env: + FORCE_COLOR: 1 + +jobs: + quality: + name: ${{ matrix.check }} + runs-on: ubuntu-latest + timeout-minutes: 10 + strategy: + fail-fast: false + matrix: + check: + - format + - lint + - types + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: ./.github/actions/setup + with: + python-version: "3.12" + install-deps: ${{ matrix.check == 'types' && 'true' || 'false' }} + + - name: Run format check + if: matrix.check == 'format' + 
run: uvx ruff format --diff + + - name: Run lint check + if: matrix.check == 'lint' + run: uvx ruff check + + - name: Run type check + if: matrix.check == 'types' + run: uv run mypy src + + tests: + name: Tests (Python 3.12) + runs-on: ubuntu-latest + timeout-minutes: 30 + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: ./.github/actions/setup + with: + python-version: "3.12" + + - name: Run pytest with coverage + run: | + uv run coverage run -m pytest tests --durations=10 -m "not slow" \ + --junit-xml=test-results.xml \ + --html=test-report.html --self-contained-html + + - name: Generate test summary + if: always() + run: python3 .github/scripts/test_summary.py >> $GITHUB_STEP_SUMMARY + + - name: Upload test results + if: always() + uses: actions/upload-artifact@v4 + with: + name: test-results-py312 + retention-days: 7 + path: | + test-results.xml + test-report.html + .coverage + + - name: Generate coverage report + run: uv run coverage xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v5.4.3 + with: + files: ./coverage.xml + fail_ci_if_error: true + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/code_quality.yml b/.github/workflows/code_quality.yml deleted file mode 100644 index d1520009..00000000 --- a/.github/workflows/code_quality.yml +++ /dev/null @@ -1,50 +0,0 @@ -name: Code quality - -on: - push: - branches: - - main - pull_request: - workflow_dispatch: - schedule: - - cron: "0 4 * * *" - -env: - FORCE_COLOR: 1 - -jobs: - check: - name: ${{ matrix.check }} - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - check: - - format - - lint - # - types - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Install the latest version of uv - uses: astral-sh/setup-uv@v6 - - - name: Install dependencies - if: matrix.check == 'types' - run: uv sync --all-extras --all-groups - - - name: Run format check - if: 
matrix.check == 'format' - run: uvx ruff format --diff - - - name: Run lint check - if: matrix.check == 'lint' - run: uvx ruff check - - - name: Run type check - if: matrix.check == 'types' - run: uv run mypy src diff --git a/.github/workflows/dependency-review.yml b/.github/workflows/dependency-review.yml new file mode 100644 index 00000000..4101e9e3 --- /dev/null +++ b/.github/workflows/dependency-review.yml @@ -0,0 +1,22 @@ +name: Dependency Review + +on: + pull_request: + +permissions: + contents: read + pull-requests: write + +jobs: + dependency-review: + name: Review dependencies + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - uses: actions/checkout@v4 + + - name: Dependency Review + uses: actions/dependency-review-action@v4 + with: + fail-on-severity: moderate + comment-summary-in-pr: on-failure diff --git a/.github/workflows/docs-publish.yml b/.github/workflows/docs-publish.yml index 2453634c..87f01f5a 100644 --- a/.github/workflows/docs-publish.yml +++ b/.github/workflows/docs-publish.yml @@ -1,24 +1,30 @@ name: Docs Publish + on: push: branches: - main + permissions: contents: write + jobs: deploy: runs-on: ubuntu-latest + timeout-minutes: 15 steps: - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: ./.github/actions/setup + with: + install-groups: '--only-group doc' + - name: Configure Git Credentials run: | git config user.name github-actions[bot] git config user.email 41898282+github-actions[bot]@users.noreply.github.com - - name: Install the latest version of uv - uses: astral-sh/setup-uv@v6 - - - name: Install dependencies - run: uv sync --only-group doc - name: Deploy docs run: uv run --only-group doc mkdocs gh-deploy --force diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 33fad1cf..3e035888 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -6,18 +6,28 @@ on: - '*' workflow_dispatch: +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + 
cancel-in-progress: false + jobs: build: name: Build the package runs-on: ubuntu-latest - # Only run if it's a tagged commit or manual dispatch + timeout-minutes: 10 if: startsWith(github.ref, 'refs/tags') || github.event_name == 'workflow_dispatch' + permissions: + contents: read steps: - uses: actions/checkout@v4 with: fetch-depth: 0 + - uses: ./.github/actions/setup + with: + install-deps: 'false' + - name: Verify tag is on main branch if: startsWith(github.ref, 'refs/tags') run: | @@ -27,9 +37,6 @@ jobs: exit 1 fi - - name: Install uv - uses: astral-sh/setup-uv@v6 - - name: Build a binary wheel and a source tarball run: uv build @@ -43,13 +50,11 @@ jobs: name: Publish the package needs: build runs-on: ubuntu-latest + timeout-minutes: 10 permissions: id-token: write steps: - - name: Print ref - run: echo ${{ github.ref }} - - name: Download all workflow run artifacts uses: actions/download-artifact@v4 with: @@ -58,6 +63,8 @@ jobs: - name: Install uv uses: astral-sh/setup-uv@v6 + with: + enable-cache: true - name: Publish package to PyPI run: uv publish --verbose --token ${{ secrets.PYPI_TOKEN }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..9d8aaacd --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,36 @@ +name: Release + +on: + push: + tags: + - '*' + +permissions: + contents: write + +jobs: + release: + name: Create GitHub Release + runs-on: ubuntu-latest + timeout-minutes: 5 + if: startsWith(github.ref, 'refs/tags') + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Verify tag is on main branch + run: | + git fetch origin main + if ! 
git merge-base --is-ancestor ${{ github.sha }} origin/main; then + echo "Error: Tag is not on the main branch" + exit 1 + fi + + - name: Create Release + uses: softprops/action-gh-release@v2 + with: + generate_release_notes: true + draft: false + prerelease: false diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml deleted file mode 100644 index 86628ff3..00000000 --- a/.github/workflows/tests.yml +++ /dev/null @@ -1,66 +0,0 @@ -name: Tests - -on: - push: - branches: - - main - pull_request: - workflow_dispatch: - schedule: - - cron: "0 4 * * *" - -env: - FORCE_COLOR: 1 - -jobs: - pytest: - name: Unit tests - strategy: - matrix: - os: [ubuntu-latest] - python: ["3.10", "3.11", "3.12"] - fail-fast: false - - runs-on: ${{ matrix.os }} - env: - OS: ${{ matrix.os }} - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Install the latest version of uv - uses: astral-sh/setup-uv@v6 - with: - python-version: ${{ matrix.python }} - - - name: Install dependencies - run: uv sync --all-extras --all-groups - - # Run all tests on schedule, but only non-slow tests on push - - name: Run pytest with coverage - run: | - if [ "${{ github.event_name }}" == "schedule" ]; then - uv run coverage run -m pytest tests --durations=0 - else - uv run coverage run -m pytest tests --durations=0 -m "not slow" - fi - shell: bash - - - name: Generate coverage report - if: ${{ matrix.os == 'ubuntu-latest' && matrix.python == '3.12' }} - run: | - uv run coverage xml - ls -la coverage.xml - - - name: Upload coverage reports to Codecov - if: ${{ matrix.os == 'ubuntu-latest' && matrix.python == '3.12' }} - uses: codecov/codecov-action@v5.4.3 - with: - files: ./coverage.xml - fail_ci_if_error: true - verbose: true - env: - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/.gitignore b/.gitignore index 2c69e633..3755fab6 100644 --- a/.gitignore +++ b/.gitignore @@ -159,16 +159,15 @@ cifar10/ # Paper paper/jats/ -# Projects 
folder -projects/* -!projects/README.md -!projects/cifar10 -# !projects/01_simple_classification -# !projects/02_self_supervised_learning -# !projects/03_multi_task_learning -# !projects/04_image_segmentation # Coding agents CLAUDE.md GEMINI.md AGENTS.md + +**/.claude/ + +**/lightning_logs/ +**/outputs/ +**/.datasets/ +**/*.zip diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index e69de29b..00000000 diff --git a/.safety-policy.yml b/.safety-policy.yml deleted file mode 100644 index c21c3c95..00000000 --- a/.safety-policy.yml +++ /dev/null @@ -1,81 +0,0 @@ -# Safety Security and License Configuration file -# We recommend checking this file into your source control in the root of your Python project -# If this file is named .safety-policy.yml and is in the same directory where you run `safety check` it will be used by default. -# Otherwise, you can use the flag `safety check --policy-file ` to specify a custom location and name for the file. -# To validate and review your policy file, run the validate command: `safety validate policy_file --path ` -security: # configuration for the `safety check` command - ignore-cvss-severity-below: 4 # A severity number between 0 and 10. Some helpful reference points: 9=ignore all vulnerabilities except CRITICAL severity. 7=ignore all vulnerabilities except CRITICAL & HIGH severity. 4=ignore all vulnerabilities except CRITICAL, HIGH & MEDIUM severity. - ignore-cvss-unknown-severity: False # True or False. We recommend you set this to False. - ignore-vulnerabilities: # Here you can list multiple specific vulnerabilities you want to ignore (optionally for a time period) - # We recommend making use of the optional `reason` and `expires` keys for each vulnerability that you ignore. 
- 70612: - reason: "https://data.safetycli.com/v/70612/97c/" - expires: '2025-12-31' # datetime string - date this ignore will expire, best practice to use this variable - continue-on-vulnerability-error: True # Suppress non-zero exit codes when vulnerabilities are found. Enable this in pipelines and CI/CD processes if you want to pass builds that have vulnerabilities. We recommend you set this to False. -alert: # configuration for the `safety alert` command - security: - # Configuration specific to Safety's GitHub Issue alerting - github-issue: - # Same as for security - these allow controlling if this alert will fire based - # on severity information. - # default: not set - # ignore-cvss-severity-below: 6 - # ignore-cvss-unknown-severity: False - - # Add a label to pull requests with the cvss severity, if available - # default: true - # label-severity: True - - # Add a label to pull requests, default is 'security' - # requires private repo permissions, even on public repos - # default: security - # labels: - # - security - - # Assign users to pull requests, default is not set - # requires private repo permissions, even on public repos - # default: empty - # assignees: - # - example-user - - # Prefix to give issues when creating them. Note that changing - # this might cause duplicate issues to be created. - # default: "[PyUp] " - # issue-prefix: "[PyUp] " - - # Configuration specific to Safety's GitHub PR alerting - github-pr: - # Same as for security - these allow controlling if this alert will fire based - # on severity information. 
- # default: not set - # ignore-cvss-severity-below: 6 - # ignore-cvss-unknown-severity: False - - # Set the default branch (ie, main, master) - # default: empty, the default branch on GitHub - branch: '' - - # Add a label to pull requests with the cvss severity, if available - # default: true - # label-severity: True - - # Add a label to pull requests, default is 'security' - # requires private repo permissions, even on public repos - # default: security - # labels: - # - security - - # Assign users to pull requests, default is not set - # requires private repo permissions, even on public repos - # default: empty - # assignees: - # - example-user - - # Configure the branch prefix for PRs created by this alert. - # NB: Changing this will likely cause duplicate PRs. - # default: pyup/ - branch-prefix: pyup/ - - # Set a global prefix for PRs - # default: "[PyUp] " - pr-prefix: "[PyUp] " diff --git a/README.md b/README.md index 82e7c1a1..da80e21a 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@

- Tests + CI Coverage PyPI License @@ -12,16 +12,33 @@


-
- - - - Features - -
-
+**Lighter** makes PyTorch Lightning experiments reproducible and composable through YAML configuration. Stop hardcoding hyperparameters—configure everything from the command line. + +## Why Lighter? + +You're already using PyTorch Lightning. But every experiment requires editing Python code to change hyperparameters: + +```python +# Want to try a different learning rate? Edit the code. +optimizer = Adam(params, lr=0.001) # Change this line + +# Want to use a different batch size? Edit the code. +train_loader = DataLoader(dataset, batch_size=32) # And this one -**Lighter** is a YAML-driven deep learning framework built on PyTorch Lightning. Define your model, data, and training in config files instead of writing boilerplate code. +# Want to train longer? Edit the code again. +trainer = Trainer(max_epochs=10) # And this one too +``` + +**With Lighter, configure everything in YAML and override from the CLI:** + +```bash +# Try different learning rates without touching code +lighter fit config.yaml model::optimizer::lr=0.001 +lighter fit config.yaml model::optimizer::lr=0.01 +lighter fit config.yaml model::optimizer::lr=0.1 + +# Every experiment is reproducible - just version control your configs +``` ## Quick Start @@ -29,69 +46,172 @@ pip install lighter ``` -Create `config.yaml`: +**Use your existing PyTorch Lightning code:** + +```python +# model.py +import torch +import torch.nn.functional as F +import pytorch_lightning as pl + +class MyModel(pl.LightningModule): + def __init__(self, network, learning_rate=0.001): + super().__init__() + self.network = network + self.lr = learning_rate + + def training_step(self, batch, batch_idx): + x, y = batch + loss = F.cross_entropy(self.network(x), y) + self.log("train/loss", loss) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.lr) +``` + +**Configure in YAML instead of hardcoding:** + ```yaml +# config.yaml trainer: _target_: pytorch_lightning.Trainer max_epochs: 10 
-system: - _target_: lighter.System - model: +model: + _target_: model.MyModel + network: + _target_: torchvision.models.resnet18 + num_classes: 10 + learning_rate: 0.001 + +data: + _target_: lighter.LighterDataModule + train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 32 + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: true + download: true +``` + +**Run and iterate fast:** + +```bash +# Run your experiment +lighter fit config.yaml + +# Try different hyperparameters - no code editing needed +lighter fit config.yaml model::learning_rate=0.01 +lighter fit config.yaml trainer::max_epochs=50 +lighter fit config.yaml data::train_dataloader::batch_size=64 + +# Use multiple GPUs +lighter fit config.yaml trainer::devices=4 + +# Every run creates timestamped outputs with saved configs +# outputs/2025-11-21/14-30-45/config.yaml # Fully reproducible +``` + +## Key Benefits + +- **Reproducible**: Every experiment = one YAML file. Version control configs like code. +- **Fast iteration**: Override any parameter from CLI without editing code. +- **Zero lock-in**: Works with any PyTorch Lightning module. Your code, your logic. +- **Composable**: Merge configs, create recipes, share experiments as files. +- **Organized**: Automatic timestamped output directories with saved configs. +- **Simple**: ~500 lines of code. Read the framework in 30 minutes. + +## Optional: Use LighterModule for Less Boilerplate + +If you want automatic optimizer configuration and dual logging (step + epoch), use `LighterModule`: + +```python +from lighter import LighterModule + +class MyModel(LighterModule): + def training_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + loss = self.criterion(pred, y) + + if self.train_metrics: + self.train_metrics(pred, y) + + return {"loss": loss} # Framework logs automatically + + # validation_step, test_step, predict_step... 
+``` + +```yaml +model: + _target_: model.MyModel + network: _target_: torchvision.models.resnet18 num_classes: 10 criterion: _target_: torch.nn.CrossEntropyLoss optimizer: _target_: torch.optim.Adam - params: "$@system::model.parameters()" + params: "$@model::network.parameters()" lr: 0.001 - dataloaders: - train: - _target_: torch.utils.data.DataLoader - batch_size: 32 - dataset: - _target_: torchvision.datasets.CIFAR10 - root: ./data - train: true - download: true - transform: - _target_: torchvision.transforms.ToTensor + train_metrics: + - _target_: torchmetrics.Accuracy + task: multiclass + num_classes: 10 ``` -Run: -```bash -lighter fit config.yaml -``` +**LighterModule gives you:** +- Automatic `configure_optimizers()` handling +- Automatic dual logging (step + epoch) +- Config-driven criterion and metrics + +**But you still control:** +- All step implementations +- Loss computation logic +- When to call metrics + +## Example: Running a Hyperparameter Sweep -Override from CLI: ```bash -lighter fit config.yaml system::optimizer::lr=0.01 +# Run grid search without editing code +for lr in 0.001 0.01 0.1; do + for bs in 32 64 128; do + lighter fit config.yaml \ + model::optimizer::lr=$lr \ + data::train_dataloader::batch_size=$bs + done +done + +# Each run saved in outputs/YYYY-MM-DD/HH-MM-SS/ with config.yaml +# Compare experiments by diffing configs ``` -**[→ Full tutorial](https://project-lighter.github.io/lighter/tutorials/get-started/)** - ## Documentation -- 📚 [Get Started](https://project-lighter.github.io/lighter/tutorials/get-started/) -- ⚙️ [Configuration Guide](https://project-lighter.github.io/lighter/how-to/configuration/) -- 🔌 [Adapters](https://project-lighter.github.io/lighter/how-to/adapters/) -- 🏗️ [Architecture](https://project-lighter.github.io/lighter/design/overview/) +- 📚 [Get Started Tutorial](https://project-lighter.github.io/lighter/tutorials/get-started/) - 15 min walkthrough +- ⚙️ [Configuration 
Guide](https://project-lighter.github.io/lighter/how-to/configuration/) - Master the syntax +- 🎯 [LighterModule Design](https://project-lighter.github.io/lighter/design/model/) - Understand the internals +- 🏗️ [Architecture Overview](https://project-lighter.github.io/lighter/design/overview/) - How it all works -## Projects Using Lighter +## Real-World Usage - 🏥 [Foundation Models for Cancer Imaging](https://aim.hms.harvard.edu/foundation-cancer-image-biomarker) - 🧠 [Vision Foundation Models for CT](https://arxiv.org/abs/2501.09001) ## Community -- 💬 [Discord](https://discord.gg/zJcnp6KrUp) -- 🐛 [GitHub Issues](https://github.com/project-lighter/lighter/issues) -- 📺 [YouTube](https://www.youtube.com/channel/UCef1oTpv2QEBrD2pZtrdk1Q) -- 🤝 [Contributing](CONTRIBUTING.md) +- 💬 [Discord](https://discord.gg/zJcnp6KrUp) - Chat with users +- 🐛 [GitHub Issues](https://github.com/project-lighter/lighter/issues) - Report bugs +- 📺 [YouTube](https://www.youtube.com/channel/UCef1oTpv2QEBrD2pZtrdk1Q) - Video tutorials +- 🤝 [Contributing](CONTRIBUTING.md) - Help improve Lighter ## Citation +If Lighter helps your research, please cite our [JOSS paper](https://joss.theoj.org/papers/10.21105/joss.08101): + ```bibtex @article{lighter, doi = {10.21105/joss.08101}, diff --git a/assets/images/coverage.svg b/assets/images/coverage.svg deleted file mode 100644 index 0fa96494..00000000 --- a/assets/images/coverage.svg +++ /dev/null @@ -1,21 +0,0 @@ - - - - - - - - - - - - - - - - coverage - coverage - 98% - 98% - - diff --git a/assets/images/features_dark.png b/assets/images/features_dark.png deleted file mode 100644 index eb9bc2c2..00000000 Binary files a/assets/images/features_dark.png and /dev/null differ diff --git a/assets/images/features_light.png b/assets/images/features_light.png deleted file mode 100644 index b6ecde7e..00000000 Binary files a/assets/images/features_light.png and /dev/null differ diff --git a/docker/Dockerfile b/docker/Dockerfile deleted file mode 100644 index 
05746468..00000000 --- a/docker/Dockerfile +++ /dev/null @@ -1,25 +0,0 @@ -FROM python:3.8-slim-buster - -ENV LANG=C.UTF-8 \ - LC_ALL=C.UTF-8 \ - PATH="${PATH}:/root/.poetry/bin" - -RUN apt-get update && \ - apt-get install -y --no-install-recommends \ - curl \ - && rm -rf /var/lib/apt/lists/* - -COPY pyproject.toml ./ - -# Install Poetry -RUN curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/install-poetry.py | POETRY_HOME=/opt/poetry python && \ - cd /usr/local/bin && \ - ln -s /opt/poetry/bin/poetry && \ - poetry config virtualenvs.create false - -# Allow installing dev dependencies to run tests -ARG INSTALL_DEV=false -RUN bash -c "if [ $INSTALL_DEV == 'true' ] ; then poetry install --no-root ; else poetry install --no-root --no-dev ; fi" - -CMD mkdir -p /workspace -WORKDIR /workspace diff --git a/docker/README.md b/docker/README.md deleted file mode 100644 index dab70c6e..00000000 --- a/docker/README.md +++ /dev/null @@ -1,47 +0,0 @@ -# Docker for lighter - -## Installation - -To create Docker you need to run: - -```bash -make docker-build -``` - -which is equivalent to: - -```bash -make docker-build VERSION=latest -``` - -You may provide name and version for the image. -Default name is `IMAGE := lighter`. -Default version is `VERSION := latest`. 
- -```bash -make docker-build IMAGE=some_name VERSION=0.0.1 -``` - -## Usage - -```bash -docker run -it --rm \ - -v $(pwd):/workspace \ - lighter bash -``` - -## How to clean up - -To uninstall docker image run `make docker-remove` with `VERSION`: - -```bash -make docker-remove VERSION=0.0.1 -``` - -you may also choose the image name - -```bash -make docker-remove IMAGE=some_name VERSION=latest -``` - -If you want to clean all, including `build` and `pycache` run `make cleanup` diff --git a/docs/assets/images/overview_all.png b/docs/assets/images/overview_all.png deleted file mode 100644 index 97226bfe..00000000 Binary files a/docs/assets/images/overview_all.png and /dev/null differ diff --git a/docs/assets/images/overview_system.png b/docs/assets/images/overview_system.png deleted file mode 100644 index 6ebaca1d..00000000 Binary files a/docs/assets/images/overview_system.png and /dev/null differ diff --git a/docs/design/adapters.md b/docs/design/adapters.md deleted file mode 100644 index cf6a9293..00000000 --- a/docs/design/adapters.md +++ /dev/null @@ -1,61 +0,0 @@ -# The Adapter Pattern - -## The Problem - -Different ML components may expect different data formats. Consider a scenario where: - -- Dataset returns dictionaries of tensors -- Model expects tensors -- Loss function needs specific argument order -- Metrics need different format than loss - -Traditionally, you'd implement a pipeline specific to this scenario. This tightly couples components, making reuse and experimentation difficult. - -## The Solution: Adapters - -![System Data Flow](../assets/images/overview_system.png) -*Data flow through Lighter's System. Adapters bridge components with incompatible interfaces.* - -In software engineering, the [adapter pattern](https://refactoring.guru/design-patterns/adapter) allows incompatible interfaces to work together. Lighter uses adapters to handle variability in data formats. 
- -## Lighter's Adapter Types - -| Adapter | Purpose | When to Use | -|---------|---------|-------------| -| **BatchAdapter** | Extract data from batches | Different dataset formats | -| **CriterionAdapter** | Format loss inputs | Custom loss functions | -| **MetricsAdapter** | Format metric inputs | Third-party metrics | -| **LoggingAdapter** | Transform before logging | Visualization needs | - -## Example: Task-Agnostic Configuration - -```yaml -adapters: - train: - criterion: - _target_: lighter.adapters.CriterionAdapter - pred_transforms: # Apply sigmoid before loss - _target_: torch.sigmoid - pred_argument: 0 # Map pred to first argument - target_argument: 1 # Map target to second argument -``` - -This enables **any task**—classification, segmentation, self-supervised learning—without framework modifications. - -## Under the Hood - -Adapters are invoked in this order during training: - -1. **BatchAdapter** - Extract input/target from batch -2. **Forward pass** - Model processes input -3. **CriterionAdapter** - Format for loss computation -4. **MetricsAdapter** - Format for metric computation -5. **LoggingAdapter** - Transform for visualization - -## Practical Usage - -For detailed adapter configuration and examples, see: - -- [Adapters How-To Guide](../how-to/adapters.md) - Complete usage guide -- [Metrics Guide](../how-to/metrics.md) - Using MetricsAdapter -- [Writers Guide](../how-to/writers.md) - Using LoggingAdapter diff --git a/docs/design/overview.md b/docs/design/overview.md deleted file mode 100644 index ceba03fa..00000000 --- a/docs/design/overview.md +++ /dev/null @@ -1,239 +0,0 @@ ---- -title: Architecture & Design ---- - -# Architecture & Design Overview - -Lighter is a configuration-driven deep learning framework that separates experimental setup from code implementation. - -## Core Architecture - -![Lighter Overview](../assets/images/overview_all.png) -*Figure: Lighter's three-component (bolded) architecture. 
Config parses YAML definitions, System encapsulates DL components, and Trainer executes training.* - -### 1. Config -Transforms YAML experiment definitions into Python objects using Sparkwheel. One config file = one reproducible experiment. - -[→ Configuration guide](../how-to/configuration.md) - -### 2. System -Orchestrates your deep learning pipeline—model, optimizer, loss, metrics, data. Extends PyTorch Lightning's LightningModule. - -[→ System internals](system.md) - -### 3. Trainer -PyTorch Lightning's Trainer executes experiments with multi-GPU, mixed precision, gradient accumulation, and checkpointing. - -[→ Running experiments](../how-to/run.md) - -## Stages and Modes - -A key concept in Lighter is the distinction between **Stages** and **Modes**: - -### Stages -Stages are what you invoke from the CLI. They represent high-level operations: - -- `lighter fit` - Train and validate a model -- `lighter validate` - Run validation only -- `lighter test` - Evaluate on test set -- `lighter predict` - Generate predictions - -### Modes -Modes are internal execution contexts that the System uses during a stage. Each mode has its own dataloader, metrics, and adapters: - -- `train` - Training loop with backpropagation -- `val` - Validation loop (no gradients) -- `test` - Testing loop (no gradients) -- `predict` - Inference loop (no targets or loss) - -### Stage-to-Mode Mapping - -Each stage executes one or more modes: - -``` -FIT → [train, val] # Training with validation -VALIDATE → [val] # Validation only -TEST → [test] # Testing only -PREDICT → [predict] # Inference only -``` - -This means when you run `lighter fit`, the system will execute both training and validation modes. When you run `lighter validate`, only the validation mode executes. - -**Code reference**: `src/lighter/engine/runner.py:26-31` - -## Auto Config Pruning - -Lighter automatically prunes (removes) unused components from your configuration based on the stage you're running. 
This allows you to define a single comprehensive configuration file that works for all stages. - -### What Gets Pruned - -The Runner (`src/lighter/engine/runner.py:80-112`) automatically removes: - -1. **Unused mode components**: Dataloaders and metrics for modes not required by the current stage - - Example: When running `lighter test`, train and val dataloaders/metrics are removed - -2. **Training-only components** (for non-FIT stages): - - Optimizer - - Learning rate scheduler - - Criterion (except for VALIDATE stage, which needs it for computing validation loss) - -3. **Stage-specific arguments**: Arguments defined for other stages are removed - - Example: `args::fit` is removed when running `lighter test` - -### Why This Matters - -Configuration pruning means you can: - -- **Define once, use everywhere**: Write one config with train/val/test dataloaders, then use it for any stage -- **Avoid duplication**: No need for separate train.yaml, test.yaml, predict.yaml files -- **Reduce errors**: The system ensures only relevant components are used for each stage - -### Example - -```yaml -system: - dataloaders: - train: # Used only in FIT stage - _target_: torch.utils.data.DataLoader - # ... config ... - - val: # Used in FIT and VALIDATE stages - _target_: torch.utils.data.DataLoader - # ... config ... - - test: # Used only in TEST stage - _target_: torch.utils.data.DataLoader - # ... config ... - - predict: # Used only in PREDICT stage - _target_: torch.utils.data.DataLoader - # ... config ... - - optimizer: # Pruned for VALIDATE, TEST, PREDICT stages - _target_: torch.optim.Adam - lr: 0.001 - - criterion: # Pruned for TEST, PREDICT stages - _target_: torch.nn.CrossEntropyLoss -``` - -When you run `lighter test`, the Runner automatically removes the train, val, and predict dataloaders, as well as the optimizer and criterion, keeping only what's needed for testing. 
- -## System Data Flow - -Understanding how data flows through the System is crucial for working with Lighter effectively. The `System._step()` method (`src/lighter/system.py:74-94`) orchestrates this flow: - -### The Pipeline - -``` -1. Batch (from DataLoader) - ↓ -2. BatchAdapter → [input, target, identifier] - ↓ -3. Model.forward(input) → prediction - ↓ (or Inferer in val/test/predict modes) - ↓ -4. CriterionAdapter → Criterion(pred, target) → loss - ↓ -5. MetricsAdapter → Metrics(pred, target) → metric values - ↓ -6. LoggingAdapter → Logger - ↓ -7. Output dict returned to callbacks -``` - -### Step-by-Step Breakdown - -**1. Batch Preparation** (`System._prepare_batch`) -- The raw batch from the DataLoader is passed to the BatchAdapter -- Returns: `(input, target, identifier)` tuple -- Identifier is optional and used for tracking samples - -**2. Forward Pass** (`System.forward`) -- Input goes through `model.forward()` to produce predictions -- In val/test/predict modes, the Inferer replaces model.forward() if specified -- Automatically injects `epoch` and `step` arguments if model accepts them - -**3. Loss Calculation** (`System._calculate_loss`) -- Only in train and val modes (test/predict skip this) -- CriterionAdapter transforms data before passing to criterion -- Supports dict-based losses with sublosses (must include 'total' key) - -**4. Metrics Calculation** (`System._calculate_metrics`) -- Only in train, val, and test modes (predict skips this) -- MetricsAdapter transforms data before passing to metrics -- Returns None if no metrics are defined - -**5. Logging** (`System._log_stats`) -- Logs loss and metrics to the logger -- Automatically logs optimizer stats (lr, momentum, beta) once per epoch in train mode - -**6. 
Output Preparation** (`System._prepare_output`) -- LoggingAdapter can transform data for cleaner callback access -- Returns a dictionary with all step information - -### The Output Dictionary - -Each step returns a dictionary with the following keys: - -```python -{ - "identifier": batch_identifier, # Optional, for tracking - "input": input_data, - "target": target_data, - "pred": predictions, - "loss": loss_value, # None in test/predict - "metrics": metrics_dict, # None in predict - "step": global_step, - "epoch": current_epoch, -} -``` - -This dictionary is accessible in callbacks, allowing you to write predictions to disk, visualize results, or perform custom analysis. - -**Code references**: -- Data flow orchestration: `src/lighter/system.py:74-94` -- Dict-based loss handling: `src/lighter/system.py:156-160` -- Epoch/step injection: `src/lighter/system.py:121-126` - -### Key Behaviors - -- **Inferer replaces forward()**: In val/test/predict modes, if an inferer is specified, it's called instead of `model.forward()`. This is useful for inference-specific logic like sliding window or test-time augmentation. - -- **Dict-based losses**: If your criterion returns a dictionary of sublosses, it must include a `'total'` key that combines all sublosses. This is used for backpropagation. - -- **Mode-specific adapters**: Each mode (train/val/test/predict) has its own set of adapters, allowing different preprocessing for different stages. - -- **Automatic optimizer stats**: Learning rate, momentum, and beta values are logged automatically once per epoch during training. - -## The Adapter Pattern - -Adapters make Lighter task-agnostic by handling data format differences between components. - -[→ Learn more about adapters](adapters.md) - -## Design Philosophy - -Lighter follows four core principles: **Configuration over Code**, **Composition over Inheritance**, **Convention over Configuration**, and **Separation of Concerns**. 
- -[→ Understand the philosophy](philosophy.md) - -## Framework Comparison - -Lighter's goal is to bring reproducibility and structure, while keeping you in full control of your code. This is different from other configuration-driven frameworks that provide higher-level abstractions. - -| Feature | **Lighter** | **[Ludwig](https://github.com/ludwig-ai/ludwig)** | **[Quadra](https://github.com/orobix/quadra)** | **[GaNDLF](https://github.com/mlcommons/GaNDLF)** | -|---|---|---|---|---| -| **Primary Focus** | Config-driven, task-agnostic DL | Config-driven, multi-task DL | Config-driven computer vision | Config-driven medical imaging | -| **Configuration** | YAML (Sparkwheel) | YAML (Custom) | YAML (Hydra) | YAML (Custom) | -| **Abstraction** | Medium. Extends PyTorch Lightning, expects standard PyTorch components. | High. Provides pre-built flows for various tasks. | High. Pre-defined structures for computer vision. | High. Pre-defined structures for medical imaging. | -| **Flexibility** | High. New components are added via project module. | Medium. Adding new components requires code editing. | Low. Adding new components requires code editing. | Low. Adding new components requires code editing. | -| **Use Case** | Organized experimentation | Production-level applications | Traditional computer vision | Established medical imaging methods | - -Lighter is the tool for you if you like PyTorch's flexibility but want to manage your experiments in a structured and reproducible way. - -## Next Steps - -- Deep dive into [the Adapter Pattern](adapters.md) -- Understand [Design Philosophy](philosophy.md) -- Get started with the [Zero to Hero tutorial](../tutorials/get-started.md) diff --git a/docs/design/philosophy.md b/docs/design/philosophy.md deleted file mode 100644 index 3025571e..00000000 --- a/docs/design/philosophy.md +++ /dev/null @@ -1,81 +0,0 @@ -# Design Philosophy - -Why Lighter is designed the way it is. - -## Core Principles - -### 1. 
Configuration Over Code -Experiments are data, not code. YAML configs are easier to: -- Version control and compare -- Share and reproduce -- Parametrize and sweep -- Audit and validate - -### 2. Composition Over Inheritance -Instead of subclassing for behavior changes, compose with adapters. More flexible, less coupled. - -### 3. Convention Over Configuration -Sensible defaults (like BatchAdapter assuming `(input, target)` tuples) reduce boilerplate. Override when needed. - -### 4. Separation of Concerns -Clear boundaries: -- **Config** - Experiment definition -- **System** - Component orchestration -- **Trainer** - Execution engine -- **Adapters** - Interface translation - -### 5. Task-Agnostic by Design -No per-task pipelines. Adapters handle variability, enabling unlimited flexibility for novel research. - -## Why ~1,000 Lines of Code? - -**Benefits:** -- Read entire codebase in an afternoon -- Easy to debug and understand -- Simple to extend and maintain -- Low long-term maintenance burden - -**Achieved by:** -- Leveraging PyTorch Lightning (not reinventing training loops) -- Using Sparkwheel's config system (powerful, flexible) -- Focusing on core value: config-driven experiments + adapters - -## Integration Philosophy - -### Standing on Shoulders of Giants - -**PyTorch Lightning** - Battle-tested training engine -- Multi-GPU/TPU support -- Callbacks, loggers, profilers -- Gradient accumulation, mixed precision -- [→ PL Trainer docs](https://lightning.ai/docs/pytorch/stable/common/trainer.html) - -**Sparkwheel** - Powerful configuration system -- Config parsing and validation -- Reference resolution -- Dynamic instantiation -- [→ Sparkwheel docs](https://project-lighter.github.io/sparkwheel/) - -Lighter adds: adapters + System orchestration. 
- -## Trade-offs - -### When to Use Lighter - -- Configuration-driven experiments are valuable -- You need task-agnostic flexibility -- You want minimal framework overhead -- Reproducibility and sharing are priorities - -### When NOT to Use Lighter - -- Highly custom training loops (use PyTorch directly) -- Prefer code over configuration -- Need high-level AutoML (use Ludwig) -- Domain-specific pipelines sufficient (use GaNDLF/Quadra) - -## Learn More - -- [Architecture Overview](overview.md) - Component details -- [Adapter Pattern](adapters.md) - Deep dive -- [Configuration Guide](../how-to/configuration.md) - Practical usage diff --git a/docs/design/system.md b/docs/design/system.md deleted file mode 100644 index a78cc213..00000000 --- a/docs/design/system.md +++ /dev/null @@ -1,212 +0,0 @@ ---- -title: System Internals ---- - -# System Internals - -The `System` class extends PyTorch Lightning's `LightningModule` and orchestrates your entire training pipeline. Understanding its operation helps with debugging and customization. - -## Overview - -System manages: - -- Model architecture -- Optimizer and scheduler -- Loss function (criterion) -- Metrics computation -- Data loading -- Adapters for data transformation -- Inference strategies (inferer) - -## The Unified `_step()` Method - -All modes (train, val, test, predict) use the same `_step()` method: - -``` -1. Batch → BatchAdapter → [input, target, identifier] -2. Model.forward(input) → prediction - (or Inferer in val/test/predict modes) -3. CriterionAdapter → Criterion → loss (train/val only) -4. MetricsAdapter → Metrics → values (train/val/test) -5. LoggingAdapter → Logger -6. Output dict → callbacks -``` - -This unified approach ensures consistency while allowing mode-specific behavior through adapters. - -## Automatic Pruning - -The Runner automatically removes unused components based on stage: - -```yaml -system: - dataloaders: - train: ... # Removed for TEST, PREDICT - val: ... 
# Removed for TEST, PREDICT - test: ... # Removed for FIT, VALIDATE, PREDICT - predict: ... # Removed for FIT, VALIDATE, TEST - - optimizer: ... # Removed for VALIDATE, TEST, PREDICT - criterion: ... # Removed for TEST, PREDICT -``` - -This enables **one config for all stages**. - -## Mode-Specific Behavior - -### Loss Calculation - -Loss is calculated only in **train** and **val** modes: - -```python -if self.mode in [Mode.TRAIN, Mode.VAL]: - loss = adapters.criterion(self.criterion, input, target, pred) -``` - -Test and predict modes return `None`. - -### Dict-Based Losses - -For multi-task learning, return a dict with `"total"` key: - -```python -def my_criterion(pred, target): - return { - "total": loss1 + loss2, # Required for backprop - "classification": loss1, - "segmentation": loss2, - } -``` - -All sublosses logged automatically; `"total"` used for gradients. - -### Metrics Calculation - -Metrics calculated in **train**, **val**, and **test** modes (not predict): - -```python -if self.mode == Mode.PREDICT or self.metrics[self.mode] is None: - return None -``` - -## Special Features - -### Epoch/Step Injection - -If your model accepts `epoch` or `step` parameters, they're injected automatically: - -```python -class MyModel(nn.Module): - def forward(self, x, epoch=None, step=None): - # Use for curriculum learning - if epoch is not None: - difficulty = min(epoch / self.max_epochs, 1.0) - x = self.apply_difficulty(x, difficulty) - return self.process(x) -``` - -No configuration needed—works automatically. 
- -### Inferer in Val/Test/Predict - -In validation, testing, and prediction modes, an inferer can replace the forward pass: - -```python -if self.inferer and self.mode in [Mode.VAL, Mode.TEST, Mode.PREDICT]: - return self.inferer(input, self.model, **kwargs) -return self.model(input, **kwargs) -``` - -Use for: - -- Sliding window inference -- Test-time augmentation -- Ensemble methods -- Custom post-processing - -## Automatic Logging - -System logs automatically: - -### Loss - -- Step and epoch level: `{mode}/loss/step`, `{mode}/loss/epoch` -- Individual sublosses for dict-based losses - -### Metrics - -- Step and epoch level: `{mode}/metrics/{name}/step`, `{mode}/metrics/{name}/epoch` - -### Optimizer Stats - -Once per epoch during training: - -- Learning rate: `train/lr` -- Momentum (SGD) -- Beta values (Adam/AdamW) - -## Output Dictionary - -Each step returns: - -```python -{ - "identifier": batch_identifier, # Optional - "input": input_data, # After LoggingAdapter - "target": target_data, # After LoggingAdapter - "pred": predictions, # After LoggingAdapter - "loss": loss_value, # None in test/predict - "metrics": metrics_dict, # None in predict - "step": self.global_step, - "epoch": self.current_epoch, -} -``` - -This dictionary is passed to callbacks for custom processing. - -## Customization - -Extend System for advanced use cases: - -```python -from lighter.system import System - -class CustomSystem(System): - def _log_stats(self, loss, metrics, batch_idx): - super()._log_stats(loss, metrics, batch_idx) - # Add custom logging - if self.mode == Mode.TRAIN: - self.log("custom/my_metric", my_value) - - def on_train_epoch_end(self): - # Custom behavior at epoch end - pass -``` - -Use in config: - -```yaml -system: - _target_: project.CustomSystem - model: ... -``` - -## Summary - -System provides: - -1. **Unified execution**: Same `_step()` for all modes -2. **Automatic pruning**: Unused components removed by stage -3. 
**Flexible loss**: Scalar or dict-based -4. **Smart injection**: Epoch/step passed to model automatically -5. **Inferer support**: Custom inference logic -6. **Comprehensive logging**: Loss, metrics, optimizer stats -7. **Extensibility**: Subclass for custom behavior - -Understanding System helps you debug issues, optimize performance, and implement advanced training strategies. - -## Next Steps - -- [Adapters](../how-to/adapters.md) - Data transformation -- [Architecture Overview](overview.md) - High-level design -- [API Reference](../reference/) - Complete documentation diff --git a/docs/examples/image-classification.md b/docs/examples/image-classification.md new file mode 100644 index 00000000..fc0cfe5f --- /dev/null +++ b/docs/examples/image-classification.md @@ -0,0 +1,685 @@ +--- +title: Image Classification Example +--- + +# Complete Image Classification Example + +Train a CIFAR-10 classifier from scratch with all the features. + +This example shows a complete, production-ready setup including: + +- Both LightningModule and LighterModule approaches +- Data augmentation +- Learning rate scheduling +- Multiple metrics +- Checkpointing +- Early stopping +- TensorBoard logging +- Multi-GPU support + +## Project Setup + +### Directory Structure + +``` +cifar10/ +├── __lighter__.py # Marker file (enables project.* imports) +├── __init__.py +├── models.py # Model definitions +├── data.py # Data utilities (optional) +├── configs/ +│ ├── resnet18.yaml # ResNet-18 config +│ ├── resnet50.yaml # ResNet-50 config +│ └── efficientnet.yaml # EfficientNet config +└── outputs/ # Generated by Lighter +``` + +### Installation + +```bash +pip install lighter torch torchvision pytorch-lightning torchmetrics +``` + +## Approach 1: Using LightningModule + +### Step 1: Create the Module + +`models.py`: + +```python +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +import torchmetrics + + +class CIFAR10Classifier(pl.LightningModule): + """CIFAR-10 image 
classifier with custom training logic.""" + + def __init__( + self, + network, + learning_rate=0.001, + weight_decay=0.0001, + max_epochs=100, + ): + super().__init__() + self.save_hyperparameters(ignore=['network']) + self.network = network + + # Metrics + self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=10) + self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=10) + self.val_f1 = torchmetrics.F1Score(task='multiclass', num_classes=10) + + def forward(self, x): + return self.network(x) + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + + # Update metrics + self.train_acc(logits, y) + + # Log + self.log('train/loss', loss, on_step=True, on_epoch=True) + self.log('train/acc', self.train_acc, on_step=False, on_epoch=True) + + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + + # Update metrics + self.val_acc(logits, y) + self.val_f1(logits, y) + + # Log + self.log('val/loss', loss) + self.log('val/acc', self.val_acc, on_step=False, on_epoch=True) + self.log('val/f1', self.val_f1, on_step=False, on_epoch=True) + + def configure_optimizers(self): + optimizer = torch.optim.AdamW( + self.parameters(), + lr=self.hparams.learning_rate, + weight_decay=self.hparams.weight_decay, + ) + + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=self.hparams.max_epochs, + eta_min=1e-6, + ) + + return { + 'optimizer': optimizer, + 'lr_scheduler': { + 'scheduler': scheduler, + 'interval': 'epoch', + } + } +``` + +### Step 2: Create the Config + +`configs/resnet18.yaml`: + +```yaml +# CIFAR-10 Classification with ResNet-18 + +vars: + num_classes: 10 + batch_size: 128 + num_workers: 4 + base_lr: 0.001 + max_epochs: 100 + +trainer: + _target_: pytorch_lightning.Trainer + max_epochs: "%vars::max_epochs" + accelerator: auto + devices: 1 + + callbacks: + # Save best models + - 
_target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: val/acc + mode: max + save_top_k: 3 + filename: 'best-acc-{epoch:02d}-{val/acc:.4f}' + + # Early stopping + - _target_: pytorch_lightning.callbacks.EarlyStopping + monitor: val/loss + patience: 15 + mode: min + verbose: true + + # Log learning rate + - _target_: pytorch_lightning.callbacks.LearningRateMonitor + logging_interval: epoch + + logger: + _target_: pytorch_lightning.loggers.TensorBoardLogger + save_dir: logs + name: cifar10_resnet18 + +model: + _target_: project.models.CIFAR10Classifier + learning_rate: "%vars::base_lr" + weight_decay: 0.0001 + max_epochs: "%vars::max_epochs" + + network: + _target_: torchvision.models.resnet18 + num_classes: "%vars::num_classes" + weights: null # Train from scratch + +data: + _target_: lighter.LighterDataModule + + train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: "%vars::batch_size" + shuffle: true + num_workers: "%vars::num_workers" + pin_memory: true + persistent_workers: true + + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: true + download: true + + transform: + _target_: torchvision.transforms.Compose + transforms: + # Augmentation + - _target_: torchvision.transforms.RandomCrop + size: 32 + padding: 4 + - _target_: torchvision.transforms.RandomHorizontalFlip + p: 0.5 + + # Normalization + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.4914, 0.4822, 0.4465] + std: [0.2470, 0.2435, 0.2616] + + val_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: "%vars::batch_size" + num_workers: "%vars::num_workers" + pin_memory: true + + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: false + download: true + + transform: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.4914, 0.4822, 0.4465] + std: [0.2470, 
0.2435, 0.2616] +``` + +### Step 3: Run Training + +```bash +cd cifar10 +lighter fit configs/resnet18.yaml +``` + +**Expected output:** + +``` +Epoch 0: 100%|████████| 391/391 [00:45<00:00, 8.58it/s, loss=1.82, train/acc=0.335, v_num=0] +Validation: 100%|████████| 79/79 [00:05<00:00, 14.23it/s] +Epoch 1: 100%|████████| 391/391 [00:44<00:00, 8.76it/s, loss=1.45, train/acc=0.472, v_num=0] +... +Epoch 99: 100%|████████| 391/391 [00:43<00:00, 8.98it/s, loss=0.23, train/acc=0.921, v_num=0] +``` + +**Expected accuracy:** ~90-92% on validation set after 100 epochs. + +## Approach 2: Using LighterModule + +### Step 1: Create the Module + +`models.py`: + +```python +from lighter import LighterModule + + +class CIFAR10Model(LighterModule): + """CIFAR-10 classifier using LighterModule.""" + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = self.criterion(logits, y) + + # Update metrics + if self.train_metrics: + self.train_metrics(logits, y) + + return {'loss': loss} + + def validation_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = self.criterion(logits, y) + + # Update metrics + if self.val_metrics: + self.val_metrics(logits, y) + + return {'loss': loss} +``` + +### Step 2: Create the Config + +`configs/resnet18.yaml`: + +```yaml +# CIFAR-10 with LighterModule + +vars: + num_classes: 10 + batch_size: 128 + num_workers: 4 + base_lr: 0.001 + max_epochs: 100 + +trainer: + _target_: pytorch_lightning.Trainer + max_epochs: "%vars::max_epochs" + accelerator: auto + devices: 1 + + callbacks: + - _target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: val/Accuracy + mode: max + save_top_k: 3 + filename: 'best-acc-{epoch:02d}-{val/Accuracy:.4f}' + + - _target_: pytorch_lightning.callbacks.EarlyStopping + monitor: val/loss + patience: 15 + mode: min + + - _target_: pytorch_lightning.callbacks.LearningRateMonitor + logging_interval: epoch + + logger: + _target_: pytorch_lightning.loggers.TensorBoardLogger + 
save_dir: logs + name: cifar10_resnet18 + +model: + _target_: project.models.CIFAR10Model + + network: + _target_: torchvision.models.resnet18 + num_classes: "%vars::num_classes" + weights: null + + criterion: + _target_: torch.nn.CrossEntropyLoss + label_smoothing: 0.1 # Regularization + + optimizer: + _target_: torch.optim.AdamW + params: "$@model::network.parameters()" + lr: "%vars::base_lr" + weight_decay: 0.0001 + + scheduler: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR + optimizer: "@model::optimizer" + T_max: "%vars::max_epochs" + eta_min: 0.000001 + + train_metrics: + - _target_: torchmetrics.Accuracy + task: multiclass + num_classes: "%vars::num_classes" + - _target_: torchmetrics.F1Score + task: multiclass + num_classes: "%vars::num_classes" + average: macro + + val_metrics: "%model::train_metrics" + +data: + # Same as above... + _target_: lighter.LighterDataModule + train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: "%vars::batch_size" + shuffle: true + num_workers: "%vars::num_workers" + pin_memory: true + persistent_workers: true + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: true + download: true + transform: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.RandomCrop + size: 32 + padding: 4 + - _target_: torchvision.transforms.RandomHorizontalFlip + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.4914, 0.4822, 0.4465] + std: [0.2470, 0.2435, 0.2616] + + val_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: "%vars::batch_size" + num_workers: "%vars::num_workers" + pin_memory: true + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: false + transform: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.4914, 0.4822, 0.4465] + std: [0.2470, 0.2435, 
0.2616] +``` + +### Step 3: Run Training + +```bash +lighter fit configs/resnet18.yaml +``` + +Same results, less code! + +## Experiments + +### Try Different Architectures + +Override the network directly from CLI: + +```bash +# ResNet-50 +lighter fit configs/resnet18.yaml \ + model::network::_target_=torchvision.models.resnet50 \ + vars::batch_size=64 \ + vars::base_lr=0.0005 + +# EfficientNet +lighter fit configs/resnet18.yaml \ + model::network::_target_=torchvision.models.efficientnet_b0 \ + vars::batch_size=64 +``` + +### Hyperparameter Tuning + +Try different learning rates: + +```bash +lighter fit configs/resnet18.yaml vars::base_lr=0.0001 +lighter fit configs/resnet18.yaml vars::base_lr=0.001 +lighter fit configs/resnet18.yaml vars::base_lr=0.01 +``` + +Try different batch sizes: + +```bash +lighter fit configs/resnet18.yaml vars::batch_size=64 +lighter fit configs/resnet18.yaml vars::batch_size=256 +``` + +## Multi-GPU Training + +Use all GPUs: + +```bash +lighter fit configs/resnet18.yaml \ + trainer::devices=-1 \ + trainer::strategy=ddp +``` + +Or specific number: + +```bash +lighter fit configs/resnet18.yaml \ + trainer::devices=4 \ + trainer::strategy=ddp +``` + +**Note:** With DDP, effective batch size = `batch_size × num_gpus`. + +## Mixed Precision + +Train faster with 16-bit precision: + +```bash +lighter fit configs/resnet18.yaml trainer::precision=16 +``` + +Or BFloat16 (on A100/H100): + +```bash +lighter fit configs/resnet18.yaml trainer::precision="bf16-mixed" +``` + +## Transfer Learning + +Use pretrained ImageNet weights: + +`configs/pretrained.yaml`: + +```yaml +model: + network: + weights: IMAGENET1K_V2 # Pretrained weights + +# Lower LR for finetuning +vars: + base_lr: 0.0001 +``` + +Run by composing with base config: + +```bash +lighter fit configs/resnet18.yaml configs/pretrained.yaml +``` + +**Expected:** ~94-95% accuracy (better than training from scratch). 
+ +## Monitoring Training + +### TensorBoard + +View logs: + +```bash +tensorboard --logdir logs +``` + +Open browser to `http://localhost:6006`. + +You'll see: +- Train/val loss curves +- Accuracy curves +- Learning rate schedule +- Model graph + +### Weights & Biases + +Use W&B for experiment tracking: + +```yaml +trainer: + logger: + _target_: pytorch_lightning.loggers.WandbLogger + project: cifar10 + name: resnet18_experiment +``` + +Run: + +```bash +wandb login +lighter fit configs/resnet18.yaml +``` + +## Testing + +After training, test on the test set: + +Add test dataloader to config: + +```yaml +data: + test_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 128 + num_workers: 4 + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: false + transform: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.4914, 0.4822, 0.4465] + std: [0.2470, 0.2435, 0.2616] +``` + +Add test_metrics to model: + +```yaml +model: + test_metrics: "%model::train_metrics" +``` + +Run test: + +```bash +lighter test configs/resnet18.yaml \ + args::test::ckpt_path=outputs/.../checkpoints/best-acc-*.ckpt +``` + +## Complete Working Example + +The complete code is available in the Lighter repository: + +[View on GitHub →](https://github.com/project-lighter/lighter/tree/main/projects/cifar10) + +## Common Issues + +### Out of Memory + +**Problem:** CUDA out of memory error + +**Solutions:** + +1. Reduce batch size: + ```bash + lighter fit config.yaml vars::batch_size=64 + ``` + +2. Use gradient accumulation: + ```yaml + trainer: + accumulate_grad_batches: 2 + ``` + +3. Use mixed precision: + ```yaml + trainer: + precision: 16 + ``` + +### Slow Training + +**Problem:** Training too slow + +**Solutions:** + +1. Increase num_workers: + ```yaml + data: + train_dataloader: + num_workers: 8 + ``` + +2. 
Enable pin_memory and persistent_workers: + ```yaml + data: + train_dataloader: + pin_memory: true + persistent_workers: true + ``` + +3. Use mixed precision: + ```yaml + trainer: + precision: 16 + ``` + +### Low Accuracy + +**Problem:** Model not learning well + +**Solutions:** + +1. Check data normalization matches pretrained weights +2. Try different learning rate: + ```bash + lighter fit config.yaml vars::base_lr=0.01 + ``` +3. Increase max_epochs +4. Add more augmentation +5. Use pretrained weights + +## Next Steps + +- [Multi-GPU Example](multi-gpu.md) - Distributed training setup +- [Training Guide](../guides/training.md) - More training tips +- [Best Practices](../guides/best-practices.md) - Production patterns + +## Summary + +This example showed: + +- ✅ Complete CIFAR-10 classifier setup +- ✅ Both LightningModule and LighterModule approaches +- ✅ Data augmentation and normalization +- ✅ Learning rate scheduling +- ✅ Multiple metrics +- ✅ Checkpointing and early stopping +- ✅ TensorBoard logging +- ✅ Hyperparameter tuning from CLI +- ✅ Multi-GPU support +- ✅ Mixed precision training +- ✅ Transfer learning + +**Key takeaway:** One config file controls everything. Iterate fast without code changes. diff --git a/docs/examples/multi-gpu.md b/docs/examples/multi-gpu.md new file mode 100644 index 00000000..1e45f18d --- /dev/null +++ b/docs/examples/multi-gpu.md @@ -0,0 +1,704 @@ +--- +title: Multi-GPU Training +--- + +# Multi-GPU Training + +Scale your training across multiple GPUs with Distributed Data Parallel (DDP). + +This guide shows how to train on multiple GPUs using PyTorch Lightning's DDP strategy through Lighter configs. + +## Quick Start + +Train on all available GPUs: + +```bash +lighter fit config.yaml trainer::devices=-1 trainer::strategy=ddp +``` + +Train on specific number of GPUs: + +```bash +lighter fit config.yaml trainer::devices=4 trainer::strategy=ddp +``` + +That's it! Your code works unchanged. 
+ +## Configuration + +### In Config File + +```yaml +trainer: + _target_: pytorch_lightning.Trainer + devices: 4 # Use 4 GPUs + strategy: ddp # Distributed Data Parallel + accelerator: auto # Automatically use CUDA if available +``` + +### From CLI + +Override devices: + +```bash +# All GPUs +lighter fit config.yaml trainer::devices=-1 + +# Specific GPUs (0, 1, 2, 3) +lighter fit config.yaml trainer::devices=4 + +# Specific GPU IDs +lighter fit config.yaml 'trainer::devices=[0,2,3]' +``` + +## DDP Strategy + +### Basic DDP + +Recommended for most cases: + +```yaml +trainer: + strategy: ddp +``` + +Features: +- Each GPU gets own process +- Gradients synchronized across GPUs +- Model replicated on each GPU +- Data split across GPUs + +### DDP Spawn + +Alternative that spawns subprocesses: + +```yaml +trainer: + strategy: ddp_spawn +``` + +Use when: +- DDP doesn't work on your system +- Debugging (easier to see errors) + +**Note:** Slightly slower than DDP. + +### DDP Find Unused Parameters + +If you get "unused parameters" error: + +```yaml +trainer: + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + find_unused_parameters: true +``` + +## Batch Size Considerations + +### Per-GPU Batch Size + +Each GPU processes `batch_size` samples: + +```yaml +data: + train_dataloader: + batch_size: 32 # Each GPU: 32 samples +``` + +**Effective batch size** with 4 GPUs = 32 × 4 = 128 + +### Keep Total Batch Size + +To keep same total batch size across different GPU counts: + +```yaml +vars: + num_gpus: 4 + total_batch_size: 128 + +data: + train_dataloader: + batch_size: "$%vars::total_batch_size // %vars::num_gpus" +``` + +Override for different GPU counts: + +```bash +# 1 GPU: batch_size = 128 +lighter fit config.yaml vars::num_gpus=1 + +# 4 GPUs: batch_size = 32 per GPU +lighter fit config.yaml vars::num_gpus=4 + +# 8 GPUs: batch_size = 16 per GPU +lighter fit config.yaml vars::num_gpus=8 +``` + +## Learning Rate Scaling + +### Linear Scaling Rule + +When 
increasing batch size, scale LR proportionally: + +```yaml +vars: + num_gpus: 4 + base_lr: 0.001 + +model: + optimizer: + lr: "$%vars::base_lr * %vars::num_gpus" +``` + +**Example:** +- 1 GPU: LR = 0.001, batch = 32 +- 4 GPUs: LR = 0.004, batch = 128 (32×4) + +### Square Root Scaling + +Alternative for very large batch sizes: + +```yaml +model: + optimizer: + lr: "$%vars::base_lr * (%vars::num_gpus ** 0.5)" +``` + +## Complete Multi-GPU Example + +`experiments/multi_gpu.yaml`: + +```yaml +vars: + # Hardware + num_gpus: 4 + + # Dataset + num_classes: 10 + + # Hyperparameters + base_lr: 0.001 + total_batch_size: 512 # Total across all GPUs + max_epochs: 100 + + # Computed + per_gpu_batch_size: "$%vars::total_batch_size // %vars::num_gpus" + scaled_lr: "$%vars::base_lr * %vars::num_gpus" + +trainer: + _target_: pytorch_lightning.Trainer + max_epochs: "%vars::max_epochs" + devices: "%vars::num_gpus" + strategy: ddp + accelerator: auto + + # Recommended settings + sync_batchnorm: true # Sync batch normalization + precision: 16 # Mixed precision + + callbacks: + - _target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: val/acc + mode: max + save_top_k: 3 + + - _target_: pytorch_lightning.callbacks.LearningRateMonitor + logging_interval: epoch + + logger: + _target_: pytorch_lightning.loggers.TensorBoardLogger + save_dir: logs + name: multi_gpu_experiment + +model: + _target_: lighter.LighterModule + + network: + _target_: torchvision.models.resnet50 + num_classes: "%vars::num_classes" + + criterion: + _target_: torch.nn.CrossEntropyLoss + + optimizer: + _target_: torch.optim.AdamW + params: "$@model::network.parameters()" + lr: "%vars::scaled_lr" + weight_decay: 0.0001 + + scheduler: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR + optimizer: "@model::optimizer" + T_max: "%vars::max_epochs" + + train_metrics: + - _target_: torchmetrics.Accuracy + task: multiclass + num_classes: "%vars::num_classes" + + val_metrics: "%model::train_metrics" + +data: + 
_target_: lighter.LighterDataModule + + train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: "%vars::per_gpu_batch_size" + shuffle: true + num_workers: 8 # Increase for multi-GPU + pin_memory: true + persistent_workers: true + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: true + download: true + transform: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.RandomCrop + size: 32 + padding: 4 + - _target_: torchvision.transforms.RandomHorizontalFlip + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.4914, 0.4822, 0.4465] + std: [0.2470, 0.2435, 0.2616] + + val_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: "%vars::per_gpu_batch_size" + num_workers: 8 + pin_memory: true + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: false + transform: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.4914, 0.4822, 0.4465] + std: [0.2470, 0.2435, 0.2616] +``` + +Run: + +```bash +lighter fit experiments/multi_gpu.yaml +``` + +## Advanced Strategies + +### FSDP (Fully Sharded Data Parallel) + +For very large models that don't fit on single GPU: + +```yaml +trainer: + strategy: fsdp + devices: 4 +``` + +FSDP shards: +- Model parameters +- Gradients +- Optimizer states + +Across GPUs, saving memory. 
+ +### DeepSpeed + +For even larger models: + +```yaml +trainer: + strategy: + _target_: pytorch_lightning.strategies.DeepSpeedStrategy + stage: 2 # ZeRO Stage 2 + devices: 4 + precision: 16 +``` + +Stages: +- **Stage 1**: Shard optimizer states +- **Stage 2**: Shard gradients + Stage 1 +- **Stage 3**: Shard parameters + Stage 2 + +### DDP with Static Graph + +For maximum performance (PyTorch 1.11+): + +```yaml +trainer: + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + static_graph: true +``` + +**Requirements:** +- Model structure doesn't change between steps +- No dynamic control flow + +**Benefit:** ~10% speedup. + +## Data Loading Optimization + +### Increase num_workers + +More workers for multi-GPU: + +```yaml +data: + train_dataloader: + num_workers: "$%vars::num_gpus * 4" # 4 workers per GPU +``` + +### Use Persistent Workers + +Avoid worker respawning: + +```yaml +data: + train_dataloader: + persistent_workers: true +``` + +### Pin Memory + +Faster GPU transfer: + +```yaml +data: + train_dataloader: + pin_memory: true +``` + +## Gradient Accumulation + +Simulate even larger batch sizes: + +```yaml +trainer: + accumulate_grad_batches: 4 +``` + +**Effective batch size** = `batch_size × num_gpus × accumulate_grad_batches` + +Example with 4 GPUs: +- `batch_size = 32` +- `num_gpus = 4` +- `accumulate_grad_batches = 4` +- **Effective = 32 × 4 × 4 = 512** + +## Sync Batch Normalization + +Important for small per-GPU batch sizes: + +```yaml +trainer: + sync_batchnorm: true +``` + +Synchronizes batch norm statistics across GPUs. + +**Use when:** Per-GPU batch size < 8. + +## Mixed Precision + +Combine with multi-GPU for maximum speed: + +```yaml +trainer: + precision: 16 # or "bf16-mixed" + devices: 4 + strategy: ddp +``` + +**Speedup:** ~2-3× faster than FP32. + +## Monitoring Multi-GPU Training + +### TensorBoard + +Same as single GPU: + +```bash +tensorboard --logdir logs +``` + +Metrics automatically aggregated across GPUs. 
+ +### Weights & Biases + +Works out of the box: + +```yaml +trainer: + logger: + _target_: pytorch_lightning.loggers.WandbLogger + project: my_project +``` + +Only rank 0 process logs to avoid duplicates. + +## Checkpointing + +Same as single GPU: + +```yaml +trainer: + callbacks: + - _target_: pytorch_lightning.callbacks.ModelCheckpoint + save_top_k: 3 +``` + +Only rank 0 saves checkpoints automatically. + +## Testing Locally + +Test DDP on single machine with multiple physical GPUs: + +```bash +# Run DDP on 2 physical GPUs +lighter fit config.yaml trainer::devices=2 trainer::strategy=ddp +``` + +**Important:** `devices=k` selects `k` physical GPUs per node (equivalent to `list(range(k))`). No GPU virtualization or simulation is performed. You must have at least as many physical GPUs as specified (e.g., at least 2 physical GPUs for the example above). + +## Common Issues + +### Out of Memory + +**Solutions:** + +1. Reduce per-GPU batch size: + ```bash + lighter fit config.yaml data::train_dataloader::batch_size=16 + ``` + +2. Use gradient accumulation: + ```yaml + trainer: + accumulate_grad_batches: 2 + ``` + +3. Use FSDP or DeepSpeed for large models + +### Slow Startup + +**Problem:** Long startup time with DDP + +**Cause:** Dataset download or preprocessing on each rank + +**Solution:** Download data before training: + +```bash +# Download once +python -c "from torchvision.datasets import CIFAR10; CIFAR10('./data', download=True)" + +# Then train +lighter fit config.yaml +``` + +### Hanging at Initialization + +**Problem:** Process hangs at "Initializing distributed" + +**Solutions:** + +1. Check firewall settings +2. 
Try different DDP backend: + ```yaml + trainer: + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + process_group_backend: gloo # Instead of nccl + ``` + +### Different Results Across GPUs + +**Problem:** Metrics differ between runs + +**Cause:** Random seed not set or data shuffling + +**Solution:** + +```python +# In __lighter__.py +import pytorch_lightning as pl +pl.seed_everything(42, workers=True) +``` + +```yaml +data: + train_dataloader: + shuffle: true # Ensure shuffling +``` + +### Unused Parameters Error + +**Problem:** "RuntimeError: Expected to have finished reduction in the prior iteration" + +**Solution:** + +```yaml +trainer: + strategy: + _target_: pytorch_lightning.strategies.DDPStrategy + find_unused_parameters: true +``` + +## Performance Tips + +### 1. Use All CPU Cores + +```yaml +data: + train_dataloader: + num_workers: "$%vars::num_gpus * 4" +``` + +### 2. Prefetch Data + +```yaml +data: + train_dataloader: + prefetch_factor: 2 +``` + +### 3. Mixed Precision + +```yaml +trainer: + precision: 16 +``` + +### 4. Compile Model (PyTorch 2.0+) + +In your module: + +```python +def __init__(self, network, ...): + super().__init__() + self.network = torch.compile(network) +``` + +### 5. Optimize Data Loading + +- Cache dataset if it fits in RAM +- Preprocess data offline +- Use fast storage (SSD > HDD) + +## Scaling Example + +Compare different GPU counts: + +```bash +# 1 GPU baseline +lighter fit config.yaml vars::num_gpus=1 + +# 2 GPUs (~1.8× speedup) +lighter fit config.yaml vars::num_gpus=2 + +# 4 GPUs (~3.5× speedup) +lighter fit config.yaml vars::num_gpus=4 + +# 8 GPUs (~6.5× speedup) +lighter fit config.yaml vars::num_gpus=8 +``` + +**Expected scaling:** ~85-90% efficiency (linear would be 100%). 
+ +## Multi-Node Training + +For training across multiple machines: + +```yaml +trainer: + strategy: ddp + devices: 4 # GPUs per node + num_nodes: 2 # Number of machines +``` + +Run on each node: + +```bash +# Node 0 +MASTER_ADDR=node0_address MASTER_PORT=12345 \ + lighter fit config.yaml \ + trainer::num_nodes=2 \ + trainer::devices=4 + +# Node 1 +MASTER_ADDR=node0_address MASTER_PORT=12345 NODE_RANK=1 \ + lighter fit config.yaml \ + trainer::num_nodes=2 \ + trainer::devices=4 +``` + +Requires: +- Shared filesystem for checkpoints +- Network connectivity between nodes +- Matching software environment + +## Quick Reference + +```yaml +# Basic multi-GPU +trainer: + devices: 4 + strategy: ddp + +# All GPUs +trainer: + devices: -1 + strategy: ddp + +# Large models +trainer: + strategy: fsdp + devices: 4 + +# Very large models +trainer: + strategy: + _target_: pytorch_lightning.strategies.DeepSpeedStrategy + stage: 3 + devices: 4 + +# Batch size scaling +vars: + num_gpus: 4 + total_batch: 512 + +data: + train_dataloader: + batch_size: "$%vars::total_batch // %vars::num_gpus" + +# LR scaling +model: + optimizer: + lr: "$%vars::base_lr * %vars::num_gpus" +``` + +## Next Steps + +- [Training Guide](../guides/training.md) - More training strategies +- [Best Practices](../guides/best-practices.md) - Production optimization +- [Image Classification Example](image-classification.md) - Complete example + +## Summary + +Multi-GPU training with Lighter: + +- ✅ Simple config changes only +- ✅ Code works unchanged +- ✅ Automatic gradient synchronization +- ✅ Linear scaling with proper settings +- ✅ Multiple strategies (DDP, FSDP, DeepSpeed) +- ✅ Works with all Lightning features + +**Key takeaway:** Add `trainer::devices=4 trainer::strategy=ddp` to use 4 GPUs. That's it! 
diff --git a/docs/faq.md b/docs/faq.md index a9cf2a5f..fbd7299e 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -1,250 +1,270 @@ -# Frequently Asked Questions - -## General +--- +title: FAQ +--- -**What is Lighter?** +# Frequently Asked Questions -A configuration-driven deep learning framework built on PyTorch Lightning. Define experiments in YAML instead of writing training code. [Get Started →](tutorials/get-started.md) +Quick answers to common questions. -**How does Lighter compare to PyTorch Lightning?** +## What is Lighter? -Lighter extends Lightning's `LightningModule` but uses YAML configs instead of Python classes. You get all Lightning features (multi-GPU, callbacks, loggers) plus config-driven simplicity. [Migration Guide →](migration/from-pytorch-lightning.md) +Lighter is a YAML configuration layer for PyTorch Lightning experiments. -**When should I use Lighter?** +**You write:** Standard PyTorch Lightning code (LightningModule, datasets, etc.) -Use Lighter when: +**Lighter provides:** YAML configs, CLI overrides, experiment tracking -- Running many experiments with different hyperparameters -- Reproducibility and experiment tracking are priorities -- You prefer configuration over code -- You want PyTorch's flexibility with structure +**Result:** Reproducible experiments without hardcoded hyperparameters. -Don't use Lighter when: +## Do I need to rewrite my LightningModule? -- You need ultra-custom training loops -- Rapid architecture prototyping (code first, config later) -- You prefer code-only workflows +No! Use it directly: -**Is Lighter only for medical imaging?** +```yaml +model: + _target_: my_project.MyLightningModule # Your existing code + learning_rate: 0.001 +``` -No. Lighter is task-agnostic and works for any deep learning task: classification, detection, segmentation, NLP, self-supervised learning, etc. [See examples →](how-to/recipes.md) +No changes to your Python code required. 
-**What's the performance overhead?** +## When should I use LighterModule vs my own LightningModule? -Minimal (<1%). Config resolution happens once at startup. Training speed is identical to PyTorch Lightning. +**Use LightningModule when:** -## Configuration +- Migrating existing projects +- Need custom training logic +- Want full control +- Team knows Lightning well -**What's the difference between `@`, `%`, and `$`?** +**Use LighterModule when:** -Lighter uses [Sparkwheel](https://project-lighter.github.io/sparkwheel/) for configuration: +- Starting new projects +- Want less boilerplate +- Standard workflows (classification, segmentation, etc.) +- Config-driven everything -| Symbol | Purpose | Example | -|--------|---------|---------| -| `@` | Resolved reference (instantiated object) | `"@system::optimizer"` | -| `%` | Raw reference (unprocessed YAML) | `"%system::metrics::train"` | -| `$` | Evaluate Python expression | `"$0.001 * 2"` | +Both are equally supported and give you YAML configs + CLI overrides. -- `@` gets the final instantiated object after processing -- `%` copies raw YAML config (creates new instance when used with `_target_`) -- `$` evaluates Python code in expressions +## How do I use my custom models and datasets? -[Complete syntax guide →](how-to/configuration.md) | [Sparkwheel docs →](https://project-lighter.github.io/sparkwheel/) +Three steps: -**How do I pass model parameters to the optimizer?** +1. Add `__lighter__.py` to your project root (can be empty) +2. Ensure all directories have `__init__.py` +3. Reference as `project.module.ClassName` in config ```yaml -optimizer: - _target_: torch.optim.Adam - params: "$@system::model.parameters()" - lr: 0.001 +model: + network: + _target_: my_project.models.CustomNet + num_classes: 10 ``` -The `$` evaluates Python, `@` gets the resolved model instance, `.parameters()` calls the method. 
+[Full guide →](guides/custom-code.md) -**Can I use Python code in configs?** +## How do I override config from CLI? -Yes. Use `$` prefix for expressions: +Use `::` to navigate config paths: -```yaml -optimizer: - lr: "$0.001 * 2" # Evaluates to 0.002 +```bash +# Single override +lighter fit config.yaml model::optimizer::lr=0.01 + +# Multiple overrides +lighter fit config.yaml \ + model::optimizer::lr=0.01 \ + trainer::max_epochs=100 \ + data::train_dataloader::batch_size=64 ``` -[Advanced configuration →](how-to/configuration.md#advanced-features) +No file editing needed! + +## What's the difference between `@` and `%`? -**How do I add callbacks without replacing existing ones?** +- **`@`** = Resolved reference (gets the instantiated Python object) +- **`%`** = Raw reference (copies the YAML config to create new instance) -By default, configs merge automatically. Later configs add to earlier ones: +**Critical:** Always use `%` for metrics: ```yaml -# base.yaml -trainer: - callbacks: - - _target_: pytorch_lightning.callbacks.ModelCheckpoint +# ❌ WRONG - Shared instance pollutes metrics +val_metrics: "@model::train_metrics" -# experiment.yaml (merges automatically) -trainer: - callbacks: - - _target_: pytorch_lightning.callbacks.EarlyStopping -# Result: Both ModelCheckpoint AND EarlyStopping +# ✅ CORRECT - New instance for validation +val_metrics: "%model::train_metrics" +``` -# To replace instead of merge, use = -trainer: - =callbacks: - - _target_: pytorch_lightning.callbacks.EarlyStopping -# Result: Only EarlyStopping +Use `@` for everything else (optimizer, scheduler, network): + +```yaml +# ✅ CORRECT - Pass actual object +optimizer: + params: "$@model::network.parameters()" ``` -**How do I remove specific items from lists or dicts?** +## Can I use multiple GPUs? -Use `~` with path notation or batch syntax: +Yes! 
Same as PyTorch Lightning: ```yaml -# Delete entire key trainer: - ~callbacks: null + devices: 4 # Use 4 GPUs + strategy: ddp # Distributed Data Parallel +``` -# Delete single list item by index -trainer: - ~callbacks::1: null # Removes item at index 1 +Or use all available GPUs: -# Delete multiple list items (batch syntax - recommended) +```yaml trainer: - ~callbacks: [1, 3] # Removes items at indices 1 and 3 - -# Delete dict keys (batch syntax) -system: - ~dataloaders: ["train", "test"] # Removes train and test loaders - -# Delete nested dict key (path notation) -system: - ~model::pretrained: null + devices: -1 # All GPUs + strategy: ddp ``` -!!! tip - For multiple list items, use batch syntax `~key: [indices]` to avoid index shifting issues when doing sequential deletions. +[Multi-GPU guide →](guides/training.md#multi-gpu-training) -[Merging guide →](how-to/configuration.md#advanced-merging-and) +## How do I debug config errors? -## Training +### 1. Start Simple -**How do I resume training?** +Run one batch to catch errors fast: ```bash -lighter fit config.yaml args::fit::ckpt_path="checkpoint.ckpt" +lighter fit config.yaml trainer::fast_dev_run=true ``` -**How do I use multiple GPUs?** +### 2. Check Common Issues -```bash -lighter fit config.yaml trainer::devices=2 trainer::strategy=ddp -``` +**Import errors** - Check: -[Multi-GPU recipes →](how-to/recipes.md#multi-gpu-ddp) +- `__lighter__.py` exists in project root +- All directories have `__init__.py` +- Running `lighter` from directory with `__lighter__.py` -**How do I freeze layers?** +**Attribute errors** - Using `::` for Python methods? ```yaml -trainer: - callbacks: - - _target_: lighter.callbacks.Freezer - modules: ["backbone.layer1", "backbone.layer2"] +# ❌ WRONG +params: "$@model::network::parameters()" + +# ✅ CORRECT +params: "$@model::network.parameters()" ``` -[Freezers guide →](how-to/freezers.md) +Remember: `::` for config, `.` for Python. -**Can I use custom training loops?** +### 3. 
Validate Config Syntax -Lighter uses Lightning's standard loop. For exotic training logic: +Use quotes for expressions: -1. Extend System class and override methods -2. Use Lightning directly +```yaml +# ❌ WRONG +lr: $0.001 * 2 -Most customizations achievable through callbacks or System extension. [System internals →](design/system.md) +# ✅ CORRECT +lr: "$0.001 * 2" +``` -## Design +[Full troubleshooting →](guides/training.md#debugging) -**Why adapters instead of custom LightningModule?** +## How do I save predictions? -Adapters separate data transformation from model logic, making both reusable. Configure transforms in YAML, reuse models across tasks. [Adapter pattern →](how-to/adapters.md) +Use Writers: -**What's the difference between stages and modes?** +```yaml +trainer: + callbacks: + - _target_: lighter.callbacks.CSVWriter + write_interval: batch +``` -- **Stages**: CLI commands (fit, validate, test, predict) -- **Modes**: Internal execution contexts (train, val, test, predict) +In your module: -Example: `lighter fit` executes train + val modes. [Architecture →](design/overview.md#understanding-stages-and-modes) +```python +def predict_step(self, batch, batch_idx): + x = batch + pred = self(x) -**How does config pruning work?** + return { + "prediction": pred.argmax(dim=1), + "probability": pred.max(dim=1).values, + } +``` -Lighter automatically removes unused components based on stage. `lighter test` removes train/val dataloaders, optimizer, and scheduler. One config works for all stages. [Pruning details →](design/overview.md#automatic-configuration-pruning) +Results saved to `predictions.csv`. -## Troubleshooting +[Writers guide →](guides/training.md#saving-predictions) -**Training is slow?** +## Can I merge multiple configs? -Check: +Yes! 
Separate comma-separated paths: -- Increase `num_workers` in dataloaders -- Enable mixed precision: `trainer::precision="16-mixed"` -- Profile: `trainer::profiler="simple"` +```bash +lighter fit base.yaml,experiment.yaml +``` -[Performance recipes →](how-to/recipes.md#performance-optimization) +Later files override earlier ones. Use this for: -**Loss is NaN?** +- Base config + experiment-specific overrides +- Shared settings + dataset configs +- Default values + hyperparameter sweeps -Common causes: +[Config merging →](guides/training.md#merging-configs) -1. Learning rate too high → reduce by 10x -2. Gradient explosion → `trainer::gradient_clip_val=1.0` -3. Wrong loss function → verify for your task -4. Bad data → check for inf/nan in inputs +## How does Lighter compare to Hydra? -[Full troubleshooting guide →](how-to/troubleshooting.md) +Similar concept, different focus: -**ModuleNotFoundError: No module named 'project'?** +**Hydra:** -Ensure: +- General-purpose Python app configuration +- Works with any codebase +- Manual integration -1. Project path set: `project: ./path` -2. All directories have `__init__.py` +**Lighter:** -[Project module guide →](how-to/project_module.md) +- Specialized for PyTorch Lightning +- Built-in CLI commands (`fit`, `test`, `predict`) +- Automatic object instantiation +- Zero-config experiment tracking -## Comparisons +If you're doing deep learning with Lightning, Lighter is simpler. If you need general Python app config, use Hydra. -**Lighter vs Hydra?** +## Where can I get help? -- **Hydra**: General config framework for any Python app -- **Lighter**: Deep learning-specific with built-in training pipeline +- **Discord**: [discord.gg/zJcnp6KrUp](https://discord.gg/zJcnp6KrUp) - Community support +- **GitHub**: [github.com/project-lighter/lighter](https://github.com/project-lighter/lighter) - Issues and discussions +- **Docs**: You're here! Check the guides. -Use Lighter for DL experiments with automatic training loops. 
+## How do I cite Lighter? -**Lighter vs Ludwig?** +```bibtex +@article{lighter2024, + title={Lighter: A lightweight deep learning framework for rapid experimentation}, + author={...}, + journal={Journal of Open Source Software}, + year={2024}, + doi={10.21105/joss.08101} +} +``` -- **Ludwig**: High-level, declarative ML with pre-built flows -- **Lighter**: Mid-level, requires standard PyTorch components +[Paper →](https://joss.theoj.org/papers/10.21105/joss.08101) -Use Ludwig for no-code ML. Use Lighter when you write custom PyTorch but want config-driven experiments. +## Can I contribute? -**Can I migrate from Lightning to Lighter?** +Yes! Lighter is open source: -Yes. Main steps: +- **Report bugs**: [GitHub Issues](https://github.com/project-lighter/lighter/issues) +- **Request features**: [GitHub Discussions](https://github.com/project-lighter/lighter/discussions) +- **Contribute code**: Submit PRs -1. Convert LightningModule to YAML config -2. Move training_step logic to adapters (if needed) -3. Configure dataloaders in YAML +See [CONTRIBUTING.md](https://github.com/project-lighter/lighter/blob/main/CONTRIBUTING.md) for guidelines. -[Complete migration guide →](migration/from-pytorch-lightning.md) +## More Questions? 
-## Getting Help +Check out: -| Need | Resource | -|------|----------| -| Getting started | [Tutorials](tutorials/get-started.md) | -| Configuration help | [Configuration Guide](how-to/configuration.md) | -| Common errors | [Troubleshooting](how-to/troubleshooting.md) | -| Examples | [Recipes](how-to/recipes.md) | -| Community | [Discord](https://discord.gg/zJcnp6KrUp) | -| Bug reports | [GitHub Issues](https://github.com/project-lighter/lighter/issues) | +- [Quick Start](quickstart.md) - Get started in 10 minutes +- [Configuration Guide](guides/configuration.md) - Master the syntax +- [Training Guide](guides/training.md) - Run experiments +- [Best Practices](guides/best-practices.md) - Production patterns diff --git a/docs/guides/best-practices.md b/docs/guides/best-practices.md new file mode 100644 index 00000000..dc184247 --- /dev/null +++ b/docs/guides/best-practices.md @@ -0,0 +1,1011 @@ +--- +title: Best Practices +--- + +# Best Practices + +Production-ready patterns for Lighter projects. + +This guide collects battle-tested patterns for structuring projects, organizing configs, and debugging effectively. + +## Project Structure + +### Recommended Layout + +``` +my_project/ +├── __lighter__.py # Marker file +├── __init__.py # Package root +├── pyproject.toml # Dependencies +├── README.md # Project docs +│ +├── src/ # Source code +│ ├── __init__.py +│ ├── models/ +│ │ ├── __init__.py +│ │ ├── resnet.py +│ │ └── unet.py +│ ├── data/ +│ │ ├── __init__.py +│ │ ├── datasets.py +│ │ └── transforms.py +│ └── callbacks/ +│ ├── __init__.py +│ └── custom.py +│ +├── experiments/ # Configs +│ ├── base.yaml # Shared settings +│ ├── image_classification/ +│ │ ├── resnet18.yaml +│ │ └── resnet50.yaml +│ └── segmentation/ +│ └── unet.yaml +│ +├── tests/ # Tests +│ ├── test_models.py +│ └── test_data.py +│ +└── outputs/ # Generated (gitignore) + └── YYYY-MM-DD/ +``` + +### Why This Structure? 
+ +- **`__lighter__.py`**: Marks project root for auto-discovery +- **`src/`**: Clear separation of code +- **`experiments/`**: Version-controlled configs +- **`outputs/`**: Generated artifacts (not tracked) +- **`tests/`**: Ensure code quality + +## Config Organization + +### Base + Experiments Pattern + +**`experiments/base.yaml`** - Shared settings: + +```yaml +# Base configuration shared across experiments + +trainer: + _target_: pytorch_lightning.Trainer + max_epochs: 100 + accelerator: auto + callbacks: + - _target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: val/acc + mode: max + save_top_k: 3 + + - _target_: pytorch_lightning.callbacks.LearningRateMonitor + logging_interval: epoch + +data: + _target_: lighter.LighterDataModule + train_dataloader: + _target_: torch.utils.data.DataLoader + num_workers: 4 + pin_memory: true + persistent_workers: true + + val_dataloader: + _target_: torch.utils.data.DataLoader + num_workers: 4 + pin_memory: true +``` + +**`experiments/resnet18.yaml`** - Specific experiment: + +```yaml +# Experiment-specific settings (merged with base.yaml via CLI) +model: + _target_: src.models.ImageClassifier + learning_rate: 0.001 + network: + _target_: torchvision.models.resnet18 + num_classes: 10 + +data: + train_dataloader: + batch_size: 128 + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: true +``` + +Run with both configs - they merge in order: + +```bash +lighter fit experiments/base.yaml experiments/resnet18.yaml +``` + +Or override specific values from CLI: + +```bash +lighter fit experiments/base.yaml experiments/resnet18.yaml trainer::max_epochs=50 +``` + +!!! info "How config composition works" + Each config file (and CLI override) is applied sequentially via Sparkwheel's `.update()` method. + Dictionaries merge recursively, lists extend by default. Use `=key:` to replace instead of merge, + or `~key:` to delete. 
See the [Sparkwheel docs](https://project-lighter.github.io/sparkwheel/) for details. + +### Use Variables for Reusability + +```yaml +vars: + # Dataset settings + num_classes: 10 + img_size: 224 + + # Training settings + batch_size: 32 + base_lr: 0.001 + max_epochs: 100 + + # Hardware settings + num_workers: 4 + devices: 1 + +trainer: + max_epochs: "%vars::max_epochs" + devices: "%vars::devices" + +model: + network: + num_classes: "%vars::num_classes" + optimizer: + lr: "%vars::base_lr" + +data: + train_dataloader: + batch_size: "%vars::batch_size" + num_workers: "%vars::num_workers" +``` + +Override easily: + +```bash +lighter fit config.yaml vars::batch_size=64 vars::base_lr=0.01 +``` + +### Separate Data Configs + +For large datasets, separate data configs: + +**`experiments/data/cifar10.yaml`**: + +```yaml +data: + _target_: lighter.LighterDataModule + + train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 128 + shuffle: true + num_workers: 4 + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: true + download: true + transform: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.RandomCrop + size: 32 + padding: 4 + - _target_: torchvision.transforms.RandomHorizontalFlip + - _target_: torchvision.transforms.ToTensor +``` + +**`experiments/my_experiment.yaml`**: + +```yaml +model: + # ... model config ... +``` + +Run with both configs: + +```bash +lighter fit experiments/data/cifar10.yaml experiments/my_experiment.yaml +``` + +## Module Design + +### Save Hyperparameters + +Always save hyperparameters for reproducibility: + +```python +class MyModule(pl.LightningModule): + def __init__(self, network, learning_rate=0.001, weight_decay=0.0): + super().__init__() + # Save all args except objects + self.save_hyperparameters(ignore=['network']) + self.network = network +``` + +Access with `self.hparams.learning_rate`. 
+ +### Use Type Hints + +```python +from typing import Dict, Any +import torch + +class MyModule(LighterModule): + def training_step( + self, + batch: tuple[torch.Tensor, torch.Tensor], + batch_idx: int + ) -> Dict[str, torch.Tensor]: + x, y = batch + loss = self.criterion(self(x), y) + return {"loss": loss} +``` + +### Document __init__ Parameters + +Config values map to `__init__` args: + +```python +class MyDataset(Dataset): + """Custom dataset for my task. + + Args: + root: Path to data directory + split: One of 'train', 'val', 'test' + transform: Optional transform to apply + cache: Whether to cache preprocessed data in memory + """ + + def __init__( + self, + root: str, + split: str = 'train', + transform: Optional[Callable] = None, + cache: bool = False + ): + ... +``` + +Users can see available options in docstrings. + +### Separate Concerns + +**Bad** - Everything in one module: + +```python +class MyModule(pl.LightningModule): + def __init__(self): + # Network definition + self.conv1 = nn.Conv2d(...) + self.conv2 = nn.Conv2d(...) + # Loss + self.criterion = nn.CrossEntropyLoss() + # Metrics + self.acc = Accuracy() +``` + +**Good** - Modular design: + +```python +class MyModule(pl.LightningModule): + def __init__(self, network, criterion, metrics): + super().__init__() + self.network = network + self.criterion = criterion + self.metrics = metrics +``` + +Config controls composition: + +```yaml +model: + network: + _target_: src.networks.ResNet + criterion: + _target_: torch.nn.CrossEntropyLoss + metrics: + _target_: torchmetrics.Accuracy +``` + +## Data Best Practices + +### Use num_workers + +```yaml +data: + train_dataloader: + num_workers: 4 # Parallelize data loading + pin_memory: true # Faster GPU transfer + persistent_workers: true # Keep workers alive +``` + +Start with `num_workers = num_cpus / num_gpus`. 
+ +### Prefetch Factor + +For slow data loading: + +```yaml +data: + train_dataloader: + num_workers: 4 + prefetch_factor: 2 # Each worker prefetches 2 batches +``` + +### Cache Small Datasets + +For datasets that fit in RAM: + +```python +class CachedDataset(Dataset): + def __init__(self, root, transform=None, cache=True): + self.transform = transform + self.cache = cache + + # Load all data at init + if self.cache: + self.data = [self._load_item(i) for i in range(len(files))] + + def __getitem__(self, idx): + if self.cache: + item = self.data[idx] + else: + item = self._load_item(idx) + + if self.transform: + item = self.transform(item) + + return item +``` + +### Validate Data in __init__ + +Check data availability early: + +```python +class MyDataset(Dataset): + def __init__(self, root, split='train'): + self.root = Path(root) + + # Validate + if not self.root.exists(): + raise ValueError(f"Data directory not found: {root}") + + split_file = self.root / f"{split}.txt" + if not split_file.exists(): + raise ValueError(f"Split file not found: {split_file}") + + # Load + self.samples = self._load_split(split_file) + + if len(self.samples) == 0: + raise ValueError(f"No samples found for split '{split}'") +``` + +Fails fast with clear error messages. 
+
+## Training Best Practices
+
+### Learning Rate Warmup
+
+Stabilize early training:
+
+```python
+def configure_optimizers(self):
+    optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
+
+    # Warmup for first 10 epochs (scheduler is stepped per epoch below)
+    warmup = torch.optim.lr_scheduler.LinearLR(
+        optimizer,
+        start_factor=0.1,
+        total_iters=10
+    )
+
+    # Then cosine decay
+    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
+        optimizer,
+        T_max=self.trainer.max_epochs - 10
+    )
+
+    # Combine
+    scheduler = torch.optim.lr_scheduler.SequentialLR(
+        optimizer,
+        schedulers=[warmup, cosine],
+        milestones=[10]
+    )
+
+    return {
+        "optimizer": optimizer,
+        "lr_scheduler": {
+            "scheduler": scheduler,
+            "interval": "epoch",
+        }
+    }
+```
+
+### Gradient Clipping
+
+Prevent exploding gradients:
+
+```yaml
+trainer:
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: norm
+```
+
+Essential for RNNs and transformers.
+
+### Weight Averaging (SWA)
+
+Smoother model weights:
+
+```yaml
+trainer:
+  callbacks:
+    - _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+      swa_lrs: 0.001
+```
+
+### Mixed Precision
+
+Faster training on modern GPUs:
+
+```yaml
+trainer:
+  precision: 16  # or "bf16-mixed" for BFloat16
+```
+
+### Accumulate Gradients
+
+Simulate large batch size:
+
+```yaml
+trainer:
+  accumulate_grad_batches: 4  # Effective batch = batch_size × 4
+```
+
+## Logging Best Practices
+
+### Log Hyperparameters
+
+Log all hyperparameters for comparison:
+
+```python
+def __init__(self, ...):
+    super().__init__()
+    self.save_hyperparameters()  # Logs to tensorboard/wandb
+```
+
+### Dual Logging (Step + Epoch)
+
+LighterModule does this automatically. For Lightning:
+
+```python
+def training_step(self, batch, batch_idx):
+    loss = ...
+ + # Log both step and epoch + self.log("train/loss", loss, on_step=True, on_epoch=True) + + return loss +``` + +### Log Learning Rate + +```yaml +trainer: + callbacks: + - _target_: pytorch_lightning.callbacks.LearningRateMonitor + logging_interval: step +``` + +### Log Sample Images + +```python +def validation_step(self, batch, batch_idx): + if batch_idx == 0: + x, y = batch + pred = self(x) + + # Log first 8 images + self.logger.experiment.add_images( + "val/predictions", + x[:8], + self.global_step + ) +``` + +### Use Structured Logging + +Group related metrics: + +```python +# Good - grouped +self.log("train/loss", loss) +self.log("train/acc", acc) +self.log("val/loss", val_loss) +self.log("val/acc", val_acc) + +# Bad - flat +self.log("loss", loss) +self.log("acc", acc) +``` + +## Debugging Strategies + +### Start Simple + +1. **Overfit 1 batch**: + ```bash + lighter fit config.yaml trainer::overfit_batches=1 + ``` + Should reach ~0 loss quickly. + +2. **Fast dev run**: + ```bash + lighter fit config.yaml trainer::fast_dev_run=true + ``` + Catches basic errors. + +3. **Limit batches**: + ```bash + lighter fit config.yaml trainer::limit_train_batches=10 + ``` + Faster iteration during development. + +### Add Assertions + +```python +def training_step(self, batch, batch_idx): + x, y = batch + + # Validate shapes + assert x.dim() == 4, f"Expected 4D input, got {x.dim()}D" + assert y.dim() == 1, f"Expected 1D targets, got {y.dim()}D" + + pred = self(x) + + # Validate output + assert pred.shape[0] == x.shape[0], "Batch size mismatch" + assert not torch.isnan(pred).any(), "NaN in predictions" + + loss = self.criterion(pred, y) + return loss +``` + +Remove in production. + +### Log Distributions + +```python +def training_step(self, batch, batch_idx): + loss = ... 
+ + # Log gradient norms + if batch_idx % 100 == 0: + for name, param in self.named_parameters(): + if param.grad is not None: + self.logger.experiment.add_histogram( + f"gradients/{name}", + param.grad, + self.global_step + ) + + return loss +``` + +### Use Anomaly Detection + +For NaN debugging: + +```python +torch.autograd.set_detect_anomaly(True) +``` + +Or in config: + +```yaml +trainer: + detect_anomaly: true +``` + +Slower, but helps find NaN sources. + +## Config Best Practices + +### Use Comments + +```yaml +model: + optimizer: + _target_: torch.optim.AdamW + params: "$@model::network.parameters()" + lr: 0.001 # Tuned via LR finder + weight_decay: 0.01 # L2 regularization + +trainer: + max_epochs: 100 # Convergence around epoch 80 + devices: 4 # 4× A100 GPUs +``` + +### Avoid Hardcoded Paths + +**Bad**: + +```yaml +data: + dataset: + root: /home/user/data/cifar10 # Breaks on other machines +``` + +**Good**: + +```yaml +vars: + data_root: ./data # Relative path + +data: + dataset: + root: "%vars::data_root/cifar10" +``` + +Or use environment variables: + +```yaml +vars: + data_root: "$os.environ.get('DATA_ROOT', './data')" +``` + +### Version Configs + +```yaml +# config.yaml +_meta_: + version: "1.2.0" + description: "ResNet50 with strong augmentation" + created: "2024-01-15" + author: "your-name" + +# ... rest of config ... 
+``` + +### Keep Configs DRY + +Use references to avoid duplication: + +```yaml +vars: + num_classes: 10 + +model: + network: + num_classes: "%vars::num_classes" + + train_metrics: + - _target_: torchmetrics.Accuracy + num_classes: "%vars::num_classes" # Same value +``` + +## Reproducibility + +### Set Seeds + +```yaml +trainer: + _target_: pytorch_lightning.Trainer + deterministic: true # Reproducible + +# In __lighter__.py +import pytorch_lightning as pl +pl.seed_everything(42, workers=True) +``` + +### Log Everything + +```python +def __init__(self, ...): + super().__init__() + self.save_hyperparameters() # Save all args + +# In trainer +trainer: + logger: + _target_: pytorch_lightning.loggers.WandbLogger + log_model: true # Save model to wandb +``` + +### Version Dependencies + +`pyproject.toml`: + +```toml +[project] +dependencies = [ + "torch==2.1.0", + "pytorch-lightning==2.1.0", + "lighter>=3.0.0", +] +``` + +### Save Config with Outputs + +Lightning does this automatically: + +``` +outputs/ +└── 2024-01-15/ + └── 10-30-45/ + ├── config.yaml # Exact config used + └── checkpoints/ +``` + +## Testing + +### Unit Test Models + +```python +# tests/test_models.py +import pytest +import torch +from src.models import MyModule + +def test_forward_pass(): + model = MyModule(network=...) + x = torch.randn(2, 3, 32, 32) + y = model(x) + + assert y.shape == (2, 10) + assert not torch.isnan(y).any() + +def test_training_step(): + model = MyModule(...) 
+ batch = (torch.randn(2, 3, 32, 32), torch.randint(0, 10, (2,))) + + result = model.training_step(batch, 0) + + assert "loss" in result + assert result["loss"].dim() == 0 # Scalar +``` + +### Integration Test Configs + +```python +# tests/test_configs.py +def test_config_loads(): + """Test that config instantiates correctly.""" + from lighter.engine.runner import load_config + + config = load_config("experiments/resnet18.yaml") + + assert config.trainer is not None + assert config.model is not None + assert config.data is not None + +def test_fast_dev_run(tmp_path): + """Test full pipeline on 1 batch.""" + import subprocess + + result = subprocess.run( + [ + "lighter", "fit", + "experiments/resnet18.yaml", + f"trainer::default_root_dir={tmp_path}", + "trainer::fast_dev_run=true", + ], + capture_output=True, + ) + + assert result.returncode == 0 +``` + +## Performance Optimization + +### Profile First + +```bash +lighter fit config.yaml trainer::profiler=simple +``` + +Identify bottlenecks before optimizing. + +### Data Loading + +Most common bottleneck: + +1. Increase `num_workers` +2. Add `pin_memory=true` +3. Add `persistent_workers=true` +4. Cache small datasets +5. 
Preprocess data offline + +### Model Optimization + +```python +# Compile model (PyTorch 2.0+) +self.network = torch.compile(self.network) + +# Use channels_last memory format +self.network = self.network.to(memory_format=torch.channels_last) +``` + +### Mixed Precision + +```yaml +trainer: + precision: "bf16-mixed" # BFloat16 on A100/H100 +``` + +### Gradient Checkpointing + +For large models: + +```python +from torch.utils.checkpoint import checkpoint + +def forward(self, x): + # Trade compute for memory + x = checkpoint(self.layer1, x) + x = checkpoint(self.layer2, x) + return x +``` + +## Production Deployment + +### Export to TorchScript + +```python +# After training +model = MyModule.load_from_checkpoint("best.ckpt") +model.eval() + +scripted = torch.jit.script(model) +scripted.save("model.pt") +``` + +### Export to ONNX + +```python +model = MyModule.load_from_checkpoint("best.ckpt") +model.eval() + +dummy_input = torch.randn(1, 3, 224, 224) + +torch.onnx.export( + model, + dummy_input, + "model.onnx", + input_names=["input"], + output_names=["output"], + dynamic_axes={ + "input": {0: "batch"}, + "output": {0: "batch"}, + }, +) +``` + +### Serve with TorchServe + +See [TorchServe docs](https://pytorch.org/serve/) for deployment. + +## Common Pitfalls + +### ❌ Using `@` for Metrics + +```yaml +# WRONG - Shared instance! +val_metrics: "@model::train_metrics" +``` + +```yaml +# CORRECT - New instance +val_metrics: "%model::train_metrics" +``` + +### ❌ Forgetting to Call .eval() + +```python +# Inference +model = MyModule.load_from_checkpoint("best.ckpt") +model.eval() # Important! + +with torch.no_grad(): + pred = model(x) +``` + +### ❌ Not Saving Hyperparameters + +```python +def __init__(self, learning_rate=0.001): + super().__init__() + self.save_hyperparameters() # Don't forget! 
+ self.lr = learning_rate +``` + +### ❌ Leaking Validation Data + +```python +# BAD - Using training mode on validation +def validation_step(self, batch, batch_idx): + self.train() # ❌ Wrong! + ... + +# GOOD - Lightning handles this +def validation_step(self, batch, batch_idx): + # Already in eval mode + ... +``` + +### ❌ Ignoring Batch Size in Metrics + +```python +# BAD +self.log("train/loss", loss) # Averages across uneven batches + +# GOOD +self.log("train/loss", loss, batch_size=x.size(0)) +``` + +## Checklist + +Before training: + +- [ ] Seeds set for reproducibility +- [ ] Hyperparameters saved (`save_hyperparameters()`) +- [ ] Metrics use `%` not `@` +- [ ] Data loading optimized (`num_workers`, `pin_memory`) +- [ ] Checkpointing configured +- [ ] Logging configured +- [ ] Config tested with `fast_dev_run=true` + +Before production: + +- [ ] Model validated on held-out test set +- [ ] Inference tested on sample data +- [ ] Export format chosen (TorchScript/ONNX) +- [ ] Dependencies versioned +- [ ] Config and checkpoints saved +- [ ] Documentation updated + +## Next Steps + +- [Training Guide](training.md) - Run experiments +- [Examples](../examples/image-classification.md) - Complete working code +- [FAQ](../faq.md) - Common questions + +## Quick Reference + +```python +# Save hyperparameters +self.save_hyperparameters(ignore=['network']) + +# Dual logging +self.log("train/loss", loss, on_step=True, on_epoch=True, batch_size=x.size(0)) + +# Assertions (dev only) +assert not torch.isnan(x).any() + +# Profile +torch.autograd.set_detect_anomaly(True) +``` + +```yaml +# Reproducibility +trainer: + deterministic: true + +# Performance +trainer: + precision: "bf16-mixed" + accumulate_grad_batches: 4 + +data: + train_dataloader: + num_workers: 4 + pin_memory: true + persistent_workers: true +``` diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md new file mode 100644 index 00000000..d9b944f4 --- /dev/null +++ b/docs/guides/configuration.md @@ 
-0,0 +1,594 @@ +--- +title: Configuration Guide +--- + +# Configuration Guide + +Master Lighter's configuration system in 15 minutes. + +Lighter uses [Sparkwheel](https://project-lighter.github.io/sparkwheel/) for configuration - a powerful YAML-based system with references, expressions, and object instantiation. + +## The 5 Essential Symbols + +You only need to understand 5 symbols to use Lighter effectively: + +### 1. `_target_`: Create Objects + +Instantiate any Python class from YAML: + +```yaml +model: + _target_: torch.nn.Linear + in_features: 784 + out_features: 10 +``` + +**Equivalent Python:** +```python +model = torch.nn.Linear(in_features=784, out_features=10) +``` + +Works with **any** Python class - PyTorch, third-party libraries, or your own code. + +**Using `_args_` for positional arguments:** + +When you need to pass positional arguments instead of keyword arguments: + +```yaml +model: + _target_: torch.nn.Sequential + _args_: + - _target_: torch.nn.Linear + in_features: 784 + out_features: 128 + - _target_: torch.nn.ReLU +``` + +**Equivalent Python:** +```python +model = torch.nn.Sequential( + torch.nn.Linear(in_features=784, out_features=128), + torch.nn.ReLU() +) +``` + +The `_args_` list contains positional arguments passed to the target class. Each item can have its own `_target_` for nested instantiation. 
+ +**Using `_mode_` to control instantiation:** + +Control how `_target_` instantiates objects: + +```yaml +model: + # Factory function that returns a model + _target_: project.model_factory.create_model + _mode_: callable # Returns a partial function + model_name: resnet50 + num_classes: 10 +``` + +**Available modes:** + +- `"default"` (default): Normal instantiation - `Component(*args, **kwargs)` +- `"callable"`: Returns a partial function - `functools.partial(Component, *args, **kwargs)` +- `"debug"`: Runs in debugger - `pdb.runcall(Component, *args, **kwargs)` + +**When to use `callable` mode:** + +Use `_mode_: callable` when you need a factory function or lazy instantiation: + +```yaml +data: + train_dataloader: + _target_: torch.utils.data.DataLoader + collate_fn: + _target_: project.collate.custom_collate + _mode_: callable # DataLoader needs the function, not the result + padding_value: 0 +``` + +**Equivalent Python:** +```python +from functools import partial + +collate_fn = partial(custom_collate, padding_value=0) +dataloader = DataLoader(..., collate_fn=collate_fn) +``` + +**When to use `debug` mode:** + +Use `_mode_: debug` to debug instantiation issues by entering the debugger when the component is created: + +```yaml +model: + network: + _target_: project.model.ComplexModel + _mode_: debug # Will enter pdb when instantiating + num_layers: 12 + hidden_size: 768 +``` + +This is equivalent to: +```python +import pdb +network = pdb.runcall(ComplexModel, num_layers=12, hidden_size=768) +``` + +Useful when you need to step through the `__init__` method to diagnose instantiation errors. 
+ +**Using `_disabled_` to skip instantiation:** + +Skip instantiation of a component without removing it from config: + +```yaml +trainer: + callbacks: + - _target_: pytorch_lightning.callbacks.EarlyStopping + monitor: val_loss + patience: 3 + - _target_: pytorch_lightning.callbacks.ModelCheckpoint + _disabled_: true # This callback is removed from the list + save_top_k: 3 + +system: + scheduler: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR + _disabled_: true # Disable while debugging optimizer issues + optimizer: "@system::optimizer" + T_max: 100 +``` + +When `_disabled_: true`: + +- **Inline in lists/dicts**: Disabled components are **removed** from the parent structure +- **Direct resolution or `@` references**: Returns `None` + +The config is preserved—re-enable by setting `_disabled_: false` or removing the key. + +For complete details (string values, expressions, use cases), see the [Sparkwheel documentation](https://project-lighter.github.io/sparkwheel/user-guide/instantiation/#_disabled_-skip-instantiation). + +### 2. `@`: Resolved References (Lazy) + +Reference values that are **resolved lazily** when needed: + +```yaml +model: + network: + _target_: torchvision.models.resnet18 + num_classes: 10 + + optimizer: + _target_: torch.optim.Adam + params: "$@model::network.parameters()" # Resolved lazily + lr: 0.001 +``` + +**What `@` does:** +- Resolves **lazily** (when you call `resolve()`, not when loading config) +- Returns **final computed values** after instantiation and evaluation +- Gets the actual Python object, so you can call methods on it + +### 3. 
`%`: Raw References (Eager) + +Copy **raw YAML content** that is processed **eagerly** during config merge: + +```yaml +model: + train_metrics: + - _target_: torchmetrics.Accuracy + task: multiclass + num_classes: 10 + + val_metrics: "%model::train_metrics" # Copies raw YAML +``` + +**What `%` does:** +- Processes **eagerly** (during config loading, before instantiation) +- Copies **unprocessed YAML** definition +- Creates a **new instance** when later resolved (not shared!) + +!!! danger "Critical: Use `%` for Metrics, Not `@`" + Metrics accumulate state. Using `@` shares the same instance between train and val: + + ```yaml + # ❌ WRONG - Shares the same metric instance + val_metrics: "@model::train_metrics" + + # ✅ CORRECT - Copies config, creates separate instance + val_metrics: "%model::train_metrics" + ``` + + **Why this matters:** `%` copies the raw YAML template, so when each is resolved, you get separate instances. `@` would resolve once and share that same object. + +### 4. `$`: Evaluate Python Expressions + +Run Python code in your configs: + +```yaml +# Simple math +lr: "$0.001 * 2" # = 0.002 + +# Call methods +optimizer: + params: "$@model::network.parameters()" + +# Conditionals +batch_size: "$64 if %vars::large_batch else 32" + +# List comprehensions +layer_sizes: "$[64 * (2**i) for i in range(4)]" # [64, 128, 256, 512] + +# Type conversions +warmup_steps: "$int(%vars::total_steps * 0.1)" +``` + +### 5. `::`: Navigate Config Paths + +Access nested values using `::` separator: + +```yaml +model::optimizer::lr # Navigate to nested value +data::train_dataloader::batch_size +``` + +Use in CLI overrides: +```bash +lighter fit config.yaml model::optimizer::lr=0.01 +``` + +## The Critical Rule: `::` vs `.` + +- `::` navigates **config** structure +- `.` accesses **Python** attributes/methods + +```yaml +# ❌ WRONG +params: "$@model::network::parameters()" # :: for Python method + +# ✅ CORRECT +params: "$@model::network.parameters()" # . 
for Python method +``` + +## Common Patterns + +### Pattern 1: Network → Optimizer + +```yaml +model: + network: + _target_: torchvision.models.resnet50 + num_classes: 10 + + optimizer: + _target_: torch.optim.Adam + params: "$@model::network.parameters()" # Pass network params + lr: 0.001 +``` + +### Pattern 2: Optimizer → Scheduler + +```yaml +model: + optimizer: + _target_: torch.optim.Adam + params: "$@model::network.parameters()" + lr: 0.001 + + scheduler: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR + optimizer: "@model::optimizer" # Pass optimizer object + T_max: 100 +``` + +### Pattern 3: Reusing Metrics + +```yaml +model: + train_metrics: + - _target_: torchmetrics.Accuracy + task: multiclass + num_classes: 10 + - _target_: torchmetrics.F1Score + task: multiclass + num_classes: 10 + + val_metrics: "%model::train_metrics" # Reuse config + test_metrics: "%model::train_metrics" +``` + +### Pattern 4: Shared Variables + +```yaml +vars: + num_classes: 10 + base_lr: 0.001 + batch_size: 32 + +model: + network: + _target_: torchvision.models.resnet18 + num_classes: "%vars::num_classes" + + optimizer: + lr: "%vars::base_lr" + +data: + train_dataloader: + batch_size: "%vars::batch_size" +``` + +### Pattern 5: Differential Learning Rates + +```yaml +model: + optimizer: + _target_: torch.optim.SGD + params: + - params: "$@model::network.backbone.parameters()" + lr: 0.0001 # Low LR for pretrained backbone + - params: "$@model::network.head.parameters()" + lr: 0.01 # High LR for new head + momentum: 0.9 +``` + +## Config Structure + +Every Lighter config has three main sections: + +```yaml +trainer: # PyTorch Lightning Trainer + _target_: pytorch_lightning.Trainer + max_epochs: 10 + accelerator: auto + devices: 1 + +model: # LightningModule or LighterModule + _target_: your.Module + # ... module arguments ... + +data: # LighterDataModule or custom LightningDataModule + _target_: lighter.LighterDataModule + train_dataloader: ... + val_dataloader: ... 
+``` + +### Optional Sections + +```yaml +_requires_: # Import Python modules + - "$import torch" + - "$from datetime import datetime" + +vars: # Reusable variables + num_classes: 10 + lr: 0.001 + +args: # Stage-specific arguments + fit: + ckpt_path: null + test: + ckpt_path: "checkpoints/best.ckpt" +``` + +## CLI Overrides + +Override any config value from command line: + +```bash +# Single override +lighter fit config.yaml trainer::max_epochs=100 + +# Nested values +lighter fit config.yaml model::optimizer::lr=0.001 + +# Multiple overrides +lighter fit config.yaml \ + trainer::max_epochs=100 \ + model::optimizer::lr=0.001 \ + data::train_dataloader::batch_size=64 +``` + +## Merging Configs + +Combine multiple YAML files: + +```bash +lighter fit base.yaml,experiment.yaml +``` + +### Default Behavior: Merge + +Configs merge automatically: + +```yaml +# base.yaml +trainer: + max_epochs: 10 + accelerator: auto + +# experiment.yaml +trainer: + max_epochs: 100 # Overrides + devices: 4 # Adds +``` + +**Result:** `max_epochs=100`, `accelerator=auto`, `devices=4` + +### Replace with `=` + +Replace instead of merge: + +```yaml +# experiment.yaml +trainer: + =callbacks: # Replace entire callbacks list + - _target_: pytorch_lightning.callbacks.EarlyStopping +``` + +### Delete with `~` + +Remove keys: + +```yaml +# Delete entire key +trainer: + ~callbacks: null + +# Delete list items by index +trainer: + ~callbacks: [1, 3] # Remove items at indices 1 and 3 + +# Delete dict keys +data: + ~test_dataloader: null +``` + +## Common Pitfalls + +### ❌ Wrong: Using `@` for Metrics + +```yaml +val_metrics: "@model::train_metrics" # Shared instance! 
+``` + +### ✅ Correct: Using `%` for Metrics + +```yaml +val_metrics: "%model::train_metrics" # New instance +``` + +--- + +### ❌ Wrong: Using `::` for Python Attributes + +```yaml +params: "$@model::network::parameters()" +``` + +### ✅ Correct: Using `.` for Python Attributes + +```yaml +params: "$@model::network.parameters()" +``` + +--- + +### ❌ Wrong: Missing `$` for Expressions + +```yaml +batch_size: "@vars::base_batch * 2" # Treated as string! +``` + +### ✅ Correct: Using `$` for Expressions + +```yaml +batch_size: "$%vars::base_batch * 2" # Evaluated +``` + +## Advanced: Conditional Config + +```yaml +vars: + use_pretrained: true + +model: + network: + _target_: torchvision.models.resnet18 + weights: "$'IMAGENET1K_V2' if %vars::use_pretrained else None" + num_classes: 10 +``` + +## Advanced: Dynamic Imports + +```yaml +_requires_: + - "$import datetime" + - "$from pathlib import Path" + +trainer: + logger: + name: "$datetime.datetime.now().strftime('%Y%m%d_%H%M%S')" +``` + +## Complete Example + +```yaml +_requires_: + - "$import torch" + +vars: + num_classes: 10 + base_lr: 0.001 + max_epochs: 100 + +trainer: + _target_: pytorch_lightning.Trainer + max_epochs: "%vars::max_epochs" + accelerator: auto + callbacks: + - _target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: val_loss + mode: min + save_top_k: 3 + +model: + _target_: lighter.LighterModule + + network: + _target_: torchvision.models.resnet18 + num_classes: "%vars::num_classes" + + criterion: + _target_: torch.nn.CrossEntropyLoss + + optimizer: + _target_: torch.optim.Adam + params: "$@model::network.parameters()" + lr: "%vars::base_lr" + + scheduler: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR + optimizer: "@model::optimizer" + T_max: "%vars::max_epochs" + + train_metrics: + - _target_: torchmetrics.Accuracy + task: multiclass + num_classes: "%vars::num_classes" + + val_metrics: "%model::train_metrics" + +data: + _target_: lighter.LighterDataModule + train_dataloader: + 
_target_: torch.utils.data.DataLoader + batch_size: 32 + shuffle: true + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: true + download: true +``` + +## Quick Reference + +| Symbol | Use | Example | +|--------|-----|---------| +| `_target_` | Instantiate class | `_target_: torch.nn.Linear` | +| `_args_` | Positional arguments | `_args_: [arg1, arg2]` | +| `_mode_` | Instantiation mode | `_mode_: callable` | +| `_disabled_` | Skip instantiation (removed from parent) | `_disabled_: true` | +| `@` | Resolved reference | `@model::optimizer` | +| `%` | Raw reference | `%model::train_metrics` | +| `$` | Python expression | `$0.001 * 2` | +| `::` | Config path | `model::optimizer::lr` | +| `.` | Python attribute | `@model::network.parameters()` | +| `=` | Replace operator | `=callbacks:` | +| `~` | Delete operator | `~callbacks: [0, 2]` | + +## Next Steps + +- [Custom Code Guide](custom-code.md) - Use your own models/datasets +- [Training Guide](training.md) - Run experiments +- [Sparkwheel Docs](https://project-lighter.github.io/sparkwheel/) - Complete reference diff --git a/docs/guides/custom-code.md b/docs/guides/custom-code.md new file mode 100644 index 00000000..eef268d9 --- /dev/null +++ b/docs/guides/custom-code.md @@ -0,0 +1,829 @@ +--- +title: Custom Code Guide +--- + +# Custom Code Guide + +Use your own models, datasets, and transforms with Lighter. + +Lighter works with **any** Python class. This guide shows how to structure your project and reference your custom code in configs. + +## The Project Folder Pattern + +Lighter uses a **project folder** with auto-discovery. When you add `__lighter__.py` to your folder, Lighter automatically makes it available as `project`, allowing you to reference your code as `project.module.Class`. 
+ +### Quick Example + +``` +cifar10/ +├── __lighter__.py # Marker file +├── __init__.py # Python package +├── model.py # Your models +└── configs/ + └── config.yaml +``` + +```yaml +# configs/config.yaml +model: + _target_: project.model.CIFAR10Model # Auto-discovered! +``` + +This is the **recommended approach** for organizing Lighter projects. + +## Project Structure + +A typical Lighter project looks like this: + +``` +my_project/ +├── __lighter__.py # Marker file (can be empty) +├── __init__.py # Makes it a package +├── model.py # Your models +├── data.py # Your datasets +├── transforms.py # Your transforms +├── configs/ +│ ├── baseline.yaml +│ └── improved.yaml +└── outputs/ # Created by Lighter +``` + +### The `__lighter__.py` Marker + +This file tells Lighter where your project root is. It can be empty: + +```python +# __lighter__.py +# This file can be empty - it just marks your project root +``` + +Or use it for project-level imports: + +```python +# __lighter__.py +import warnings +warnings.filterwarnings("ignore", category=UserWarning) + +# Any imports here run before config loading +``` + +### The `__init__.py` Files + +Every directory containing code needs `__init__.py`: + +``` +my_project/ +├── __init__.py # Required +├── models/ +│ ├── __init__.py # Required +│ └── resnet.py +└── data/ + ├── __init__.py # Required + └── dataset.py +``` + +Without `__init__.py`, Python can't import from that directory. 
+ +## Using Custom Models + +### Example: Custom Model + +Create **`model.py`**: + +```python +import torch.nn as nn + +class SimpleNet(nn.Module): + """Custom network for CIFAR-10.""" + + def __init__(self, num_classes=10, dropout=0.5): + super().__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Dropout(dropout), + nn.Conv2d(64, 128, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Dropout(dropout), + ) + self.classifier = nn.Sequential( + nn.Linear(128 * 8 * 8, 256), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(256, num_classes), + ) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), -1) + x = self.classifier(x) + return x +``` + +### Reference in Config + +```yaml +model: + _target_: lighter.LighterModule + + network: + _target_: project.model.SimpleNet # project.file.Class + num_classes: 10 + dropout: 0.3 + + criterion: + _target_: torch.nn.CrossEntropyLoss + + optimizer: + _target_: torch.optim.Adam + params: "$@model::network.parameters()" + lr: 0.001 +``` + +**Pattern**: `project.module_name.ClassName` + +- `project` - auto-discovered from `__lighter__.py` +- `model` - Python file (model.py) +- `SimpleNet` - class name + +## Using Custom Datasets + +### Example: Custom Dataset + +Create **`data.py`**: + +```python +import torch +from torch.utils.data import Dataset +from pathlib import Path +from PIL import Image + +class CustomImageDataset(Dataset): + """Load images from directory structure.""" + + def __init__(self, root_dir, transform=None): + self.root_dir = Path(root_dir) + self.transform = transform + + # Assume structure: root_dir/class_name/image.jpg + self.samples = [] + self.class_to_idx = {} + + for idx, class_dir in enumerate(sorted(self.root_dir.iterdir())): + if class_dir.is_dir(): + self.class_to_idx[class_dir.name] = idx + for img_path in class_dir.glob("*.jpg"): + self.samples.append((img_path, idx)) + + def __len__(self): + 
return len(self.samples) + + def __getitem__(self, idx): + img_path, label = self.samples[idx] + image = Image.open(img_path).convert("RGB") + + if self.transform: + image = self.transform(image) + + return image, label +``` + +### Reference in Config + +```yaml +data: + _target_: lighter.LighterDataModule + + train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 32 + shuffle: true + num_workers: 4 + dataset: + _target_: project.data.CustomImageDataset + root_dir: ./data/train + transform: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.Resize + size: [224, 224] + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] +``` + +## Using Custom Transforms + +### Example: Custom Transform + +Create **`transforms.py`**: + +```python +import torch +import random + +class RandomCutout: + """Randomly mask out a square patch from the image.""" + + def __init__(self, size=16, p=0.5): + self.size = size + self.p = p + + def __call__(self, img): + if random.random() > self.p: + return img + + h, w = img.shape[1:] + y = random.randint(0, h - self.size) + x = random.randint(0, w - self.size) + + img[:, y:y+self.size, x:x+self.size] = 0 + return img +``` + +### Reference in Config + +```yaml +data: + train_dataloader: + dataset: + transform: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.ToTensor + - _target_: project.transforms.RandomCutout + size: 16 + p: 0.5 + - _target_: torchvision.transforms.Normalize + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] +``` + +## Complete Example: Custom LightningModule + +### Step 1: Create Your Module + +**`model.py`**: + +```python +import pytorch_lightning as pl +import torch +import torch.nn.functional as F + +class MyCIFAR10Module(pl.LightningModule): + """Custom training logic for CIFAR-10.""" + + def __init__(self, network, 
learning_rate=0.001, weight_decay=1e-4): + super().__init__() + self.save_hyperparameters(ignore=['network']) + self.network = network + self.lr = learning_rate + self.weight_decay = weight_decay + + def forward(self, x): + return self.network(x) + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + + # Log metrics + acc = (logits.argmax(dim=1) == y).float().mean() + self.log("train/loss", loss) + self.log("train/acc", acc) + + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + acc = (logits.argmax(dim=1) == y).float().mean() + + self.log("val/loss", loss) + self.log("val/acc", acc) + + def configure_optimizers(self): + optimizer = torch.optim.AdamW( + self.parameters(), + lr=self.lr, + weight_decay=self.weight_decay + ) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=self.trainer.max_epochs + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "epoch", + } + } +``` + +### Step 2: Create Config + +**`configs/custom.yaml`**: + +```yaml +trainer: + _target_: pytorch_lightning.Trainer + max_epochs: 100 + accelerator: auto + devices: 1 + callbacks: + - _target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: val/acc + mode: max + save_top_k: 3 + +model: + _target_: project.model.MyCIFAR10Module + learning_rate: 0.001 + weight_decay: 0.0001 + network: + _target_: project.model.SimpleNet + num_classes: 10 + dropout: 0.3 + +data: + _target_: lighter.LighterDataModule + train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 128 + shuffle: true + num_workers: 4 + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: true + download: true + transform: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.RandomCrop + size: 32 + padding: 4 + - _target_: 
torchvision.transforms.RandomHorizontalFlip + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.4914, 0.4822, 0.4465] + std: [0.2470, 0.2435, 0.2616] + + val_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 128 + num_workers: 4 + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: false + download: true + transform: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.4914, 0.4822, 0.4465] + std: [0.2470, 0.2435, 0.2616] +``` + +### Step 3: Run + +```bash +cd my_project +lighter fit configs/custom.yaml +``` + +## Common Patterns + +### Pattern 1: Separate Network and Module + +Keep network architecture separate from training logic: + +``` +my_project/ +├── networks/ +│ ├── __init__.py +│ ├── resnet.py +│ └── unet.py +├── modules/ +│ ├── __init__.py +│ ├── classifier.py +│ └── segmentation.py +└── configs/ + └── config.yaml +``` + +Config: + +```yaml +model: + _target_: project.modules.classifier.ClassificationModule + network: + _target_: project.networks.resnet.ResNet50 + num_classes: 10 +``` + +### Pattern 2: Shared Base Classes + +Create base modules for common functionality: + +**`modules/base.py`**: + +```python +from lighter import LighterModule + +class BaseVisionModule(LighterModule): + """Base module with common vision model utilities.""" + + def on_train_start(self): + # Log model architecture + self.logger.experiment.add_text( + "model/architecture", + str(self.network) + ) + + def log_images(self, images, name, n=8): + # Helper to log images + import torchvision + grid = torchvision.utils.make_grid(images[:n]) + self.logger.experiment.add_image(name, grid, self.global_step) +``` + +Use in your modules: + +```python +from project.modules.base import BaseVisionModule + +class MyModule(BaseVisionModule): + def training_step(self, batch, batch_idx): + x, y = 
batch + + # Log images every 100 steps + if batch_idx % 100 == 0: + self.log_images(x, "train/inputs") + + # ... rest of training step ... +``` + +### Pattern 3: Config Inheritance + +Use YAML anchors for shared config: + +**`configs/base.yaml`**: + +```yaml +# Shared settings +defaults: &defaults + trainer: + max_epochs: 100 + accelerator: auto + + data: + train_dataloader: + batch_size: 32 + num_workers: 4 + +# Experiment inherits defaults +<<: *defaults + +model: + _target_: project.model.SimpleNet +``` + +**`configs/large_batch.yaml`**: + +```yaml +# Override just batch size (compose with base.yaml via CLI) +data: + train_dataloader: + batch_size: 128 +``` + +Run by composing configs: + +```bash +lighter fit configs/base.yaml configs/large_batch.yaml +``` + +## Organizing Larger Projects + +For projects with many modules, organize by functionality: + +``` +my_project/ +├── __lighter__.py +├── __init__.py +├── models/ +│ ├── __init__.py +│ ├── classifier.py +│ ├── segmentation.py +│ └── detection.py +├── data/ +│ ├── __init__.py +│ ├── datasets.py +│ ├── samplers.py +│ └── augmentation.py +├── utils/ +│ ├── __init__.py +│ ├── losses.py +│ └── metrics.py +└── configs/ + ├── classification/ + │ ├── resnet18.yaml + │ └── efficientnet.yaml + └── segmentation/ + └── unet.yaml +``` + +Reference as: + +```yaml +model: + _target_: project.models.classifier.ImageClassifier + network: + _target_: project.models.classifier.ResNet18 + +data: + train_dataloader: + dataset: + _target_: project.data.datasets.CustomDataset + sampler: + _target_: project.data.samplers.BalancedSampler + + criterion: + _target_: project.utils.losses.FocalLoss +``` + +## Troubleshooting + +### Import Error: `ModuleNotFoundError` + +**Problem**: `ModuleNotFoundError: No module named 'project'` + +**Solution**: Check these in order: + +1. `__lighter__.py` exists in project root +2. You're running `lighter` from the directory containing `__lighter__.py` +3. 
All directories have `__init__.py` + +```bash +# Run from here +my_project/ +├── __lighter__.py # ✅ Exists +├── __init__.py # ✅ Exists +└── model.py + +# Not from here +parent/ +└── my_project/ + └── ... +``` + +### Import Error: `cannot import name 'MyClass'` + +**Problem**: `ImportError: cannot import name 'MyClass' from 'project.model'` + +**Solution**: Check class name matches exactly: + +```python +# model.py +class SimpleNet(nn.Module): # Must match config exactly + ... +``` + +```yaml +# config.yaml +network: + _target_: project.model.SimpleNet # Exact match +``` + +### Attribute Error in Config + +**Problem**: `AttributeError: 'SimpleNet' object has no attribute 'parameters'` + +**Solution**: You're using `::` instead of `.` for Python methods: + +```yaml +# ❌ WRONG +params: "$@model::network::parameters()" + +# ✅ CORRECT +params: "$@model::network.parameters()" +``` + +Remember: `::` navigates config, `.` accesses Python attributes. + +## Best Practices + +### 1. Use Descriptive Module Names + +```python +# ❌ Avoid generic names +class Net(nn.Module): + ... + +# ✅ Use descriptive names +class ResNetCIFAR10(nn.Module): + """ResNet-18 adapted for CIFAR-10.""" + ... +``` + +### 2. Document __init__ Parameters + +Config values map to `__init__` arguments, so document them: + +```python +class CustomDataset(Dataset): + """Custom dataset for my task. + + Args: + root_dir: Path to data directory + split: One of 'train', 'val', 'test' + transform: Optional transform to apply + target_transform: Optional target transform + """ + + def __init__(self, root_dir, split='train', transform=None, target_transform=None): + ... +``` + +### 3. 
Keep Configs DRY with Variables + +```yaml +vars: + num_classes: 10 + img_size: 224 + base_lr: 0.001 + +model: + network: + num_classes: "%vars::num_classes" + optimizer: + lr: "%vars::base_lr" + +data: + train_dataloader: + dataset: + transform: + - _target_: torchvision.transforms.Resize + size: ["%vars::img_size", "%vars::img_size"] +``` + +### 4. Version Control Configs + +```bash +git add configs/baseline.yaml +git commit -m "Add baseline experiment config" +``` + +Compare experiments: + +```bash +git diff configs/baseline.yaml configs/improved.yaml +``` + +## Complete Project Example + +Here's a full working example: + +``` +cifar10/ +├── __lighter__.py +├── __init__.py +├── model.py +├── data.py +├── configs/ +│ ├── baseline.yaml +│ └── augmented.yaml +└── README.md +``` + +**model.py**: + +```python +import torch.nn as nn +import pytorch_lightning as pl +import torch.nn.functional as F + +class SimpleCNN(nn.Module): + def __init__(self, num_classes=10): + super().__init__() + self.conv1 = nn.Conv2d(3, 32, 3, padding=1) + self.conv2 = nn.Conv2d(32, 64, 3, padding=1) + self.pool = nn.MaxPool2d(2) + self.fc1 = nn.Linear(64 * 8 * 8, 128) + self.fc2 = nn.Linear(128, num_classes) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(x.size(0), -1) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return x + +class CIFAR10Module(pl.LightningModule): + def __init__(self, network, lr=0.001): + super().__init__() + self.network = network + self.lr = lr + + def forward(self, x): + return self.network(x) + + def training_step(self, batch, batch_idx): + x, y = batch + loss = F.cross_entropy(self(x), y) + self.log("train/loss", loss) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + acc = (logits.argmax(1) == y).float().mean() + self.log("val/loss", loss) + self.log("val/acc", acc) + + def configure_optimizers(self): + return 
torch.optim.Adam(self.parameters(), lr=self.lr) +``` + +**configs/baseline.yaml**: + +```yaml +trainer: + _target_: pytorch_lightning.Trainer + max_epochs: 10 + accelerator: auto + +model: + _target_: project.model.CIFAR10Module + lr: 0.001 + network: + _target_: project.model.SimpleCNN + num_classes: 10 + +data: + _target_: lighter.LighterDataModule + train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 64 + shuffle: true + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: true + download: true + transform: + _target_: torchvision.transforms.ToTensor + + val_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 64 + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: false + transform: + _target_: torchvision.transforms.ToTensor +``` + +**Run it**: + +```bash +cd cifar10 +lighter fit configs/baseline.yaml +``` + +## Next Steps + +- [Training Guide](training.md) - Run experiments, save outputs +- [Best Practices](best-practices.md) - Production patterns +- [Complete Example](../examples/image-classification.md) - Full CIFAR-10 with all features + +## Quick Reference + +```python +# Project structure +my_project/ +├── __lighter__.py # Required marker +├── __init__.py # Required for imports +└── code.py # Your code + +# Import syntax in config +_target_: project.module.ClassName + +# Common issues +# ❌ No __lighter__.py +# ❌ Missing __init__.py +# ❌ Wrong working directory +# ❌ Using :: for Python methods (use . instead) +``` diff --git a/docs/guides/lighter-module.md b/docs/guides/lighter-module.md new file mode 100644 index 00000000..2668197e --- /dev/null +++ b/docs/guides/lighter-module.md @@ -0,0 +1,1013 @@ +--- +title: Using LighterModule +--- + +# Using LighterModule + +Build models with less boilerplate using `LighterModule`. + +**Key insight**: You write step logic. LighterModule handles optimizers, schedulers, and logging automatically. 
+ +## When to Use This Approach + +**Use LighterModule when:** + +- Starting new projects +- Want less boilerplate code +- Standard training workflows +- Config-driven everything + +**You get:** + +- Automatic `configure_optimizers()` +- Dual logging (step + epoch) +- Config-driven metrics +- All PyTorch Lightning features + +**You write:** + +- Step implementations only (`training_step`, `validation_step`, etc.) +- Your model's forward logic +- That's it! + +## Basic Example + +### Minimal Implementation + +`model.py`: + +```python +from lighter import LighterModule + +class MyModel(LighterModule): + """Minimal model - just implement steps.""" + + def training_step(self, batch, batch_idx): + x, y = batch + pred = self(x) # Forward pass through self.network + loss = self.criterion(pred, y) # Use self.criterion + + # Update metrics + if self.train_metrics: + self.train_metrics(pred, y) + + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + loss = self.criterion(pred, y) + + if self.val_metrics: + self.val_metrics(pred, y) + + return {"loss": loss} +``` + +### Config + +`config.yaml`: + +```yaml +trainer: + _target_: pytorch_lightning.Trainer + max_epochs: 10 + accelerator: auto + +model: + _target_: model.MyModel + + # Network architecture + network: + _target_: torchvision.models.resnet18 + num_classes: 10 + + # Loss function + criterion: + _target_: torch.nn.CrossEntropyLoss + + # Optimizer (auto-configured!) 
+ optimizer: + _target_: torch.optim.Adam + params: "$@model::network.parameters()" + lr: 0.001 + + # Metrics (optional) + train_metrics: + - _target_: torchmetrics.Accuracy + task: multiclass + num_classes: 10 + + val_metrics: "%model::train_metrics" # Copy config + +data: + _target_: lighter.LighterDataModule + train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 32 + shuffle: true + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: true + download: true + transform: + _target_: torchvision.transforms.ToTensor +``` + +**That's it!** No `configure_optimizers()`, no manual logging. + +## How It Works + +LighterModule provides these automatically: + +### 1. Forward Pass + +```python +def forward(self, x): + """Calls self.network(x) automatically.""" + return self.network(x) +``` + +You can override if needed: + +```python +def forward(self, x): + # Custom forward logic + features = self.network.encoder(x) + output = self.network.decoder(features) + return output +``` + +### 2. Configure Optimizers + +```python +def configure_optimizers(self): + """Auto-creates optimizer and optional scheduler.""" + # Creates optimizer from config + # Adds scheduler if provided + # Returns proper format for Lightning +``` + +No manual implementation needed! + +### 3. Automatic Logging + +**LighterModule automatically logs:** + +1. **Loss values** - Dual logging (step + epoch) +2. **Metrics** - Dual logging (step + epoch) +3. 
**Optimizer stats** - Learning rate, momentum, betas, weight decay (epoch only) + +#### Loss Logging + +Return loss from your step methods: + +```python +def training_step(self, batch, batch_idx): + loss = self.criterion(pred, y) + return {"loss": loss} +``` + +**Automatically logged as:** +- `train/loss/step` - Per-step values +- `train/loss/epoch` - Epoch average + +**Multi-component loss:** + +```python +def training_step(self, batch, batch_idx): + return { + "loss": { + "total": total_loss, # Required key + "ce": ce_loss, # Optional component + "reg": reg_loss # Optional component + } + } +``` + +**Logged as:** +- `train/loss/total/step`, `train/loss/total/epoch` +- `train/loss/ce/step`, `train/loss/ce/epoch` +- `train/loss/reg/step`, `train/loss/reg/epoch` + +#### Metrics Logging + +Call metrics in your step methods: + +```python +def training_step(self, batch, batch_idx): + if self.train_metrics: + self.train_metrics(pred, y) + return {"loss": loss} +``` + +**Automatically logged as:** +- `train/metrics/Accuracy/step` - Per-step values +- `train/metrics/Accuracy/epoch` - Epoch average +- `train/metrics/F1Score/step`, `train/metrics/F1Score/epoch` + +#### Optimizer Stats Logging + +Automatically logged at the start of each training epoch: +- `train/optimizer/Adam/lr/epoch` +- `train/optimizer/Adam/beta1/epoch` +- `train/optimizer/Adam/beta2/epoch` + +See [Automatic Optimizer Stats Logging](#automatic-optimizer-stats-logging) for details. 
+ +## What LighterModule Provides + +### Attributes Available + +```python +class MyModel(LighterModule): + def training_step(self, batch, batch_idx): + # Available attributes: + self.network # From config: model::network + self.criterion # From config: model::criterion + self.optimizer # From config: model::optimizer + self.scheduler # From config: model::scheduler (optional) + self.train_metrics # From config: model::train_metrics (optional) + self.val_metrics # From config: model::val_metrics (optional) + self.test_metrics # From config: model::test_metrics (optional) +``` + +All optional except `network` (you need something to run!). + +### Required Implementations + +You **must** implement: + +```python +def training_step(self, batch, batch_idx): + """Required.""" + return {"loss": loss} +``` + +Optional but common: + +```python +def validation_step(self, batch, batch_idx): + """Optional.""" + return {"loss": loss} + +def test_step(self, batch, batch_idx): + """Optional.""" + return {"loss": loss} + +def predict_step(self, batch, batch_idx): + """Optional.""" + return predictions +``` + +## Complete Examples + +### Example 1: Image Classification + +`models.py`: + +```python +from lighter import LighterModule + +class ImageClassifier(LighterModule): + """Image classification with metrics.""" + + def training_step(self, batch, batch_idx): + images, labels = batch + + # Forward pass + logits = self(images) + + # Loss + loss = self.criterion(logits, labels) + + # Metrics + if self.train_metrics: + self.train_metrics(logits, labels) + + # Return dict - all values logged automatically + return { + "loss": loss, + } + + def validation_step(self, batch, batch_idx): + images, labels = batch + logits = self(images) + loss = self.criterion(logits, labels) + + if self.val_metrics: + self.val_metrics(logits, labels) + + return {"loss": loss} + + def test_step(self, batch, batch_idx): + images, labels = batch + logits = self(images) + + if self.test_metrics: + 
self.test_metrics(logits, labels) + + return {"predictions": logits.argmax(dim=1)} +``` + +`config.yaml`: + +```yaml +model: + _target_: models.ImageClassifier + + network: + _target_: torchvision.models.resnet50 + weights: IMAGENET1K_V2 # Pretrained + num_classes: 10 + + criterion: + _target_: torch.nn.CrossEntropyLoss + label_smoothing: 0.1 + + optimizer: + _target_: torch.optim.AdamW + params: "$@model::network.parameters()" + lr: 0.001 + weight_decay: 0.01 + + scheduler: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR + optimizer: "@model::optimizer" + T_max: 100 + + train_metrics: + - _target_: torchmetrics.Accuracy + task: multiclass + num_classes: 10 + - _target_: torchmetrics.F1Score + task: multiclass + num_classes: 10 + average: macro + + val_metrics: "%model::train_metrics" + test_metrics: "%model::train_metrics" +``` + +### Example 2: Semantic Segmentation + +`models.py`: + +```python +from lighter import LighterModule +import torch.nn.functional as F + +class SemanticSegmentation(LighterModule): + """Semantic segmentation with dice loss.""" + + def training_step(self, batch, batch_idx): + images, masks = batch + + # Forward + logits = self(images) + + # Resize logits to match mask size if needed + if logits.shape[-2:] != masks.shape[-2:]: + logits = F.interpolate( + logits, + size=masks.shape[-2:], + mode='bilinear', + align_corners=False + ) + + # Loss + loss = self.criterion(logits, masks) + + # Metrics + if self.train_metrics: + preds = logits.argmax(dim=1) + self.train_metrics(preds, masks) + + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + images, masks = batch + logits = self(images) + + if logits.shape[-2:] != masks.shape[-2:]: + logits = F.interpolate( + logits, + size=masks.shape[-2:], + mode='bilinear', + align_corners=False + ) + + loss = self.criterion(logits, masks) + + if self.val_metrics: + preds = logits.argmax(dim=1) + self.val_metrics(preds, masks) + + return {"loss": loss} +``` + +`config.yaml`: + 
+```yaml +model: + _target_: models.SemanticSegmentation + + network: + _target_: segmentation_models_pytorch.Unet + encoder_name: resnet34 + encoder_weights: imagenet + in_channels: 3 + classes: 21 + + criterion: + _target_: segmentation_models_pytorch.losses.DiceLoss + mode: multiclass + + optimizer: + _target_: torch.optim.Adam + params: "$@model::network.parameters()" + lr: 0.0001 + + train_metrics: + - _target_: torchmetrics.JaccardIndex + task: multiclass + num_classes: 21 + + val_metrics: "%model::train_metrics" +``` + +### Example 3: Multi-Task Learning + +`models.py`: + +```python +from lighter import LighterModule + +class MultiTaskModel(LighterModule): + """Multi-task: classification + regression.""" + + def __init__(self, network, criterion_cls, criterion_reg, + optimizer, alpha=0.5): + super().__init__( + network=network, + criterion=None, # We have multiple + optimizer=optimizer + ) + self.criterion_cls = criterion_cls + self.criterion_reg = criterion_reg + self.alpha = alpha # Task weighting + + def training_step(self, batch, batch_idx): + images, labels_cls, labels_reg = batch + + # Forward + out_cls, out_reg = self(images) + + # Two losses + loss_cls = self.criterion_cls(out_cls, labels_cls) + loss_reg = self.criterion_reg(out_reg, labels_reg) + + # Combined loss + loss = self.alpha * loss_cls + (1 - self.alpha) * loss_reg + + # Return all for logging + return { + "loss": loss, + "loss_cls": loss_cls.detach(), + "loss_reg": loss_reg.detach(), + } + + def validation_step(self, batch, batch_idx): + images, labels_cls, labels_reg = batch + out_cls, out_reg = self(images) + + loss_cls = self.criterion_cls(out_cls, labels_cls) + loss_reg = self.criterion_reg(out_reg, labels_reg) + loss = self.alpha * loss_cls + (1 - self.alpha) * loss_reg + + # Accuracy for classification head + acc = (out_cls.argmax(1) == labels_cls).float().mean() + + return { + "loss": loss, + "loss_cls": loss_cls, + "loss_reg": loss_reg, + "acc": acc, + } +``` + +### Example 4: 
Custom Forward Pass + +Override `forward` for custom logic: + +```python +from lighter import LighterModule + +class AutoencoderModel(LighterModule): + """Autoencoder with custom forward.""" + + def forward(self, x): + """Custom forward through encoder-decoder.""" + latent = self.network.encoder(x) + reconstruction = self.network.decoder(latent) + return reconstruction, latent + + def training_step(self, batch, batch_idx): + images, _ = batch + + # Forward returns tuple + reconstruction, latent = self(images) + + # Reconstruction loss + loss_recon = self.criterion(reconstruction, images) + + # Optional: regularization on latent + loss_kl = 0.001 * (latent ** 2).mean() + + loss = loss_recon + loss_kl + + return { + "loss": loss, + "loss_recon": loss_recon.detach(), + "loss_kl": loss_kl.detach(), + } + + def validation_step(self, batch, batch_idx): + images, _ = batch + reconstruction, _ = self(images) + loss = self.criterion(reconstruction, images) + return {"loss": loss} +``` + +## Adding Schedulers + +LighterModule handles schedulers automatically: + +```yaml +model: + optimizer: + _target_: torch.optim.Adam + params: "$@model::network.parameters()" + lr: 0.001 + + scheduler: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR + optimizer: "@model::optimizer" # Reference optimizer + T_max: 100 + eta_min: 0.00001 +``` + +**Supported scheduler types:** + +### Step-based + +```yaml +scheduler: + _target_: torch.optim.lr_scheduler.StepLR + optimizer: "@model::optimizer" + step_size: 30 + gamma: 0.1 +``` + +### Plateau-based + +```yaml +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + optimizer: "@model::optimizer" + mode: min + factor: 0.5 + patience: 10 +``` + +### Warmup + +```yaml +scheduler: + _target_: torch.optim.lr_scheduler.LinearLR + optimizer: "@model::optimizer" + start_factor: 0.1 + total_iters: 1000 +``` + +### Chained Schedulers + +For complex schedules, override `configure_optimizers`: + +```python +def configure_optimizers(self): 
+    # Warmup then cosine
+    warmup = torch.optim.lr_scheduler.LinearLR(
+        self.optimizer,
+        start_factor=0.1,
+        total_iters=10  # 10 warmup epochs - must match the SequentialLR milestone below
+    )
+    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
+        self.optimizer,
+        T_max=self.trainer.max_epochs - 10
+    )
+
+    scheduler = torch.optim.lr_scheduler.SequentialLR(
+        self.optimizer,
+        schedulers=[warmup, cosine],
+        milestones=[10]
+    )
+
+    return {
+        "optimizer": self.optimizer,
+        "lr_scheduler": {
+            "scheduler": scheduler,
+            "interval": "epoch",
+        }
+    }
+```
+
+## Automatic Optimizer Stats Logging
+
+**LighterModule automatically logs optimizer statistics** including learning rate, momentum, betas, and weight decay at the start of each training epoch. You do **not** need to add `LearningRateMonitor` callback.
+
+Logged stats (per parameter group):
+- **Learning rate**: `train/optimizer/{OptimizerName}/lr/epoch`
+- **Momentum**: `train/optimizer/{OptimizerName}/momentum/epoch` (SGD, RMSprop)
+- **Beta1/Beta2**: `train/optimizer/{OptimizerName}/beta1/epoch`, `beta2/epoch` (Adam variants)
+- **Weight decay**: `train/optimizer/{OptimizerName}/weight_decay/epoch` (if non-zero)
+
+For multiple parameter groups (e.g., differential learning rates):
+- `train/optimizer/{OptimizerName}/lr/group1/epoch`
+- `train/optimizer/{OptimizerName}/lr/group2/epoch`
+
+**Example logged metrics:**
+
+```
+train/optimizer/Adam/lr/epoch: 0.001
+train/optimizer/Adam/beta1/epoch: 0.9
+train/optimizer/Adam/beta2/epoch: 0.999
+```
+
+!!! note "No LearningRateMonitor needed"
+    The PyTorch Lightning `LearningRateMonitor` callback is redundant with LighterModule since optimizer stats are already logged automatically.
+ +## Working with Metrics + +### Single Metric + +```yaml +model: + train_metrics: + _target_: torchmetrics.Accuracy + task: multiclass + num_classes: 10 +``` + +Update in code: + +```python +if self.train_metrics: + self.train_metrics(preds, targets) +``` + +### Multiple Metrics + +```yaml +model: + train_metrics: + - _target_: torchmetrics.Accuracy + task: multiclass + num_classes: 10 + - _target_: torchmetrics.F1Score + task: multiclass + num_classes: 10 + - _target_: torchmetrics.Precision + task: multiclass + num_classes: 10 +``` + +Update in code (same!): + +```python +if self.train_metrics: + self.train_metrics(preds, targets) # Updates all metrics +``` + +### Metric Collections + +Use `MetricCollection` for grouped metrics: + +```yaml +model: + train_metrics: + _target_: torchmetrics.MetricCollection + metrics: + accuracy: + _target_: torchmetrics.Accuracy + task: multiclass + num_classes: 10 + f1: + _target_: torchmetrics.F1Score + task: multiclass + num_classes: 10 +``` + +### Per-Class Metrics + +```yaml +model: + val_metrics: + - _target_: torchmetrics.Accuracy + task: multiclass + num_classes: 10 + average: none # Per-class accuracy +``` + +## Custom Initialization + +Need custom setup? Override `__init__`: + +```python +from lighter import LighterModule + +class MyModel(LighterModule): + def __init__(self, network, criterion, optimizer, + special_param=42): + super().__init__( + network=network, + criterion=criterion, + optimizer=optimizer + ) + + # Custom initialization + self.special_param = special_param + self.custom_buffer = [] + + # Freeze backbone + for param in self.network.backbone.parameters(): + param.requires_grad = False + + def training_step(self, batch, batch_idx): + # Use custom attributes + if batch_idx % self.special_param == 0: + self.custom_buffer.append(batch_idx) + + # ... rest of step ... +``` + +Config: + +```yaml +model: + _target_: models.MyModel + special_param: 100 + network: + _target_: ... + criterion: + _target_: ... 
+ optimizer: + _target_: ... +``` + +## Using Lightning Hooks + +All Lightning hooks work: + +```python +class MyModel(LighterModule): + def on_train_start(self): + print("Training starting!") + + def on_train_epoch_end(self): + # Log custom metrics + avg_loss = self.trainer.callback_metrics.get('train/loss') + if avg_loss is not None: + print(f"Epoch {self.current_epoch}: {avg_loss:.4f}") + + def on_validation_epoch_end(self): + # Custom validation logic + pass + + def on_save_checkpoint(self, checkpoint): + # Add custom data + checkpoint['my_data'] = self.custom_buffer + + def on_load_checkpoint(self, checkpoint): + # Load custom data + self.custom_buffer = checkpoint.get('my_data', []) +``` + +## Differential Learning Rates + +Use parameter groups in optimizer config: + +```yaml +model: + optimizer: + _target_: torch.optim.SGD + params: + - params: "$@model::network.backbone.parameters()" + lr: 0.0001 # Low LR for pretrained backbone + - params: "$@model::network.head.parameters()" + lr: 0.01 # High LR for new head + momentum: 0.9 +``` + +## Gradient Accumulation + +Use Trainer config: + +```yaml +trainer: + accumulate_grad_batches: 4 # Accumulate 4 batches +``` + +Effective batch size = batch_size × accumulate_grad_batches. + +## Mixed Precision Training + +```yaml +trainer: + precision: 16 # Use 16-bit precision +``` + +Or: + +```yaml +trainer: + precision: "bf16-mixed" # BFloat16 mixed precision +``` + +## Saving Predictions + +Override `predict_step`: + +```python +def predict_step(self, batch, batch_idx): + images, _ = batch + predictions = self(images) + + return { + "predictions": predictions.argmax(dim=1), + "probabilities": predictions.softmax(dim=1), + } +``` + +Run: + +```bash +lighter predict config.yaml +``` + +Use Writers to save to files - see [Training Guide](training.md#saving-predictions). 
+ +## Validation Without Training + +```python +def validation_step(self, batch, batch_idx): + # Just metrics, no loss needed + x, y = batch + pred = self(x) + + if self.val_metrics: + self.val_metrics(pred, y) + + return {} # Empty dict is fine +``` + +## Common Patterns + +### Pattern 1: Return Dict for Auto-Logging + +```python +def training_step(self, batch, batch_idx): + # Everything in the dict gets logged automatically + return { + "loss": loss, # Required + "accuracy": accuracy, # Optional + "learning_rate": current_lr, # Optional + "custom_metric": custom_value, # Optional + } +``` + +Logs as: + +- `train/loss` +- `train/accuracy` +- `train/learning_rate` +- `train/custom_metric` + +### Pattern 2: Conditional Logging + +```python +def training_step(self, batch, batch_idx): + loss = self.criterion(self(x), y) + + # Only log images every 100 steps + if batch_idx % 100 == 0: + self.logger.experiment.add_images( + "train/images", + x[:8], + self.global_step + ) + + return {"loss": loss} +``` + +### Pattern 3: Custom Metric Update + +```python +def validation_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + loss = self.criterion(pred, y) + + # Update specific metrics with transforms + if self.val_metrics: + # Apply softmax before metric + probs = pred.softmax(dim=1) + self.val_metrics(probs, y) + + return {"loss": loss} +``` + +## Comparison: LighterModule vs LightningModule + +| Feature | LighterModule | LightningModule | +|---------|--------------|---------------------| +| Boilerplate | Less | More | +| configure_optimizers | Automatic | Manual | +| Logging | Automatic dual logging | Manual | +| Metrics | Config-driven | Code-driven | +| Learning curve | Learn LighterModule | Just Lightning | +| Flexibility | Standard patterns | Full control | +| Migration | Adapt existing code | Use as-is | + +**Choose LighterModule when:** + +- Starting fresh +- Want minimal code +- Standard workflows +- Config everything + +**Choose LightningModule when:** + 
+- Have existing code +- Need custom logic +- Want full control +- Complex training loops + +Both give you YAML configs and CLI overrides! + +## Next Steps + +- [Lightning Module Guide](lightning-module.md) - Compare with the other approach +- [Training Guide](training.md) - Run experiments, save outputs +- [Best Practices](best-practices.md) - Production patterns + +## Quick Reference + +```python +from lighter import LighterModule + +class MyModel(LighterModule): + # Optional custom __init__ + def __init__(self, network, criterion, optimizer, **kwargs): + super().__init__( + network=network, + criterion=criterion, + optimizer=optimizer + ) + + # Required: training step + def training_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + loss = self.criterion(pred, y) + + if self.train_metrics: + self.train_metrics(pred, y) + + return {"loss": loss} + + # Optional: validation step + def validation_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + loss = self.criterion(pred, y) + + if self.val_metrics: + self.val_metrics(pred, y) + + return {"loss": loss} +``` + +```yaml +# Config +model: + _target_: models.MyModel + network: + _target_: ... + criterion: + _target_: ... + optimizer: + _target_: ... + params: "$@model::network.parameters()" +``` diff --git a/docs/guides/lightning-module.md b/docs/guides/lightning-module.md new file mode 100644 index 00000000..3d423028 --- /dev/null +++ b/docs/guides/lightning-module.md @@ -0,0 +1,802 @@ +--- +title: Using LightningModule +--- + +# Using LightningModule + +Use your existing PyTorch Lightning code with Lighter's configuration system. + +**Key insight**: You don't rewrite your LightningModule. You just add YAML configs. 
+ +## When to Use This Approach + +**Use LightningModule when:** + +- You have existing Lightning code +- You need custom training logic +- You want full control over step methods +- You're integrating with existing projects + +**You get:** + +- YAML configuration for hyperparameters +- CLI overrides without code changes +- Experiment tracking and versioning +- All PyTorch Lightning features + +**You write:** + +- All step methods (`training_step`, `validation_step`, etc.) +- `configure_optimizers()` +- Your own logging +- Custom hooks and callbacks + +## Basic Example + +### Your Existing Module + +`model.py`: + +```python +import pytorch_lightning as pl +import torch +import torch.nn.functional as F + +class ImageClassifier(pl.LightningModule): + """Standard PyTorch Lightning module.""" + + def __init__(self, num_classes=10, learning_rate=0.001): + super().__init__() + self.save_hyperparameters() + + self.model = torch.nn.Sequential( + torch.nn.Conv2d(3, 64, 3, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d(2), + torch.nn.Conv2d(64, 128, 3, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d(2), + torch.nn.Flatten(), + torch.nn.Linear(128 * 8 * 8, 256), + torch.nn.ReLU(), + torch.nn.Linear(256, num_classes), + ) + + def forward(self, x): + return self.model(x) + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + acc = (logits.argmax(dim=1) == y).float().mean() + + self.log("train/loss", loss) + self.log("train/acc", acc) + + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + acc = (logits.argmax(dim=1) == y).float().mean() + + self.log("val/loss", loss) + self.log("val/acc", acc) + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + return optimizer +``` + +### Add Config + +`config.yaml`: + +```yaml +trainer: + _target_: pytorch_lightning.Trainer + 
max_epochs: 10 + accelerator: auto + +model: + _target_: model.ImageClassifier # Your module! + num_classes: 10 + learning_rate: 0.001 + +data: + _target_: lighter.LighterDataModule + train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 32 + shuffle: true + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: true + download: true + transform: + _target_: torchvision.transforms.ToTensor + + val_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 32 + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: false + transform: + _target_: torchvision.transforms.ToTensor +``` + +### Run + +```bash +# Run with default config +lighter fit config.yaml + +# Override learning rate +lighter fit config.yaml model::learning_rate=0.01 + +# Override multiple values +lighter fit config.yaml \ + model::learning_rate=0.01 \ + trainer::max_epochs=100 \ + data::train_dataloader::batch_size=64 +``` + +**That's it!** Your existing Lightning code works unchanged. 
+ +## Advanced Examples + +### Example 1: Complex Network Architecture + +Pass complex architectures through config: + +`models.py`: + +```python +import pytorch_lightning as pl +import torch +import torch.nn.functional as F + +class FlexibleClassifier(pl.LightningModule): + """Accepts any network architecture.""" + + def __init__(self, network, learning_rate=0.001): + super().__init__() + self.save_hyperparameters(ignore=['network']) + self.network = network + self.lr = learning_rate + + def forward(self, x): + return self.network(x) + + def training_step(self, batch, batch_idx): + x, y = batch + loss = F.cross_entropy(self(x), y) + self.log("train/loss", loss) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + acc = (logits.argmax(1) == y).float().mean() + + self.log("val/loss", loss) + self.log("val/acc", acc) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.lr) +``` + +`config.yaml`: + +```yaml +model: + _target_: models.FlexibleClassifier + learning_rate: 0.001 + + network: + _target_: torchvision.models.resnet18 + num_classes: 10 + weights: null # Train from scratch +``` + +Now you can swap architectures by changing just the config: + +```yaml +# Try ResNet50 +network: + _target_: torchvision.models.resnet50 + num_classes: 10 + +# Try EfficientNet +network: + _target_: torchvision.models.efficientnet_b0 + num_classes: 10 + +# Try your custom network +network: + _target_: my_project.networks.CustomNet + num_classes: 10 + hidden_dim: 512 +``` + +### Example 2: Multiple Optimizers + +For GANs or other multi-optimizer setups: + +`gan.py`: + +```python +import pytorch_lightning as pl +import torch + +class GAN(pl.LightningModule): + def __init__(self, generator, discriminator, lr_g=0.0002, lr_d=0.0002): + super().__init__() + self.save_hyperparameters(ignore=['generator', 'discriminator']) + self.generator = generator + self.discriminator = 
discriminator + + def training_step(self, batch, batch_idx, optimizer_idx): + real_imgs, _ = batch + + # Train generator + if optimizer_idx == 0: + z = torch.randn(real_imgs.size(0), self.generator.latent_dim) + fake_imgs = self.generator(z) + g_loss = -torch.mean(self.discriminator(fake_imgs)) + self.log("train/g_loss", g_loss) + return g_loss + + # Train discriminator + if optimizer_idx == 1: + z = torch.randn(real_imgs.size(0), self.generator.latent_dim) + fake_imgs = self.generator(z).detach() + + d_loss_real = -torch.mean(self.discriminator(real_imgs)) + d_loss_fake = torch.mean(self.discriminator(fake_imgs)) + d_loss = d_loss_real + d_loss_fake + + self.log("train/d_loss", d_loss) + return d_loss + + def configure_optimizers(self): + opt_g = torch.optim.Adam( + self.generator.parameters(), + lr=self.hparams.lr_g, + betas=(0.5, 0.999) + ) + opt_d = torch.optim.Adam( + self.discriminator.parameters(), + lr=self.hparams.lr_d, + betas=(0.5, 0.999) + ) + return [opt_g, opt_d] +``` + +`config.yaml`: + +```yaml +model: + _target_: gan.GAN + lr_g: 0.0002 + lr_d: 0.0002 + + generator: + _target_: gan.Generator + latent_dim: 100 + img_shape: [3, 32, 32] + + discriminator: + _target_: gan.Discriminator + img_shape: [3, 32, 32] +``` + +### Example 3: Custom Metrics + +Use torchmetrics or your own: + +`models.py`: + +```python +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +import torchmetrics + +class MetricsModule(pl.LightningModule): + def __init__(self, network, num_classes=10, learning_rate=0.001): + super().__init__() + self.save_hyperparameters(ignore=['network']) + self.network = network + + # Initialize metrics + self.train_acc = torchmetrics.Accuracy( + task='multiclass', + num_classes=num_classes + ) + self.val_acc = torchmetrics.Accuracy( + task='multiclass', + num_classes=num_classes + ) + self.val_f1 = torchmetrics.F1Score( + task='multiclass', + num_classes=num_classes + ) + + def forward(self, x): + return self.network(x) + 
+ def training_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + + # Update metrics + self.train_acc(logits, y) + + # Log + self.log("train/loss", loss) + self.log("train/acc", self.train_acc, on_step=False, on_epoch=True) + + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + + # Update metrics + self.val_acc(logits, y) + self.val_f1(logits, y) + + # Log + self.log("val/loss", loss) + self.log("val/acc", self.val_acc, on_step=False, on_epoch=True) + self.log("val/f1", self.val_f1, on_step=False, on_epoch=True) + + def configure_optimizers(self): + return torch.optim.Adam( + self.parameters(), + lr=self.hparams.learning_rate + ) +``` + +### Example 4: Learning Rate Schedulers + +Add schedulers in `configure_optimizers`: + +```python +def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=self.trainer.max_epochs, + eta_min=1e-6 + ) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "epoch", + "frequency": 1, + } + } +``` + +Or use ReduceLROnPlateau: + +```python +def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode='min', + factor=0.5, + patience=5, + min_lr=1e-6 + ) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val/loss", # Metric to monitor + "interval": "epoch", + } + } +``` + +### Example 5: Gradient Clipping + +Add in config or code: + +**Option 1: In Config** + +```yaml +trainer: + _target_: pytorch_lightning.Trainer + gradient_clip_val: 0.5 + gradient_clip_algorithm: norm +``` + +**Option 2: In Code** + +```python +def configure_gradient_clipping(self, 
optimizer, gradient_clip_val, gradient_clip_algorithm): + self.clip_gradients( + optimizer, + gradient_clip_val=gradient_clip_val, + gradient_clip_algorithm=gradient_clip_algorithm + ) +``` + +### Example 6: Model Hooks + +Use Lightning hooks for custom behavior: + +```python +class MyModule(pl.LightningModule): + def __init__(self, network, learning_rate=0.001): + super().__init__() + self.network = network + self.lr = learning_rate + + def on_train_start(self): + """Called when training starts.""" + print(f"Starting training with LR: {self.lr}") + + def on_train_epoch_end(self): + """Called at the end of each epoch.""" + # Log learning rate + current_lr = self.trainer.optimizers[0].param_groups[0]['lr'] + self.log("train/lr", current_lr) + + def on_validation_epoch_end(self): + """Called after validation epoch.""" + # Custom validation logic + pass + + def on_save_checkpoint(self, checkpoint): + """Modify what gets saved.""" + checkpoint['custom_data'] = {'my_value': 42} + + def on_load_checkpoint(self, checkpoint): + """Load custom data.""" + custom_data = checkpoint.get('custom_data', {}) + print(f"Loaded custom data: {custom_data}") +``` + +All Lightning hooks work normally! 
+ +## Passing Data from Config + +You can pass **any** data structure through config: + +### Lists + +```yaml +model: + _target_: models.MyModule + layer_sizes: [64, 128, 256, 512] +``` + +```python +def __init__(self, layer_sizes): + layers = [] + for in_size, out_size in zip(layer_sizes[:-1], layer_sizes[1:]): + layers.append(nn.Linear(in_size, out_size)) + layers.append(nn.ReLU()) + self.network = nn.Sequential(*layers) +``` + +### Dicts + +```yaml +model: + _target_: models.MyModule + config: + hidden_dim: 256 + num_layers: 4 + dropout: 0.1 +``` + +```python +def __init__(self, config): + self.hidden_dim = config['hidden_dim'] + self.num_layers = config['num_layers'] + self.dropout = config['dropout'] +``` + +### Nested Objects + +```yaml +model: + _target_: models.MyModule + encoder: + _target_: models.Encoder + hidden_dim: 256 + decoder: + _target_: models.Decoder + hidden_dim: 256 +``` + +```python +def __init__(self, encoder, decoder): + self.encoder = encoder + self.decoder = decoder +``` + +## Integration with Callbacks + +Use any PyTorch Lightning callback: + +```yaml +trainer: + _target_: pytorch_lightning.Trainer + max_epochs: 100 + callbacks: + # Model checkpointing + - _target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: val/acc + mode: max + save_top_k: 3 + filename: 'epoch{epoch:02d}-acc{val/acc:.4f}' + + # Early stopping + - _target_: pytorch_lightning.callbacks.EarlyStopping + monitor: val/loss + patience: 10 + mode: min + + # Learning rate monitor + - _target_: pytorch_lightning.callbacks.LearningRateMonitor + logging_interval: epoch + + # Custom callback + - _target_: my_project.callbacks.MyCustomCallback + some_param: 42 +``` + +## Integration with Loggers + +Use any PyTorch Lightning logger: + +```yaml +trainer: + _target_: pytorch_lightning.Trainer + logger: + _target_: pytorch_lightning.loggers.TensorBoardLogger + save_dir: ./logs + name: my_experiment +``` + +Or multiple loggers: + +```yaml +trainer: + logger: + - _target_: 
pytorch_lightning.loggers.TensorBoardLogger + save_dir: ./logs + + - _target_: pytorch_lightning.loggers.CSVLogger + save_dir: ./logs + + - _target_: pytorch_lightning.loggers.WandbLogger + project: my_project + name: experiment_1 +``` + +## Testing and Validation + +Your module works with all Lightning testing features: + +### Fast Dev Run + +```bash +lighter fit config.yaml trainer::fast_dev_run=true +``` + +Runs 1 batch of train/val to catch bugs. + +### Validation Only + +```bash +lighter validate config.yaml args::validate::ckpt_path=checkpoints/best.ckpt +``` + +### Testing + +```bash +lighter test config.yaml args::test::ckpt_path=checkpoints/best.ckpt +``` + +### Overfit on Small Batch + +```bash +lighter fit config.yaml trainer::overfit_batches=10 +``` + +## Common Patterns + +### Pattern 1: Save Hyperparameters + +Always save hyperparameters for reproducibility: + +```python +def __init__(self, network, learning_rate=0.001, weight_decay=0.0): + super().__init__() + # Save all args except network (it's not serializable) + self.save_hyperparameters(ignore=['network']) + self.network = network +``` + +Now `self.hparams` contains your config: + +```python +def configure_optimizers(self): + return torch.optim.Adam( + self.parameters(), + lr=self.hparams.learning_rate, + weight_decay=self.hparams.weight_decay + ) +``` + +### Pattern 2: Separate Forward from Loss + +Keep forward pass separate from loss calculation: + +```python +def forward(self, x): + """Just the forward pass.""" + return self.network(x) + +def training_step(self, batch, batch_idx): + """Loss calculation and logging.""" + x, y = batch + logits = self(x) # Call forward + loss = F.cross_entropy(logits, y) + self.log("train/loss", loss) + return loss +``` + +This makes your model usable for inference. 
+ +### Pattern 3: Shared Step Logic + +Reduce duplication with shared methods: + +```python +def _shared_step(self, batch, stage): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + acc = (logits.argmax(1) == y).float().mean() + + self.log(f"{stage}/loss", loss) + self.log(f"{stage}/acc", acc) + + return loss + +def training_step(self, batch, batch_idx): + return self._shared_step(batch, "train") + +def validation_step(self, batch, batch_idx): + return self._shared_step(batch, "val") + +def test_step(self, batch, batch_idx): + return self._shared_step(batch, "test") +``` + +## Migration from Pure Lightning + +Already have a Lightning project? Add Lighter in 3 steps: + +### Step 1: Add `__lighter__.py` + +```bash +cd my_lightning_project +touch __lighter__.py +``` + +### Step 2: Create Config + +`config.yaml`: + +```yaml +trainer: + _target_: pytorch_lightning.Trainer + # Copy your Trainer args from your Python script + +model: + _target_: my_project.models.MyLightningModule + # Copy your module's __init__ args + +data: + _target_: my_project.data.MyDataModule + # Or use LighterDataModule +``` + +### Step 3: Run + +```bash +lighter fit config.yaml +``` + +Your code works unchanged! + +## Comparison: LightningModule vs LighterModule + +Both approaches are fully supported. 
Here's when to use each: + +| Feature | LightningModule | LighterModule | +|---------|---------------------|---------------| +| Use existing code | ✅ Yes | ❌ Need to adapt | +| Custom training logic | ✅ Full control | ⚠️ Must fit pattern | +| Auto configure_optimizers | ❌ Manual | ✅ Automatic | +| Automatic logging | ❌ Manual | ✅ Dual logging | +| Learning curve | Easy if you know Lightning | Need to learn LighterModule | +| Boilerplate | More | Less | + +**Use your LightningModule when:** + +- Migrating existing projects +- Need full control +- Have custom training loops +- Team knows Lightning well + +**Use LighterModule when:** + +- Starting new projects +- Want less boilerplate +- Standard workflows +- Config-driven everything + +Both give you YAML configs and CLI overrides! + +## Next Steps + +- [LighterModule Guide](lighter-module.md) - Compare with the other approach +- [Training Guide](training.md) - Run experiments, save outputs +- [Best Practices](best-practices.md) - Production patterns + +## Quick Reference + +```python +# LightningModule works as-is +class MyModule(pl.LightningModule): + def __init__(self, network, learning_rate=0.001): + super().__init__() + self.save_hyperparameters(ignore=['network']) + self.network = network + + def training_step(self, batch, batch_idx): + # Your logic + return loss + + def configure_optimizers(self): + # Your optimizer + return optimizer +``` + +```yaml +# Reference it in config +model: + _target_: my_project.models.MyModule + learning_rate: 0.001 + network: + _target_: torchvision.models.resnet18 +``` + +```bash +# Run with overrides +lighter fit config.yaml model::learning_rate=0.01 +``` diff --git a/docs/guides/training.md b/docs/guides/training.md new file mode 100644 index 00000000..8d676f5f --- /dev/null +++ b/docs/guides/training.md @@ -0,0 +1,1024 @@ +--- +title: Training Guide +--- + +# Training Guide + +Run experiments, track results, and save outputs with Lighter. 
+ +This guide covers the full training workflow from config to results. + +## Basic Commands + +Lighter provides four main commands: + +```bash +# Train and validate +lighter fit config.yaml + +# Validate only (requires checkpoint) +lighter validate config.yaml + +# Test only (requires checkpoint) +lighter test config.yaml + +# Run inference +lighter predict config.yaml +``` + +All commands use the same config structure. + +## The Fit Command + +Train your model with automatic validation: + +```bash +lighter fit config.yaml +``` + +### What Happens + +1. Loads config from YAML +2. Instantiates trainer, model, and data +3. Runs training loop with validation +4. Saves checkpoints automatically +5. Logs metrics to configured logger + +### Output Structure + +``` +outputs/ +└── YYYY-MM-DD/ + └── HH-MM-SS/ + ├── config.yaml # Copy of config used + ├── checkpoints/ + │ ├── last.ckpt # Latest checkpoint + │ └── epoch=09-step=1000.ckpt + └── logs/ # Tensorboard/CSV logs +``` + +### Resuming Training + +Resume from latest checkpoint: + +```bash +lighter fit config.yaml --ckpt_path path/to/checkpoint.ckpt +``` + +## Overriding from CLI + +Change any config value without editing files: + +### Single Override + +```bash +# Change learning rate +lighter fit config.yaml model::optimizer::lr=0.01 + +# Train longer +lighter fit config.yaml trainer::max_epochs=100 + +# Use more GPUs +lighter fit config.yaml trainer::devices=4 +``` + +### Multiple Overrides + +```bash +lighter fit config.yaml \ + model::optimizer::lr=0.01 \ + trainer::max_epochs=100 \ + data::train_dataloader::batch_size=64 \ + trainer::devices=4 +``` + +### Nested Overrides + +```bash +# Override nested values +lighter fit config.yaml \ + model::network::num_classes=100 \ + model::optimizer::weight_decay=0.0001 +``` + +### Complex Overrides + +```bash +# Add callbacks from CLI +lighter fit config.yaml \ + 'trainer::callbacks=[{_target_: pytorch_lightning.callbacks.EarlyStopping, monitor: val/loss}]' +``` + +## 
Merging Configs + +Combine multiple YAML files: + +```bash +lighter fit base.yaml,experiment.yaml +``` + +### Example: Base + Experiment + +`base.yaml`: + +```yaml +trainer: + max_epochs: 100 + accelerator: auto + devices: 1 + +model: + _target_: models.MyModel + network: + _target_: torchvision.models.resnet18 + num_classes: 10 + +data: + _target_: lighter.LighterDataModule + train_dataloader: + batch_size: 32 +``` + +`experiment.yaml`: + +```yaml +# Override specific values +trainer: + max_epochs: 200 # Override + devices: 4 # Add + +model: + optimizer: + lr: 0.01 # Add optimizer config +``` + +**Result**: Merged config with `max_epochs=200`, `devices=4`, new optimizer. + +### Merge Operators + +Control how configs merge: + +**Replace with `=`**: + +```yaml +# experiment.yaml +trainer: + =callbacks: # Replace entire list + - _target_: pytorch_lightning.callbacks.EarlyStopping + monitor: val/loss +``` + +**Delete with `~`**: + +```yaml +# experiment.yaml +trainer: + ~callbacks: null # Remove callbacks entirely + +data: + ~test_dataloader: null # Remove test dataloader +``` + +## Checkpointing + +### Automatic Checkpointing + +Lightning saves `last.ckpt` automatically. 
For more control: + +```yaml +trainer: + callbacks: + - _target_: pytorch_lightning.callbacks.ModelCheckpoint + dirpath: checkpoints + filename: 'epoch{epoch:02d}-loss{val/loss:.4f}' + monitor: val/loss + mode: min + save_top_k: 3 # Keep best 3 + save_last: true # Keep last checkpoint + every_n_epochs: 1 # Save every epoch +``` + +### Save Based on Metric + +```yaml +# Save best validation accuracy +- _target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: val/acc + mode: max + save_top_k: 1 + filename: 'best-acc{val/acc:.4f}' +``` + +### Multiple Checkpointers + +Save different metrics: + +```yaml +trainer: + callbacks: + # Best accuracy + - _target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: val/acc + mode: max + save_top_k: 1 + filename: 'best-acc' + + # Best loss + - _target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: val/loss + mode: min + save_top_k: 1 + filename: 'best-loss' + + # Regular saves + - _target_: pytorch_lightning.callbacks.ModelCheckpoint + every_n_epochs: 10 + filename: 'epoch{epoch:02d}' +``` + +### Loading Checkpoints + +**For validation/testing**: + +```bash +lighter validate config.yaml --ckpt_path checkpoints/best.ckpt +lighter test config.yaml --ckpt_path checkpoints/best.ckpt +``` + +**For inference**: + +```bash +lighter predict config.yaml --ckpt_path checkpoints/best.ckpt +``` + +**To resume training**: + +```bash +lighter fit config.yaml --ckpt_path checkpoints/last.ckpt +``` + +## Logging + +### TensorBoard (Default) + +```yaml +trainer: + logger: + _target_: pytorch_lightning.loggers.TensorBoardLogger + save_dir: logs + name: my_experiment +``` + +View logs: + +```bash +tensorboard --logdir logs +``` + +### CSV Logger + +```yaml +trainer: + logger: + _target_: pytorch_lightning.loggers.CSVLogger + save_dir: logs + name: my_experiment +``` + +Results saved to `logs/my_experiment/version_0/metrics.csv`. 
+ +### Weights & Biases + +```yaml +trainer: + logger: + _target_: pytorch_lightning.loggers.WandbLogger + project: my_project + name: experiment_1 + save_dir: logs +``` + +### Multiple Loggers + +Use all at once: + +```yaml +trainer: + logger: + - _target_: pytorch_lightning.loggers.TensorBoardLogger + save_dir: logs + + - _target_: pytorch_lightning.loggers.CSVLogger + save_dir: logs + + - _target_: pytorch_lightning.loggers.WandbLogger + project: my_project +``` + +### No Logging + +Disable logging: + +```yaml +trainer: + logger: false +``` + +## Saving Predictions + +Use Writers to save predictions to files. + +### CSV Writer + +Save predictions to CSV: + +```yaml +trainer: + callbacks: + - _target_: lighter.callbacks.CSVWriter + write_interval: batch # or 'epoch' +``` + +Your `predict_step` should return a dict: + +```python +def predict_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + + return { + "prediction": pred.argmax(dim=1), + "probability": pred.max(dim=1).values, + "target": y, + } +``` + +Output: `predictions.csv` with columns for each key. + +### File Writer + +Save predictions to individual files: + +```yaml +trainer: + callbacks: + - _target_: lighter.callbacks.FileWriter + write_interval: batch +``` + +Return dict with data and filenames: + +```python +def predict_step(self, batch, batch_idx, dataloader_idx=0): + images, paths = batch + + predictions = self(images) + + # Save each prediction + results = [] + for i, (pred, path) in enumerate(zip(predictions, paths)): + results.append({ + "prediction": pred.cpu().numpy(), + "$id": f"pred_{batch_idx}_{i}", # Unique filename + }) + + return results +``` + +Saves: `predictions/pred_0_0.npz`, `pred_0_1.npz`, etc. 
+

### Custom Writer

Create your own:

```python
from lighter.callbacks import BaseWriter

class CustomWriter(BaseWriter):
    def write(self, data):
        """Save data however you want."""
        # data is what you returned from predict_step
        output_path = self.output_dir / f"{data['$id']}.pkl"

        with open(output_path, 'wb') as f:
            pickle.dump(data, f)
```

Use in config:

```yaml
trainer:
  callbacks:
    - _target_: my_project.writers.CustomWriter
      write_interval: batch
```

## Debugging

### Fast Dev Run

Run 1 batch of train/val/test to catch bugs:

```bash
lighter fit config.yaml trainer::fast_dev_run=true
```

Or specify the number of batches:

```bash
lighter fit config.yaml trainer::fast_dev_run=5
```

### Overfit on Small Batch

Test if the model can overfit (sanity check):

```bash
lighter fit config.yaml trainer::overfit_batches=10
```

Trains on the same 10 batches repeatedly.

### Limit Batches

Run partial epoch:

```bash
# Train on 10% of data
lighter fit config.yaml \
  trainer::limit_train_batches=0.1 \
  trainer::limit_val_batches=0.1
```

Or a specific number of batches:

```bash
lighter fit config.yaml trainer::limit_train_batches=100
```

### Profiler

Profile your code:

```bash
lighter fit config.yaml trainer::profiler=simple
```

Options:

- `simple` - Basic profiling
- `advanced` - Detailed profiling
- `pytorch` - PyTorch profiler

Results saved to logs directory. 
+ +### Find Learning Rate + +Automatically find optimal LR: + +```yaml +trainer: + _target_: pytorch_lightning.Trainer + callbacks: + - _target_: pytorch_lightning.callbacks.LearningRateFinder + min_lr: 1e-6 + max_lr: 1.0 +``` + +Or run tuner: + +```bash +lighter fit config.yaml trainer::auto_lr_find=true +``` + +## Multi-GPU Training + +### Single Machine, Multiple GPUs + +```yaml +trainer: + devices: 4 # Use 4 GPUs + strategy: ddp # Distributed Data Parallel +``` + +Or use all available GPUs: + +```yaml +trainer: + devices: -1 # All GPUs + strategy: ddp +``` + +### Strategy Options + +**DDP (Recommended)**: + +```yaml +trainer: + strategy: ddp +``` + +**DDP Spawn**: + +```yaml +trainer: + strategy: ddp_spawn +``` + +**DeepSpeed**: + +```yaml +trainer: + strategy: + _target_: pytorch_lightning.strategies.DeepSpeedStrategy + stage: 2 +``` + +**FSDP (Fully Sharded)**: + +```yaml +trainer: + strategy: fsdp +``` + +### Batch Size Adjustment + +Scale batch size with GPUs: + +```yaml +vars: + num_gpus: 4 + per_gpu_batch: 32 + +data: + train_dataloader: + batch_size: "$%vars::per_gpu_batch * %vars::num_gpus" +``` + +Or keep per-GPU batch size: + +```yaml +# Each GPU gets batch_size=32 +data: + train_dataloader: + batch_size: 32 +``` + +## Mixed Precision Training + +Use 16-bit precision for faster training: + +```yaml +trainer: + precision: 16 +``` + +Or BFloat16: + +```yaml +trainer: + precision: "bf16-mixed" +``` + +Automatic mixed precision (AMP) is handled by Lightning. + +## Gradient Accumulation + +Simulate larger batch sizes: + +```yaml +trainer: + accumulate_grad_batches: 4 +``` + +Effective batch size = `batch_size × accumulate_grad_batches`. 
+ +Example: + +```yaml +# Effective batch size = 32 × 4 = 128 +data: + train_dataloader: + batch_size: 32 + +trainer: + accumulate_grad_batches: 4 +``` + +## Early Stopping + +Stop training when metric stops improving: + +```yaml +trainer: + callbacks: + - _target_: pytorch_lightning.callbacks.EarlyStopping + monitor: val/loss + patience: 10 + mode: min + verbose: true +``` + +Parameters: + +- `monitor`: Metric to track +- `patience`: Epochs to wait before stopping +- `mode`: `min` or `max` +- `min_delta`: Minimum change to qualify as improvement + +## Progress Bars + +### Default Progress Bar + +Shows by default. Disable with: + +```yaml +trainer: + enable_progress_bar: false +``` + +### Custom Progress Bar + +```yaml +trainer: + callbacks: + - _target_: pytorch_lightning.callbacks.RichProgressBar +``` + +Or: + +```yaml +trainer: + callbacks: + - _target_: pytorch_lightning.callbacks.TQDMProgressBar + refresh_rate: 10 +``` + +## Validation + +### Validate Only + +Run validation on a checkpoint: + +```bash +lighter validate config.yaml --ckpt_path checkpoints/best.ckpt +``` + +### Validation Frequency + +Validate every N epochs: + +```yaml +trainer: + check_val_every_n_epoch: 5 +``` + +Or every N steps: + +```yaml +trainer: + val_check_interval: 0.5 # Validate twice per epoch +``` + +Or specific number of steps: + +```yaml +trainer: + val_check_interval: 100 # Every 100 training steps +``` + +### Skip Validation + +```yaml +trainer: + limit_val_batches: 0 # No validation +``` + +## Testing + +Run final test after training: + +```bash +# Fit then test automatically +lighter fit config.yaml + +# Test separately +lighter test config.yaml --ckpt_path checkpoints/best.ckpt +``` + +### Test During Fit + +Not recommended, but possible by loading checkpoint at end of fit. + +## Prediction/Inference + +Run inference on data: + +```bash +lighter predict config.yaml --ckpt_path checkpoints/best.ckpt +``` + +Requires: + +1. `predict_step` in your module +2. 
`predict_dataloader` in your data config +3. Optional: Writer callback to save results + +Example config: + +```yaml +data: + predict_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 32 + dataset: + _target_: my_project.data.PredictionDataset + root: ./inference_data + +trainer: + callbacks: + - _target_: lighter.callbacks.FileWriter + write_interval: batch +``` + +Example predict_step: + +```python +def predict_step(self, batch, batch_idx): + images = batch + predictions = self(images) + + return { + "predictions": predictions.cpu(), + "batch_idx": batch_idx, + } +``` + +## Experiment Organization + +### Recommended Structure + +``` +my_project/ +├── __lighter__.py +├── models.py +├── data.py +├── configs/ +│ ├── base.yaml # Baseline config +│ ├── resnet50.yaml # Architecture variants +│ ├── augmented.yaml # Augmentation experiments +│ └── ablation/ +│ ├── no_dropout.yaml +│ └── no_batchnorm.yaml +└── outputs/ # Generated by Lighter + └── YYYY-MM-DD/ + └── HH-MM-SS/ +``` + +### Config Naming + +Use descriptive names: + +``` +configs/ +├── baseline-resnet18.yaml +├── baseline-resnet50.yaml +├── lr0.01-batch128.yaml +├── augment-strong.yaml +└── finetune-imagenet.yaml +``` + +### Version Control + +Track configs in git: + +```bash +git add configs/ +git commit -m "Add strong augmentation experiment" +``` + +Compare experiments: + +```bash +git diff configs/baseline.yaml configs/improved.yaml +``` + +## Common Workflows + +### Workflow 1: Hyperparameter Search + +Create configs for different hyperparameters: + +```bash +# Try different learning rates +lighter fit base.yaml model::optimizer::lr=0.001 +lighter fit base.yaml model::optimizer::lr=0.01 +lighter fit base.yaml model::optimizer::lr=0.1 + +# Try different architectures +lighter fit base.yaml model::network::_target_=torchvision.models.resnet18 +lighter fit base.yaml model::network::_target_=torchvision.models.resnet50 +lighter fit base.yaml 
model::network::_target_=torchvision.models.efficientnet_b0
```

### Workflow 2: Resume Failed Training

Training crashed? Resume:

```bash
lighter fit config.yaml --ckpt_path outputs/2024-01-15/10-30-45/checkpoints/last.ckpt
```

### Workflow 3: Incremental Training

Train, then finetune:

```bash
# Initial training
lighter fit pretrain.yaml

# Finetune with lower LR
lighter fit finetune.yaml \
  --ckpt_path outputs/.../checkpoints/last.ckpt \
  model::optimizer::lr=0.0001
```

### Workflow 4: Cross-Validation

Run multiple folds:

```bash
for fold in {0..4}; do
  lighter fit config.yaml data::fold=$fold
done
```

Dataset implementation (`data.py`):

```python
# data.py
class CVDataset(Dataset):
    def __init__(self, root, fold, num_folds=5):
        # Split data by fold
        ...
```

## Output Management

### Change Output Directory

```yaml
# In config
trainer:
  default_root_dir: ./my_outputs
```

Or CLI:

```bash
lighter fit config.yaml trainer::default_root_dir=./my_outputs
```

### Disable Checkpoints

```yaml
trainer:
  enable_checkpointing: false
```

### Save Frequency

Save less often:

```yaml
trainer:
  callbacks:
    - _target_: pytorch_lightning.callbacks.ModelCheckpoint
      every_n_epochs: 10  # Save every 10 epochs
```

Or based on steps:

```yaml
trainer:
  callbacks:
    - _target_: pytorch_lightning.callbacks.ModelCheckpoint
      every_n_train_steps: 1000
```

## Troubleshooting

### Out of Memory

**Solutions:**

1. Reduce batch size:
   ```bash
   lighter fit config.yaml data::train_dataloader::batch_size=16
   ```

2. Use gradient accumulation:
   ```yaml
   trainer:
     accumulate_grad_batches: 4
   ```

3. Use mixed precision:
   ```yaml
   trainer:
     precision: 16
   ```

4. Reduce model size:
   ```bash
   lighter fit config.yaml model::network::_target_=torchvision.models.resnet18
   ```

### Training Too Slow

**Solutions:**

1. 
Use more workers: + ```yaml + data: + train_dataloader: + num_workers: 8 + ``` + +2. Pin memory: + ```yaml + data: + train_dataloader: + pin_memory: true + ``` + +3. Use multiple GPUs: + ```yaml + trainer: + devices: 4 + strategy: ddp + ``` + +4. Mixed precision: + ```yaml + trainer: + precision: 16 + ``` + +### Model Not Learning + +**Debug steps:** + +1. Overfit on small batch: + ```bash + lighter fit config.yaml trainer::overfit_batches=10 + ``` + +2. Check learning rate: + ```bash + lighter fit config.yaml trainer::auto_lr_find=true + ``` + +3. Visualize data: + ```python + # In training_step + if batch_idx == 0: + self.logger.experiment.add_images("train/batch", x[:8]) + ``` + +4. Profile: + ```bash + lighter fit config.yaml trainer::profiler=simple + ``` + +## Next Steps + +- [Best Practices](best-practices.md) - Production patterns +- [Examples](../examples/image-classification.md) - Complete working examples +- [CLI Reference](../reference/cli.md) - Full command documentation + +## Quick Reference + +```bash +# Basic commands +lighter fit config.yaml +lighter validate config.yaml +lighter test config.yaml +lighter predict config.yaml + +# Override from CLI +lighter fit config.yaml key::path=value + +# Merge configs +lighter fit base.yaml,experiment.yaml + +# Resume training +lighter fit config.yaml --ckpt_path path/to/last.ckpt + +# Multi-GPU +lighter fit config.yaml trainer::devices=4 trainer::strategy=ddp + +# Debug +lighter fit config.yaml trainer::fast_dev_run=true +lighter fit config.yaml trainer::overfit_batches=10 +``` diff --git a/docs/how-to/adapters.md b/docs/how-to/adapters.md deleted file mode 100644 index 8dc85ce4..00000000 --- a/docs/how-to/adapters.md +++ /dev/null @@ -1,347 +0,0 @@ -# Adapters: Data Flow Control - -Adapters are the **secret sauce** that makes Lighter incredibly flexible. They act as intelligent translators between different components of your pipeline, ensuring data flows correctly regardless of format differences. 
- -## The Data Flow Problem - -In a typical training step, data flows sequentially: a **batch** from the dataloader is fed to the **model** to produce a **prediction**. This prediction, along with the **target** from the batch, is then used by the **loss function**, **metrics**, and **logger**. - -The problem is that each component might have different expectations for the data it receives: - -- A `Dataset` might return a dictionary of tensors, but the `model` only needs a specific tensor. -- A `loss function` might expect `(prediction, target)`, while another expects `(target, prediction)`. -- A `metric` might need class indices, but the model outputs probabilities. - -**Without adapters**, you would have to write glue code inside your model or training loop to handle these mismatches. This makes your components less reusable. - -**With Lighter's adapters**, you define these translations in your configuration, keeping your components clean and independent. - -## The Solution: Adapters - -Lighter provides four types of adapters, each designed to intercept and transform data at a specific point in the data flow: - -| Adapter | Purpose | When to Use | -|---|---|---| -| **BatchAdapter** | Extracts and formats data from the dataloader's batch. | When your dataset has a different structure than the expected `(input, target)`. | -| **CriterionAdapter** | Prepares `input`, `target`, and `pred` for the loss function. | When your loss function has specific argument or data format requirements. | -| **MetricsAdapter** | Prepares `input`, `target`, and `pred` for metrics. | When your metrics have specific argument or data format requirements. | -| **LoggingAdapter** | Prepares `input`, `target`, and `pred` for logging. | When you need to transform data for visualization. | - -Here are some common challenges and how adapters solve them. 
Don't worry about the details yet—we'll dive deeper into each adapter below: - -| Component / Scenario | Common Challenge | Adapter Solution | -| :--- | :--- | :--- | -| **Batch** 📦 | Dataset returns a `dict` (e.g., `{"x": ..., "y": ...}`), but the pipeline needs `input` and `target`. | Use `BatchAdapter` to map dictionary keys to `input` and `target` (e.g., `input_accessor="x"`). | -| **Batch in Self-Supervision** | The training step has no `target`. | Use `BatchAdapter` and set `target_accessor` to `None`. | -| **Loss Function** 📉 | Loss function expects `loss(target, pred)`, but Lighter's default is `loss(pred, target)`. | Use `CriterionAdapter` to reorder arguments (e.g., `pred_argument=1`, `target_argument=0`). | -| **Metrics** 📏 | A metric needs class indices, but the model outputs probabilities. | Use `MetricsAdapter` with `pred_transforms` to apply `torch.argmax` before the metric calculation. | -| **Logging** 📊 | Logger expects RGB images, but the data is grayscale. | Use `LoggingAdapter` with `input_transforms` to repeat the channel dimension. | - -## Execution Order - -Understanding when each adapter is called is crucial for debugging and designing your pipeline.
Here's the complete data flow during a training step: - -``` -Step 1: DataLoader produces batch - ↓ -Step 2: BatchAdapter - └─→ Extracts: (input, target, identifier) - ↓ -Step 3: Model.forward(input) or Inferer - └─→ Produces: prediction - ↓ -Step 4: CriterionAdapter (train/val modes only) - ├─→ Transforms: input, target, pred - └─→ Maps to loss function arguments - ↓ -Step 5: Loss function computes loss - ↓ -Step 6: MetricsAdapter (train/val/test modes) - ├─→ Transforms: input, target, pred - └─→ Maps to metrics arguments - ↓ -Step 7: Metrics compute values - ↓ -Step 8: LoggingAdapter - ├─→ Transforms: input, target, pred - └─→ Prepares for logger/callbacks - ↓ -Step 9: Logger and callbacks receive data -``` - -### Key Points - -- **BatchAdapter** runs first, before the model sees any data -- **CriterionAdapter** runs only in train and val modes (skipped in test/predict) -- **MetricsAdapter** runs only in train, val, and test modes (skipped in predict) -- **LoggingAdapter** runs last, after all computations are done -- Each mode (train/val/test/predict) can have its own set of adapters - -### Practical Implications - -1. **BatchAdapter transforms are executed on every batch**: Keep them lightweight, or move expensive operations to your Dataset's `__getitem__`. - -2. **CriterionAdapter and MetricsAdapter run after model forward**: You can safely apply post-processing like `argmax` or `sigmoid` here without affecting the model. - -3. **LoggingAdapter is for visualization only**: Transforms here don't affect training—use them for detaching tensors, converting to CPU, or formatting for display. - -4. **Mode-specific adapters**: You can configure different adapters for train vs val vs test: - ```yaml - system: - adapters: - train: - batch: ... - criterion: ... - val: - batch: ... # Can be different from train - criterion: ... - ``` - -For a deeper understanding of the complete System data flow, see [System Internals](../design/system.md#system-data-flow). 
- -## Deep Dive - -### BatchAdapter - -The `BatchAdapter` is responsible for extracting `input`, `target`, and an optional `identifier` from each batch produced by your `DataLoader`. This is the first and most common adapter you'll use. - -Lighter needs to know how to get the following from a batch: - -| Item | Description | -|-----------|-------------| -| Input | The data fed into the model (e.g., images, text). | -| Target **(optional)** | The ground truth data (e.g., labels, masks) used for loss calculation. | -| Identifier **(optional)** | A unique ID for the data point (e.g., filename, patient ID). | - -!!! note - By default, Lighter assumes the batch is an `(input, target)` tuple: - `BatchAdapter(input_accessor=0, target_accessor=1)` - -You can specify `accessors` to handle different batch structures: - -| Accessor Type | Description | Example | -| --- | --- | --- | -| **Integer Index** | Access elements by position in a list or tuple batch. | `input_accessor: 0` | -| **String Key** | Access elements by key in a dictionary batch. | `input_accessor: "image"` | -| **Callable** | Use a function for complex logic. | `target_accessor: $lambda batch: one_hot(batch[1])` | - -#### Example 1: Dictionary-based Dataset - -If your dataset returns a dictionary, you can map the keys to `input`, `target`, and `identifier`. - -```yaml -# Problem: Dataset returns a dict, but the model needs tensors. -system: - adapters: - train: - batch: - _target_: lighter.adapters.BatchAdapter - input_accessor: "image" - target_accessor: "mask" - identifier_accessor: "patient_id" -``` - -#### Example 2: Self-Supervised Learning - -In self-supervised learning, you might not have a `target`. You can set its accessor to `None`. - -```yaml -# Problem: No targets in this training phase. -system: - adapters: - train: - batch: - _target_: lighter.adapters.BatchAdapter - input_accessor: 0 - target_accessor: null # No targets! 
-``` - -For more details, see the [`BatchAdapter` documentation](../../reference/adapters/#lighter.adapters.BatchAdapter). - -### CriterionAdapter - -The `CriterionAdapter` acts as a bridge between your model's prediction and your loss function. It allows you to: - -1. **Map** `pred`, `target`, and `input` to the arguments of your loss function. -2. **Transform** these tensors before they are passed to the loss function. - -**Argument Mappers** can be: - -| Mapper Type | Description | Example | -| --- | --- | --- | -| **Integer Index** | Map to a positional argument. | `pred_argument: 1`, `target_argument: 0` | -| **String Key** | Map to a keyword argument. | `pred_argument: "prediction"` | -| **`None`** | Don't pass this tensor to the loss function. | `input_argument: None` | - -**Transforms** are functions applied to the tensors before mapping. - -| Transform Type | Description | Example | -| --- | --- | --- | -| **Callable** | A function or list of functions to apply. | `pred_transforms: [_target_: torch.sigmoid]` | - -#### Example: Custom Argument Order and Activation - -If your loss function expects `(target, pred)` and requires sigmoid on the predictions: - -```yaml -# Problem: Loss function has a non-standard signature and needs activated predictions. -system: - adapters: - train: - criterion: - _target_: lighter.adapters.CriterionAdapter - pred_argument: 1 # Pass 'pred' as the 2nd argument - target_argument: 0 # Pass 'target' as the 1st argument - pred_transforms: - - _target_: torch.sigmoid -``` - -For more details, see the [`CriterionAdapter` documentation](../../reference/adapters/#lighter.adapters.CriterionAdapter). - -### MetricsAdapter - -The `MetricsAdapter` is identical in configuration to the `CriterionAdapter`, but for your metrics. You can use it to map and transform `pred`, `target`, and `input` before they are fed into your `torchmetrics` functions. 
- -#### Example: Preparing Predictions for a Metric - -If a metric requires class indices from your model's output probabilities: - -```yaml -# Problem: Metric expects class indices, not probabilities. -system: - adapters: - val: - metrics: - _target_: lighter.adapters.MetricsAdapter - pred_argument: "preds" - target_argument: "target" - pred_transforms: - - _target_: torch.argmax - dim: 1 -``` - -For more details, see the [`MetricsAdapter` documentation](../../reference/adapters/#lighter.adapters.MetricsAdapter). - -### LoggingAdapter - -The `LoggingAdapter` is used to transform `input`, `target`, and `pred` tensors just before they are sent to the logger (e.g., for image visualization in TensorBoard). It only supports transforms, not argument mapping. - -#### Example: Visualizing Grayscale Images - -If you want to log a single-channel image and your logger expects a 3-channel RGB image: - -```yaml -# Problem: Logger expects a 3-channel image, but data is grayscale. -system: - adapters: - train: - logging: - _target_: lighter.adapters.LoggingAdapter - input_transforms: "$lambda x: x.repeat(1, 3, 1, 1)" # Convert to 3-channel -``` - -For more details, see the [`LoggingAdapter` documentation](../../reference/adapters/#lighter.adapters.LoggingAdapter). - -## Advanced Usage - -### Custom Adapters - -For highly complex scenarios, you can implement your own adapter by inheriting from Lighter's base adapters. 
- -```python -# my_project/adapters.py -from lighter.adapters import BatchAdapter - -class MultiModalBatchAdapter(BatchAdapter): - """A custom adapter for multi-modal data.""" - def __call__(self, batch): - # Custom logic to unpack a complex batch - return { - "image": batch["image_data"], - "text": batch["text_embeddings"], - "tabular": batch["clinical_features"] - }, batch["diagnosis"], batch["id"] -``` - -Then, use it in your config: - -```yaml -system: - adapters: - train: - batch: - _target_: my_project.adapters.MultiModalBatchAdapter -``` - -## Tips and Troubleshooting - -| Tip / Issue | Solution | -|---|---| -| **Performance** | Put expensive transforms in your `Dataset`, not in adapters, to leverage multi-worker data loading. | -| **Debugging** | Add a print transform (`$lambda x: print(x.shape) or x`) to inspect tensor shapes at any point in the flow. | -| **`KeyError` in batch** | Your `input_accessor` or `target_accessor` might be wrong. Use a print transform to inspect the batch keys/indices. | -| **Wrong argument order** | Double-check the signature of your loss/metric function and use named arguments in `CriterionAdapter` or `MetricsAdapter` for clarity. | -| **Reusing configs** | Use raw references (`%`) to reuse adapters: `val: "%system::adapters::train::batch"` creates a new instance with the same config. | - -## Complete Example: Segmentation Pipeline - -Here’s how adapters work together in a complete segmentation pipeline: - - -```yaml -system: - adapters: - train: - # 1. Extract 'image' and 'mask' from the batch dictionary. - batch: - _target_: lighter.adapters.BatchAdapter - input_accessor: "image" - target_accessor: "mask" - identifier_accessor: "patient_id" - - # 2. Pass prediction and target to the loss function and apply softmax. - criterion: - _target_: lighter.adapters.CriterionAdapter - pred_argument: 0 - target_argument: 1 - pred_transforms: - - _target_: torch.nn.functional.softmax - dim: 1 - - # 3. 
Pass prediction and target to metrics by name and apply argmax. - metrics: - _target_: lighter.adapters.MetricsAdapter - pred_argument: "preds" - target_argument: "target" - pred_transforms: - - _target_: torch.argmax - dim: 1 - - # 4. Reuse the same batch adapter for validation. - val: - batch: "%system::adapters::train::batch" - # Criterion and Metrics adapters for 'val' would also be defined here. -``` - -## Recap and Next Steps - -Adapters are what make Lighter truly flexible: - -✅ **Key Benefits:** - -- Handle any data format without code changes -- Connect incompatible components seamlessly -- Transform data at the right pipeline stage -- Debug and monitor data flow easily - -🎯 **Best Practices:** - -- Keep transforms simple and composable -- Move expensive operations to datasets -- Use debug prints during development -- Reuse configurations with YAML anchors - -## Related Guides -- [Metrics](metrics.md) - Using MetricsAdapter -- [Writers](writers.md) - Using LoggingAdapter -- [Inferers](inferers.md) - Inference-time adaptation diff --git a/docs/how-to/configuration.md b/docs/how-to/configuration.md deleted file mode 100644 index 95dbe824..00000000 --- a/docs/how-to/configuration.md +++ /dev/null @@ -1,338 +0,0 @@ ---- -title: Configuration Reference ---- - -# Configuration Reference - -Lighter uses **[Sparkwheel](https://project-lighter.github.io/sparkwheel/)** for configuration—a powerful YAML-based system supporting references, expressions, and object instantiation. - -!!! tip "Complete Documentation" - This page covers Lighter-specific patterns and common usage. For complete Sparkwheel syntax, advanced features, and detailed examples, see the **[Sparkwheel documentation](https://project-lighter.github.io/sparkwheel/)**. 
- -## Quick Reference - -| Symbol | Purpose | Sparkwheel Docs | -|--------|---------|-----------------| -| `_target_` | Instantiate a class | [Instantiation](https://project-lighter.github.io/sparkwheel/user-guide/instantiation/) | -| `@path::to::value` | Resolved reference (instantiated object) | [References](https://project-lighter.github.io/sparkwheel/user-guide/references/) | -| `%path::to::value` | Raw reference (unprocessed YAML) | [References](https://project-lighter.github.io/sparkwheel/user-guide/references/) | -| `$expression` | Evaluate Python expression | [Expressions](https://project-lighter.github.io/sparkwheel/user-guide/expressions/) | -| `::` | Path notation (navigate config) | [Basics](https://project-lighter.github.io/sparkwheel/user-guide/basics/) | -| `.` | Access Python attributes | [Expressions](https://project-lighter.github.io/sparkwheel/user-guide/expressions/) | -| `=key:` | Replace operator (override merge) | [Operators](https://project-lighter.github.io/sparkwheel/user-guide/operators/) | -| `~key:` | Delete operator | [Operators](https://project-lighter.github.io/sparkwheel/user-guide/operators/) | - -## Lighter Configuration Structure - -Every Lighter config has two mandatory sections: - -```yaml -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 10 - -system: - _target_: lighter.System - model: ... - criterion: ... - optimizer: ... - dataloaders: ... -``` - -Optional sections: -```yaml -_requires_: # Import Python modules -project: ./path # Custom module directory -vars: ... # Variables for reuse -args: ... # Stage-specific arguments (fit, test, etc.) 
-``` - -## Essential Syntax - -### `_target_`: Instantiate Classes - -```yaml -model: - _target_: torchvision.models.resnet18 - num_classes: 10 -``` - -**Equivalent to:** `torchvision.models.resnet18(num_classes=10)` - -[Learn more →](https://project-lighter.github.io/sparkwheel/user-guide/instantiation/) - -### `@` and `%`: References - -| Type | Syntax | Use Case | -|------|--------|----------| -| **Resolved** (`@`) | `@system::optimizer` | Pass actual object instances | -| **Raw** (`%`) | `%system::metrics::train` | Reuse config to create new instances | - -**Example:** -```yaml -scheduler: - _target_: torch.optim.lr_scheduler.StepLR - optimizer: "@system::optimizer" # Resolved: actual optimizer object - -metrics: - train: - - _target_: torchmetrics.Accuracy - task: multiclass - num_classes: 10 - val: "%system::metrics::train" # Raw: creates new instance -``` - -[Learn more →](https://project-lighter.github.io/sparkwheel/user-guide/references/) - -### `$`: Expressions - -Evaluate Python in configs: - -```yaml -optimizer: - _target_: torch.optim.Adam - params: "$@system::model.parameters()" # Call model.parameters() - lr: "$0.001 * 2" # Result: 0.002 -``` - -[Learn more →](https://project-lighter.github.io/sparkwheel/user-guide/expressions/) - -### `::`: Path Notation - -Navigate nested configs: - -```yaml -@system::model # Access model -@system::optimizer::lr # Access nested value -%::train::batch_size # Relative reference (sibling) -``` - -[Learn more →](https://project-lighter.github.io/sparkwheel/user-guide/basics/) - -## CLI Overrides - -Override any config value from command line: - -```bash -# Simple override -lighter fit config.yaml trainer::max_epochs=100 - -# Nested values -lighter fit config.yaml system::optimizer::lr=0.001 - -# Multiple overrides -lighter fit config.yaml \ - trainer::max_epochs=100 \ - system::optimizer::lr=0.001 \ - trainer::devices=4 -``` - -[Learn more →](https://project-lighter.github.io/sparkwheel/user-guide/cli/) - -## Merging 
Configs - -Combine multiple YAML files for modular experiments: - -```bash -# Merge base + experiment -lighter fit base.yaml,experiment.yaml - -# Compose from modules -lighter fit base.yaml,models/resnet.yaml,data/cifar10.yaml -``` - -### Default Merging Behavior - -**Dictionaries merge recursively:** -```yaml -# base.yaml -trainer: - max_epochs: 10 - devices: 1 - -# experiment.yaml -trainer: - max_epochs: 100 # Overrides - accelerator: gpu # Adds - -# Result: max_epochs=100, devices=1, accelerator=gpu -``` - -**Lists extend (append):** -```yaml -# base.yaml -trainer: - callbacks: - - _target_: pytorch_lightning.callbacks.ModelCheckpoint - -# experiment.yaml -trainer: - callbacks: - - _target_: pytorch_lightning.callbacks.EarlyStopping - -# Result: Both callbacks present -``` - -### Override Merging: `=` and `~` - -**Replace with `=`:** -```yaml -# experiment.yaml -trainer: - =callbacks: # Replace instead of extend - - _target_: pytorch_lightning.callbacks.RichProgressBar -``` - -**Delete with `~`:** -```yaml -# Delete entire key -trainer: - ~callbacks: null - -# Delete list items -trainer: - ~callbacks: [1, 3] # Delete indices 1 and 3 - -# Delete dict keys -system: - ~dataloaders: ["train", "test"] -``` - -[Complete merging reference →](https://project-lighter.github.io/sparkwheel/user-guide/operators/) - -## Common Lighter Patterns - -### Pattern 1: Model → Optimizer - -```yaml -system: - model: - _target_: torchvision.models.resnet18 - num_classes: 10 - - optimizer: - _target_: torch.optim.Adam - params: "$@system::model.parameters()" - lr: 0.001 -``` - -### Pattern 2: Optimizer → Scheduler - -```yaml -system: - optimizer: - _target_: torch.optim.Adam - params: "$@system::model.parameters()" - lr: 0.001 - - scheduler: - _target_: torch.optim.lr_scheduler.ReduceLROnPlateau - optimizer: "@system::optimizer" - factor: 0.5 -``` - -### Pattern 3: Reusing Configurations - -```yaml -system: - metrics: - train: - - _target_: torchmetrics.Accuracy - task: multiclass - 
num_classes: 10 - val: "%system::metrics::train" # Reuse - - dataloaders: - train: - _target_: torch.utils.data.DataLoader - batch_size: 128 - num_workers: 4 - val: - _target_: torch.utils.data.DataLoader - batch_size: "%::train::batch_size" # Relative reference - num_workers: "%::train::num_workers" -``` - -### Pattern 4: Variables for Reuse - -```yaml -vars: - batch_size: 32 - num_classes: 10 - base_lr: 0.001 - -system: - model: - _target_: torchvision.models.resnet18 - num_classes: "%vars::num_classes" - - optimizer: - lr: "%vars::base_lr" - - dataloaders: - train: - batch_size: "%vars::batch_size" -``` - -### Pattern 5: Stage-Specific Arguments - -```yaml -args: - fit: - ckpt_path: null # Start from scratch - test: - ckpt_path: "checkpoints/best.ckpt" - predict: - ckpt_path: "checkpoints/best.ckpt" - return_predictions: true -``` - -**Override from CLI:** -```bash -lighter test config.yaml args::test::ckpt_path="other.ckpt" -``` - -## Common Pitfalls - -### 1. Resolved vs Raw Reference - -```yaml -# ❌ Wrong: Shares same instance -metrics: - val: "@system::metrics::train" - -# ✅ Correct: Creates new instance -metrics: - val: "%system::metrics::train" -``` - -### 2. Path Notation vs Python Attributes - -```yaml -# ❌ Wrong: :: is for config paths -params: "$@system::model::parameters()" - -# ✅ Correct: . is for Python attributes -params: "$@system::model.parameters()" -``` - -### 3. Missing $ for Expressions - -```yaml -# ❌ Wrong: Treated as string -batch_size: "@vars::base_batch * 2" - -# ✅ Correct: Evaluated -batch_size: "$%vars::base_batch * 2" -``` - -## Advanced Features - -Refer to **[Sparkwheel documentation](https://project-lighter.github.io/sparkwheel/)** for advanced usage. 
- -## Next Steps - -- [Running Experiments](run.md) - Execute training, testing, prediction -- [Configuration Recipes](recipes.md) - Ready-to-use patterns -- [Troubleshooting](troubleshooting.md) - Debug config errors -- [Sparkwheel Documentation](https://project-lighter.github.io/sparkwheel/) - Complete reference diff --git a/docs/how-to/experiment_tracking.md b/docs/how-to/experiment_tracking.md deleted file mode 100644 index ffd2c497..00000000 --- a/docs/how-to/experiment_tracking.md +++ /dev/null @@ -1,421 +0,0 @@ -# Experiment Tracking - -Lighter provides comprehensive experiment tracking through PyTorch Lightning loggers and Writer callbacks. This guide covers what's logged automatically, how to configure loggers, and how to add custom logging. - -## What Lighter Logs Automatically - -Lighter's System class automatically logs the following metrics without any configuration: - -### 1. Loss Values - -**Logged in**: train and val modes (test mode doesn't compute loss by default) - -``` -{mode}/loss/step # Per-batch loss -{mode}/loss/epoch # Epoch-averaged loss -``` - -For dict-based losses (multi-task learning), all sublosses are logged: - -``` -{mode}/loss/total/step -{mode}/loss/total/epoch -{mode}/loss/classification/step -{mode}/loss/classification/epoch -{mode}/loss/segmentation/step -{mode}/loss/segmentation/epoch -``` - -**Code reference**: `src/lighter/system.py:194-202` - -### 2. Metrics - -**Logged in**: train, val, and test modes (predict mode doesn't compute metrics) - -``` -{mode}/metrics/{metric_name}/step -{mode}/metrics/{metric_name}/epoch -``` - -All metrics defined in your config are automatically logged with both step-level and epoch-level aggregation. - -**Code reference**: `src/lighter/system.py:205-208` - -### 3. 
Optimizer Statistics - -**Logged in**: train mode only, **once per epoch** (at the beginning) - -``` -train/lr # Learning rate -train/momentum # If using SGD with momentum -train/beta1 # If using Adam/AdamW -train/beta2 # If using Adam/AdamW -``` - -This automatic logging helps you track learning rate schedules and optimizer behavior without any additional configuration. - -**Code reference**: `src/lighter/system.py:210-213` - -### 4. Hyperparameters - -The Runner automatically logs all configuration parameters (the entire YAML config) to the logger at the start of training. This ensures full reproducibility. - -**Code reference**: `src/lighter/engine/runner.py:143-146` - -## Logger Configuration - -Lighter uses PyTorch Lightning's logger system. You can configure any Lightning-compatible logger in your config. - -### TensorBoard (Built-in) - -```yaml -trainer: - logger: - _target_: pytorch_lightning.loggers.TensorBoardLogger - save_dir: logs - name: my_experiment - version: null # Auto-incrementing version -``` - -```bash -# View logs -tensorboard --logdir logs -``` - -## Weights & Biases - -```yaml -trainer: - logger: - _target_: pytorch_lightning.loggers.WandbLogger - project: my_project - name: experiment_name - save_dir: logs -``` - -## MLflow - -```yaml -trainer: - logger: - _target_: pytorch_lightning.loggers.MLFlowLogger - experiment_name: my_experiment - tracking_uri: file:./mlruns -``` - -## CSV Logger - -```yaml -trainer: - logger: - _target_: pytorch_lightning.loggers.CSVLogger - save_dir: logs - name: my_experiment -``` - -## Multiple Loggers - -```yaml -trainer: - logger: - - _target_: pytorch_lightning.loggers.TensorBoardLogger - save_dir: logs - - _target_: pytorch_lightning.loggers.WandbLogger - project: my_project -``` - -For advanced logging features, see [PyTorch Lightning Logger docs](https://lightning.ai/docs/pytorch/stable/extensions/logging.html). 
- -## Writer Callbacks for Predictions - -While loggers handle scalar metrics, Writer callbacks save predictions, inputs, and targets to disk. Writers are triggered automatically after each batch in val, test, and predict modes. - -### Available Writers - -Lighter provides several built-in writers in `lighter.callbacks.writer`: - -- **FileWriter**: Save individual files (images, arrays) -- **TableWriter**: Save predictions in tabular format (CSV, Parquet) - -For detailed usage, see the [Writers Guide](writers.md). - -### When Writers Are Triggered - -Writers are **batch-level callbacks** that run after each batch in val/test/predict modes: - -``` -For each batch in validation/test/predict: - 1. System._step() computes predictions - 2. Output dict is returned to callbacks - 3. Writer callbacks process the dict - 4. Files are written to disk -``` - -This allows you to save all predictions without accumulating them in memory. - -### Writer Memory Management - -Writers actively clear predictions from the output dictionary after processing to save CPU memory. This is especially important for large-scale inference. 
- -**Code reference**: `src/lighter/callbacks/writer.py:141-143` - -## Custom Logging Strategies - -### Strategy 1: Logging Additional Scalars - -To log custom values, extend the System class and override `_log_stats`: - -```python title="my_project/custom_system.py" -from lighter.system import System - -class CustomSystem(System): - def _log_stats(self, loss, metrics, batch_idx): - # Call parent to log standard metrics - super()._log_stats(loss, metrics, batch_idx) - - # Log custom values - if self.mode == "train": - # Example: Log gradient norms - grad_norm = self._compute_gradient_norm() - self.log(f"{self.mode}/grad_norm", grad_norm) - - # Example: Log model statistics - self.log(f"{self.mode}/model_mean_weight", - self.model.fc.weight.mean()) - - def _compute_gradient_norm(self): - total_norm = 0.0 - for p in self.model.parameters(): - if p.grad is not None: - total_norm += p.grad.data.norm(2).item() ** 2 - return total_norm ** 0.5 -``` - -Use in config: - -```yaml -system: - _target_: my_project.custom_system.CustomSystem - model: ... - # ... 
other components -``` - -### Strategy 2: Conditional Logging - -Log different metrics based on mode or epoch: - -```python title="my_project/conditional_system.py" -from lighter.system import System - -class ConditionalSystem(System): - def _log_stats(self, loss, metrics, batch_idx): - super()._log_stats(loss, metrics, batch_idx) - - # Only log expensive metrics every N epochs - if self.current_epoch % 10 == 0: - if self.mode == "val": - self.log(f"{self.mode}/expensive_metric", - self._compute_expensive_metric()) - - def _compute_expensive_metric(self): - # Expensive computation here - pass -``` - -### Strategy 3: Logging Images/Media - -For image logging, use the logger directly: - -```python title="my_project/vision_system.py" -from lighter.system import System -import torch - -class VisionSystem(System): - def on_validation_epoch_end(self): - # Log sample predictions as images - if self.trainer.logger is not None: - sample_input = self.validation_samples[:8] # 8 images - sample_pred = self.model(sample_input) - - # Log to TensorBoard - if hasattr(self.trainer.logger.experiment, 'add_images'): - self.trainer.logger.experiment.add_images( - 'val/predictions', - sample_pred, - self.current_epoch - ) -``` - -## Integration: System, Logger, and Writers - -Understanding how these components work together: - -``` -Training Loop (System) - ↓ -Automatic Logging (System._log_stats) - ├─→ Logger receives scalar metrics - └─→ TensorBoard/W&B/MLflow displays them - ↓ -Step Output (System._step returns dict) - ↓ -Callbacks (Writers) - └─→ Save predictions/inputs/targets to files -``` - -### Example: Complete Tracking Setup - -```yaml title="complete_tracking.yaml" -trainer: - _target_: pytorch_lightning.Trainer - - # Multiple loggers for comprehensive tracking - logger: - - _target_: pytorch_lightning.loggers.TensorBoardLogger - save_dir: logs - name: my_experiment - version: null - - - _target_: pytorch_lightning.loggers.WandbLogger - project: my_project - name: 
my_experiment - save_dir: logs - log_model: true - - # Logging frequency - log_every_n_steps: 50 - - # Callbacks for saving predictions - callbacks: - - _target_: pytorch_lightning.callbacks.LearningRateMonitor - logging_interval: epoch - - - _target_: lighter.callbacks.writer.FileWriter - write_dir: predictions - predicates: ["val", "test"] - -system: - _target_: lighter.System - # ... rest of config -``` - -## Advanced: Logger-Specific Features - -### TensorBoard: Hyperparameter Tuning - -TensorBoard can visualize hyperparameter search results: - -```yaml -trainer: - logger: - _target_: pytorch_lightning.loggers.TensorBoardLogger - save_dir: logs - name: hparam_search - default_hp_metric: true # Enable HP tracking - -# Vary hyperparameters across runs -system: - optimizer: - lr: 0.001 # Change this across experiments -``` - -View in TensorBoard: -```bash -tensorboard --logdir logs -# Navigate to HPARAMS tab -``` - -### W&B: Artifact Tracking - -Track model checkpoints as W&B artifacts: - -```yaml -trainer: - logger: - _target_: pytorch_lightning.loggers.WandbLogger - project: my_project - log_model: all # 'all', True, False - - callbacks: - - _target_: pytorch_lightning.callbacks.ModelCheckpoint - dirpath: checkpoints - save_top_k: 3 -``` - -### MLflow: Model Registry - -Integrate with MLflow's model registry: - -```yaml -trainer: - logger: - _target_: pytorch_lightning.loggers.MLFlowLogger - experiment_name: my_experiment - tracking_uri: file:./mlruns - log_model: true # Log as MLflow model -``` - -## Troubleshooting - -### Issue: No logs appearing - -**Solution**: Check that logger is not None: - -```yaml -trainer: - logger: ... # Make sure this is configured -``` - -### Issue: Metrics not syncing across GPUs - -**Solution**: Lighter automatically sets `sync_dist=True` for epoch-level metrics. 
For custom metrics, ensure you use `on_epoch=True`: - -```python -self.log("custom_metric", value, on_epoch=True, sync_dist=True) -``` - -### Issue: Too much logging slowing down training - -**Solution**: Reduce logging frequency: - -```yaml -trainer: - log_every_n_steps: 100 # Default is 50 -``` - -### Issue: Writer memory usage too high - -**Solution**: Writers automatically clear predictions. If still high, process predictions in smaller batches: - -```yaml -system: - dataloaders: - predict: - batch_size: 16 # Reduce batch size -``` - -## Best Practices - -1. **Use multiple loggers**: Combine TensorBoard (local) with W&B/MLflow (team) -2. **Log hyperparameters**: Automatic in Lighter, but verify they appear in your logger -3. **Monitor optimizer stats**: Use `LearningRateMonitor` callback for detailed tracking -4. **Separate concerns**: Use Logger for scalars, Writers for predictions -5. **Version your experiments**: Use timestamps or version numbers in logger names -6. **Document runs**: Add notes/tags to experiments in W&B or MLflow - -## Summary - -Lighter provides comprehensive tracking out of the box: - -- **Automatic**: Loss, metrics, optimizer stats, hyperparameters -- **Flexible**: Support for all PyTorch Lightning loggers -- **Scalable**: Batch-level Writers prevent memory issues -- **Extensible**: Easy to add custom logging via System subclassing - -For most use cases, you just need to configure a logger—everything else is automatic! 
- -## Related Guides -- [System Internals](../design/system.md) - Understanding automatic logging -- [Writers](writers.md) - Saving predictions to disk -- [Configuration Guide](configuration.md) - Logger configuration syntax -- [Run Guide](run.md) - Running experiments diff --git a/docs/how-to/freezers.md b/docs/how-to/freezers.md deleted file mode 100644 index 6131f355..00000000 --- a/docs/how-to/freezers.md +++ /dev/null @@ -1,228 +0,0 @@ -# Freezers: Smart Layer Management for Transfer Learning - -Freezing layers is a powerful technique that can accelerate training, prevent catastrophic forgetting, and improve model performance. Lighter's `Freezer` callback gives you fine-grained control over which layers train and when. - -## Quick Start 🚀 - -```yaml title="config.yaml" -# Freeze encoder for first 10 epochs, then unfreeze -trainer: - callbacks: - - _target_: lighter.callbacks.Freezer - name_starts_with: ["model.encoder"] # What to freeze - until_epoch: 10 # When to unfreeze -``` - -## Why Freeze Layers? 🤔 - -| Scenario | Strategy | Benefit | -|----------|----------|---------| -| **Transfer Learning** | Freeze pretrained layers initially | Preserve learned features | -| **Limited Data** | Freeze most layers | Prevent overfitting | -| **Fine-tuning** | Gradual unfreezing | Stable adaptation | -| **Multi-stage Training** | Stage-wise freezing | Focused learning | - -**Freezing Strategies**: - -`Freezer` callback offers flexible layer freezing strategies: - -1. **Freeze by Name Prefix (`name_starts_with`)**: - - * Freeze parameters with names starting with prefix/prefixes in `name_starts_with` arg. - * Useful for freezing modules or layer groups. 
- * **Example**: - - ```yaml title="config.yaml" - trainer: - callbacks: - - _target_: lighter.callbacks.Freezer - name_starts_with: ["model.encoder", "model.embedding"] # Freeze encoder/embedding layers - until_epoch: 5 - ``` - - Config freezes parameters starting with `"model.encoder"` or `"model.embedding"` until epoch 5. - -2. **Freeze by Exact Name (`names`)**: - - * Freeze specific parameters by name using `names` arg. - * For fine-grained control over individual layers/parameters. - * **Example**: - - ```yaml title="config.yaml" - trainer: - callbacks: - - _target_: lighter.callbacks.Freezer - names: ["model.classifier.weight", "model.classifier.bias"] # Freeze classifier layer weights/bias - until_step: 1000 - ``` - - Config freezes parameters named `"model.classifier.weight"` and `"model.classifier.bias"` until step 1000. - -3. **Exclude Layers from Freezing (`except_names`, `except_name_starts_with`)**: - - * Exclude layers from freezing (even if matched by `name_starts_with` or `names`) using `except_names`/`except_name_starts_with`. - * Selectively unfreeze parts of otherwise frozen module. - * **Example**: - - ```yaml title="config.yaml" - trainer: - callbacks: - - _target_: lighter.callbacks.Freezer - name_starts_with: ["model.encoder"] # Freeze all encoder layers - except_name_starts_with: ["model.encoder.layer5"] # Except "model.encoder.layer5" layers - until_epoch: 7 - ``` - - Config freezes `"model.encoder"` layers except `"model.encoder.layer5"`, keeping layer5 trainable. - -4. **Unfreezing after Condition (`until_step`, `until_epoch`)**: - - * `until_step`: Unfreeze layers after training step. - * `until_epoch`: Unfreeze layers after epoch. - * Use either/both `until_step`/`until_epoch`. Unfreezes when either condition met. - * Omit `until_step`/`until_epoch` to freeze layers for entire training (or manual unfreezing). 
- * **Example**: - - ```yaml title="config.yaml" - trainer: - callbacks: - - _target_: lighter.callbacks.Freezer - name_starts_with: ["model.backbone"] - until_epoch: 5 # Unfreeze after epoch 5 - until_step: 5000 # OR after step 5000 - ``` - - Config unfreezes `"model.backbone"` layers after epoch 5 OR step 5000 (whichever first). - -**Combining Freezing Strategies**: - -Combine `Freezer` callbacks in `config.yaml` for complex freezing schedules. E.g., initial backbone freeze, gradual part unfreezing. - -**Example: Gradual Layer Unfreezing** - -```yaml title="config.yaml" -trainer: - callbacks: - # Unfreeze backbone layers at epoch 5 - - _target_: lighter.callbacks.Freezer - name_starts_with: ["model.backbone"] - until_epoch: 5 - # Unfreeze early encoder layers at epoch 10 - - _target_: lighter.callbacks.Freezer - name_starts_with: ["model.encoder.layer1", "model.encoder.layer2"] - until_epoch: 10 -``` - -Example: 2 `Freezer` callbacks for gradual unfreezing - initial backbone freeze, gradual encoder layer unfreezing. - -**Inspecting Frozen Layers**: - -`Freezer` callback logs freezing info during training. Check logs (TensorBoard/console) to verify. - -**`Freezer` Callback Use Cases**: - -* **Transfer Learning**: Freeze pre-trained model's early layers, train head. -* **Fine-tuning**: Gradually unfreeze pre-trained layers. -* **Training Stability**: Initial layer freezing. -* **Regularization**: Layer freezing for regularization. -* **Efficient Training**: Reduce training time/memory. 
- -## Practical Example: Transfer Learning - -```yaml -trainer: - callbacks: - # Stage 1: Train only the classifier head - - _target_: lighter.callbacks.Freezer - name_starts_with: ["model.backbone"] # Freeze pretrained backbone - until_epoch: 5 - - # Stage 2: Fine-tune top layers - - _target_: lighter.callbacks.Freezer - name_starts_with: ["model.backbone.layer1", "model.backbone.layer2"] - until_epoch: 10 # Keep early layers frozen longer -``` - -This progressive unfreezing strategy: -1. **Epochs 0-5**: Only train the classifier head -2. **Epochs 5-10**: Unfreeze and train top backbone layers -3. **Epochs 10+**: Train entire model - -## Advanced Pattern: Discriminative Learning Rates - -```yaml -# Combine freezing with different learning rates -trainer: - callbacks: - - _target_: lighter.callbacks.Freezer - name_starts_with: ["model.backbone"] - until_epoch: 5 - -system: - optimizer: - _target_: torch.optim.Adam - params: - - params: "$[p for n, p in @system::model.named_parameters() if 'backbone' in n]" - lr: 0.0001 # Lower LR for pretrained layers - - params: "$[p for n, p in @system::model.named_parameters() if 'head' in n]" - lr: 0.001 # Higher LR for new layers -``` - -## Troubleshooting Guide 🔧 - -| Issue | Solution | -|-------|----------| -| **Parameters not freezing** | Print parameter names to verify: `for n, p in model.named_parameters(): print(n, p.requires_grad)` | -| **Performance drops after unfreezing** | Reduce LR with scheduler at unfreeze epoch | -| **BatchNorm issues** | Keep BN layers in eval mode even when unfrozen | -| **Memory increases** | Use gradient checkpointing or accumulation | -| **Don't know what to freeze** | Print model structure with `model.named_children()` | - -## Best Practices 🏆 - -1. **Start Conservative**: Freeze more layers initially, then gradually unfreeze -2. **Monitor Metrics**: Track validation loss when layers unfreeze -3. **Use Warmup**: Apply learning rate warmup after unfreezing -4. 
**Reduce LR**: Lower learning rate when unfreezing pretrained layers -5. **Test First**: Verify which layers to freeze by printing model structure - -## Quick Reference Card 📄 - -```yaml -# Freeze by prefix -name_starts_with: ["model.encoder", "model.embeddings"] - -# Freeze specific layers -names: ["model.layer1.weight", "model.layer1.bias"] - -# Exclude from freezing -except_names: ["model.encoder.final_layer.weight"] -except_name_starts_with: ["model.encoder.norm"] - -# Unfreeze timing -until_epoch: 10 # After epoch 10 -until_step: 1000 # After step 1000 -# Both: unfreeze when EITHER condition is met -``` - -## Recap and Next Steps - -✅ **You've Mastered:** - -- Strategic layer freezing for transfer learning -- Progressive unfreezing techniques -- Troubleshooting common freezing issues -- Best practices for stable training - -🎯 **Key Insights:** - -- Freezing preserves pretrained knowledge -- Gradual unfreezing prevents catastrophic forgetting -- Monitor performance when changing freeze status -- Combine with appropriate learning rates - -💡 **Pro Tip:** Log which layers are frozen/unfrozen at each epoch for reproducibility! - -## Related Guides -- [Run Guide](run.md) - Training workflows -- [Configuration](configuration.md) - Advanced config patterns diff --git a/docs/how-to/inferers.md b/docs/how-to/inferers.md deleted file mode 100644 index 77012b64..00000000 --- a/docs/how-to/inferers.md +++ /dev/null @@ -1,297 +0,0 @@ -# Inferers: Bridging Training and Inference - -Inferers adapt how your model processes data at inference time based on how it was trained and what your deployment needs are. Lighter's inferer system handles these adaptations seamlessly. 
- -## Why Inferers Matter 🎯 - -The way you train a model often differs from how you need to use it: - -| Training Scenario | Inference Challenge | Solution (Inferer) | -|------------------|--------------------|--------------------| -| Trained on fixed-size patches | Need to process full images | **Sliding Window** - breaks large images into patches | -| Trained on clean data | Noisy test data | **Test-Time Augmentation** - average multiple predictions | -| Single model | Need confidence scores | **Monte Carlo Dropout** - uncertainty estimation | -| Multiple models trained | Want best performance | **Ensemble** - combine predictions | - -## Common Inference Patterns - -### Sliding Window Inference -**When to use:** Your model was trained on fixed-size patches but you need to process larger images/volumes - -**Why it works:** The model only knows how to process the patch size it was trained on - -**Example:** A model trained on 256×256 patches from CT scans needs to process full 512×512×400 volumes - -### Test-Time Augmentation (TTA) -**When to use:** You want more robust predictions and can afford extra compute time - -**Why it works:** Averaging predictions from augmented inputs reduces noise - -**Example:** Medical image segmentation where rotation/flip invariance improves boundaries - -### Monte Carlo Dropout -**When to use:** You need uncertainty estimates along with predictions - -**Why it works:** Dropout at inference creates an ensemble effect - -**Example:** Medical diagnosis where you need to know prediction confidence - -## Configuring Inferers in Lighter - -You can configure any custom inferer within the `system.inferer` section of your `config.yaml` file. The inferer can be any callable that takes the model and input, then returns predictions. 
- -#### Configuration - -Here's an example of a basic inferer configuration: - -```yaml title="config.yaml" -system: - inferer: - _target_: my_project.inferers.CustomInferer - # Add any arguments your inferer needs -``` - -When an inferer is configured, Lighter automatically uses it during the `forward` pass in validation, test, and predict modes. - -## Implementing a Custom Inferer - -You can implement custom inference logic to handle: - -* **Advanced Ensembling Strategies:** Implementing ensembling techniques beyond simple averaging. -* **Sliding Window Inference:** Processing large images in patches. -* **Test-Time Augmentation:** Averaging predictions across augmentations. -* **Highly Specialized Output Processing:** Tailoring output processing to your unique research problem. - -To implement a custom inferer in Lighter, you'll create a Python class that adheres to a specific structure. - -### Custom Inferer Class Structure - -```python title="my_project/inferers/my_custom_inferer.py" -from typing import Any - -import torch -from torch.nn import Module - -class MyCustomInferer: - def __init__(self, arg1, arg2, **kwargs): - """ - Initialize your custom inferer. - - Args: - arg1: Custom argument 1. - arg2: Custom argument 2. - **kwargs: Additional keyword arguments. - """ - self.arg1 = arg1 - self.arg2 = arg2 - #... initialize any internal components... - - def __call__(self, inputs: torch.Tensor, network: Module, *args: Any, **kwargs: Any) -> torch.Tensor: - """ - Perform inference using your custom logic. - - Args: - inputs: Input tensor(s) to the model. - network: The deep learning model (torch.nn.Module). - *args: Additional positional arguments (if needed). - **kwargs: Additional keyword arguments (if needed). - - Returns: - torch.Tensor: The processed prediction tensor(s). 
- """ - # Implement your custom inference logic here - # This could include: - # - Test-time augmentation - # - Model ensembling - # - Sliding window or patch-based inference - # - Any other custom processing - - # Example: Simple forward pass with optional post-processing - outputs = network(inputs, *args, **kwargs) - processed_outputs = self.post_process(outputs) - return processed_outputs - - def post_process(self, outputs: torch.Tensor) -> torch.Tensor: - """ - Optional post-processing of model outputs. - - Args: - outputs (torch.Tensor): Raw model output tensor(s). - - Returns: - torch.Tensor: Processed output tensor(s). - """ - # Implement post-processing logic if needed (e.g., thresholding, softmax) - return outputs -``` - -### Key Components - -1. **`__init__`:** - * This is the constructor of your inferer class. - * It takes any custom arguments that you can define in your `config.yaml`. - * Use this method to initialize any internal components or parameters your inferer needs. - -2. **`__call__`:** - * This method makes your class callable like a function, enabling it to be used directly for inference. - * **Arguments:** - * `inputs (torch.Tensor)`: The input tensor(s) to your model. - * `network (torch.nn.Module)`: Your deep learning model (equivalent to `self.model` in your `System`). - * `*args`, `**kwargs`: These allow you to pass additional arguments if required, although they are not typically used in inferers. - * **Logic:** - * This is where you implement your core inference logic. - * A common pattern is to perform a forward pass through your `network` using `outputs = network(inputs)`. - * You can integrate various inference techniques here, such as TTA, ensembling, or sliding window inference. - * You can also call a `post_process` method to further refine the model's raw outputs. - * **Return Value:** - * This method must return the processed prediction tensor(s) as a `torch.Tensor`. 
This output will be used as the `pred` value in your validation, testing, or prediction steps. - -3. **`post_process` (Optional):** - * This is an optional method for applying post-processing operations to the model's raw outputs. - * You can use it for tasks like thresholding, applying a softmax function, or any other custom processing relevant to your problem. - * If no post-processing is required, you can simply return the `outputs` tensor directly. - -#### Integrating a Custom Inferer - -1. **Save:** Save your custom inferer class (e.g., `MyCustomInferer`) in a Python file within your project (e.g., `my_project/inferers/my_custom_inferer.py`). - -2. **Configure:** In your `config.yaml`, specify the inferer within the `system.inferer` section, providing the path to your class and any necessary arguments for its `__init__` method: - - ```yaml title="config.yaml" - system: - inferer: - _target_: my_project.inferers.my_custom_inferer.MyCustomInferer - arg1: value1 - arg2: value2 - ``` - - * **`_target_`:** Points to your custom inferer class. - * **`arg1` and `arg2`:** Arguments passed to your inferer's `__init__` method. - -With this configuration, Lighter will create an instance of your custom inferer and use it during the appropriate stages of your experiment. 
- -## Example: Test-Time Augmentation Inferer - -```python -# my_project/inferers/tta_inferer.py -import torch -import torchvision.transforms.functional as TF -from torch.nn import Module - -class TTAInferer: - """Test-Time Augmentation for robust predictions.""" - def __init__(self, num_augmentations=4, aggregate="mean"): - self.num_augmentations = num_augmentations - self.aggregate = aggregate - - def __call__(self, inputs: torch.Tensor, network: Module) -> torch.Tensor: - predictions = [] - - # Define augmentations - augmentations = [ - lambda x: x, # Original - lambda x: TF.hflip(x), # Horizontal flip - lambda x: TF.vflip(x), # Vertical flip - lambda x: torch.rot90(x, k=2, dims=[-2, -1]) # 180 rotation - ] - - for aug_fn in augmentations[:self.num_augmentations]: - # Apply augmentation - aug_input = aug_fn(inputs) - - # Get prediction - with torch.no_grad(): - pred = network(aug_input) - - # Reverse augmentation on prediction - if aug_fn != augmentations[0]: # Skip original - pred = aug_fn(pred) # Most augmentations are self-inverse - - predictions.append(pred) - - # Aggregate predictions - stacked = torch.stack(predictions) - if self.aggregate == "mean": - return stacked.mean(dim=0) - elif self.aggregate == "voting": - votes = stacked.argmax(dim=-1) - return torch.mode(votes, dim=0)[0] - else: - return stacked.mean(dim=0) -``` - -Use in config: -```yaml -system: - inferer: - _target_: my_project.inferers.TTAInferer - num_augmentations: 4 - aggregate: mean -``` - -## Performance Tips ⚡ - -| Optimization | Technique | Impact | -|-------------|-----------|--------| -| **Batch TTA** | Process all augmentations in one batch | 2-3x faster | -| **Mixed Precision** | Use `torch.cuda.amp.autocast()` | 30-50% speedup | -| **Reduce Augmentations** | Use only most impactful transforms | Proportional speedup | - -## Choosing the Right Inferer - -```mermaid -graph TD - A[What was your model trained on?] 
--> B{Training Setup} - B -->|Fixed patches| C[Use Sliding Window] - B -->|Full images| D{Memory constraints?} - D -->|Yes| C - D -->|No| E[Use Simple Inferer] - - A --> F{Need uncertainty?} - F -->|Yes| G[Add MC Dropout] - - A --> H{Want robustness?} - H -->|Yes| I[Add TTA] - - A --> J{Have multiple models?} - J -->|Yes| K[Use Ensemble] -``` - -## Configuration Examples - -### Model Trained on Patches → Full Image Inference -```yaml -# Model was trained on 128×128×128 patches -# Now need to process 512×512×200 volumes -system: - inferer: - _target_: monai.inferers.SlidingWindowInferer - roi_size: [128, 128, 128] # Must match training patch size! - sw_batch_size: 4 - overlap: 0.5 # 50% overlap for smooth predictions - mode: gaussian # Smooth blending at boundaries -``` - -### Adding Robustness with TTA -```yaml -# Model trained normally, but test data is noisier -system: - inferer: - _target_: my_project.inferers.TTAInferer - num_augmentations: 4 # Balance speed vs robustness - aggregate: mean # Average predictions -``` - -### Uncertainty Quantification with MC Dropout -```yaml -# Model has dropout layers, need confidence intervals -system: - inferer: - _target_: my_project.inferers.MCDropoutInferer - num_samples: 20 # Multiple forward passes - return_std: true # Return standard deviation as uncertainty -``` - -## Related Guides -- [Adapters](adapters.md) - Transform inference outputs -- [Writers](writers.md) - Save predictions diff --git a/docs/how-to/metrics.md b/docs/how-to/metrics.md deleted file mode 100644 index 6529b906..00000000 --- a/docs/how-to/metrics.md +++ /dev/null @@ -1,273 +0,0 @@ -# Custom Metrics: Beyond Standard Evaluation - -Metrics are the compass that guides your model development. While `torchmetrics` provides excellent built-in metrics, real-world projects often need custom evaluation logic. This guide shows you how to create powerful custom metrics that provide deep insights into your model's behavior. 
- -## Quick Start: Your First Custom Metric in 30 Seconds 🚀 - -```python -# my_project/metrics/weighted_accuracy.py -from torchmetrics import Metric -import torch - -class WeightedAccuracy(Metric): - """Accuracy that cares more about certain classes.""" - def __init__(self, class_weights): - super().__init__() - self.class_weights = class_weights - self.add_state("weighted_correct", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total_weight", default=torch.tensor(0.0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - pred_classes = preds.argmax(dim=1) - correct = pred_classes == target - weights = torch.tensor([self.class_weights[t.item()] for t in target]) - self.weighted_correct += (correct * weights).sum() - self.total_weight += weights.sum() - - def compute(self): - return self.weighted_correct / self.total_weight -``` - -Use it in your config: -```yaml -system: - metrics: - val: - - _target_: my_project.metrics.WeightedAccuracy - class_weights: [1.0, 2.0, 5.0] # Class 2 is 5x more important -``` - -## Core Concepts: The Metric Trinity 🏆 - -Every custom metric needs three essential components: - -| Component | Purpose | When Called | -|-----------|---------|-------------| -| **1. `add_state()`** | Register variables to track | Once at initialization | -| **2. `update()`** | Process batch & accumulate stats | Every batch | -| **3. `compute()`** | Calculate final metric value | End of epoch/validation | - -**The lifecycle flow:** -1. **Initialize** → Set up state variables -2. **Update** (repeated) → Process each batch -3. **Compute** → Calculate final metric - -## Creating a Custom Metric: Step-by-Step - -Let's walk through the steps of creating a custom metric in Lighter using `torchmetrics`. We'll create a simple example custom metric called `MyCustomMetric` for binary classification, which calculates a variation of accuracy. - -1. 
**Subclass `torchmetrics.Metric`**: - First, create a new Python file (e.g., `my_project/metrics/my_custom_metric.py`) within your project to define your custom metric class. Start by importing `torchmetrics.Metric` and subclassing it. - - ```python title="my_project/metrics/my_custom_metric.py" - from torchmetrics import Metric - import torch - - class MyCustomMetric(Metric): - def __init__(self): - super().__init__() - # ... (state initialization will be added in the next step) ... - - def update(self, preds: torch.Tensor, target: torch.Tensor): - # ... (update logic will be added in the next step) ... - pass - - def compute(self): - # ... (compute logic will be added in the next step) ... - pass - ``` - -2. **Initialize Metric State with `add_state()`**: - In the `__init__` method, use `self.add_state()` to initialize state variables for accumulated statistics. For `MyCustomMetric`, track correct and total predictions: - - ```python title="my_project/metrics/my_custom_metric.py" - from torchmetrics import Metric - import torch - - class MyCustomMetric(Metric): - def __init__(self): - super().__init__() - self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") # Tracks correct predictions - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") # Tracks total predictions - ``` - - - Registers "correct" state variable: - - Initializes to a PyTorch tensor of 0. - - `dist_reduce_fx="sum"`: Reduces state across distributed processes by summing. - - Registers "total" state variable: - - Initializes to a PyTorch tensor of 0. - - `dist_reduce_fx="sum"`: Reduces state across distributed processes similarly. - -3. **Implement `update()` Method**: - The `update()` method processes each batch of predictions and targets. For `MyCustomMetric`, implement the following: - - 1. Convert probability predictions to binary (0/1). - 2. Count correct predictions and update state variables. 
-
-    ```python title="my_project/metrics/my_custom_metric.py"
-    from torchmetrics import Metric
-    import torch
-
-    class MyCustomMetric(Metric):
-        def __init__(self):
-            super().__init__()
-            self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
-            self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
-
-        def update(self, preds: torch.Tensor, target: torch.Tensor):
-            # 1. Convert probability predictions to binary (0/1).
-            binary_preds = (preds >= 0.5).int()
-
-            # 2. Count correct predictions and update state variables.
-            self.correct += torch.sum(binary_preds == target)
-            self.total += target.numel()
-    ```
-
-    - `binary_preds = (preds >= 0.5).int()`: Converts probabilities to binary predictions (0 or 1).
-    - `self.correct += torch.sum(binary_preds == target)`: Increments `correct` state with the batch's correct predictions.
-    - `self.total += target.numel()`: Increments `total` state with the batch size.
-
-4. **Implement `compute()` Method**:
-    The `compute()` method calculates the final metric value at the epoch end. 
For `MyCustomMetric`, calculate custom accuracy: - - ```python title="my_project/metrics/my_custom_metric.py" - from torchmetrics import Metric - import torch - - class MyCustomMetric(Metric): - def __init__(self): - super().__init__() - self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - binary_preds = (preds >= 0.5).int() - self.correct += torch.sum(binary_preds == target) - self.total += target.numel() - - def compute(self): - # Returns custom accuracy: correct predictions / total predictions - return self.correct.float() / self.total -``` - - - Returns custom accuracy: correct predictions / total predictions - -5. **Integrate with Lighter Configuration**: - Reference your custom metric in `config.yaml` to use it during train/val/test. - - **Example `config.yaml`**: - - ```yaml title="config.yaml" - project: my_project/ # Project root directory - - system: - metrics: - train: - - _target_: torchmetrics.Accuracy - - _target_: my_project.metrics.MyCustomMetric # Use custom metric - - val: - - _target_: torchmetrics.Accuracy - - _target_: my_project.metrics.MyCustomMetric - ``` - - This config uses both built-in `Accuracy` and `MyCustomMetric` during train/val stages. 
- -## Practical Example: Domain-Specific Metric - -```python -from torchmetrics import Metric -import torch - -class DiceScore(Metric): - """Dice coefficient for segmentation tasks.""" - def __init__(self, smooth=1e-6): - super().__init__() - self.smooth = smooth - self.add_state("intersection", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("union", default=torch.tensor(0.0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - # Flatten predictions and targets - preds = preds.view(-1) - target = target.view(-1) - - # Calculate intersection and union - intersection = (preds * target).sum() - union = preds.sum() + target.sum() - - self.intersection += intersection - self.union += union - - def compute(self): - # Calculate Dice score - dice = (2 * self.intersection + self.smooth) / (self.union + self.smooth) - return dice -``` - -Use in config: -```yaml -system: - metrics: - val: - - _target_: my_project.metrics.DiceScore - smooth: 1e-6 -``` - -## Key Optimization Tips ⚡ - -| Tip | Do | Don't | -|-----|----|---------| -| **Use Vectorization** | `(preds == target).sum()` | Loop through elements | -| **Accumulate Stats** | Store sums and counts | Store all predictions | -| **Handle Edge Cases** | Check for zero division | Assume valid inputs | - -## Common Pitfalls 🛡️ - -| Pitfall | Solution | -|---------|----------| -| **State accumulation across epochs** | Lighter resets automatically | -| **Wrong distributed reduction** | Use `dist_reduce_fx="sum"` for counts | -| **Type mismatches** | Convert tensors to same dtype | - -## Quick Reference Card 📄 - -```python -# Minimal custom metric template -from torchmetrics import Metric -import torch - -class YourMetric(Metric): - def __init__(self, your_param=1.0): - super().__init__() - # 1. Register state variables - self.add_state("state_var", default=torch.tensor(0.0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - # 2. 
Process batch and update state - self.state_var += your_computation(preds, target) - - def compute(self): - # 3. Calculate final metric - return self.state_var / normalization_factor -``` - -## Recap and Next Steps - -You now have the power to create sophisticated custom metrics: - -🎨 **What You Learned:** - -- Core metric lifecycle: `add_state()` → `update()` → `compute()` -- Performance optimization techniques -- Testing strategies for robust metrics -- Common patterns for classification, regression, and calibration - -💡 **Pro Tip:** Start simple, test thoroughly, optimize later! - -## Related Guides -- [Adapters](adapters.md) - Transform data for metrics -- [Writers](writers.md) - Save metric results diff --git a/docs/how-to/project_module.md b/docs/how-to/project_module.md deleted file mode 100644 index 809850c5..00000000 --- a/docs/how-to/project_module.md +++ /dev/null @@ -1,274 +0,0 @@ -# Project Module: Seamless Custom Code Integration - -As an ML practitioner, you often develop custom components like models, datasets, or metrics. Integrating these efficiently into your training framework, while maintaining a clean and reusable structure, is key for rapid experimentation. Lighter solves this with its Project Module system. - -## What is a Project Module? 🎯 - -A Project Module in Lighter utilizes Python's native module system. In Python, a **module** can be a single `.py` file, or a directory containing an `__init__.py` file, used to organize related code hierarchically. - -Lighter lets you designate a directory as your "project root." This root and its subdirectories (if they contain `__init__.py`) become dynamically importable in your Lighter configurations. This allows you to reference and instantiate custom classes and functions directly from your YAML files. - -This integration enables you to define and manage: - -- **🧠 Custom Models**: Your neural network architectures. -- **📦 Custom Datasets**: Your data loading logic. 
-- **🎯 Custom Metrics**: Your evaluation methods. -- **🔄 Custom Transforms**: Your data preprocessing. -- **🎛️ Custom Callbacks**: Your training hooks. - -**Key Benefits:** - -- 📦 **Encapsulation**: Project-specific code. -- 🚀 **Rapid Prototyping**: Test ideas quickly without changing other code. - -## Project Structure: Organizing Your Custom Code - -**Key Principle:** For Lighter to import your custom code, it must be a valid Python module. - -- **Python Module (File)**: A single `.py` file (e.g., `my_model.py`). -- **Python Module (Directory)**: A directory with an `__init__.py` file (e.g., `models/` with `models/__init__.py`). This groups related submodules. - -Example: `my_project/` is a Python module due to `__init__.py`. `models/` and `datasets/` are also modules. `experiments/` is not a module, typically holding config files. - -``` -my_project/ -├── __init__.py # Makes 'my_project' a module -├── models/ -│ ├── __init__.py # Makes 'models' a module -│ └── my_model.py # A module within the 'models' module -├── datasets/ -│ ├── __init__.py # Makes 'datasets' a module -│ └── my_dataset.py # A module within the 'datasets' module -└── experiments/ # Not a Python; typically for config files - ├── finetune_full.yaml - └── finetune_decoder.yaml -``` - -## Defining Project Modules - -With your project structure set up, defining custom components is straightforward. Within your module directories (e.g., `models/`, `datasets/`), define your custom Python modules as regular `.py` files. 
For example, define `MyModel` in `my_model.py`: - -```python title="my_project/models/my_model.py" -import torch.nn as nn - -class MyModel(nn.Module): - def __init__(self, input_size, num_classes): - super().__init__() - # Define a linear layer - self.linear = nn.Linear(input_size, num_classes) - - def forward(self, x): - # Forward pass through the linear layer - return self.linear(x) -``` - -and a custom dataset `MyDataset` in `my_dataset.py`: - -```python title="my_project/datasets/my_dataset.py" -from torch.utils.data import Dataset - -class MyDataset(Dataset): - def __init__(self, data_path, transform=None): - self.data_path = data_path - self.transform = transform - # Load data from data_path (implementation not shown) - self.samples = [...] # List of data samples (replace [...] with actual data loading) - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - sample = self.samples[idx] - # Preprocess sample (implementation not shown) - # ... - if self.transform: - sample = self.transform(sample) - return sample -``` - -## Importing Project Module in the Config - -After defining your custom modules, make them accessible to Lighter by configuring it to dynamically load your project. Lighter's [`import_module_from_path`](../../reference/utils/dynamic_imports/#lighter.utils.dynamic_imports.import_module_from_path) function imports your designated project root as a top-level module named `project`. - -### Specifying `project` Path - -Specify your project's root directory in `config.yaml` using the `project` key. This tells Lighter the path to your custom module collection. - -**Example:** - -```yaml title="config.yaml" -project: my_project/ # Project root path -``` - -!!! warning "Relative Path Behavior" - The `project` path is relative to your **current working directory** when running the `lighter` command, not relative to the config file location. 
- - ```bash - # If you run from parent directory - cd /path/to/parent && lighter fit /path/to/my_project/experiments/config.yaml project=/path/to/my_project/ - - # If you run from project directory - cd /path/to/parent/my_project && lighter fit experiments/config.yaml project=. - ``` - - **Tip:** Use absolute paths to avoid confusion, or be mindful of your current working directory. - -### Referencing Your Project Module - -With the `project` path specified, Lighter makes your custom modules available under the top-level module name `project`. Reference your project's modules and classes like any other Python module. See `system::model` and `system::dataloaders::train::dataset` in the config below: - -**Example:** - -```yaml title="config.yaml" hl_lines="5 13" -project: my_project/ - -system: - model: - _target_: project.models.MyModel - input_size: 784 - num_classes: 10 - - dataloaders: - train: - _target_: torch.utils.data.DataLoader - dataset: - _target_: project.datasets.MyDataset - data_path: "data/train.csv" - # ... dataset arguments ... 
- batch_size: 32 - shuffle: True -``` - - -## Practical Example: Custom Model Architecture - -```python -# my_project/models/custom_unet.py -import torch -import torch.nn as nn - -class CustomUNet(nn.Module): - """U-Net for segmentation tasks.""" - def __init__(self, in_channels=3, num_classes=2, features=[64, 128, 256, 512]): - super().__init__() - self.encoder = nn.ModuleList() - self.decoder = nn.ModuleList() - self.pool = nn.MaxPool2d(2, 2) - - # Encoder - for feature in features: - self.encoder.append(self._block(in_channels, feature)) - in_channels = feature - - # Decoder - for feature in reversed(features[:-1]): - self.decoder.append( - nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2) - ) - self.decoder.append(self._block(feature*2, feature)) - - self.final = nn.Conv2d(features[0], num_classes, kernel_size=1) - - def _block(self, in_channels, out_channels): - return nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True) - ) - - def forward(self, x): - skip_connections = [] - - # Encoder - for encode in self.encoder: - x = encode(x) - skip_connections.append(x) - x = self.pool(x) - - skip_connections = skip_connections[::-1] - - # Decoder - for idx in range(0, len(self.decoder), 2): - x = self.decoder[idx](x) - skip = skip_connections[idx//2] - x = torch.cat((skip, x), dim=1) - x = self.decoder[idx+1](x) - - return self.final(x) -``` - -Use in config: -```yaml -project: my_project/ - -system: - model: - _target_: project.models.custom_unet.CustomUNet - in_channels: 3 - num_classes: 10 - features: [64, 128, 256, 512] -``` - -## Best Practices for Project Organization 🏆 - -### Recommended Structure -``` -my_project/ -├── __init__.py -├── models/ # Neural network architectures -├── datasets/ # Data loading and processing -├── metrics/ # Custom 
evaluation metrics -├── callbacks/ # Training callbacks -└── utils/ # Helper functions -``` - -### Key Guidelines - -1. **Ensure `__init__.py` files are present** in all directories intended to be Python modules (i.e., containing code you wish to import). -2. **Use type hints** for better IDE support -3. **Write tests** for critical components -4. **Document with docstrings** for team collaboration -5. **Keep modules focused** - one concept per file - -## Running with Custom Modules - -```bash -# Basic training -lighter fit config.yaml - -# With module path override -lighter fit config.yaml project=./my_research_project - -# Multiple configs with custom modules -lighter fit base.yaml,models/unet.yaml,data/custom.yaml -``` - -## Common Issues & Solutions - -| Issue | Solution | -|-------|----------| -| **ModuleNotFoundError** | Check `__init__.py` files and project path | -| **AttributeError** | Ensure classes/functions are correctly imported or exposed in `__init__.py` files if accessing them directly from the package. | -| **Circular imports** | Use lazy imports inside functions | -| **Path issues** | Use absolute imports or `project: ./my_project` | - -## Recap and Next Steps - -You're now equipped to build sophisticated custom modules: - -🎯 **Key Takeaways:** - -- Structure projects with clear module organization -- Use type hints and documentation for maintainability -- Leverage advanced patterns (multi-modal, caching, custom augmentations) -- Test your modules for reliability - -💡 **Remember:** Great research code is modular, tested, and reusable! 
- -## Related Guides -- [Configuration](configuration.md) - Referencing the project module -- [Adapters](adapters.md) - Custom adapter creation -- [Metrics](metrics.md) - Custom metric creation diff --git a/docs/how-to/recipes.md b/docs/how-to/recipes.md deleted file mode 100644 index 044035df..00000000 --- a/docs/how-to/recipes.md +++ /dev/null @@ -1,718 +0,0 @@ ---- -title: Configuration Recipes ---- - -# Configuration Recipes - -Ready-to-use configurations for common scenarios. Copy, adapt, and run. - -## Training Infrastructure - -### Multi-GPU: DDP - -```yaml -trainer: - devices: -1 # All GPUs (or specify: devices: 4) - accelerator: gpu - strategy: ddp - precision: "16-mixed" - max_epochs: 100 - sync_batchnorm: true # Synchronize batch norm across GPUs - -system: - dataloaders: - train: - batch_size: 32 # Per GPU! Effective = 32 * num_gpus - num_workers: 4 # Per GPU - pin_memory: true - persistent_workers: true -``` - -### Multi-GPU: DeepSpeed (Large Models) - -```yaml -trainer: - devices: -1 - accelerator: gpu - strategy: deepspeed_stage_2 # or deepspeed_stage_3 - precision: "16-mixed" - -system: - dataloaders: - train: - batch_size: 8 # Smaller for large models -``` - -### Multi-GPU: FSDP (Very Large Models) - -```yaml -trainer: - devices: -1 - accelerator: gpu - strategy: fsdp - precision: "bf16-mixed" # BFloat16 often better for FSDP -``` - -### Experiment Tracking: TensorBoard - -```yaml -trainer: - logger: - _target_: pytorch_lightning.loggers.TensorBoardLogger - save_dir: logs - name: my_experiment - version: null # Auto-increment - - callbacks: - - _target_: pytorch_lightning.callbacks.LearningRateMonitor - logging_interval: step -``` - -### Experiment Tracking: Weights & Biases - -```yaml -_requires_: - - "$import datetime" - -trainer: - logger: - _target_: pytorch_lightning.loggers.WandbLogger - project: my_project - name: "$'exp_' + datetime.datetime.now().strftime('%Y%m%d_%H%M%S')" - save_dir: logs - log_model: true # Save checkpoints to W&B - - 
callbacks: - - _target_: pytorch_lightning.callbacks.LearningRateMonitor - logging_interval: epoch -``` - -### Multiple Loggers - -```yaml -trainer: - logger: - - _target_: pytorch_lightning.loggers.TensorBoardLogger - save_dir: logs - name: experiment - - _target_: pytorch_lightning.loggers.WandbLogger - project: my_project - name: experiment - - _target_: pytorch_lightning.loggers.CSVLogger - save_dir: logs -``` - -### Best Model Checkpointing - -```yaml -trainer: - callbacks: - - _target_: pytorch_lightning.callbacks.ModelCheckpoint - dirpath: checkpoints - filename: "best-{epoch:02d}-{val_loss:.4f}" - monitor: val_loss - mode: min - save_top_k: 3 # Keep best 3 - save_last: true -``` - -### Monitor Multiple Metrics - -```yaml -trainer: - callbacks: - - _target_: pytorch_lightning.callbacks.ModelCheckpoint - dirpath: checkpoints/loss - filename: "loss-{epoch:02d}-{val_loss:.4f}" - monitor: val_loss - mode: min - save_top_k: 2 - - - _target_: pytorch_lightning.callbacks.ModelCheckpoint - dirpath: checkpoints/acc - filename: "acc-{epoch:02d}-{val_acc:.4f}" - monitor: val_acc - mode: max - save_top_k: 2 -``` - -### Early Stopping - -```yaml -trainer: - callbacks: - - _target_: pytorch_lightning.callbacks.EarlyStopping - monitor: val_loss - patience: 10 - mode: min - min_delta: 0.001 - verbose: true - - - _target_: pytorch_lightning.callbacks.ModelCheckpoint - monitor: val_loss - mode: min - save_top_k: 1 -``` - -## Data Augmentation - -### Image Classification - -```yaml -system: - dataloaders: - train: - dataset: - transform: - _target_: torchvision.transforms.Compose - transforms: - - _target_: torchvision.transforms.RandomResizedCrop - size: 224 - scale: [0.8, 1.0] - - _target_: torchvision.transforms.RandomHorizontalFlip - p: 0.5 - - _target_: torchvision.transforms.RandomRotation - degrees: 15 - - _target_: torchvision.transforms.ColorJitter - brightness: 0.4 - contrast: 0.4 - saturation: 0.4 - hue: 0.1 - - _target_: torchvision.transforms.ToTensor - - 
_target_: torchvision.transforms.Normalize - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] - - val: - dataset: - transform: - _target_: torchvision.transforms.Compose - transforms: - - _target_: torchvision.transforms.Resize - size: 256 - - _target_: torchvision.transforms.CenterCrop - size: 224 - - _target_: torchvision.transforms.ToTensor - - _target_: torchvision.transforms.Normalize - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] -``` - -### RandAugment - -```yaml -_requires_: - - "$from torchvision.transforms import RandAugment" - -system: - dataloaders: - train: - dataset: - transform: - _target_: torchvision.transforms.Compose - transforms: - - _target_: torchvision.transforms.RandomResizedCrop - size: 224 - - _target_: torchvision.transforms.RandAugment - num_ops: 2 - magnitude: 9 - - _target_: torchvision.transforms.ToTensor - - _target_: torchvision.transforms.Normalize - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] -``` - -## Learning Rate Schedules - -### Cosine Annealing - -```yaml -vars: - max_epochs: 100 - -system: - optimizer: - _target_: torch.optim.AdamW - params: "$@system::model.parameters()" - lr: 0.001 - weight_decay: 0.01 - - scheduler: - _target_: torch.optim.lr_scheduler.CosineAnnealingLR - optimizer: "@system::optimizer" - T_max: "%vars::max_epochs" - eta_min: 0.00001 -``` - -### ReduceLROnPlateau - -```yaml -system: - scheduler: - _target_: torch.optim.lr_scheduler.ReduceLROnPlateau - optimizer: "@system::optimizer" - mode: min - factor: 0.5 - patience: 5 - min_lr: 0.00001 - verbose: true -``` - -### Step Decay - -```yaml -system: - scheduler: - _target_: torch.optim.lr_scheduler.MultiStepLR - optimizer: "@system::optimizer" - milestones: [30, 60, 90] - gamma: 0.1 -``` - -## Transfer Learning - -### Fine-tuning Pretrained Models - -```yaml -system: - model: - _target_: torchvision.models.resnet50 - weights: IMAGENET1K_V2 - num_classes: 10 - - optimizer: - _target_: torch.optim.SGD - params: 
"$@system::model.parameters()" - lr: 0.001 # Lower LR for fine-tuning - momentum: 0.9 - - callbacks: - - _target_: lighter.callbacks.Freezer - modules: ["layer1", "layer2"] # Freeze early layers -``` - -### Differential Learning Rates - -```yaml -system: - model: - _target_: torchvision.models.resnet50 - weights: IMAGENET1K_V2 - num_classes: 10 - - optimizer: - _target_: torch.optim.SGD - params: - - params: "$@system::model.layer1.parameters()" - lr: 0.0001 - - params: "$@system::model.layer4.parameters()" - lr: 0.001 - - params: "$@system::model.fc.parameters()" - lr: 0.01 - momentum: 0.9 -``` - -## Gradient Handling - -### Gradient Clipping - -```yaml -trainer: - gradient_clip_val: 1.0 - gradient_clip_algorithm: norm # or 'value' -``` - -### Gradient Accumulation - -```yaml -trainer: - accumulate_grad_batches: 4 - -system: - dataloaders: - train: - batch_size: 8 # Effective: 8 * 4 = 32 -``` - -## Performance Optimization - -### Fast Training - -```yaml -trainer: - precision: "16-mixed" - devices: -1 - accelerator: gpu - benchmark: true # cudnn.benchmark - -system: - dataloaders: - train: - num_workers: 8 - pin_memory: true - persistent_workers: true - prefetch_factor: 2 - batch_size: 64 -``` - -### Memory Optimization - -```yaml -trainer: - precision: "16-mixed" - accumulate_grad_batches: 8 - -system: - dataloaders: - train: - batch_size: 4 # Effective: 4 * 8 = 32 - num_workers: 2 -``` - -## Development & Debugging - -### Fast Development Run - -```yaml -trainer: - fast_dev_run: true # 1 batch of train/val/test -``` - -### Overfit Single Batch - -```yaml -trainer: - overfit_batches: 1 - max_epochs: 100 - logger: false -``` - -### Profiling - -```yaml -trainer: - profiler: simple # or 'advanced', 'pytorch' - max_epochs: 1 - limit_train_batches: 10 -``` - ---- - -## Machine Learning Paradigms - -### Multi-Task Learning - -Train one model on multiple tasks with shared representations. 
- -```yaml -system: - model: - _target_: my_project.MultiTaskModel - backbone: resnet50 - num_classes_classification: 10 - num_classes_segmentation: 2 - - criterion: - _target_: my_project.MultiTaskLoss - classification_weight: 1.0 - segmentation_weight: 0.5 - - metrics: - train: - classification: - - _target_: torchmetrics.Accuracy - task: multiclass - num_classes: 10 - segmentation: - - _target_: torchmetrics.JaccardIndex - task: binary - val: "%system::metrics::train" -``` - -```python title="my_project/losses.py" -import torch.nn as nn - -class MultiTaskLoss(nn.Module): - def __init__(self, classification_weight=1.0, segmentation_weight=1.0): - super().__init__() - self.cls_loss = nn.CrossEntropyLoss() - self.seg_loss = nn.BCEWithLogitsLoss() - self.cls_weight = classification_weight - self.seg_weight = segmentation_weight - - def forward(self, pred, target): - cls_loss = self.cls_loss(pred['classification'], target['class']) - seg_loss = self.seg_loss(pred['segmentation'], target['mask']) - - return { - "total": self.cls_weight * cls_loss + self.seg_weight * seg_loss, - "classification": cls_loss, - "segmentation": seg_loss, - } -``` - -All sublosses logged automatically. 
- -### Self-Supervised Learning (Contrastive) - -```yaml -system: - model: - _target_: my_project.SimCLRModel - backbone: resnet18 - projection_dim: 128 - - criterion: - _target_: my_project.NTXentLoss - temperature: 0.5 - - adapters: - train: - batch: - _target_: lighter.adapters.BatchAdapter - input_accessor: 0 - target_accessor: null # No labels -``` - -```python title="my_project/model.py" -import torch.nn as nn - -class SimCLRModel(nn.Module): - def __init__(self, backbone='resnet18', projection_dim=128): - super().__init__() - self.encoder = torch.hub.load('pytorch/vision', backbone, weights=None) - self.encoder.fc = nn.Identity() - self.projector = nn.Sequential( - nn.Linear(512, 512), - nn.ReLU(), - nn.Linear(512, projection_dim) - ) - - def forward(self, x): - features = self.encoder(x) - return self.projector(features) -``` - -### Knowledge Distillation - -Train small student model from large teacher. - -```yaml -system: - model: - _target_: my_project.DistillationModel - student: - _target_: torchvision.models.resnet18 - num_classes: 10 - teacher: - _target_: torchvision.models.resnet50 - num_classes: 10 - teacher_weights: checkpoints/teacher.ckpt - - criterion: - _target_: my_project.DistillationLoss - temperature: 3.0 - alpha: 0.7 # Distillation loss weight -``` - -```python title="my_project/model.py" -import torch.nn as nn - -class DistillationModel(nn.Module): - def __init__(self, student, teacher, teacher_weights=None): - super().__init__() - self.student = student - self.teacher = teacher - - if teacher_weights: - self.teacher.load_state_dict(torch.load(teacher_weights)) - - for param in self.teacher.parameters(): - param.requires_grad = False - self.teacher.eval() - - def forward(self, x): - student_logits = self.student(x) - with torch.no_grad(): - teacher_logits = self.teacher(x) - return {"student": student_logits, "teacher": teacher_logits} -``` - -### Curriculum Learning - -Progressively increase task difficulty. 
- -```python title="my_project/model.py" -class CurriculumModel(nn.Module): - def __init__(self, num_classes=10, max_epochs=100): - super().__init__() - self.backbone = nn.Sequential(...) - self.classifier = nn.Linear(512, num_classes) - self.max_epochs = max_epochs - - def forward(self, x, epoch=None): - # Lighter automatically injects epoch - if epoch is not None: - difficulty = min(epoch / self.max_epochs, 1.0) - x = self.apply_difficulty(x, difficulty) - - features = self.backbone(x) - return self.classifier(features) -``` - -No config needed—epoch injection is automatic. - -### Model Ensembling - -```yaml -system: - model: - _target_: my_project.EnsembleModel - models: - - _target_: torchvision.models.resnet18 - num_classes: 10 - - _target_: torchvision.models.resnet34 - num_classes: 10 - checkpoints: - - checkpoints/model1.ckpt - - checkpoints/model2.ckpt - - inferer: - _target_: my_project.EnsembleInferer - ensemble_method: "average" # or "voting" -``` - -```python title="my_project/model.py" -class EnsembleInferer: - def __init__(self, ensemble_method="average"): - self.method = ensemble_method - - def __call__(self, x, model, **kwargs): - predictions = [] - for m in model.models: - m.eval() - with torch.no_grad(): - predictions.append(m(x)) - - if self.method == "average": - return torch.stack(predictions).mean(dim=0) - elif self.method == "voting": - return torch.stack(predictions).mode(dim=0)[0] -``` - -### Cross-Validation - -```python title="cross_validation.py" -import subprocess - -def run_kfold(base_config="config.yaml", k=5): - for fold in range(k): - subprocess.run([ - "lighter", "fit", base_config, - f"system::dataloaders::train::dataset::fold={fold}", - f"trainer::logger::version=fold_{fold}" - ]) - -if __name__ == "__main__": - run_kfold() -``` - -```yaml title="config.yaml" -system: - dataloaders: - train: - dataset: - _target_: my_project.KFoldDataset - k: 5 - fold: 0 # Overridden from CLI -``` - ---- - -## Production Setup - -Complete 
production-ready configuration: - -```yaml -_requires_: - - "$import datetime" - -vars: - experiment_name: "production_run" - timestamp: "$datetime.datetime.now().strftime('%Y%m%d_%H%M%S')" - -trainer: - _target_: pytorch_lightning.Trainer - devices: -1 - accelerator: gpu - strategy: ddp - precision: "16-mixed" - max_epochs: 200 - gradient_clip_val: 1.0 - - logger: - - _target_: pytorch_lightning.loggers.TensorBoardLogger - save_dir: logs - name: "%vars::experiment_name" - version: "%vars::timestamp" - - _target_: pytorch_lightning.loggers.WandbLogger - project: production - name: "$%vars::experiment_name + '_' + %vars::timestamp" - - callbacks: - - _target_: pytorch_lightning.callbacks.ModelCheckpoint - dirpath: "$'checkpoints/' + %vars::experiment_name" - filename: "best-{epoch:02d}-{val_loss:.4f}" - monitor: val_loss - mode: min - save_top_k: 3 - save_last: true - - - _target_: pytorch_lightning.callbacks.EarlyStopping - monitor: val_loss - patience: 20 - mode: min - - - _target_: pytorch_lightning.callbacks.LearningRateMonitor - logging_interval: epoch - -system: - _target_: lighter.System - - model: - _target_: torchvision.models.resnet50 - weights: IMAGENET1K_V2 - num_classes: 10 - - criterion: - _target_: torch.nn.CrossEntropyLoss - - optimizer: - _target_: torch.optim.AdamW - params: "$@system::model.parameters()" - lr: 0.001 - weight_decay: 0.01 - - scheduler: - _target_: torch.optim.lr_scheduler.CosineAnnealingLR - optimizer: "@system::optimizer" - T_max: 200 - - dataloaders: - train: - batch_size: 64 - num_workers: 8 - pin_memory: true - persistent_workers: true - val: - batch_size: 128 - num_workers: 4 - pin_memory: true -``` - -## Next Steps - -- [Configuration Reference](configuration.md) - Complete syntax guide -- [Troubleshooting](troubleshooting.md) - Debug issues -- [Adapters](adapters.md) - Handle any data format -- [System Internals](../design/system.md) - Understanding the pipeline diff --git a/docs/how-to/run.md b/docs/how-to/run.md deleted 
file mode 100644 index e7b6951d..00000000 --- a/docs/how-to/run.md +++ /dev/null @@ -1,229 +0,0 @@ -# Running Experiments - -Lighter streamlines deep learning experiments through well-defined stages. This guide covers everything from basic runs to advanced workflows. - -## Stages Overview - -Lighter orchestrates experiments through four distinct stages, each serving a specific purpose: - -| Stage | Purpose | Common Use Cases | -|-------|---------|------------------| -| **`fit`** | Train model on training data | Initial training, fine-tuning, transfer learning | -| **`validate`** | Evaluate on validation data | Hyperparameter tuning, model selection | -| **`test`** | Final evaluation on test data | Performance benchmarking, final metrics | -| **`predict`** | Generate predictions on new data | Inference, deployment, data analysis | - -## Quick Start - -### Basic Commands - -```bash -# Train a model -lighter fit config.yaml - -# Validate a trained model -lighter validate config.yaml - -# Test final performance -lighter test config.yaml - -# Generate predictions -lighter predict config.yaml -``` - -### Common Workflows - -#### 1. Full Training Pipeline -```bash -# Train + validate (automatic validation during training) -lighter fit config.yaml - -# Then test on held-out data - equivalent to Trainer.test(..., ckpt_path="best") -lighter test config.yaml args::test::ckpt_path="best.ckpt" -``` - -#### 2. Resume Training from Checkpoint -```bash -lighter fit config.yaml args::fit::ckpt_path="last.ckpt" -``` - -#### 3. Fine-tuning Pre-trained Model -```bash -# Use base config + fine-tuning overrides -lighter fit base_config.yaml,finetune_config.yaml -``` - -## Advanced Configuration - -### Stage-Specific Arguments - -The `args` key provides a way to pass arguments directly to the PyTorch Lightning `Trainer`'s stage methods: `fit`, `validate`, `test`, and `predict`. Each key under `args` corresponds to a stage, and the arguments within it are passed to that stage's method. 
For instance, `args.test.ckpt_path` is passed as `Trainer.test(ckpt_path=...)`. - -Configure stage arguments in your YAML or via CLI: - -```yaml title="config.yaml" -args: - fit: - ckpt_path: null # Start from scratch - validate: - ckpt_path: "checkpoints/best.ckpt" - test: - ckpt_path: "checkpoints/best.ckpt" - predict: - ckpt_path: "checkpoints/best.ckpt" - return_predictions: true -``` - -### CLI Override Patterns - -```bash -# Override any config parameter -lighter fit config.yaml trainer::max_epochs=50 - -# Override nested parameters -lighter fit config.yaml system::optimizer::lr=0.001 - -# Override list elements -lighter fit config.yaml trainer::callbacks::0::patience=10 - -# Multiple overrides -lighter fit config.yaml \ - trainer::max_epochs=100 \ - system::optimizer::lr=0.0001 \ - trainer::devices=2 -``` - -## Pro Tips 💡 - -### 1. Config Composition -Combine multiple configs for modular experiments: -```bash -# Base + dataset + model + training configs -lighter fit base.yaml,data/cifar10.yaml,models/resnet.yaml,train/sgd.yaml -``` - -### 2. Quick Experimentation -Test configurations without full training: -```bash -# Fast run with 2 batches -lighter fit config.yaml trainer::fast_dev_run=2 -``` - -### 3. GPU Management -```bash -# Use specific GPUs -lighter fit config.yaml trainer::devices=[0,1] - -# Use all available GPUs -lighter fit config.yaml trainer::devices=-1 -``` - -### 4. 
Debugging Runs -```bash -# Enable detailed logging -lighter fit config.yaml trainer::log_every_n_steps=1 - -# Profile your code -lighter fit config.yaml trainer::profiler="simple" -``` - -## Common Patterns - -### Pattern 1: Hyperparameter Search -```bash -# Run multiple experiments with different learning rates -for lr in 0.001 0.01 0.1; do - lighter fit config.yaml \ - system::optimizer::lr=$lr \ - trainer::logger::name="lr_$lr" -done -``` - -### Pattern 2: Cross-Validation -```bash -# Run k-fold cross-validation -for fold in {1..5}; do - lighter fit config.yaml \ - system::dataloaders::train::dataset::fold=$fold \ - trainer::logger::name="fold_$fold" -done -``` - -### Pattern 3: Progressive Training -```bash -# Start with small resolution, then increase -lighter fit config.yaml vars::image_size=128 trainer::max_epochs=10 -lighter fit config.yaml vars::image_size=256 args::fit::ckpt_path="last.ckpt" -lighter fit config.yaml vars::image_size=512 args::fit::ckpt_path="last.ckpt" -``` - -## Troubleshooting - -### Issue: Out of Memory -```bash -# Reduce batch size -lighter fit config.yaml system::dataloaders::train::batch_size=8 - -# Enable gradient accumulation -lighter fit config.yaml trainer::accumulate_grad_batches=4 - -# Use mixed precision -lighter fit config.yaml trainer::precision="16-mixed" -``` - -### Issue: Training Too Slow -```bash -# Increase number of workers -lighter fit config.yaml system::dataloaders::train::num_workers=8 - -# Enable compile mode (PyTorch 2.0+) -lighter fit config.yaml system::model::compile=true -``` - -### Issue: Validation Takes Too Long -```bash -# Reduce validation frequency -lighter fit config.yaml trainer::check_val_every_n_epoch=5 - -# Limit validation batches -lighter fit config.yaml trainer::limit_val_batches=0.25 -``` - -## Environment Variables - -Control Lighter behavior with environment variables: - -```bash -# Set random seed for reproducibility via Pytorch Lightning -PL_GLOBAL_SEED=42 lighter fit config.yaml - -# 
Enable debugging mode -LIGHTER_DEBUG=1 lighter fit config.yaml -``` - -## Quick Reference - -| Task | Command | -|------|---------| -| Train from scratch | `lighter fit config.yaml` | -| Resume training | `lighter fit config.yaml args::fit::ckpt_path="last.ckpt"` | -| Validate checkpoint | `lighter validate config.yaml args::validate::ckpt_path="best.ckpt"` | -| Test model | `lighter test config.yaml args::test::ckpt_path="best.ckpt"` | -| Generate predictions | `lighter predict config.yaml args::predict::ckpt_path="best.ckpt"` | -| Fast debugging | `lighter fit config.yaml trainer::fast_dev_run=true` | -| Multi-GPU training | `lighter fit config.yaml trainer::devices=4` | -| Mixed precision | `lighter fit config.yaml trainer::precision="16-mixed"` | - -## Recap and Next Steps - -You now have a comprehensive understanding of running experiments with Lighter. Key takeaways: - -* Four stages (`fit`, `validate`, `test`, `predict`) cover the full ML lifecycle -* Flexible configuration through YAML files and CLI overrides -* Powerful composition and workflow patterns -* Built-in solutions for common issues - -## Related Guides -- [Configuration Guide](configuration.md) - Config syntax and patterns -- [Troubleshooting](troubleshooting.md) - Common errors and solutions -- [Experiment Tracking](experiment_tracking.md) - Logging experiments diff --git a/docs/how-to/troubleshooting.md b/docs/how-to/troubleshooting.md deleted file mode 100644 index affb956e..00000000 --- a/docs/how-to/troubleshooting.md +++ /dev/null @@ -1,776 +0,0 @@ -# Troubleshooting - -Comprehensive guide to debugging errors and issues in Lighter. Includes actual error messages and step-by-step solutions. 
- -## Configuration Errors - -### ModuleNotFoundError: No module named 'project' - -**Cause:** Missing `__init__.py` files or incorrect project path - -**Solution:** -```yaml -# In your config.yaml -project: ./my_project # Ensure path is correct - -# Ensure all module directories have __init__.py: -my_project/ -├── __init__.py # Required! -├── models/ -│ ├── __init__.py # Required! -│ └── my_model.py -``` - -### Config Reference Errors - -**Wrong:** `"$@system::model::parameters()"` - Using `::` for attributes -**Correct:** `"$@system::model.parameters()"` - Use `.` for Python attributes - -**Wrong:** Circular references -```yaml -model: - lr: "@system::optimizer::lr" # Circular! -optimizer: - lr: "@system::model.lr" # Circular! -``` - -**Correct:** Use `vars` section -```yaml -vars: - lr: 0.001 -model: - lr: "%vars::lr" -optimizer: - lr: "%vars::lr" -``` - -### YAML Syntax Errors - -Common mistakes: -- Missing colons after keys -- Inconsistent indentation (use spaces, not tabs) -- Missing quotes around values with special characters -- Missing values (like the `roi_size` example in inferers) - -### Sparkwheel Validation Errors - -Lighter uses Sparkwheel for config validation. 
Here's how to interpret validation errors: - -**Error Example:** -``` -ValueError: Configuration validation failed: -system.model: Missing required field '_target_' -system.optimizer.lr: Expected float, got str -``` - -**Solution:** -Check the schema in `src/lighter/engine/schema.py` or add missing fields: -```yaml -system: - model: - _target_: torch.nn.Linear # Must have _target_ - optimizer: - lr: 0.001 # Not "0.001" (string) -``` - -**Error: Missing required component** -``` -ValueError: Configuration validation failed: -system.optimizer: Required field missing -``` - -**Solution:** Lighter requires certain components depending on the stage: -- FIT stage: model, optimizer, criterion, train dataloader required -- VALIDATE stage: model, criterion, val dataloader required -- TEST stage: model, test dataloader required -- PREDICT stage: model, predict dataloader required - -### Reference Resolution Errors - -**Error: Reference not found** -``` -KeyError: 'Could not resolve reference @system::modell' -``` - -**Solution:** Typo in reference path. 
Check spelling: -```yaml -optimizer: - params: "$@system::model.parameters()" # Not 'modell' -``` - -**Error: Attribute not found** -``` -AttributeError: 'ResNet' object has no attribute 'paramters' -``` - -**Solution:** Typo in method name: -```yaml -optimizer: - params: "$@system::model.parameters()" # Not 'paramters' -``` - -**Error: Using :: for Python attributes** -``` -# Wrong -params: "$@system::model::parameters()" - -# Correct -params: "$@system::model.parameters()" -``` - -**Rule**: Use `::` for config navigation, `.` for Python attributes - -## Training Issues - -### CUDA Out of Memory - -**Solutions:** -```bash -# Reduce batch size -lighter fit config.yaml system::dataloaders::train::batch_size=8 - -# Enable gradient accumulation -lighter fit config.yaml trainer::accumulate_grad_batches=4 - -# Use mixed precision -lighter fit config.yaml trainer::precision="16-mixed" -``` - -For distributed strategies, see [PyTorch Lightning docs](https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html). - -### Loss is NaN - -**Check:** -1. Learning rate too high → Reduce by 10x -2. Missing data normalization → Add transforms -3. Wrong loss function for task → Verify criterion -4. Gradient explosion → Add gradient clipping in Trainer config - -### Slow Training - -**Optimize:** -```yaml -system: - dataloaders: - train: - num_workers: 8 # Increase for faster data loading - pin_memory: true # For GPU training - persistent_workers: true # Reduce worker startup overhead -``` - -For profiling and optimization, see [PyTorch Lightning performance docs](https://lightning.ai/docs/pytorch/stable/tuning/profiler.html). 
- -### Metrics Not Computing - -**Error:** No metrics logged to TensorBoard/W&B - -**Cause 1:** Metrics not defined for the mode -```yaml -# Wrong: No val metrics -system: - metrics: - train: - - _target_: torchmetrics.Accuracy -``` - -**Solution:** Add metrics for each mode -```yaml -system: - metrics: - train: - - _target_: torchmetrics.Accuracy - task: multiclass - num_classes: 10 - val: "%system::metrics::train" # Reuse train config -``` - -**Cause 2:** MetricsAdapter argument mismatch -```yaml -# Wrong: Metric expects 'preds', but adapter uses default (positional) -system: - adapters: - train: - metrics: - _target_: lighter.adapters.MetricsAdapter - # Missing pred_argument and target_argument -``` - -**Solution:** Match metric signature -```yaml -system: - adapters: - train: - metrics: - _target_: lighter.adapters.MetricsAdapter - pred_argument: "preds" - target_argument: "target" -``` - -### Model Not Learning (Loss Not Decreasing) - -**Check 1: Are gradients flowing?** - -Add to config temporarily: -```yaml -_requires_: - - "$import torch" - -system: - model: - _target_: MyModel - # Check grad flow after first batch - _debug: "$print('Has grad:', next(iter(@system::model.parameters())).grad is not None)" -``` - -**Check 2: Is optimizer updating weights?** - -Enable optimizer stats logging (automatic in Lighter): -```yaml -trainer: - callbacks: - - _target_: pytorch_lightning.callbacks.LearningRateMonitor - logging_interval: step -``` - -Check `train/lr` in logs. If it's not changing with a scheduler, scheduler may be misconfigured. - -**Check 3: Is loss function appropriate?** - -- Classification: Use `CrossEntropyLoss` (takes logits, not probabilities) -- Regression: Use `MSELoss` -- Binary: Use `BCEWithLogitsLoss` (takes logits) - -**Common mistake:** -```yaml -# Wrong: Applying softmax before CrossEntropyLoss -system: - adapters: - train: - criterion: - pred_transforms: - - _target_: torch.nn.functional.softmax # ❌ Don't do this! 
- criterion: - _target_: torch.nn.CrossEntropyLoss -``` - -**Correct:** -```yaml -# CrossEntropyLoss applies softmax internally -system: - criterion: - _target_: torch.nn.CrossEntropyLoss # No softmax needed -``` - -## Distributed Training Issues (DDP) - -### Error: Address already in use - -**Error:** -``` -RuntimeError: Address already in use -``` - -**Cause:** Previous DDP process didn't terminate cleanly - -**Solution:** -```bash -# Find and kill lingering processes -ps aux | grep python -kill -9 PID - -# Or restart your terminal/jupyter kernel -``` - -### File Writing Conflicts - -**Error:** Multiple processes writing to same file causing corruption - -**Cause:** Writers in predict mode running on all GPUs - -**Solution:** Use rank-zero only for file operations -```python -# In custom callback -from pytorch_lightning.utilities import rank_zero_only - -class MyWriter(Callback): - @rank_zero_only - def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - # Only rank 0 writes - save_predictions(outputs) -``` - -Lighter's built-in Writers handle this automatically. - -### Metrics Aggregation Issues - -**Error:** Metrics values differ across GPUs - -**Solution:** Lighter automatically sets `sync_dist=True` for epoch-level metrics.
For manual logging: -```python -self.log("custom_metric", value, on_epoch=True, sync_dist=True) -``` - -### Different Behavior on Single vs Multi-GPU - -**Cause:** Batch normalization or dropout behaving differently - -**Solution:** Use `sync_batchnorm` for BN: -```yaml -trainer: - strategy: ddp - sync_batchnorm: true # Synchronize BN statistics -``` - -## Adapter Debugging - -### Inspecting Tensor Shapes - -Add print transforms to see tensor shapes at each stage: - -```yaml -system: - adapters: - train: - batch: - _target_: lighter.adapters.BatchAdapter - input_accessor: "image" - target_accessor: "mask" - # Debug: Print shapes after extraction - input_transforms: - - "$lambda x: print(f'Input shape: {x.shape}') or x" - target_transforms: - - "$lambda x: print(f'Target shape: {x.shape}') or x" - - criterion: - _target_: lighter.adapters.CriterionAdapter - pred_transforms: - - "$lambda x: print(f'Pred before softmax: {x.shape}') or x" - - _target_: torch.nn.functional.softmax - dim: 1 - - "$lambda x: print(f'Pred after softmax: {x.shape}') or x" -``` - -### Common Adapter Errors - -**Error: Wrong argument order** -``` -TypeError: forward() got an unexpected keyword argument 'prediction' -``` - -**Solution:** Check loss function signature and match adapter: -```python -# If loss expects loss(pred, target) -def my_loss(pred, target): - ... - -# Adapter config (default is correct) -system: - adapters: - train: - criterion: - _target_: lighter.adapters.CriterionAdapter - pred_argument: 0 - target_argument: 1 -``` - -**Error: KeyError in batch** -``` -KeyError: 'image' -``` - -**Solution:** Check dataset output: -```python -# Temporarily add to dataset -def __getitem__(self, idx): - batch = {...} - print(f"Batch keys: {batch.keys()}") # Debug - return batch -``` - -Then match `input_accessor` to actual key. 
- -## Performance Profiling - -### Identify Bottlenecks - -Use PyTorch Lightning profiler: - -```yaml -trainer: - profiler: simple # or 'advanced', 'pytorch' - max_epochs: 1 - limit_train_batches: 100 -``` - -Output shows time spent in each method: -``` -╒═══════════════════════╤═════════╤═════════╕ -│ Action │ Mean │ Total │ -╞═══════════════════════╪═════════╪═════════╡ -│ run_training_batch │ 0.105 │ 10.5 │ -│ get_train_batch │ 0.095 │ 9.5 │ -│ training_step │ 0.008 │ 0.8 │ -╘═══════════════════════╧═════════╧═════════╛ -``` - -If `get_train_batch` is slow: Increase `num_workers` in dataloader - -If `training_step` is slow: Profile model forward/loss computation - -### Data Loading Bottleneck - -**Symptom:** GPU underutilized, low GPU usage - -**Solution:** -```yaml -system: - dataloaders: - train: - num_workers: 8 # Increase (up to CPU cores) - prefetch_factor: 4 # Prefetch more batches - persistent_workers: true # Reuse workers - pin_memory: true # Faster GPU transfer -``` - -**Test different num_workers:** -```bash -for i in 2 4 8 16; do - echo "Testing num_workers=$i" - lighter fit config.yaml \ - system::dataloaders::train::num_workers=$i \ - trainer::limit_train_batches=100 \ - trainer::max_epochs=1 -done -``` - -### Model Bottleneck - -**Symptom:** High GPU usage, slow batches - -**Solutions:** - -1. **Mixed precision training:** -```yaml -trainer: - precision: "16-mixed" # ~2x speedup, less memory -``` - -2. **Gradient accumulation (simulate larger batch):** -```yaml -trainer: - accumulate_grad_batches: 4 # Update every 4 batches - -system: - dataloaders: - train: - batch_size: 8 # Effective: 8 * 4 = 32 -``` - -3. 
**Compile model (PyTorch 2.0+):** -```yaml -_requires_: - - "$import torch" - -system: - model: - _target_: torchvision.models.resnet50 - # Compile for faster execution - _post_init_: "$lambda m: torch.compile(m)" -``` - -## Memory Optimization - -### Beyond Reducing Batch Size - -**Strategy 1: Gradient Checkpointing** - -Trade compute for memory (recompute activations during backward): - -```python title="my_project/model.py" -from torch.utils.checkpoint import checkpoint_sequential - -class MyModel(nn.Module): - def __init__(self): - super().__init__() - self.layers = nn.Sequential(*[ - ResidualBlock() for _ in range(50) - ]) - - def forward(self, x): - # Checkpoint every 10 layers - return checkpoint_sequential(self.layers, 10, x) -``` - -**Strategy 2: CPU Offloading** - -Move intermediate results to CPU: - -```yaml -system: - adapters: - val: - logging: - _target_: lighter.adapters.LoggingAdapter - # Move to CPU for callbacks - pred_transforms: - - "$lambda x: x.cpu()" -``` - -**Strategy 3: Clear Cache Periodically** - -```python title="my_project/custom_system.py" -from lighter.system import System -import torch - -class MemoryEfficientSystem(System): - def on_train_batch_end(self, outputs, batch, batch_idx): - # Clear cache every 100 batches - if batch_idx % 100 == 0: - torch.cuda.empty_cache() -``` - -**Strategy 4: Use Smaller Precision** - -```yaml -trainer: - precision: "bf16-mixed" # BFloat16 (less memory than fp16 in some cases) -``` - -## Debugging Strategies - -### Quick Testing -```bash -# Test with 2 batches only -lighter fit config.yaml trainer::fast_dev_run=2 -``` - -### Debug Config Values -```yaml -# Print values during config resolution -optimizer: - lr: "$print('LR:', 0.001) or 0.001" -``` - -### Check Adapter Outputs -Temporarily add print transforms in adapters: -```yaml -adapters: - train: - criterion: - pred_transforms: - - "$lambda x: print('Pred shape:', x.shape) or x" -``` - -### Debug Mode Checklist - -When encountering an error: - -1. 
**Test with fast_dev_run:** - ```bash - lighter fit config.yaml trainer::fast_dev_run=true - ``` - -2. **Verify config resolves:** - ```yaml - _requires_: - - "$import sys" - vars: - _debug: "$print('Config loaded successfully', file=sys.stderr)" - ``` - -3. **Check tensor shapes:** - Add print transforms in adapters (see Adapter Debugging above) - -4. **Isolate the component:** - Test individual components in Python: - ```python - from sparkwheel import Config - config = Config.from_file("config.yaml") - model = config.resolve("system::model") - print(model) # Does it instantiate correctly? - ``` - -5. **Check logs:** - Look for warnings/errors in the console output - -## Common Error Messages Reference - -### TypeError: unhashable type: 'dict' - -**Cause:** Passing a dict where Lighter expects a hashable key - -**Common scenario:** Using dict-based batch in metric that expects tensors - -**Solution:** Use MetricsAdapter to extract tensors: -```yaml -system: - adapters: - train: - metrics: - pred_transforms: - - "$lambda x: x['logits']" # Extract tensor from dict -``` - -### RuntimeError: Expected all tensors to be on the same device - -**Cause:** Model on GPU but data on CPU (or vice versa) - -**Solution:** Lighter handles this automatically. If you see this: -- Check custom transforms aren't moving data -- Ensure model is properly registered in System -- For manual operations, use `self.device`: - ```python - def custom_operation(self): - tensor = torch.tensor([1, 2, 3]).to(self.device) - ``` - -### ValueError: The loss dictionary must include a 'total' key - -**Cause:** Dict-based loss missing required 'total' key - -**Solution:** -```python -def my_criterion(pred, target): - loss1 = ... - loss2 = ... - return { - "total": loss1 + loss2, # Required! 
- "classification": loss1, - "segmentation": loss2, - } -``` - -### OSError: [Errno 24] Too many open files - -**Cause:** Too many num_workers, system limit reached - -**Solution:** -```bash -# Temporary fix (macOS/Linux) -ulimit -n 4096 - -# Or reduce num_workers -lighter fit config.yaml system::dataloaders::train::num_workers=4 -``` - -## Real Error Examples from Users - -### Example 1: Circular Reference - -**Error:** -``` -RecursionError: maximum recursion depth exceeded -``` - -**User's config:** -```yaml -system: - model: - _target_: MyModel - optimizer: "@system::optimizer" # ❌ Circular! - optimizer: - _target_: torch.optim.Adam - params: "$@system::model.parameters()" # ❌ Circular! -``` - -**Fix:** -```yaml -vars: - lr: 0.001 - -system: - model: - _target_: MyModel - lr: "%vars::lr" # Use var instead - optimizer: - _target_: torch.optim.Adam - params: "$@system::model.parameters()" - lr: "%vars::lr" -``` - -### Example 2: Wrong Loss Function for Task - -**Error:** Loss is negative or NaN - -**User's config:** -```yaml -# Binary classification with CrossEntropyLoss (wrong!) -system: - criterion: - _target_: torch.nn.CrossEntropyLoss # For multi-class! -``` - -**Fix:** -```yaml -# Use BCEWithLogitsLoss for binary -system: - criterion: - _target_: torch.nn.BCEWithLogitsLoss # For binary classification -``` - -### Example 3: Adapter Argument Mismatch - -**Error:** -``` -TypeError: __call__() got an unexpected keyword argument 'preds' -``` - -**User's config:** -```yaml -system: - metrics: - train: - - _target_: torchmetrics.Accuracy - # Expects 'preds' and 'target' kwargs - adapters: - train: - metrics: - _target_: lighter.adapters.MetricsAdapter - # Using positional (default) instead of kwargs -``` - -**Fix:** -```yaml -system: - adapters: - train: - metrics: - _target_: lighter.adapters.MetricsAdapter - pred_argument: "preds" # Match metric signature - target_argument: "target" -``` - -## Getting Help - -When asking for help, include: - -1. 
**Your config file** (or relevant section) -2. **Full error message** (including traceback) -3. **Lighter version:** `lighter --version` -4. **Python/PyTorch versions:** `python --version`, `python -c "import torch; print(torch.__version__)"` -5. **What you've tried** so far - -### Resources - -1. **Documentation:** Search this site -2. **FAQ:** [Common questions](../faq.md) -3. **Examples:** Check `projects/` directory in the repo -4. **PyTorch Lightning:** [Lightning docs](https://lightning.ai/docs/pytorch/stable/) for Trainer issues -5. **Discord:** [Join community](https://discord.gg/zJcnp6KrUp) -6. **GitHub Issues:** [Report bugs](https://github.com/project-lighter/lighter/issues) - -## Summary - -Most issues fall into these categories: - -| Category | Quick Fix | -|----------|-----------| -| Config syntax | Check YAML indentation, quotes, colons | -| References | Use `::` for config, `.` for Python | -| Memory | Reduce batch size, use mixed precision | -| Speed | Increase num_workers, enable pin_memory | -| DDP | Enable sync_batchnorm, use rank_zero_only | -| Adapters | Add print transforms to inspect shapes | -| Metrics | Check mode has metrics, verify adapter args | - -**Pro tip:** Most errors can be caught early with `trainer::fast_dev_run=true`! diff --git a/docs/how-to/writers.md b/docs/how-to/writers.md deleted file mode 100644 index 0c55f5e1..00000000 --- a/docs/how-to/writers.md +++ /dev/null @@ -1,327 +0,0 @@ -# Writers: Save Your Results Like a Pro - -Writers are your data persistence layer—they capture model outputs and save them in formats ready for analysis, visualization, or deployment. 
- -## Quick Start 🚀 - -```yaml -# Save predictions as images -trainer: - callbacks: - - _target_: lighter.callbacks.FileWriter - path: "outputs/predictions" - writer: "image" # PNG for 2D, MP4 for 3D - -# Save metrics to CSV -trainer: - callbacks: - - _target_: lighter.callbacks.TableWriter - path: "outputs/metrics.csv" -``` - -## Writer Types at a Glance - -| Writer | Purpose | Output Format | Best For | -|--------|---------|---------------|----------| -| **FileWriter** | Save predictions/tensors | PNG, MP4, PT | Images, videos, tensors | -| **TableWriter** | Save tabular data | CSV | Metrics, statistics, results | - -## Using `FileWriter` - -`FileWriter` callback saves tensors to files, supports various formats, and is customizable. - -**Configuration**: - -Configure `FileWriter` in `config.yaml` within `trainer.callbacks` section: - -```yaml title="config.yaml" -trainer: - callbacks: - - _target_: lighter.callbacks.FileWriter # Use the FileWriter callback - path: "outputs/predictions" # Directory to save output files - writer: "tensor" # Writer function to use -``` - -* **`_target_: lighter.callbacks.FileWriter`**: Specifies that you want to use the `FileWriter` callback. -* **`path: "outputs/predictions"`**: Defines the directory where the output files will be saved. Lighter will create this directory if it doesn't exist. -* **`writer: "tensor"`**: Specifies the writer function to be used for saving tensors. - -**Built-in Writer Functions**: - -`FileWriter` has built-in writer functions for different formats: - -* **`"tensor"`**: Raw PyTorch `.pt` files (general tensor saving). -* **`"image"`**: PNG images for 2D tensors. -* **`"video"`**: MP4 videos for 4D tensor time-series (CTHW format). - -**Usage**: - -Once configured, `FileWriter` is used by Lighter in validation, test, and predict stages (if enabled). - -In these stages, per batch, `FileWriter`: - -1. Receives `pred` tensor from `predict_step`, `validation_step`, or `test_step`. -2. 
Applies `LoggingAdapter` transforms (if configured). -3. Uses writer function (e.g., `"tensor"`) to save `pred` tensor to file in `path` dir. -4. Names file using batch `identifier` (if available) or generates unique name. - -**Example: Saving Predictions as Tensors** - -```yaml title="config.yaml" -trainer: - callbacks: - - _target_: lighter.callbacks.FileWriter - path: "outputs/predictions" - writer: "tensor" # Save as .pt files - -system: - # ... (other system configurations) ... - dataloaders: - val: - _target_: torch.utils.data.DataLoader - dataset: - _target_: my_project.datasets.MyDataset - root: "data/" - batch_size: 1 -``` - -Example config: `FileWriter` saves predictions during validation stage as PyTorch tensor files in `outputs/predictions` dir. - -## Extending `FileWriter` with Custom Writers - -Extend `FileWriter` by creating custom writer functions or classes for specific needs. - -**1. Create Custom Writer Function**: - -Define a custom writer function with two arguments: - -* `path`: Full file path for saving tensor (filename & extension). -* `tensor`: PyTorch tensor to save. - -**Example: Custom Writer Function for Text Files** - -```python title="my_project/writers/my_custom_writer.py" -import torch -import numpy as np - -def write_tensor_as_text(path: str, tensor: torch.Tensor): - """Saves tensor to text file.""" - tensor_numpy = tensor.cpu().numpy() # Convert to NumPy array - np.savetxt(path, tensor_numpy) # Save as text file -``` - -**2. Register Custom Writer Function**: - -Register custom writer function in `config.yaml` to use with `FileWriter`: - -```yaml title="config.yaml" -trainer: - callbacks: - - _target_: lighter.callbacks.FileWriter - path: "outputs/text_tensors" - writer: my_project.writers.my_custom_writer.write_tensor_as_text # Path to custom writer -``` - -* **`writer`**: Path to custom writer function. Replace `"my_project.writers.my_custom_writer"` with your module path. - -**3. 
Create Custom Writer Class (Advanced)**: - -For complex logic or stateful writers, create a custom class inheriting from `lighter.callbacks.writer.BaseWriter`. - -**Example: Custom Writer Class for Tensors with Metadata** - -```python title="my_project/writers/my_custom_writer_class.py" -from lighter.callbacks.writer import BaseWriter -import torch -import datetime -import json -import os - -class MyCustomClassWriter(BaseWriter): - @property - def writers(self): - return {"tensor_with_metadata": self.write_tensor_with_metadata} # Register writer function - - def write(self, tensor: torch.Tensor, identifier: str): - """Main write method called by FileWriter.""" - path = os.path.join(self.path, f"{identifier}.json") # Define output path - self.write_tensor_with_metadata(path, tensor, identifier=identifier) # Call writer function - - def write_tensor_with_metadata(self, path: str, tensor: torch.Tensor, identifier: str): - """Saves tensor to JSON file with metadata.""" - metadata = { - "identifier": identifier, - "shape": list(tensor.shape), - "dtype": str(tensor.dtype), - "timestamp": datetime.datetime.now().isoformat() - } - data = { - "metadata": metadata, - "data": tensor.cpu().numpy().tolist() # Convert tensor data to list - } - with open(path, 'w') as f: - json.dump(data, f, indent=4) # Save data+metadata to JSON -``` - -`MyCustomClassWriter` class example: - -* Inherits from `BaseWriter`. -* `write_tensor_with_metadata`: Saves tensors to JSON with metadata (shape, dtype, timestamp). -* Registers writer function in `writers` property with key `"tensor_with_metadata"`. -* Overrides `write` method for filename generation and calling custom writer function. - -**4.
Use Custom Writer Class in `config.yaml`**: - -Use custom writer class by specifying class module path and registered writer key in `config.yaml`: - -```yaml title="config.yaml" -trainer: - callbacks: - - _target_: lighter.callbacks.FileWriter - path: "outputs/metadata_tensors" - writer: my_project.writers.my_custom_writer_class.MyCustomClassWriter.tensor_with_metadata # Custom class writer -``` - -* **`writer`**: Path to custom writer class and registered writer key (`"tensor_with_metadata"`). - -## Using `TableWriter` - -`TableWriter` callback saves tabular data to CSV files, useful for logging metrics or aggregated predictions. - -**Configuration**: - -Configure `TableWriter` in `config.yaml` within `trainer.callbacks` section: - -```yaml title="config.yaml" -trainer: - callbacks: - - _target_: lighter.callbacks.TableWriter # Use the TableWriter callback - path: "outputs/metrics.csv" # Path to save CSV file -``` - -* **`_target_: lighter.callbacks.TableWriter`**: Use `TableWriter` callback. -* **`path: "outputs/metrics.csv"`**: CSV file path for saving tabular data. - -**Usage**: - -To use `TableWriter`, return a dictionary from `validation_step`, `test_step`, or `predict_step`. `TableWriter` saves key-value pairs from dict as CSV rows. - -**Example: Logging Metrics to CSV using `TableWriter`** - -```python title="config.yaml" -trainer: - callbacks: - - _target_: lighter.callbacks.TableWriter - path: "outputs/metrics.csv" # Save metrics to CSV - -system: - metrics: - val: - - _target_: torchmetrics.Accuracy - - _target_: torchmetrics.DiceCoefficient - - def validation_step(self, batch, batch_idx): - output = super().validation_step(batch, batch_idx) # Call base validation step - metrics = output[Data.METRICS] # Get computed metrics - self.log_dict(metrics) # Log metrics for display - return metrics # Return metrics dictionary for TableWriter -``` - -Example: `TableWriter` saves data to `outputs/metrics.csv`. 
In `validation_step`: - -* Call base class `validation_step` to compute metrics. -* Extract metrics from `output` dict. -* Return `metrics` dict from `validation_step`. - -`TableWriter` captures dict from `validation_step`/`test_step`/`predict_step`, saves as CSV rows. Dict keys become CSV column headers. - -## Custom Writer Example - -```python -# my_project/writers/visualization_writer.py -import matplotlib.pyplot as plt -from pathlib import Path - -class VisualizationWriter: - """Save comparison plots of input, target, and prediction.""" - def __init__(self, path): - self.path = Path(path) - self.path.mkdir(exist_ok=True) - - def write_comparison(self, input_img, target, prediction, identifier): - fig, axes = plt.subplots(1, 3, figsize=(15, 5)) - - axes[0].imshow(input_img.cpu().numpy().transpose(1, 2, 0)) - axes[0].set_title("Input") - - axes[1].imshow(target.cpu().numpy(), cmap='tab20') - axes[1].set_title("Ground Truth") - - axes[2].imshow(prediction.argmax(0).cpu().numpy(), cmap='tab20') - axes[2].set_title("Prediction") - - for ax in axes: - ax.axis('off') - - plt.tight_layout() - plt.savefig(self.path / f"{identifier}_comparison.png") - plt.close() -``` - -## Quick Reference 📄 - -### FileWriter Formats -```yaml -# Medical imaging -writer: "itk_nifti" # .nii.gz files -writer: "itk_nrrd" # .nrrd files -writer: "itk_seg_nrrd" # Segmentation masks - -# Standard formats -writer: "tensor" # PyTorch .pt files -writer: "image" # PNG (2D) or MP4 (3D) -writer: "video" # MP4 for time series - -# Custom -writer: my_project.writers.custom_writer -``` - -### TableWriter Patterns -```python -# Return dict from step methods for TableWriter -def validation_step(self, batch, batch_idx): - # ... compute metrics ...
- return { - "patient_id": batch["id"], - "dice_score": dice, - "loss": loss.item(), - "prediction_confidence": pred.max() - } -``` - -## Common Issues & Solutions - -| Issue | Solution | -|-------|----------| -| **Permission denied** | Create output directory: `Path("outputs").mkdir(exist_ok=True)` | -| **Out of disk space** | Use compression or write less frequently | -| **Slow writing** | Reduce precision to FP16 or use async writing | - -## Recap and Next Steps - -✅ **You've Learned:** -- Use FileWriter for predictions and tensors -- Use TableWriter for metrics and results -- Create custom writers for special needs -- Optimize writing for performance - -🎯 **Best Practices:** -- Organize outputs hierarchically -- Use compression for large outputs -- Consider async writing for speed -- Save metadata with predictions - -💡 **Pro Tip:** Always save enough information to reproduce your results! - -## Related Guides -- [Adapters](adapters.md) - Transform before writing -- [Inferers](inferers.md) - Write inference results diff --git a/docs/index.md b/docs/index.md index 154b6ead..59d901f2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -33,431 +33,333 @@ pip install lighter
- +**YAML configuration for PyTorch Lightning experiments**
-- :material-rocket-launch:{ .lg .middle } __From Idea to Experiment in Seconds__ +- :material-rocket-launch:{ .lg .middle } **Fast Iteration** --- - No boilerplate. No training loops. Just define your model, data, and optimizer in YAML and run `lighter fit config.yaml`. + Change hyperparameters from CLI without editing code. -- :material-refresh:{ .lg .middle } __100% Reproducible__ - - --- - - Every experiment is a YAML file. Version control configs like code. Share experiments with collaborators. No hidden state. + ```bash + lighter fit config.yaml model::learning_rate=0.01 + ``` -- :material-tune:{ .lg .middle } __Hyperparameter Sweeps Made Easy__ +- :material-refresh:{ .lg .middle } **Reproducible** --- - Override any parameter from CLI: `lighter fit config.yaml system::optimizer::lr=0.01`. Run 100 experiments without editing files. + One YAML file = one experiment. Version control configs like code. -- :material-puzzle-outline:{ .lg .middle } __Task-Agnostic Adapters__ +- :material-lightning-bolt:{ .lg .middle } **Pure Lightning** --- - Classification, segmentation, or self-supervised learning? Adapters handle any data format. One system, unlimited tasks. + Use any LightningModule. Full PyTorch Lightning power. Zero lock-in. -- :material-feather:{ .lg .middle } __~1,000 Lines of Code__ +
- --- +## What is Lighter? - Read the entire framework in an afternoon. Debug easily. Understand exactly what's happening. No magic. +Lighter runs PyTorch Lightning experiments from YAML configs instead of hardcoded Python values. -- :material-lightning-bolt:{ .lg .middle } __Built on PyTorch Lightning__ +**You write Lightning code. Lighter handles configuration.** - --- +```python title="model.py" +import pytorch_lightning as pl - Multi-GPU, mixed precision, gradient accumulation, profiling—all Lightning features work out of the box. +class MyModule(pl.LightningModule): + def __init__(self, learning_rate=0.001): + super().__init__() + self.lr = learning_rate + # ... your model code ... - + def training_step(self, batch, batch_idx): + # ... your training logic ... + return loss +``` -## Quick Start: 60 Seconds +```yaml title="config.yaml" +model: + _target_: project.model.MyModule # Auto-discovered with __lighter__.py + learning_rate: 0.001 -
+trainer: + max_epochs: 10 +``` -1. **Install Lighter** +```bash +# Run it +lighter fit config.yaml - ```bash - pip install lighter - ``` +# Override from CLI +lighter fit config.yaml model::learning_rate=0.01 +``` -2. **Create a config** (`config.yaml`) +## Two Approaches, Same Power - ```yaml - trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 10 +Choose the approach that fits your workflow: - system: - _target_: lighter.System - model: - _target_: torchvision.models.resnet18 - num_classes: 10 - criterion: - _target_: torch.nn.CrossEntropyLoss - optimizer: - _target_: torch.optim.Adam - params: "$@system::model.parameters()" - lr: 0.001 - dataloaders: - train: # (1)! - _target_: torch.utils.data.DataLoader - batch_size: 32 - dataset: - _target_: torchvision.datasets.CIFAR10 - root: ./data - train: true - download: true - transform: - _target_: torchvision.transforms.ToTensor - ``` +
- 1. Define your data like any PyTorch component +
-3. **Run training** +### :material-code-braces: LightningModule - ```bash - lighter fit config.yaml - ``` +**Best for:** -That's it. Automatic training loops, validation, checkpointing, and logging. +- Existing Lightning projects +- Custom training logic +- Full control over everything -
+**You write:** -!!! tip "Experiment with different hyperparameters" - ```bash - # Change learning rate without editing files - lighter fit config.yaml system::optimizer::lr=0.01 +- All step methods +- `configure_optimizers()` +- Your own logging - # Train longer - lighter fit config.yaml trainer::max_epochs=100 +**Lighter adds:** - # Use multiple GPUs - lighter fit config.yaml trainer::devices=4 - ``` +- YAML configuration +- CLI overrides +- Experiment tracking +[Learn more →](guides/lightning-module.md) -## Lighter vs. PyTorch Lightning +
-!!! abstract "Same Power, Different Interface" - Lighter uses PyTorch Lightning under the hood. You get all Lightning features (multi-GPU, callbacks, profilers) but define experiments in YAML instead of Python classes. +
-See how training a model on CIFAR-10 differs: +### :material-auto-fix: LighterModule -=== "Lighter" - ```bash title="Terminal" - lighter fit config.yaml - ``` +**Best for:** - ```yaml title="config.yaml" - trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 2 +- New projects +- Standard workflows +- Less boilerplate - system: - _target_: lighter.System +**You write:** - model: - _target_: torchvision.models.resnet18 - num_classes: 10 +- Step implementations only +- Your model's forward logic - criterion: - _target_: torch.nn.CrossEntropyLoss +**Lighter adds:** - optimizer: - _target_: torch.optim.Adam - params: "$@system::model.parameters()" - lr: 0.001 +- Automatic `configure_optimizers()` +- Dual logging (step + epoch) +- Config-driven everything - dataloaders: - train: - _target_: torch.utils.data.DataLoader - batch_size: 32 - shuffle: true - dataset: - _target_: torchvision.datasets.CIFAR10 - download: true - root: .datasets - train: true - transform: - _target_: torchvision.transforms.Compose - transforms: - - _target_: torchvision.transforms.ToTensor - - _target_: torchvision.transforms.Normalize - mean: [0.5, 0.5, 0.5] - std: [0.5, 0.5, 0.5] - ``` +[Learn more →](guides/lighter-module.md) - **Benefits:** +
- - :material-check: Experiment is self-documenting - - :material-check: Change hyperparameters from CLI without editing files - - :material-check: Version control and compare configs with git diff - - :material-check: Share experiments as single files +
-=== "PyTorch Lightning" - ```bash title="Terminal" - python cifar10.py - ``` +!!! tip "You can switch anytime" + Both approaches use the same config system. Start with one, switch to the other by changing `_target_`. No code rewrite needed. - ```py title="cifar10.py" - from pytorch_lightning import Trainer, LightningModule - from torch.nn import CrossEntropyLoss - from torch.optim import Adam - from torch.utils.data import DataLoader - from torchvision.models import resnet18 - from torchvision.datasets import CIFAR10 - from torchvision.transforms import ToTensor, Normalize, Compose +## Quick Comparison +=== "LightningModule" - class Model(LightningModule): - def __init__(self): - super().__init__() - self.model = resnet18(num_classes=10) - self.criterion = CrossEntropyLoss() + ```python title="model.py" + import torch + import torch.nn.functional as F + import pytorch_lightning as pl - def forward(self, x): - return self.model(x) + class MyModule(pl.LightningModule): + def __init__(self, network, learning_rate=0.001): + super().__init__() + self.network = network + self.lr = learning_rate def training_step(self, batch, batch_idx): x, y = batch - y_hat = self(x) - loss = self.criterion(y_hat, y) + loss = F.cross_entropy(self.network(x), y) + self.log("train/loss", loss) return loss def configure_optimizers(self): - return Adam(self.model.parameters(), lr=0.001) - - - transform = Compose([ - ToTensor(), - Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) - ]) - - train_dataset = CIFAR10( - root=".datasets", - train=True, - download=True, - transform=transform - ) - - train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) - - model = Model() - trainer = Trainer(max_epochs=2) - trainer.fit(model, train_loader) + return torch.optim.Adam(self.parameters(), lr=self.lr) ``` - **Challenges:** - - - :material-close: Need to edit Python code for hyperparameter changes - - :material-close: Harder to compare experiments (code vs config) - - :material-close: More 
boilerplate for each experiment + ```yaml title="config.yaml" + trainer: + _target_: pytorch_lightning.Trainer + max_epochs: 10 ---- + model: + _target_: project.model.MyModule # project.file.Class + network: + _target_: torchvision.models.resnet18 + num_classes: 10 + learning_rate: 0.001 + + data: + _target_: lighter.LighterDataModule + train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 32 + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: true + download: true + ``` -## Who Should Use Lighter? + ```bash + lighter fit config.yaml + ``` -
+=== "LighterModule" -
+ ```python title="model.py" + from lighter import LighterModule -### :material-check-circle:{ .green } **Perfect For** + class MyModel(LighterModule): + def training_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + loss = self.criterion(pred, y) -- **Researchers** running many experiments with hyperparameter variations -- **Teams** sharing reproducible experiments and baselines -- **Engineers** who value configuration over code for ML pipelines -- **Anyone** tired of writing boilerplate training loops + if self.train_metrics: + self.train_metrics(pred, y) -[Get Started →](tutorials/get-started.md){ .md-button .md-button--primary } + return {"loss": loss} -
+ def validation_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + loss = self.criterion(pred, y) -
+ if self.val_metrics: + self.val_metrics(pred, y) -### :material-information:{ .blue } **Consider Alternatives If** + return {"loss": loss} + ``` -- You need highly custom training loops with exotic logic -- You prefer pure Python workflows without YAML -- You're doing rapid prototyping where code is faster than config -- Your project has few experimental variations + ```yaml title="config.yaml" + trainer: + _target_: pytorch_lightning.Trainer + max_epochs: 10 -[Compare Frameworks →](design/overview.md#framework-comparison){ .md-button } + model: + _target_: project.model.MyModel # project.file.Class + network: + _target_: torchvision.models.resnet18 + num_classes: 10 + criterion: + _target_: torch.nn.CrossEntropyLoss + optimizer: + _target_: torch.optim.Adam + params: "$@model::network.parameters()" + lr: 0.001 + train_metrics: + - _target_: torchmetrics.Accuracy + task: multiclass + num_classes: 10 + val_metrics: "%model::train_metrics" + + data: + _target_: lighter.LighterDataModule + train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 32 + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: true + download: true + ``` -
+ ```bash + lighter fit config.yaml + ``` -
+## Why Lighter? ---- +### Reproducibility -## Key Features in Depth +One YAML = one experiment. Version control, share, compare. -### Configuration-Driven Everything +```bash +git diff experiment_v1.yaml experiment_v2.yaml +``` -Every component is defined in YAML. Model, optimizer, scheduler, metrics, data—all configurable. +See exactly what changed between experiments. -```yaml -# Differential learning rates? Easy. -optimizer: - _target_: torch.optim.SGD - params: - - params: "$@system::model.backbone.parameters()" - lr: 0.0001 # Low LR for pretrained backbone - - params: "$@system::model.head.parameters()" - lr: 0.01 # High LR for new head -``` +### Fast Iteration -[Learn Config Syntax →](how-to/configuration.md) +Override any config value from CLI: -### Task-Agnostic Adapters +```bash +# Change learning rate +lighter fit config.yaml model::learning_rate=0.01 -Adapters transform data between pipeline stages. This makes Lighter work for **any** task. +# Use more GPUs +lighter fit config.yaml trainer::devices=4 -```yaml -# Dict-based dataset? No problem. -system: - adapters: - train: - batch: - _target_: lighter.adapters.BatchAdapter - input_accessor: "image" # Extract from dict - target_accessor: "label" +# Combine multiple changes +lighter fit config.yaml model::learning_rate=0.01 trainer::max_epochs=100 ``` -Classification, segmentation, detection, self-supervised learning—adapters handle it all. 
+### No Lock-In -[Learn About Adapters →](how-to/adapters.md) +Lighter is a thin layer over PyTorch Lightning: -### Built on Solid Foundations +- Use **any** LightningModule +- Use **any** Lightning callback +- Use **any** Lightning logger +- Switch back to pure Lightning anytime -- **PyTorch Lightning** - Battle-tested training engine with multi-GPU, profiling, callbacks -- **[Sparkwheel](https://project-lighter.github.io/sparkwheel/)** - Powerful config system with references, expressions, and validation -- **~1,000 lines** - Read the entire framework, understand exactly what's happening +## Installation -[Architecture Deep Dive →](design/overview.md) +```bash +pip install lighter +``` ---- +## Get Started -## Choose Your Path +Ready to try it? Pick your path:
-- :material-school:{ .lg .middle } __New to Lighter?__ +- :material-rocket:{ .lg .middle } **Quick Start** --- - **Start here:** Follow our comprehensive tutorial from installation to running your first experiments. + Get a model training in 10 minutes. - Time: 15 minutes + [:octicons-arrow-right-24: Quick Start](quickstart.md) - [:octicons-arrow-right-24: Get Started Tutorial](tutorials/get-started.md) - -- :material-lightning-bolt:{ .lg .middle } __PyTorch Lightning User?__ +- :material-book-open-variant:{ .lg .middle } **Complete Examples** --- - **Migration guide:** Translate your existing Lightning code to Lighter configs in minutes. - - Time: 10 minutes + Full, working code you can copy-paste. - [:octicons-arrow-right-24: Migration Guide](migration/from-pytorch-lightning.md) + [:octicons-arrow-right-24: Examples](examples/image-classification.md) -- :material-book-open-variant:{ .lg .middle } __Learn the Syntax__ +- :material-school:{ .lg .middle } **Guides** --- - **Configuration reference:** Master Sparkwheel syntax: `_target_`, references (`@` and `%`), expressions (`$`), and path notation (`::`). - - Time: 20 minutes - - [:octicons-arrow-right-24: Configuration Guide](how-to/configuration.md) - -- :material-code-braces:{ .lg .middle } __Ready-to-Use Examples__ - - --- - - **Recipes & patterns:** Copy-paste configs for common scenarios and best practices. - - Time: 5 minutes per recipe + Task-focused how-to guides. - [:octicons-arrow-right-24: View Recipes](how-to/recipes.md) - -- :material-puzzle-outline:{ .lg .middle } __Understand Adapters__ - - --- - - **Core concept:** Learn how adapters make Lighter task-agnostic and infinitely flexible. - - Time: 15 minutes - - [:octicons-arrow-right-24: Adapter Pattern](how-to/adapters.md) - -- :material-lightbulb:{ .lg .middle } __Architecture & Philosophy__ - - --- - - **Deep dive:** Understand the design decisions and how Lighter works internally. 
- - Time: 30 minutes - - [:octicons-arrow-right-24: Design Overview](design/overview.md) + [:octicons-arrow-right-24: Guides](guides/configuration.md)
---- - -## Community & Support - -
- -
- -### :material-account-group: Get Help +## Community -[:fontawesome-brands-discord: Discord](https://discord.gg/zJcnp6KrUp) - Chat with the community +- [:fontawesome-brands-discord: Discord](https://discord.gg/zJcnp6KrUp) - Get help, share configs +- [:fontawesome-brands-github: GitHub](https://github.com/project-lighter/lighter) - Report issues, contribute +- [:material-file-document: Paper](https://joss.theoj.org/papers/10.21105/joss.08101) - Cite us -[:material-frequently-asked-questions: FAQ](faq.md) - Common questions answered +## What Next? -[:material-bug: GitHub Issues](https://github.com/project-lighter/lighter/issues) - For bugs and features - -[:material-book-open-page-variant: Troubleshooting](how-to/troubleshooting.md) - Common problems - -
- -
- -### :material-star: Contribute - -[:fontawesome-brands-github: GitHub](https://github.com/project-lighter/lighter) - Star the repo - -[:material-file-document-edit: Documentation](https://github.com/project-lighter/lighter/tree/main/docs) - Improve the docs - -[:material-code-tags: Examples](https://github.com/project-lighter/lighter/tree/main/projects) - Share your configs - -
- -
- - -## Cite - -If you find it useful, please cite our [*Journal of Open Source Software* paper](https://joss.theoj.org/papers/10.21105/joss.08101): - -```bibtex -@article{lighter, - doi = {10.21105/joss.08101}, - url = {https://doi.org/10.21105/joss.08101}, - year = {2025}, - publisher = {The Open Journal}, - volume = {10}, - number = {111}, - pages = {8101}, - author = {Hadzic, Ibrahim and Pai, Suraj and Bressem, Keno and Foldyna, Borek and Aerts, Hugo JWL}, - title = {Lighter: Configuration-Driven Deep Learning}, - journal = {Journal of Open Source Software} -} -``` +- [**Quick Start** - 10 minutes to running model](quickstart.md) +- [**Configuration Guide** - Learn the syntax](guides/configuration.md) +- [**FAQ** - Common questions](faq.md) diff --git a/docs/migration/from-pytorch-lightning.md b/docs/migration/from-pytorch-lightning.md deleted file mode 100644 index ea9056c3..00000000 --- a/docs/migration/from-pytorch-lightning.md +++ /dev/null @@ -1,154 +0,0 @@ -# Migrating from PyTorch Lightning - -Quick guide for PyTorch Lightning users transitioning to Lighter. - -## Key Difference: Configuration Over Code - -Lighter uses YAML configs (powered by [Sparkwheel](https://project-lighter.github.io/sparkwheel/)) instead of Python classes for experiment definition. 
- -## Conceptual Mapping - -| PyTorch Lightning | Lighter | -|-------------------|---------| -| `LightningModule` | `System` + YAML config | -| `Trainer` | `Trainer` (same, from PL) | -| `training_step()` | Handled by `System` | -| `validation_step()` | Handled by `System` | -| `configure_optimizers()` | Optimizer in YAML | -| Custom callbacks | Same (PL callbacks work) | -| Loggers | Same (PL loggers work) | - -## Simple Example - -### Before (PyTorch Lightning) -```python -class LitModel(LightningModule): - def __init__(self): - super().__init__() - self.model = resnet18(num_classes=10) - self.criterion = CrossEntropyLoss() - - def forward(self, x): - return self.model(x) - - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - loss = self.criterion(y_hat, y) - return loss - - def configure_optimizers(self): - return Adam(self.parameters(), lr=0.001) - -trainer = Trainer(max_epochs=10) -trainer.fit(model, train_loader) -``` - -### After (Lighter) -```yaml -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 10 - -system: - _target_: lighter.System - - model: - _target_: torchvision.models.resnet18 - num_classes: 10 - - criterion: - _target_: torch.nn.CrossEntropyLoss - - optimizer: - _target_: torch.optim.Adam - params: "$@system::model.parameters()" - lr: 0.001 - - dataloaders: - train: ... # DataLoader config -``` - -```bash -lighter fit config.yaml -``` - -**Key insight:** Same Trainer, same training logic, different interface. - -## What You Need to Learn - -Only 3 things are Lighter-specific: - -1. **Sparkwheel Configuration Syntax** - [Configuration Guide](../how-to/configuration.md) | [Sparkwheel docs](https://project-lighter.github.io/sparkwheel/) -2. **Adapters** (Lighter's key feature) - [Adapters Guide](../how-to/adapters.md) -3. 
**Project Module** (optional) - [Project Module Guide](../how-to/project_module.md) - -## What Stays the Same - -Everything else is PyTorch Lightning: - -- **Trainer arguments** - [PL Trainer docs](https://lightning.ai/docs/pytorch/stable/common/trainer.html) -- **Callbacks** - [PL Callback docs](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html) -- **Loggers** - [PL Logger docs](https://lightning.ai/docs/pytorch/stable/extensions/logging.html) -- **Distributed training** - [PL distributed docs](https://lightning.ai/docs/pytorch/stable/common/trainer.html#devices) -- **Profiling** - [PL profiler docs](https://lightning.ai/docs/pytorch/stable/tuning/profiler.html) - -## Common Migration Patterns - -### Custom Model -Your `nn.Module` works as-is: -```yaml -system: - model: - _target_: my_project.models.MyCustomModel - arg1: value1 -``` - -### Custom Dataset -Your PyTorch `Dataset` works directly: -```yaml -system: - dataloaders: - train: - _target_: torch.utils.data.DataLoader - dataset: - _target_: my_project.datasets.MyDataset - data_path: /path/to/data -``` - -### Custom Callbacks -PL callbacks work without modification: -```yaml -trainer: - callbacks: - - _target_: pytorch_lightning.callbacks.EarlyStopping - monitor: val_loss - patience: 5 - - _target_: my_project.callbacks.MyCustomCallback - arg: value -``` - -### Learning Rate Schedulers -```yaml -system: - scheduler: - _target_: torch.optim.lr_scheduler.ReduceLROnPlateau - optimizer: "@system::optimizer" - factor: 0.5 - patience: 10 -``` - -## When NOT to Migrate - -Lighter might not fit if: - -- You need highly custom training loops (stick with PL or PyTorch) -- You prefer writing code over configuration -- Your project doesn't run many experimental variations - -## Next Steps - -1. Start with the [Zero to Hero tutorial](../tutorials/get-started.md) -2. Try the [Image Classification Tutorial](../tutorials/get-started.md) -3. Understand [Design Philosophy](../design/philosophy.md) -4. 
Learn about [Adapters](../how-to/adapters.md) (Lighter's superpower) diff --git a/docs/quickstart.md b/docs/quickstart.md new file mode 100644 index 00000000..7fc2e236 --- /dev/null +++ b/docs/quickstart.md @@ -0,0 +1,421 @@ +--- +title: Quick Start +--- + +# Quick Start + +Get a model training in 10 minutes. + +## 1. Install + +```bash +pip install lighter +``` + +## 2. Create Your Project + +Lighter uses a **project folder** pattern for organizing your code. This makes it easy to reference your custom models and datasets. + +### Step 1: Create the Project Structure + +```bash +mkdir mnist_classifier +cd mnist_classifier +``` + +Create these files: + +``` +mnist_classifier/ +├── __lighter__.py # Marker file (tells Lighter this is a project) +├── __init__.py # Makes it a Python package +├── model.py # Your model code +└── configs/ + └── config.yaml # Your experiment config +``` + +### Step 2: Add the Marker Files + +**`__lighter__.py`** (can be empty): +```python +# This file marks the directory as a Lighter project. +# When you run `lighter` from this directory, you can reference +# your code as `project.module.Class` +``` + +**`__init__.py`** (can be empty): +```python +# Makes this directory a Python package +``` + +!!! tip "The `project` Prefix" + Once you have `__lighter__.py`, Lighter auto-discovers your folder and makes it available as `project`. This means `model.py` becomes `project.model`, and you can reference classes as `project.model.ClassName`. + +## 3. Choose Your Approach + +Lighter works with **any** PyTorch Lightning module. 
Pick the style that fits your workflow: + +- **LightningModule** - Use existing Lightning code, add configs +- **LighterModule** - Less boilerplate, automatic logging + +=== "LightningModule" + + ### Write Your Module + + **`model.py`**: + + ```python + import pytorch_lightning as pl + import torch + import torch.nn.functional as F + + class MNISTModule(pl.LightningModule): + def __init__(self, learning_rate=0.001): + super().__init__() + self.lr = learning_rate + self.model = torch.nn.Sequential( + torch.nn.Linear(28 * 28, 128), + torch.nn.ReLU(), + torch.nn.Linear(128, 10) + ) + + def forward(self, x): + return self.model(x.view(x.size(0), -1)) + + def training_step(self, batch, batch_idx): + x, y = batch + loss = F.cross_entropy(self(x), y) + self.log("train/loss", loss) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + loss = F.cross_entropy(self(x), y) + acc = (self(x).argmax(dim=1) == y).float().mean() + self.log("val/loss", loss) + self.log("val/acc", acc) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.lr) + ``` + + ### Create Config + + **`configs/config.yaml`**: + + ```yaml + trainer: + _target_: pytorch_lightning.Trainer + max_epochs: 3 + accelerator: auto + + model: + _target_: project.model.MNISTModule # project.file.Class + learning_rate: 0.001 + + data: + _target_: lighter.LighterDataModule + train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 64 + shuffle: true + dataset: + _target_: torchvision.datasets.MNIST + root: ./data + train: true + download: true + transform: + _target_: torchvision.transforms.ToTensor + + val_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 64 + dataset: + _target_: torchvision.datasets.MNIST + root: ./data + train: false + download: true + transform: + _target_: torchvision.transforms.ToTensor + ``` + + ### Run + + ```bash + lighter fit configs/config.yaml + ``` + + **That's it!** You now have: + + - ✅ Training and 
validation loops + - ✅ Automatic logging + - ✅ Checkpointing + - ✅ Progress bars + +=== "LighterModule" + + ### Write Your Module + + **`model.py`**: + + ```python + from lighter import LighterModule + import torch + + class MNISTModel(LighterModule): + def training_step(self, batch, batch_idx): + x, y = batch + x = x.view(x.size(0), -1) # Flatten + pred = self(x) + loss = self.criterion(pred, y) + + if self.train_metrics: + self.train_metrics(pred, y) + + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + x, y = batch + x = x.view(x.size(0), -1) # Flatten + pred = self(x) + loss = self.criterion(pred, y) + + if self.val_metrics: + self.val_metrics(pred, y) + + return {"loss": loss} + ``` + + ### Create Config + + **`configs/config.yaml`**: + + ```yaml + trainer: + _target_: pytorch_lightning.Trainer + max_epochs: 3 + accelerator: auto + + model: + _target_: project.model.MNISTModel # project.file.Class + + network: + _target_: torch.nn.Sequential + _args_: + - _target_: torch.nn.Linear + in_features: 784 + out_features: 128 + - _target_: torch.nn.ReLU + - _target_: torch.nn.Linear + in_features: 128 + out_features: 10 + + criterion: + _target_: torch.nn.CrossEntropyLoss + + optimizer: + _target_: torch.optim.Adam + params: "$@model::network.parameters()" + lr: 0.001 + + train_metrics: + - _target_: torchmetrics.Accuracy + task: multiclass + num_classes: 10 + + val_metrics: "%model::train_metrics" + + data: + _target_: lighter.LighterDataModule + train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 64 + shuffle: true + dataset: + _target_: torchvision.datasets.MNIST + root: ./data + train: true + download: true + transform: + _target_: torchvision.transforms.ToTensor + + val_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 64 + dataset: + _target_: torchvision.datasets.MNIST + root: ./data + train: false + download: true + transform: + _target_: torchvision.transforms.ToTensor + ``` + + ### Run + + ```bash + 
lighter fit configs/config.yaml + ``` + + **That's it!** You get everything from the LightningModule approach PLUS: + + - ✅ Automatic `configure_optimizers()` + - ✅ Dual logging (step + epoch) + - ✅ Config-driven metrics + +## 4. Experiment! + +Now that it's running, try changing things from the CLI: + +```bash +# Change learning rate +lighter fit configs/config.yaml model::learning_rate=0.01 + +# Train longer +lighter fit configs/config.yaml trainer::max_epochs=10 + +# Larger batch size +lighter fit configs/config.yaml data::train_dataloader::batch_size=128 + +# Combine multiple changes +lighter fit configs/config.yaml \ + model::learning_rate=0.01 \ + trainer::max_epochs=10 \ + data::train_dataloader::batch_size=128 +``` + +No file editing needed! + +## Understanding the Config + +Let's break down what's happening: + +### The Three Keys + +Every Lighter config has three main sections: + +```yaml +trainer: # How to run (PyTorch Lightning Trainer) +model: # What to run (your LightningModule) +data: # What data to use (DataModule) +``` + +### The `_target_` Pattern + +`_target_` tells Lighter what class to instantiate: + +```yaml +model: + _target_: project.model.MNISTModule # Create instance of this class + learning_rate: 0.001 # Pass as __init__ argument +``` + +This is equivalent to: + +```python +model = MNISTModule(learning_rate=0.001) +``` + +### The `project.` Prefix + +When you have `__lighter__.py` in your folder, Lighter auto-discovers it and makes it available as `project`: + +```yaml +# Your file: model.py +# Your class: MNISTModule +# Reference as: project.model.MNISTModule + +_target_: project.model.MNISTModule +``` + +This pattern keeps all your custom code organized and easy to reference. 
+ +### References + +Use `@` to reference resolved values (after instantiation): + +```yaml +optimizer: + params: "$@model::network.parameters()" # Call method on instantiated network +``` + +Use `%` to copy config (for creating new instances): + +```yaml +val_metrics: "%model::train_metrics" # New instance with same config +``` + +⚠️ **Important:** Use `%` not `@` for metrics (they're stateful and need separate instances) + +## What You Just Learned + +- ✅ How to create a Lighter project with `__lighter__.py` +- ✅ How to reference your code as `project.*` +- ✅ Two ways to use Lighter (LightningModule or LighterModule) +- ✅ Basic config structure (trainer/model/data) +- ✅ How to run experiments +- ✅ How to override from CLI +- ✅ Key config syntax (`_target_`, `@`, `%`) + +## Next Steps + +### Learn More Config Syntax + +[Configuration Guide](guides/configuration.md) - Master Sparkwheel syntax + +### See Complete Examples + +[Image Classification](examples/image-classification.md) - Full CIFAR-10 example with all the bells and whistles + +### Organize Your Project + +[Custom Code Guide](guides/custom-code.md) - Best practices for structuring larger projects + +### Best Practices + +[Best Practices](guides/best-practices.md) - Production patterns + +## Common Next Questions + +**Q: How do I add custom datasets or transforms?** +A: Put them in separate files (e.g., `data.py`, `transforms.py`) and reference as `project.data.MyDataset` or `project.transforms.MyTransform` + +**Q: Can I use multiple GPUs?** +A: Yes! 
Just add `trainer::devices=-1` (all GPUs) or `trainer::devices=4` (4 GPUs) + +**Q: How do I save predictions?** +A: Use Writers - see [Training Guide](guides/training.md#saving-predictions) + +**Q: Something not working?** +A: Check [FAQ](faq.md) or [join Discord](https://discord.gg/zJcnp6KrUp) + +## Quick Reference + +```yaml +# Essential config structure +trainer: + _target_: pytorch_lightning.Trainer + max_epochs: 10 + +model: + _target_: project.model.MyModule # Your Lightning module + # ... your module's __init__ args ... + +data: + _target_: lighter.LighterDataModule + train_dataloader: ... + val_dataloader: ... +``` + +```bash +# Essential CLI commands +lighter fit config.yaml # Train +lighter validate config.yaml # Validate only +lighter test config.yaml # Test only +lighter predict config.yaml # Inference + +# Override any config value +lighter fit config.yaml key::path=value +``` + +[→ Full CLI Reference](reference/cli.md) diff --git a/docs/reference/cli.md b/docs/reference/cli.md new file mode 100644 index 00000000..0a5104cb --- /dev/null +++ b/docs/reference/cli.md @@ -0,0 +1,558 @@ +--- +title: CLI Reference +--- + +# CLI Reference + +Complete reference for Lighter's command-line interface. + +## Commands + +Lighter provides four main commands: + +```bash +lighter fit # Train and validate +lighter validate # Validate only +lighter test # Test only +lighter predict # Run inference +``` + +All commands use the same configuration system. + +## lighter fit + +Train your model with automatic validation. + +### Basic Usage + +```bash +lighter fit CONFIG [OPTIONS] [OVERRIDES...] 
+``` + +### Arguments + +| Argument | Description | Required | +|----------|-------------|----------| +| `CONFIG` | Path to YAML config file | Yes | +| `--ckpt_path PATH` | Checkpoint path to resume from ("last", "best", or file path) | No | +| `--weights_only BOOL` | Load only weights (security option) | No | +| `OVERRIDES` | Config overrides (key::path=value) | No | + +### Examples + +```bash +# Basic training +lighter fit config.yaml + +# Resume from checkpoint +lighter fit config.yaml --ckpt_path checkpoints/last.ckpt + +# With config overrides +lighter fit config.yaml model::optimizer::lr=0.01 + +# Multiple configs +lighter fit base.yaml,experiment.yaml + +# Combine CLI flags and overrides +lighter fit config.yaml --ckpt_path last trainer::max_epochs=100 +``` + +### Config Structure + +```yaml +trainer: + _target_: pytorch_lightning.Trainer + # ... trainer args ... + +model: + _target_: your.Module + # ... model args ... + +data: + _target_: lighter.LighterDataModule + # ... data args ... +``` + +### Output + +Creates directory structure: + +``` +outputs/ +└── YYYY-MM-DD/ + └── HH-MM-SS/ + ├── config.yaml # Config used + ├── checkpoints/ + │ └── last.ckpt # Latest checkpoint + └── logs/ # Training logs +``` + +## lighter validate + +Run validation on a trained model. + +### Basic Usage + +```bash +lighter validate CONFIG [OPTIONS] [OVERRIDES...] 
+``` + +### Arguments + +| Argument | Description | Required | +|----------|-------------|----------| +| `CONFIG` | Path to YAML config file | Yes | +| `--ckpt_path PATH` | Checkpoint path for validation ("last", "best", or file path) | No | +| `--verbose BOOL` | Print validation results (default: True) | No | +| `--weights_only BOOL` | Load only weights (security option) | No | +| `OVERRIDES` | Config overrides | No | + +### Examples + +```bash +# Validate with checkpoint +lighter validate config.yaml --ckpt_path checkpoints/best.ckpt + +# Override config +lighter validate config.yaml \ + --ckpt_path checkpoints/best.ckpt \ + data::val_dataloader::batch_size=128 +``` + +### Config Structure + +Same as `fit` command. + +### Requirements + +- Checkpoint file (`.ckpt`) +- `val_dataloader` in data config +- `validation_step` in your module + +## lighter test + +Run test on a trained model. + +### Basic Usage + +```bash +lighter test CONFIG [OPTIONS] [OVERRIDES...] +``` + +### Arguments + +| Argument | Description | Required | +|----------|-------------|----------| +| `CONFIG` | Path to YAML config file | Yes | +| `--ckpt_path PATH` | Checkpoint path for testing ("last", "best", or file path) | No | +| `--verbose BOOL` | Print test results (default: True) | No | +| `--weights_only BOOL` | Load only weights (security option) | No | +| `OVERRIDES` | Config overrides | No | + +### Examples + +```bash +# Test with checkpoint +lighter test config.yaml --ckpt_path checkpoints/best.ckpt + +# Multiple test sets +lighter test config.yaml \ + --ckpt_path checkpoints/best.ckpt \ + data::test_dataloader::dataset::root=./test_data +``` + +### Config Structure + +Same as `fit` command, with a test dataloader: + +```yaml +data: + test_dataloader: + _target_: torch.utils.data.DataLoader + # ... test dataloader config ... 
+``` + +### Requirements + +- Checkpoint file (`.ckpt`) +- `test_dataloader` in data config +- `test_step` in your module + +## lighter predict + +Run inference on data. + +### Basic Usage + +```bash +lighter predict CONFIG [OPTIONS] [OVERRIDES...] +``` + +### Arguments + +| Argument | Description | Required | +|----------|-------------|----------| +| `CONFIG` | Path to YAML config file | Yes | +| `--ckpt_path PATH` | Checkpoint path for predictions ("last", "best", or file path) | No | +| `--return_predictions BOOL` | Whether to return predictions (default: True) | No | +| `--weights_only BOOL` | Load only weights (security option) | No | +| `OVERRIDES` | Config overrides | No | + +### Examples + +```bash +# Basic prediction +lighter predict config.yaml --ckpt_path checkpoints/best.ckpt + +# With writer to save results +lighter predict config.yaml \ + --ckpt_path checkpoints/best.ckpt \ + 'trainer::callbacks=[{_target_: lighter.callbacks.CSVWriter}]' +``` + +### Config Structure + +Same as `fit` command, with a predict dataloader: + +```yaml +data: + predict_dataloader: + _target_: torch.utils.data.DataLoader + # ... predict dataloader config ... + +trainer: + callbacks: + - _target_: lighter.callbacks.CSVWriter # Optional: save predictions + write_interval: batch +``` + +### Requirements + +- Checkpoint file (`.ckpt`) +- `predict_dataloader` in data config +- `predict_step` in your module +- Optional: Writer callback to save results + +## Config Overrides + +Override any config value from command line. 
+ +### Syntax + +```bash +lighter COMMAND config.yaml key::path=value +``` + +### Examples + +#### Simple Values + +```bash +# Numbers +lighter fit config.yaml model::optimizer::lr=0.01 + +# Strings +lighter fit config.yaml trainer::logger::name=my_experiment + +# Booleans +lighter fit config.yaml trainer::enable_checkpointing=false +``` + +#### Nested Values + +```bash +# Deep nesting +lighter fit config.yaml \ + model::optimizer::lr=0.01 \ + model::optimizer::weight_decay=0.0001 \ + model::network::num_classes=100 +``` + +#### Lists + +```bash +# Python list syntax +lighter fit config.yaml 'trainer::devices=[0,1,2,3]' +``` + +#### Objects + +```bash +# YAML object syntax +lighter fit config.yaml \ + 'trainer::callbacks=[{_target_: pytorch_lightning.callbacks.EarlyStopping, monitor: val/loss, patience: 10}]' +``` + +### Path Syntax + +Use `::` to navigate config hierarchy: + +```yaml +# Config structure +model: + optimizer: + lr: 0.001 +``` + +```bash +# Override +lighter fit config.yaml model::optimizer::lr=0.01 +``` + +## Config Merging + +Combine multiple config files. + +### Syntax + +```bash +lighter COMMAND config1.yaml,config2.yaml,... +``` + +### Behavior + +Later files override earlier ones (dictionary merge). + +### Examples + +```bash +# Base + experiment +lighter fit base.yaml,experiment.yaml + +# Multiple overrides +lighter fit base.yaml,data.yaml,model.yaml,overrides.yaml +``` + +### Example Files + +**base.yaml**: +```yaml +trainer: + max_epochs: 100 + accelerator: auto + +model: + network: + num_classes: 10 +``` + +**experiment.yaml**: +```yaml +trainer: + max_epochs: 200 # Override + +model: + optimizer: # Add + lr: 0.01 +``` + +**Result**: Merged config with `max_epochs=200` and new optimizer. 
+ +## Environment Variables + +### LIGHTER_OUTPUT_DIR + +Change default output directory: + +```bash +export LIGHTER_OUTPUT_DIR=./my_outputs +lighter fit config.yaml +``` + +Or in config: + +```yaml +trainer: + default_root_dir: ./my_outputs +``` + +### CUDA_VISIBLE_DEVICES + +Control GPU visibility: + +```bash +# Use only GPU 2 +CUDA_VISIBLE_DEVICES=2 lighter fit config.yaml + +# Use GPUs 0 and 3 +CUDA_VISIBLE_DEVICES=0,3 lighter fit config.yaml trainer::devices=2 +``` + +### MASTER_ADDR / MASTER_PORT + +For multi-node training: + +```bash +MASTER_ADDR=node0 MASTER_PORT=12345 lighter fit config.yaml +``` + +## Common Patterns + +### Quick Debugging + +```bash +# Fast dev run (1 batch) +lighter fit config.yaml trainer::fast_dev_run=true + +# Overfit 10 batches +lighter fit config.yaml trainer::overfit_batches=10 + +# Limit batches +lighter fit config.yaml trainer::limit_train_batches=0.1 +``` + +### Hyperparameter Tuning + +```bash +# Learning rate sweep +for lr in 0.0001 0.001 0.01; do + lighter fit config.yaml model::optimizer::lr=$lr +done + +# Batch size sweep +for bs in 32 64 128 256; do + lighter fit config.yaml data::train_dataloader::batch_size=$bs +done +``` + +### Multi-GPU + +```bash +# All GPUs +lighter fit config.yaml trainer::devices=-1 trainer::strategy=ddp + +# Specific GPUs +lighter fit config.yaml trainer::devices=4 trainer::strategy=ddp + +# Specific GPU IDs +lighter fit config.yaml 'trainer::devices=[0,2,3]' trainer::strategy=ddp +``` + +### Resume Training + +```bash +# Resume from last checkpoint +lighter fit config.yaml --ckpt_path outputs/.../checkpoints/last.ckpt + +# Resume with different LR +lighter fit config.yaml \ + --ckpt_path outputs/.../checkpoints/last.ckpt \ + model::optimizer::lr=0.0001 +``` + +### Save Predictions + +```bash +# CSV output +lighter predict config.yaml \ + --ckpt_path checkpoints/best.ckpt \ + 'trainer::callbacks=[{_target_: lighter.callbacks.CSVWriter}]' + +# File output +lighter predict config.yaml \ + 
--ckpt_path checkpoints/best.ckpt \ + 'trainer::callbacks=[{_target_: lighter.callbacks.FileWriter}]' +``` + +## Exit Codes + +| Code | Meaning | +|------|---------| +| 0 | Success | +| 1 | Config error (invalid YAML, missing files) | +| 2 | Runtime error (training failed, OOM, etc.) | + +## Verbosity + +PyTorch Lightning controls logging verbosity. + +### Reduce Logging + +```yaml +trainer: + enable_progress_bar: false + enable_model_summary: false +``` + +### Increase Logging + +```bash +# Show suppressed Python warnings +export PYTHONWARNINGS=default +lighter fit config.yaml +``` + +## Tips + +### Shell Completion + +Add to your shell config: + +```bash +# Bash +eval "$(_LIGHTER_COMPLETE=bash_source lighter)" + +# Zsh +eval "$(_LIGHTER_COMPLETE=zsh_source lighter)" +``` + +### Config Validation + +Validate the config without a full training run (`fast_dev_run` executes a single batch): + +```bash +lighter fit config.yaml trainer::fast_dev_run=true +``` + +### Find Config Issues + +Enable Sparkwheel debug output: + +```bash +SPARKWHEEL_DEBUG=1 lighter fit config.yaml +``` + +### See Resolved Config + +Add a print statement in your module: + +```python +def __init__(self, ...): + super().__init__() + self.save_hyperparameters() + print(self.hparams) # See final values +``` + +## Next Steps + +- [Configuration Guide](../guides/configuration.md) - Learn config syntax +- [Training Guide](../guides/training.md) - Training workflows +- [Examples](../examples/image-classification.md) - Complete examples + +## Quick Reference + +```bash +# Basic commands +lighter fit config.yaml +lighter validate config.yaml +lighter test config.yaml +lighter predict config.yaml + +# Overrides +lighter fit config.yaml key::path=value + +# Multiple configs +lighter fit base.yaml,experiment.yaml + +# Checkpoints +lighter fit config.yaml --ckpt_path path/to/checkpoint.ckpt +lighter validate config.yaml --ckpt_path path/to/checkpoint.ckpt +lighter test config.yaml --ckpt_path path/to/checkpoint.ckpt +lighter predict config.yaml --ckpt_path path/to/checkpoint.ckpt + +# 
Multi-GPU +lighter fit config.yaml trainer::devices=4 trainer::strategy=ddp + +# Debugging +lighter fit config.yaml trainer::fast_dev_run=true +``` diff --git a/docs/tutorials/get-started.md b/docs/tutorials/get-started.md deleted file mode 100644 index c263873d..00000000 --- a/docs/tutorials/get-started.md +++ /dev/null @@ -1,562 +0,0 @@ ---- -title: Get Started ---- - -# Get Started with Lighter - -This guide takes you from installation to running experiments in 15 minutes, using proper project structure from the start. - -## Installation - -```bash -pip install lighter -``` - -## The Core Idea - -Traditional PyTorch Lightning requires writing training loops: - -```python -class MyModule(LightningModule): - def __init__(self): - self.model = Model() - self.criterion = nn.CrossEntropyLoss() - - def training_step(self, batch, batch_idx): - x, y = batch - pred = self.model(x) - loss = self.criterion(pred, y) - return loss - - def configure_optimizers(self): - return torch.optim.Adam(self.model.parameters(), lr=0.001) - -trainer = Trainer(max_epochs=10) -trainer.fit(module, train_loader, val_loader) -``` - -**Lighter replaces this with configuration:** - -```yaml -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 10 - -system: - _target_: lighter.System - model: - _target_: MyModel - criterion: - _target_: torch.nn.CrossEntropyLoss - optimizer: - _target_: torch.optim.Adam - params: "$@system::model.parameters()" - lr: 0.001 - dataloaders: - train: ... - val: ... -``` - -```bash -lighter fit config.yaml -``` - -## Step 1: Create Your Project - -Set up a proper project structure (this will pay off as you add experiments): - -```bash -mkdir -p my_experiments/experiments -cd my_experiments -touch __init__.py -``` - -Your project structure: -``` -my_experiments/ -├── __init__.py -└── experiments/ - └── (configs will go here) -``` - -## Step 2: Minimal Example - -Create `experiments/minimal.yaml`: - -```yaml -project: . 
# Import from my_experiments/ - -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 3 - -system: - _target_: lighter.System - - model: - _target_: torch.nn.Linear - in_features: 784 # MNIST: 28x28 flattened - out_features: 10 # 10 digits - - criterion: - _target_: torch.nn.CrossEntropyLoss - - optimizer: - _target_: torch.optim.Adam - params: "$@system::model.parameters()" - lr: 0.001 - - dataloaders: - train: - _target_: torch.utils.data.DataLoader - batch_size: 64 - dataset: - _target_: torchvision.datasets.MNIST - root: ./data - train: true - download: true - transform: - _target_: torchvision.transforms.Compose - transforms: - - _target_: torchvision.transforms.ToTensor - - _target_: torchvision.transforms.Lambda - lambd: "$lambda x: x.view(-1)" # Flatten to 784 -``` - -Run it: - -```bash -lighter fit experiments/minimal.yaml -``` - -You just trained a neural network using only YAML configuration. - -## Step 3: Real Example (CIFAR-10) - -Create `experiments/cifar10.yaml`: - -```yaml -project: . 
- -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 10 - accelerator: auto - -system: - _target_: lighter.System - - model: - _target_: torchvision.models.resnet18 - num_classes: 10 - - criterion: - _target_: torch.nn.CrossEntropyLoss - - optimizer: - _target_: torch.optim.Adam - params: "$@system::model.parameters()" - lr: 0.001 - - metrics: - train: - - _target_: torchmetrics.Accuracy - task: multiclass - num_classes: 10 - val: "%system::metrics::train" # Reuse train metrics - - dataloaders: - train: - _target_: torch.utils.data.DataLoader - batch_size: 128 - shuffle: true - num_workers: 4 - dataset: - _target_: torchvision.datasets.CIFAR10 - root: ./data - train: true - download: true - transform: - _target_: torchvision.transforms.Compose - transforms: - - _target_: torchvision.transforms.RandomHorizontalFlip - - _target_: torchvision.transforms.RandomCrop - size: 32 - padding: 4 - - _target_: torchvision.transforms.ToTensor - - _target_: torchvision.transforms.Normalize - mean: [0.4914, 0.4822, 0.4465] - std: [0.2470, 0.2435, 0.2616] - - val: - _target_: torch.utils.data.DataLoader - batch_size: 256 - num_workers: 4 - dataset: - _target_: torchvision.datasets.CIFAR10 - root: ./data - train: false - download: true - transform: - _target_: torchvision.transforms.Compose - transforms: - - _target_: torchvision.transforms.ToTensor - - _target_: torchvision.transforms.Normalize - mean: [0.4914, 0.4822, 0.4465] - std: [0.2470, 0.2435, 0.2616] -``` - -Run it: - -```bash -lighter fit experiments/cifar10.yaml -``` - -You now have automatic: -- Training and validation loops -- Metrics computation and logging -- Loss tracking -- Checkpointing - -## Step 4: Add Custom Models - -Now here's why we set up a proper project structure. Let's add a custom CNN. 
- -Create `models/__init__.py` and `models/simple_cnn.py`: - -```bash -mkdir models -touch models/__init__.py -``` - -```python title="models/simple_cnn.py" -import torch.nn as nn - -class SimpleCNN(nn.Module): - def __init__(self, num_classes=10): - super().__init__() - self.conv1 = nn.Conv2d(3, 32, 3, padding=1) - self.relu = nn.ReLU() - self.pool = nn.MaxPool2d(2) - self.conv2 = nn.Conv2d(32, 64, 3, padding=1) - self.fc = nn.Linear(64 * 8 * 8, num_classes) - - def forward(self, x): - x = self.pool(self.relu(self.conv1(x))) - x = self.pool(self.relu(self.conv2(x))) - x = x.view(x.size(0), -1) - return self.fc(x) -``` - -Your project now looks like: -``` -my_experiments/ -├── __init__.py -├── experiments/ -│ ├── minimal.yaml -│ └── cifar10.yaml -├── models/ -│ ├── __init__.py -│ └── simple_cnn.py -└── data/ # Created by datasets -``` - -Create `experiments/custom_model.yaml`: - -```yaml -project: . # This imports my_experiments/ - -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 10 - -system: - _target_: lighter.System - - model: - _target_: models.simple_cnn.SimpleCNN # Your custom model! - num_classes: 10 - - criterion: - _target_: torch.nn.CrossEntropyLoss - - optimizer: - _target_: torch.optim.Adam - params: "$@system::model.parameters()" - lr: 0.001 - - dataloaders: - train: - _target_: torch.utils.data.DataLoader - batch_size: 128 - shuffle: true - dataset: - _target_: torchvision.datasets.CIFAR10 - root: ./data - train: true - download: true - transform: - _target_: torchvision.transforms.ToTensor -``` - -Run it: - -```bash -lighter fit experiments/custom_model.yaml -``` - -**This is the key insight**: By setting up proper structure from the start, adding custom components is natural, not a separate concept to learn. - -## Understanding the Syntax - -Lighter uses **[Sparkwheel](https://project-lighter.github.io/sparkwheel/)** for configuration. 
Here are the essentials: - -### `_target_:` Instantiate a Class - -```yaml -model: - _target_: torch.nn.Linear - in_features: 784 - out_features: 10 -``` - -**Equivalent to:** `model = torch.nn.Linear(in_features=784, out_features=10)` - -Works with any Python class—PyTorch, third-party, or your custom code. - -### `project:` Import Custom Modules - -```yaml -project: . # Import from current directory as a Python module -``` - -This makes `models/`, `datasets/`, `transforms/` etc. importable via `_target_`. - -### `$` Evaluate Python Expression - -```yaml -optimizer: - lr: "$0.001 * 2" # Evaluates to 0.002 -``` - -### `@` Resolved Reference - -```yaml -optimizer: - params: "$@system::model.parameters()" # Gets actual model instance, calls parameters() -``` - -Gets the instantiated object (after `_target_` processing). - -### `%` Raw Reference - -```yaml -metrics: - train: - - _target_: torchmetrics.Accuracy - task: multiclass - num_classes: 10 - val: "%system::metrics::train" # Gets raw YAML, creates new instance -``` - -Gets the unprocessed YAML configuration (before instantiation). - -### `::` Path Notation - -```yaml -system::model # Navigate to model definition -system::optimizer::lr # Navigate to nested value -``` - -Navigate nested config with `::` separator—more concise than `["system"]["model"]`. - -!!! tip "Learn More" - For complete Sparkwheel documentation including advanced features, see **[Sparkwheel docs](https://project-lighter.github.io/sparkwheel/)**. 
- -## CLI Overrides - -Change hyperparameters without editing files: - -```bash -# Change learning rate -lighter fit experiments/cifar10.yaml system::optimizer::lr=0.01 - -# Train longer -lighter fit experiments/cifar10.yaml trainer::max_epochs=100 - -# Use multiple GPUs -lighter fit experiments/cifar10.yaml trainer::devices=2 - -# Combine multiple overrides -lighter fit experiments/cifar10.yaml \ - trainer::max_epochs=100 \ - system::optimizer::lr=0.001 \ - trainer::devices=4 -``` - -## Organizing Multiple Experiments - -As your project grows, organize configs by purpose: - -``` -my_experiments/ -├── __init__.py -├── experiments/ -│ ├── baselines/ -│ │ ├── resnet18.yaml -│ │ └── resnet50.yaml -│ ├── ablations/ -│ │ ├── no_augmentation.yaml -│ │ └── different_optimizer.yaml -│ └── production/ -│ └── final_model.yaml -├── models/ -│ ├── __init__.py -│ └── simple_cnn.py -└── datasets/ # Add custom datasets here - ├── __init__.py - └── my_dataset.py -``` - -## Merging Configs - -Create reusable config components: - -```yaml title="experiments/base.yaml" -project: . - -trainer: - _target_: pytorch_lightning.Trainer - accelerator: auto - -system: - _target_: lighter.System - criterion: - _target_: torch.nn.CrossEntropyLoss -``` - -```yaml title="experiments/resnet18.yaml" -system: - model: - _target_: torchvision.models.resnet18 - num_classes: 10 - optimizer: - _target_: torch.optim.Adam - params: "$@system::model.parameters()" - lr: 0.001 -``` - -Combine them: - -```bash -lighter fit experiments/base.yaml,experiments/resnet18.yaml -``` - -Later configs override earlier ones, enabling modular experiment design. - -## Testing and Prediction - -```bash -# Test trained model -lighter test experiments/cifar10.yaml args::test::ckpt_path=checkpoints/best.ckpt - -# Generate predictions -lighter predict experiments/cifar10.yaml args::predict::ckpt_path=checkpoints/best.ckpt -``` - -## When You Need Adapters - -Your dataset returns a dict but your model needs tensors? 
Use adapters: - -```yaml -system: - adapters: - train: - batch: - _target_: lighter.adapters.BatchAdapter - input_accessor: "image" # Extract from dict - target_accessor: "label" -``` - -Your loss expects `(target, pred)` instead of `(pred, target)`? Swap them: - -```yaml -system: - adapters: - train: - criterion: - _target_: lighter.adapters.CriterionAdapter - pred_argument: 1 - target_argument: 0 -``` - -**Adapters make Lighter task-agnostic**—they connect any data format to any model/loss/metric. - -[Learn more about adapters →](../how-to/adapters.md) - -### Continue Learning - -- **[Configuration Guide](../how-to/configuration.md)** - Complete syntax reference -- **[Adapters](../how-to/adapters.md)** - Handle any data format -- **[Recipes](../how-to/recipes.md)** - Ready-to-use patterns -- **[Architecture](../design/overview.md)** - How Lighter works internally - -### Quick Reference - -```yaml -# Essential Sparkwheel syntax -project: . # Import custom modules -_target_: module.ClassName # Instantiate class -$expression # Evaluate Python expression -@path::to::object # Resolved reference (instantiated object) -%path::to::config # Raw reference (unprocessed YAML) -path::nested::key # Path notation (navigate config) -::sibling::key # Relative reference (sibling in same section) -=key: # Replace operator (override default merge) -~key: null # Delete entire key -~key::1: null # Delete single list item -~key: [0, 2] # Delete multiple items (batch syntax) - -# Lighter CLI commands -lighter fit experiments/config.yaml # Train -lighter validate experiments/config.yaml # Validate -lighter test experiments/config.yaml # Test -lighter predict experiments/config.yaml # Predict - -# Override from CLI -lighter fit experiments/cfg.yaml key::path=value - -# Merge configs (automatic by default) -lighter fit experiments/base.yaml,experiments/exp.yaml -``` - -### Project Structure - -``` -my_experiments/ -├── __init__.py # Make it a module -├── experiments/ # All configs here -│ 
├── base.yaml -│ ├── exp1.yaml -│ └── exp2.yaml -├── models/ # Custom models -│ ├── __init__.py -│ └── my_model.py -├── datasets/ # Custom datasets -│ ├── __init__.py -│ └── my_dataset.py -└── transforms/ # Custom transforms - ├── __init__.py - └── my_transform.py -``` - -## Getting Help - -- **Stuck?** [Troubleshooting Guide](../how-to/troubleshooting.md) -- **Questions?** [FAQ](../faq.md) -- **Coming from Lightning?** [Migration Guide](../migration/from-pytorch-lightning.md) -- **Community?** [Discord](https://discord.gg/zJcnp6KrUp) - -## Complete Example - -See full examples with this structure in the repository's [projects](https://github.com/project-lighter/lighter/projects/) directory. diff --git a/mkdocs.yml b/mkdocs.yml index 490d81ae..601741fb 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -17,6 +17,15 @@ theme: - footnotes - navigation.tabs - navigation.top + - navigation.instant + - navigation.tracking + - navigation.sections + - navigation.indexes + - toc.follow + - search.suggest + - search.highlight + - search.share + - content.tabs.link palette: # Palette toggle for automatic mode @@ -52,42 +61,31 @@ plugins: - docs/gen_ref_pages.py - literate-nav: nav_file: SUMMARY.md - - section-index - mkdocstrings: handlers: python: paths: [src] options: - # Removed the default filter that excludes private members (that is, members whose names start with a single underscore). 
+ # Removed the default filter that excludes private members filters: null nav: - Home: index.md - - Get Started: tutorials/get-started.md - - How-To Guides: - - Essentials: - - Run: how-to/run.md - - Configuration: how-to/configuration.md - - Project Module: how-to/project_module.md - - Features: - - Adapters: how-to/adapters.md - - Metrics: how-to/metrics.md - - Writers: how-to/writers.md - - Freezers: how-to/freezers.md - - Inferers: how-to/inferers.md - - Practical: - - Configuration Recipes: how-to/recipes.md - - Experiment Tracking: how-to/experiment_tracking.md - - Troubleshooting: how-to/troubleshooting.md - - Design & Architecture: - - Overview: design/overview.md - - System Internals: design/system.md - - Adapter Pattern: design/adapters.md - - Philosophy: design/philosophy.md - - FAQ: - - FAQ: faq.md - - Migration from Lightning: migration/from-pytorch-lightning.md - - API: reference/ + - Quick Start: quickstart.md + - Guides: + - Configuration: guides/configuration.md + - Custom Code: guides/custom-code.md + - Your LightningModule: guides/lightning-module.md + - LighterModule: guides/lighter-module.md + - Training: guides/training.md + - Best Practices: guides/best-practices.md + - Examples: + - Image Classification: examples/image-classification.md + - Multi-GPU Training: examples/multi-gpu.md + - Reference: + - CLI: reference/cli.md + - API: reference/ + - FAQ: faq.md markdown_extensions: - pymdownx.highlight: @@ -110,6 +108,7 @@ markdown_extensions: - pymdownx.emoji: emoji_index: !!python/name:material.extensions.emoji.twemoji emoji_generator: !!python/name:material.extensions.emoji.to_svg + - tables repo_name: project-lighter/lighter repo_url: https://github.com/project-lighter/lighter diff --git a/paper/lighter_diagrams.pptx b/paper/lighter_diagrams.pptx deleted file mode 100644 index cbd83887..00000000 Binary files a/paper/lighter_diagrams.pptx and /dev/null differ diff --git a/paper/overview_all.png b/paper/overview_all.png deleted file mode 
100644 index 97226bfe..00000000 Binary files a/paper/overview_all.png and /dev/null differ diff --git a/paper/overview_system.png b/paper/overview_system.png deleted file mode 100644 index 6ebaca1d..00000000 Binary files a/paper/overview_system.png and /dev/null differ diff --git a/paper/paper.bib b/paper/paper.bib deleted file mode 100644 index 4b95b8f8..00000000 --- a/paper/paper.bib +++ /dev/null @@ -1,134 +0,0 @@ -@software{Falcon_PyTorch_Lightning_2019, - author = {Falcon, William and {The PyTorch Lightning team}}, - doi = {10.5281/zenodo.3828935}, - license = {Apache-2.0}, - month = mar, - title = {{PyTorch Lightning}}, - url = {https://github.com/Lightning-AI/lightning}, - version = {1.4}, - year = {2019} -} - -@article{Cardoso_MONAI_An_open-source_2022, - author = {Cardoso, M. Jorge and Li, Wenqi and Brown, Richard and Ma, Nic and Kerfoot, Eric and Wang, Yiheng and Murray, Benjamin and Myronenko, Andriy and Zhao, Can and Yang, Dong and Nath, Vishwesh and He, Yufan and Xu, Ziyue and Hatamizadeh, Ali and Zhu, Wentao and Liu, Yun and Zheng, Mingxin and Tang, Yucheng and Yang, Isaac and Zephyr, Michael and Hashemian, Behrooz and Alle, Sachidanand and Zalbagi Darestani, Mohammad and Budd, Charlie and Modat, Marc and Vercauteren, Tom and Wang, Guotai and Li, Yiwen and Hu, Yipeng and Fu, Yunguan and Gorman, Benjamin and Johnson, Hans and Genereaux, Brad and Erdal, Barbaros S. and Gupta, Vikash and Diaz-Pinto, Andres and Dourson, Andre and Maier-Hein, Lena and Jaeger, Paul F. and Baumgartner, Michael and Kalpathy-Cramer, Jayashree and Flores, Mona and Kirby, Justin and Cooper, Lee A.D. and Roth, Holger R. and Xu, Daguang and Bericat, David and Floca, Ralf and Zhou, S. Kevin and Shuaib, Haris and Farahani, Keyvan and Maier-Hein, Klaus H. 
and Aylward, Stephen and Dogra, Prerna and Ourselin, Sebastien and Feng, Andrew}, - doi = {10.48550/arXiv.2211.02701}, - month = nov, - title = {{MONAI: An open-source framework for deep learning in healthcare}}, - year = {2022} -} - -@article{Pai2024, - title = "Foundation model for cancer imaging biomarkers", - author = "Pai, Suraj and Bontempi, Dennis and Hadzic, Ibrahim and - Prudente, Vasco and Soka{\v c}, Mateo and Chaunzwa, Tafadzwa L - and Bernatz, Simon and Hosny, Ahmed and Mak, Raymond H and - Birkbak, Nicolai J and Aerts, Hugo J W L", - abstract = "Foundation models in deep learning are characterized by a single - large-scale model trained on vast amounts of data serving as the - foundation for various downstream tasks. Foundation models are - generally trained using self-supervised learning and excel in - reducing the demand for training samples in downstream - applications. This is especially important in medicine, where - large labelled datasets are often scarce. Here, we developed a - foundation model for cancer imaging biomarker discovery by - training a convolutional encoder through self-supervised - learning using a comprehensive dataset of 11,467 radiographic - lesions. The foundation model was evaluated in distinct and - clinically relevant applications of cancer imaging-based - biomarkers. We found that it facilitated better and more - efficient learning of imaging biomarkers and yielded - task-specific models that significantly outperformed - conventional supervised and other state-of-the-art pretrained - implementations on downstream tasks, especially when training - dataset sizes were very limited. Furthermore, the foundation - model was more stable to input variations and showed strong - associations with underlying biology. 
Our results demonstrate - the tremendous potential of foundation models in discovering new - imaging biomarkers that may extend to other clinical use cases - and can accelerate the widespread translation of imaging - biomarkers into clinical settings.", - journal = "Nat. Mach. Intell.", - publisher = "Springer Science and Business Media LLC", - volume = 6, - number = 3, - pages = "354--367", - month = mar, - year = 2024, - keywords = "Biomarkers; Cancer imaging; Tumour biomarkers", - copyright = "https://creativecommons.org/licenses/by/4.0", - language = "en", - doi = {10.1038/s42256-024-00807-9} -} - -@ARTICLE{Pai2025, - title = "Vision foundation models for computed tomography", - author = "Pai, Suraj and Hadzic, Ibrahim and Bontempi, Dennis and - Bressem, Keno and Kann, Benjamin H and Fedorov, Andriy and - Mak, Raymond H and Aerts, Hugo J W L", - abstract = "Foundation models (FMs) have shown transformative potential - in radiology by performing diverse, complex tasks across - imaging modalities. Here, we developed CT-FM, a large-scale - 3D image-based pre-trained model designed explicitly for - various radiological tasks. CT-FM was pre-trained using - 148,000 computed tomography (CT) scans from the Imaging Data - Commons through label-agnostic contrastive learning. We - evaluated CT-FM across four categories of tasks, namely, - whole-body and tumor segmentation, head CT triage, medical - image retrieval, and semantic understanding, showing superior - performance against state-of-the-art models. Beyond - quantitative success, CT-FM demonstrated the ability to - cluster regions anatomically and identify similar anatomical - and structural concepts across scans. Furthermore, it - remained robust across test-retest settings and indicated - reasonable salient regions attached to its embeddings. 
This - study demonstrates the value of large-scale medical imaging - foundation models and by open-sourcing the model weights, - code, and data, aims to support more adaptable, reliable, and - interpretable AI solutions in radiology.", - year = 2025, - primaryClass = "eess.IV", - eprint = "2501.09001", - doi = {10.48550/arXiv.2501.09001} -} - -@article{Ludwig, - author = {Piero Molino and - Yaroslav Dudin and - Sai Sumanth Miryala}, - title = {Ludwig: a type-based declarative deep learning toolbox}, - journal = {CoRR}, - volume = {abs/1909.07930}, - year = {2019}, - url = {http://arxiv.org/abs/1909.07930}, - eprinttype = {arXiv}, - eprint = {1909.07930}, - timestamp = {Tue, 24 Sep 2019 11:33:51 +0200}, - biburl = {https://dblp.org/rec/journals/corr/abs-1909-07930.bib}, - bibsource = {dblp computer science bibliography, https://dblp.org} -} - -@software{Quadra, - title = {orobix/quadra}, - url = {https://github.com/orobix/quadra}, - author = {Mammana, Lorenzo and Malli, Refik Can and Polidori, Alessandro and ebernatene}, - date = {2025-05-20}, - year = {2025}, - month = {5}, - day = {20}, -} - -@article{Gandlf, - author={Pati, Sarthak and Thakur, Siddhesh P. and Hamamc{\i}, {\.{I}}brahim Ethem and Baid, Ujjwal and Baheti, Bhakti and Bhalerao, Megh and G{\"u}ley, Orhun and Mouchtaris, Sofia and Lang, David and Thermos, Spyridon and Gotkowski, Karol and Gonz{\'a}lez, Camila and Grenko, Caleb and Getka, Alexander and Edwards, Brandon and Sheller, Micah and Wu, Junwen and Karkada, Deepthi and Panchumarthy, Ravi and Ahluwalia, Vinayak and Zou, Chunrui and Bashyam, Vishnu and Li, Yuemeng and Haghighi, Babak and Chitalia, Rhea and Abousamra, Shahira and Kurc, Tahsin M. and Gastounioti, Aimilia and Er, Sezgin and Bergman, Mark and Saltz, Joel H. and Fan, Yong and Shah, Prashant and Mukhopadhyay, Anirban and Tsaftaris, Sotirios A. 
and Menze, Bjoern and Davatzikos, Christos and Kontos, Despina and Karargyris, Alexandros and Umeton, Renato and Mattson, Peter and Bakas, Spyridon}, - title={GaNDLF: the generally nuanced deep learning framework for scalable end-to-end clinical workflows}, - journal={Communications Engineering}, - year={2023}, - month={May}, - day={16}, - volume={2}, - number={1}, - pages={23}, - abstract={Deep Learning (DL) has the potential to optimize machine learning in both the scientific and clinical communities. However, greater expertise is required to develop DL algorithms, and the variability of implementations hinders their reproducibility, translation, and deployment. Here we present the community-driven Generally Nuanced Deep Learning Framework (GaNDLF), with the goal of lowering these barriers. GaNDLF makes the mechanism of DL development, training, and inference more stable, reproducible, interpretable, and scalable, without requiring an extensive technical background. GaNDLF aims to provide an end-to-end solution for all DL-related tasks in computational precision medicine. We demonstrate the ability of GaNDLF to analyze both radiology and histology images, with built-in support for k-fold cross-validation, data augmentation, multiple modalities and output classes. 
Our quantitative performance evaluation on numerous use cases, anatomies, and computational tasks supports GaNDLF as a robust application framework for deployment in clinical workflows.}, - issn={2731-3395}, - doi={10.1038/s44172-023-00066-3}, - url={https://doi.org/10.1038/s44172-023-00066-3} -} diff --git a/paper/paper.md b/paper/paper.md deleted file mode 100644 index 770351cf..00000000 --- a/paper/paper.md +++ /dev/null @@ -1,130 +0,0 @@ ---- -title: 'Lighter: Configuration-Driven Deep Learning' -tags: - - Python - - PyTorch - - deep learning - - configuration - - framework -authors: - - name: Ibrahim Hadzic - orcid: 0000-0002-8397-5940 - corresponding: true - affiliation: "1, 2" - - name: Suraj Pai - orcid: 0000-0001-8043-2230 - affiliation: "2, 3, 4" - - name: Keno Bressem - affiliation: "5, 6" - orcid: 0000-0001-9249-8624 - - name: Borek Foldyna - affiliation: 1 - orcid: 0000-0002-2466-4827 - - given-names: Hugo - dropping-particle: JWL - surname: Aerts - affiliation: "2, 3, 4" - orcid: 0000-0002-2122-2003 - -affiliations: - - name: Cardiovascular Imaging Research Center, Massachusetts General Hospital, Harvard Medical School, United States of America - index: 1 - - name: Radiology and Nuclear Medicine, CARIM & GROW, Maastricht University, The Netherlands - index: 2 - - name: Artificial Intelligence in Medicine (AIM) Program, Mass General Brigham, Harvard Medical School, Harvard Institutes of Medicine, United States of America - index: 3 - - name: Department of Radiation Oncology, Brigham and Women’s Hospital, Dana-Farber Cancer Institute, Harvard Medical School, United States of America - index: 4 - - name: Technical University of Munich, School of Medicine and Health, Klinikum rechts der Isar, TUM University Hospital, Germany - index: 5 - - name: Department of Cardiovascular Radiology and Nuclear Medicine, Technical University of Munich, School of Medicine and Health, German Heart Center, TUM University Hospital, Germany - index: 6 -date: 24 February 2025 
-bibliography: paper.bib - ---- - -# Summary - -Lighter is a configuration-driven deep learning (DL) [framework](https://github.com/project-lighter/lighter) that separates experimental setup from code implementation. Models, datasets, and other components are defined through structured configuration files (configs). Configs serve as snapshots of the experiments, enhancing reproducibility while eliminating unstructured and repetitive scripts. Lighter uses (i) PyTorch Lightning [@Falcon_PyTorch_Lightning_2019] to implement a task-agnostic DL logic, and (ii) [MONAI Bundle configuration](https://docs.monai.io/en/stable/config_syntax.html#) [@Cardoso_MONAI_An_open-source_2022] to manage experiments using YAML configs. - -# Statement of Need - -Lighter addresses several challenges in DL experimentation: - -1. **Repetitive and Error-Prone Setups**: DL typically involves significant boilerplate code for training loops, data loading, and metric calculations. The numerous hyperparameters and components across experiments can easily become complex and error-prone. Lighter abstracts these repetitive tasks and uses centralized configs for a clear, manageable experimental setup, reducing tedium and potential for errors. - -2. **Reproducibility and Collaboration**: Inconsistent or complex codebases hinder collaboration and experiment reproduction. Lighter's self-documenting configs offer clear, structured snapshots of each experiment. This greatly improves reproducibility and simplifies how teams share and reuse setups. - -3. **Pace of Research Iteration**: The cumulative effect of these challenges inherently slows down the research cycle. Lighter streamlines the entire experimental process, allowing researchers to focus on core hypotheses and iterate on ideas efficiently. - -# State of the Field - -Config-driven frameworks like Ludwig [@Ludwig], Quadra [@Quadra], and GaNDLF [@Gandlf] offer high level of abstraction by providing predefined structures and pipelines. 
While this approach simplifies usage, it limits flexibility to modify the pipeline or extend components, often requiring direct source code changes. -Lighter takes a different approach by providing medium-level abstraction. It implements a flexible pipeline that maintains direct compatibility with standard PyTorch components (models, datasets, optimizers). The pipeline itself is modifiable to any task via [adapters](#adapters), while custom code is [importable via config](#project-specific-modules) without source code modifications. - -# Design - -Lighter is built upon three fundamental components (\autoref{fig:overview_all}): - -1. **`Config`**: serves as the primary interface for interacting with Lighter. It parses and validates YAML configs that define all components, creating a self-documenting record of each experiment. - -2. **`System`**: encapsulates the components (model, optimizer, scheduler, loss function, metrics, and dataloaders) and connects them into a pipeline that can be customized through [adapters](#adapters) (\autoref{fig:overview_system}). - -3. **`Trainer`**: PyTorch Lightning's `Trainer` handles aspects like distributed or mixed-precision training and checkpoint management. Lighter uses it to execute the protocol defined by the `System`. - -![**Lighter Overview.** `Config` leverages MONAI's `ConfigParser` for parsing the user-defined YAML configs, and its features are used by Runner to instantiate the `System` and `Trainer`. `Trainer` is used directly from PyTorch Lightning, whereas `System` inherits from `LightningModule`, ensuring its compatibility with `Trainer` while implementing a logic generalizable to any task or type of data. 
Finally, `Runner` runs the paired `Trainer` and `System` for a particular stage (e.g., fit or test).\label{fig:overview_all}](overview_all.png) - -![**Flowchart of the `lighter.System`.** A `batch` from the `DataLoader` is processed by `BatchAdapter` to extract `input`, `target` (optional), and `identifier` (optional). The `Model` generates `pred` (predictions) from the `input`. `CriterionAdapter` and `MetricsAdapter` compute loss and metrics, respectively, by applying optional transformations and routing arguments for the loss and metric functions. Results, including loss, metrics, and other data prepared for logging by the `LoggingAdapter` are returned to the `Trainer`.\label{fig:overview_system}](overview_system.png) - -## Adaptability Through Modular Design - -### Adapters - -If we consider all possible DL tasks, we will find it challenging to implement a single pipeline that supports all. Instead, frameworks often implement per-task pipelines (e.g., segmentation, classification, etc.). By contrast, Lighter implements a unified pipeline modifiable via *adapter classes*. In software design, *adapter design pattern* enables components with incompatible interfaces to work together by *bridging* them using an adapter class. In Lighter, these bridges (\autoref{fig:overview_system}) specify how components should interact across data types and tasks. For example, a model's output will differ based on the task (e.g., segmentation, regression), and the adapter will specify how to pass them on to the next component (e.g., criterion or metrics). This design allows Lighter to handle any task without requiring changes to the source code. 
- -```yaml -# Example of an adapter transforming and routing data to the loss function -adapters: - train: - criterion: - _target_: lighter.adapters.CriterionAdapter - pred_transforms: # Apply sigmoid activation to predictions - _target_: torch.sigmoid - pred_argument: 0 # Pass 'pred' to criterion's first arg - target_argument: 1 # Pass 'target' to criterion's second arg -``` - -### Project-specific modules - -Using custom components does not require modifying the framework. Instead, they can be defined within a *project folder* like: - -``` -joss_project -├── __init__.py -└── models/ - ├── __init__.py - └── mlp.py -``` - -By specifying the project path in the config, it is imported as a module whose components can be referenced in the config: - -```yaml -project: /path/to/joss_project # Path to the directory above -system: - model: - _target_: project.models.mlp.MLP # Reference to the custom model - input_size: 784 - num_classes: 10 -``` - -# Research Contributions That Use Lighter - -- Foundation model for cancer imaging biomarkers [@Pai2024] -- Vision Foundation Models for Computed Tomography [@Pai2025] - -# Acknowledgments - -We thank John Zielke for the adapter design pattern idea. We thank Wenqi Li, Nic Ma, Yun Liu, and Eric Kerfoot for their continuous support with MONAI Bundle. - -# References diff --git a/paper/paper.pdf b/paper/paper.pdf deleted file mode 100644 index d4447548..00000000 Binary files a/paper/paper.pdf and /dev/null differ diff --git a/projects/cifar10/__lighter__.py b/projects/cifar10/__lighter__.py new file mode 100644 index 00000000..5f6aee28 --- /dev/null +++ b/projects/cifar10/__lighter__.py @@ -0,0 +1,3 @@ +# This file marks the directory as a Lighter project. 
+# When you run `lighter` from this directory, the project will be auto-discovered +# and you can reference modules `project.module.submodule` diff --git a/projects/cifar10/configs/example.yaml b/projects/cifar10/configs/example.yaml new file mode 100644 index 00000000..ddc48c30 --- /dev/null +++ b/projects/cifar10/configs/example.yaml @@ -0,0 +1,98 @@ +trainer: + _target_: pytorch_lightning.Trainer + max_epochs: 2 + accelerator: cpu + # devices: 1 # 2 + # strategy: ddp + log_every_n_steps: 10 + logger: False + callbacks: + - _target_: lighter.callbacks.FileWriter + directory: predictions + value_key: pred + writer_fn: tensor + +model: + _target_: project.model.CIFAR10Model + + network: + _target_: project.models.net.Net + + criterion: + _target_: torch.nn.CrossEntropyLoss + + optimizer: + _target_: torch.optim.Adam + params: "$@model::network.parameters()" + lr: 0.001 + weight_decay: 0.00001 + + # Metrics using MetricCollection with auto-naming from list + train_metrics: + _target_: torchmetrics.MetricCollection + metrics: + - _target_: torchmetrics.Accuracy + task: multiclass + num_classes: 10 + average: macro + - _target_: torchmetrics.F1Score + task: multiclass + num_classes: 10 + average: macro + - _target_: torchmetrics.Precision + task: multiclass + num_classes: 10 + average: macro + - _target_: torchmetrics.Recall + task: multiclass + num_classes: 10 + average: macro + val_metrics: "%::train_metrics" + test_metrics: "%::train_metrics" + +# NEW: Data configuration using LighterDataModule +data: + _target_: lighter.LighterDataModule + + train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 512 + num_workers: 2 + pin_memory: True + shuffle: True + dataset: + _target_: torchvision.datasets.CIFAR10 + download: True + root: ./.datasets/ + train: True + transform: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.5, 0.5, 0.5] + std: 
[0.5, 0.5, 0.5] + target_transform: null + + val_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 512 + num_workers: 2 + pin_memory: True + shuffle: False + dataset: + _target_: torchvision.datasets.CIFAR10 + download: True + root: ./.datasets/ + train: False + transform: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + target_transform: null + + test_dataloader: "%data::val_dataloader" + predict_dataloader: "%data::val_dataloader" diff --git a/projects/cifar10/configs/example_overrides.yaml b/projects/cifar10/configs/example_overrides.yaml new file mode 100644 index 00000000..bbce6591 --- /dev/null +++ b/projects/cifar10/configs/example_overrides.yaml @@ -0,0 +1,2 @@ +trainer::max_epochs: 1 +model::optimizer::_target_: torch.optim.SGD diff --git a/projects/cifar10/experiments/example.yaml b/projects/cifar10/experiments/example.yaml deleted file mode 100644 index 224fa653..00000000 --- a/projects/cifar10/experiments/example.yaml +++ /dev/null @@ -1,103 +0,0 @@ -project: ./projects/cifar10 - -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 100 - accelerator: cpu - # devices: 1 # 2 - # strategy: ddp - log_every_n_steps: 10 - logger: False - callbacks: - - _target_: lighter.callbacks.FileWriter - path: '$f"{@project}/predictions"' - writer: tensor - -system: - _target_: lighter.System - model: - _target_: "project.models.net.Net" - - criterion: - _target_: torch.nn.CrossEntropyLoss - - optimizer: - _target_: torch.optim.Adam - params: "$@system::model.parameters()" - lr: 0.001 - weight_decay: 0.00001 - - metrics: - train: - - _target_: torchmetrics.Accuracy - task: multiclass - num_classes: 10 - average: macro - - _target_: torchmetrics.F1Score - task: multiclass - num_classes: 10 - average: macro - - _target_: torchmetrics.Precision - task: multiclass - num_classes: 10 - average: macro - - 
_target_: torchmetrics.Recall - task: multiclass - num_classes: 10 - average: macro - val: "%::train" - test: "%::train" - - dataloaders: - train: - _target_: torch.utils.data.DataLoader - batch_size: 512 - num_workers: 2 - pin_memory: True - shuffle: True - dataset: - _target_: torchvision.datasets.CIFAR10 - download: True - root: ./.datasets/ - train: True - transform: - _target_: torchvision.transforms.Compose - transforms: - - _target_: torchvision.transforms.ToTensor - - _target_: torchvision.transforms.Normalize - mean: [0.5, 0.5, 0.5] - std: [0.5, 0.5, 0.5] - target_transform: null - - val: - _target_: torch.utils.data.DataLoader - batch_size: 512 - num_workers: 2 - pin_memory: True - shuffle: False - dataset: - _target_: torchvision.datasets.CIFAR10 - download: True - root: ./.datasets/ - train: False - transform: - _target_: torchvision.transforms.Compose - transforms: - - _target_: torchvision.transforms.ToTensor - - _target_: torchvision.transforms.Normalize - mean: [0.5, 0.5, 0.5] - std: [0.5, 0.5, 0.5] - target_transform: null - - test: "%::val" - predict: "%::val" - - adapters: - train: - batch: - _target_: lighter.adapters.BatchAdapter - input_accessor: 0 - target_accessor: 1 - val: "%::train" - test: "%::train" - predict: "%::train" diff --git a/projects/cifar10/experiments/example_overrides.yaml b/projects/cifar10/experiments/example_overrides.yaml deleted file mode 100644 index 6a4fdc8f..00000000 --- a/projects/cifar10/experiments/example_overrides.yaml +++ /dev/null @@ -1,2 +0,0 @@ -trainer::max_epochs: 1 -system::optimizer::_target_: torch.optim.SGD diff --git a/projects/cifar10/model.py b/projects/cifar10/model.py new file mode 100644 index 00000000..1f6690d1 --- /dev/null +++ b/projects/cifar10/model.py @@ -0,0 +1,61 @@ +"""CIFAR10 Model implementation using LighterModule class.""" + +from lighter import LighterModule + + +class CIFAR10Model(LighterModule): + """ + Simple classification model for CIFAR10. 
+ + Users have full control over step logic while framework handles automatic logging. + """ + + def training_step(self, batch, batch_idx): + """Training step with user-defined logic.""" + # Extract batch data + x, y = batch + + # Forward pass + pred = self(x) + + # Compute loss using criterion from config + if self.criterion is None: + raise RuntimeError("criterion is required for training but was not set in config") + loss = self.criterion(pred, y) + + # Update metrics (user calls them explicitly) + if self.train_metrics is not None: + self.train_metrics(pred, y) + + # Return dict - framework logs automatically + return {"loss": loss, "pred": pred, "target": y} + + def validation_step(self, batch, batch_idx): + """Validation step with user-defined logic.""" + x, y = batch + pred = self(x) + if self.criterion is None: + raise RuntimeError("criterion is required for validation but was not set in config") + loss = self.criterion(pred, y) + + if self.val_metrics is not None: + self.val_metrics(pred, y) + + return {"loss": loss, "pred": pred, "target": y} + + def test_step(self, batch, batch_idx): + """Test step with user-defined logic.""" + x, y = batch + pred = self(x) + + if self.test_metrics is not None: + self.test_metrics(pred, y) + + # No loss required in test mode + return {"pred": pred, "target": y} + + def predict_step(self, batch, batch_idx): + """Prediction step - return dict for FileWriter compatibility.""" + x, y = batch + pred = self(x) + return {"pred": pred} diff --git a/pyproject.toml b/pyproject.toml index ecb9be54..9e1e3c97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,9 +35,10 @@ dependencies = [ "torchmetrics>=1.2.0", "tensorboard>=2.11.2", "requests>=2.31.0", - "sparkwheel>=0.0.5", + "sparkwheel>=0.0.9", "rich>=13.7.0", "torchvision>=0.20.0", + "cloudpickle>=3.0.0", ] [project.urls] @@ -54,8 +55,6 @@ dev = [ { include-group = "quality" }, { include-group = "types" }, { include-group = "test" }, - "ipykernel==6.29.5", - 
"ipywidgets==8.1.5", "pre-commit-uv==4.1.4", ] doc = [ @@ -71,16 +70,15 @@ quality = ["ruff>=0.5.0"] types = [ "mypy>=1.14.1", "typing-extensions>=4.4.0", + "types-PyYAML>=6.0.0", ] test = [ "pytest>=7.4.0", "pytest-html>=3.2.0", "pytest-cov>=4.0.0", "coverage>=7.0.5", - "coverage-badge>=1.1.0", "pytest-metadata>=3.1.1", - "aiohttp>=3.8.3", - "av>=12.0.0", + "pytest-github-actions-annotate-failures>=0.2.0", ] [tool.mypy] @@ -91,7 +89,6 @@ show_traceback = true allow_redefinition = false check_untyped_defs = true -disallow_any_generics = true disallow_incomplete_defs = true ignore_missing_imports = true implicit_reexport = false diff --git a/src/lighter/__init__.py b/src/lighter/__init__.py index 371e5b7f..f0fab7af 100644 --- a/src/lighter/__init__.py +++ b/src/lighter/__init__.py @@ -8,7 +8,8 @@ _setup_logging() +from .data import LighterDataModule # noqa: E402 from .engine.runner import Runner # noqa: E402 -from .system import System # noqa: E402 +from .model import LighterModule # noqa: E402 -__all__ = ["Runner", "System"] +__all__ = ["LighterDataModule", "LighterModule", "Runner"] diff --git a/src/lighter/adapters.py b/src/lighter/adapters.py deleted file mode 100644 index 81de0ca0..00000000 --- a/src/lighter/adapters.py +++ /dev/null @@ -1,317 +0,0 @@ -from collections.abc import Callable -from typing import Any - -from lighter.utils.misc import ensure_list - - -class _TransformsAdapter: - """ - Adapter for applying transformations to data. - - Args: - input_transforms: A single or a list of transforms to apply to the input data. - target_transforms: A single or a list of transforms to apply to the target data. - pred_transforms: A single or a list of transforms to apply to the prediction data. 
- """ - - def __init__( - self, - input_transforms: Callable | list[Callable] | None = None, - target_transforms: Callable | list[Callable] | None = None, - pred_transforms: Callable | list[Callable] | None = None, - ): - self.input_transforms = input_transforms - self.target_transforms = target_transforms - self.pred_transforms = pred_transforms - - def __call__(self, input: Any, target: Any, pred: Any) -> tuple[Any, Any, Any]: - """ - Applies the specified transforms to the input, target, and prediction data. - - Args: - input: The input data. - target: The target data. - pred: The prediction data. - - Returns: - The transformed (input, target, prediction) data. - """ - input = self._transform(input, self.input_transforms) - target = self._transform(target, self.target_transforms) - pred = self._transform(pred, self.pred_transforms) - return input, target, pred - - def _transform(self, data: Any, transforms: Callable | list[Callable]) -> Any: - """ - Applies a list of transform functions to the data. - - Args: - data: The data to be transformed. - transforms: A single transform function or a list of functions. - - Returns: - The transformed data. - - Raises: - ValueError: If any transform is not callable. - """ - for transform in ensure_list(transforms): - if callable(transform): - data = transform(data) - else: - raise ValueError(f"Invalid transform type for transform: {transform}") - return data - - -class _ArgumentsAdapter: - """ - Base adapter for adapting arguments to a function based on specified argument names or positions. - """ - - def __init__( - self, - input_argument: int | str | None = None, - target_argument: int | str | None = None, - pred_argument: int | str | None = None, - ): - # Ensure that the positionals are consecutive integers. - # There cannot be positional 0 and 2, without 1. Same with a positional 1 without 0. 
- positionals = sorted(arg for arg in (input_argument, target_argument, pred_argument) if isinstance(arg, int)) - if positionals != list(range(len(positionals))): - raise ValueError("Positional arguments must be consecutive integers starting from 0.") - - self.input_argument = input_argument - self.target_argument = target_argument - self.pred_argument = pred_argument - - def __call__(self, input: Any, target: Any, pred: Any) -> tuple[list[Any], dict[str, Any]]: - """ - Adapts the input, target, and prediction data to the specified argument positions or names. - - Args: - input: The input data to be adapted. - target: The target data to be adapted. - pred: The prediction data to be adapted. - - Returns: - A tuple containing a list of positional arguments and a dictionary of keyword arguments. - """ - args = [] # List to store positional arguments - kwargs = {} # Dictionary to store keyword arguments - - # Mapping of argument names to their respective values - argument_map = {"input_argument": input, "target_argument": target, "pred_argument": pred} - - # Iterate over the argument map to adapt arguments - for arg_name, value in argument_map.items(): - # Get the position or name of the argument from the instance attributes - arg_position = getattr(self, arg_name) - if arg_position is not None: - if isinstance(arg_position, int): - # Insert the value into the args list at the specified position - args.insert(arg_position, value) - elif isinstance(arg_position, str): - # Add the value to the kwargs dictionary with the specified name - kwargs[arg_position] = value - else: - # Raise an error if the argument type is invalid - raise ValueError(f"Invalid {arg_name} type: {type(arg_position)}") - - # Return the adapted positional and keyword arguments - return args, kwargs - - -class _ArgumentsAndTransformsAdapter(_ArgumentsAdapter, _TransformsAdapter): - """ - A generic adapter for applying functions (criterion or metrics) to data. 
- """ - - def __init__( - self, - input_argument: int | str | None = None, - target_argument: int | str | None = None, - pred_argument: int | str | None = None, - input_transforms: list[Callable] | None = None, - target_transforms: list[Callable] | None = None, - pred_transforms: list[Callable] | None = None, - ): - """ - Initializes the Arguments and Transforms Adapter. - - Args: - input_argument: Position or name for the input data. - target_argument: Position or name for the target data. - pred_argument: Position or name for the prediction data. - input_transforms: Transforms to apply to the input data. - target_transforms: Transforms to apply to the target data. - pred_transforms: Transforms to apply to the prediction data. - - Raises: - ValueError: If transforms are provided without corresponding argument specifications. - """ - # Validate transform arguments - if input_argument is None and input_transforms is not None: - raise ValueError("Input transforms provided but input_argument is None") - if target_argument is None and target_transforms is not None: - raise ValueError("Target transforms provided but target_argument is None") - if pred_argument is None and pred_transforms is not None: - raise ValueError("Pred transforms provided but pred_argument is None") - - _ArgumentsAdapter.__init__(self, input_argument, target_argument, pred_argument) - _TransformsAdapter.__init__(self, input_transforms, target_transforms, pred_transforms) - - def __call__(self, fn: Callable, input: Any, target: Any, pred: Any) -> Any: - """ - Applies transforms and adapts arguments before calling the provided function. - - Args: - fn: The function/method to be called (e.g., a loss function or metric). - input: The input data. - target: The target data. - pred: The prediction data. - - Returns: - The result of the function call. 
- """ - # Apply the transforms to the input, target, and prediction data - input, target, pred = _TransformsAdapter.__call__(self, input, target, pred) - # Map the input, target, and prediction data to the function arguments - args, kwargs = _ArgumentsAdapter.__call__(self, input, target, pred) - # Call the provided function with the adapted arguments - return fn(*args, **kwargs) - - -class BatchAdapter: - def __init__( - self, - input_accessor: int | str | Callable, - target_accessor: int | str | Callable | None = None, - identifier_accessor: int | str | Callable | None = None, - ): - """ - Initializes BatchAdapter with accessors for input, target, and identifier. - - Args: - input_accessor: Accessor for the identifier data. Can be an index (lists/tuples), a key (dictionaries), - a callable (custom batch processing). - target_accessor: Accessor for the target data. Can be an index (for lists/tuples), - a key (for dictionaries), or a callable (for custom batch processing). - identifier_accessor: Accessor for the identifier data. Can be an index (lists/tuples), a key (dictionaries), - a callable (custom batch processing), or None if no identifier is present. - """ - self.input_accessor = input_accessor - self.target_accessor = target_accessor - self.identifier_accessor = identifier_accessor - - def __call__(self, batch: Any) -> tuple[Any, Any, Any]: - """ - Accesses the identifier, input, and target data from the batch. - - Args: - batch: The batch data from which to extract information. - - Returns: - A tuple containing (identifier, input, target). - - Raises: - ValueError: If accessors are invalid for the provided batch structure. 
- """ - input = self._access_value(batch, self.input_accessor) - target = self._access_value(batch, self.target_accessor) - identifier = self._access_value(batch, self.identifier_accessor) - return input, target, identifier - - def _access_value(self, data: Any, accessor: int | str | Callable) -> Any: - """ - Accesses a value from the data using the provided accessor. - - Args: - data: The data to access the value from. - accessor: The accessor to use. Can be an index (for lists/tuples), - a key (for dictionaries), or a callable. - - Returns: - The accessed value. - - Raises: - ValueError: If the accessor type or data structure is invalid. - """ - if accessor is None: - return None - elif isinstance(accessor, int) and isinstance(data, (tuple, list)): - return data[accessor] - elif isinstance(accessor, str) and isinstance(data, dict): - return data[accessor] - elif callable(accessor): - return accessor(data) - else: - raise ValueError(f"Invalid accessor {accessor} of type {type(accessor)} for data type {type(data)}.") - - -class CriterionAdapter(_ArgumentsAndTransformsAdapter): - """ - This adapter processes and transforms the input, target, and prediction data, if specified, - and forwards them to the specified arguments of a criterion (loss function). - """ - - def __call__(self, criterion: Callable, input: Any, target: Any, pred: Any) -> Any: - """ - Applies transforms and adapts arguments before calling the provided metric function. - - Args: - criterion: The criterion (loss function). - input: The input data to transform with `input_transforms` if specified and pass to the metric with - the position or argument name specified by `input_argument`. - target: The target data to transform with `target_transforms` if specified and pass to the metric with - the position or argument name specified by `target_argument`. 
- pred: The prediction data to transform with `pred_transforms` if specified and pass to the metric with - the position or argument name specified by `pred_argument`. - - Returns: - The result of the metric function call. - """ - return super().__call__(criterion, input, target, pred) - - -class MetricsAdapter(_ArgumentsAndTransformsAdapter): - """ - This adapter processes and transforms the input, target, and prediction data, if specified, - and forwards them to the specified arguments of a metric. - """ - - def __call__(self, metric: Callable, input: Any, target: Any, pred: Any) -> Any: - """ - Applies transforms and adapts arguments before calling the provided metric function. - - Args: - metric: The metric. - input: The input data to transform with `input_transforms` if specified and pass to the metric with - the position or argument name specified by `input_argument`. - target: The target data to transform with `target_transforms` if specified and pass to the metric with - the position or argument name specified by `target_argument`. - pred: The prediction data to transform with `pred_transforms` if specified and pass to the metric with - the position or argument name specified by `pred_argument`. - - Returns: - The result of the metric function call. - """ - return super().__call__(metric, input, target, pred) - - -class LoggingAdapter(_TransformsAdapter): - """ - Adapter for applying logging transformations to data. - - This adapter handles the transformation of input, target, and prediction data - specifically for logging purposes. It can preprocess or format the data before - logging, ensuring consistency and readability in logs. 
- - """ - - def __init__( - self, - input_transforms: list[Callable] | None = None, - target_transforms: list[Callable] | None = None, - pred_transforms: list[Callable] | None = None, - ): - super().__init__(input_transforms, target_transforms, pred_transforms) diff --git a/src/lighter/callbacks/__init__.py b/src/lighter/callbacks/__init__.py index e37b8d2d..fabed5d9 100644 --- a/src/lighter/callbacks/__init__.py +++ b/src/lighter/callbacks/__init__.py @@ -1,5 +1,5 @@ +from .csv_writer import CsvWriter +from .file_writer import FileWriter from .freezer import Freezer -from .writer.file import FileWriter -from .writer.table import TableWriter -__all__ = ["Freezer", "FileWriter", "TableWriter"] +__all__ = ["CsvWriter", "FileWriter", "Freezer"] diff --git a/src/lighter/callbacks/base_writer.py b/src/lighter/callbacks/base_writer.py new file mode 100644 index 00000000..39c21f05 --- /dev/null +++ b/src/lighter/callbacks/base_writer.py @@ -0,0 +1,75 @@ +import gc +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + +from loguru import logger +from pytorch_lightning import Callback, Trainer + +from lighter.model import LighterModule +from lighter.utils.types.enums import Stage + + +class BaseWriter(ABC, Callback): + """ + Base class for defining custom Writers. It provides a structure to save predictions. + + Subclasses should implement the `write` method to define the saving strategy. + + Args: + path (str | Path): Path for saving predictions. + """ + + def __init__(self, path: str | Path) -> None: + self.path = Path(path) + + @abstractmethod + def write(self, outputs: dict[str, Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None: + """ + Abstract method to define how the outputs of a prediction batch should be saved. + Args: + outputs: The dictionary of outputs from the prediction step. + batch: The current batch. + batch_idx: The index of the batch. + dataloader_idx: The index of the dataloader. 
+ """ + + def setup(self, trainer: Trainer, pl_module: LighterModule, stage: str) -> None: + if stage != Stage.PREDICT: + return + + self.path = trainer.strategy.broadcast(self.path, src=0) + directory = self.path.parent if self.path.suffix else self.path + + if self.path.exists(): + logger.warning(f"{self.path} already exists, existing predictions will be overwritten.") + + if trainer.is_global_zero: + directory.mkdir(parents=True, exist_ok=True) + + trainer.strategy.barrier() + + if not directory.exists(): + raise RuntimeError( + f"Rank {trainer.global_rank} does not share storage with rank 0. Ensure nodes have common storage access." + ) + + def on_predict_batch_end( + self, + trainer: Trainer, + pl_module: LighterModule, + outputs: dict[str, Any], + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + if not outputs: + return + self.write(outputs, batch, batch_idx, dataloader_idx) + + # Clear the predictions to save CPU memory. This is a temporary workaround for a known issue in PyTorch + # Lightning, where predictions can accumulate in memory. This line accesses a private attribute + # `_predictions` of the `predict_loop`, which is a brittle dependency and may break in future + # versions of Lightning. For more details, see: https://github.com/Lightning-AI/pytorch-lightning/issues/19398 + trainer.predict_loop._predictions = [[] for _ in range(trainer.predict_loop.num_dataloaders)] + gc.collect() diff --git a/src/lighter/callbacks/csv_writer.py b/src/lighter/callbacks/csv_writer.py new file mode 100644 index 00000000..469336ab --- /dev/null +++ b/src/lighter/callbacks/csv_writer.py @@ -0,0 +1,182 @@ +""" +This module provides the CsvWriter class, which saves predictions in a table format, such as CSV. 
+""" + +import csv +from io import TextIOWrapper +from pathlib import Path +from typing import Any + +import pandas as pd +import torch +import torch.distributed as dist +from pytorch_lightning import Trainer + +from lighter.callbacks.base_writer import BaseWriter +from lighter.model import LighterModule +from lighter.utils.types.enums import Stage + + +class CsvWriter(BaseWriter): + """ + Writer for saving predictions in a CSV format. It accumulates predictions in a temporary + file and saves them to the final destination at the end of the prediction epoch. + + Args: + path (str | Path): Path to save the final CSV file. + keys (list[str]): A list of keys to be included in the CSV file. + These keys must be present in the `outputs` dictionary + from the prediction step. + + Example: + ```yaml + trainer: + callbacks: + - _target_: lighter.callbacks.CsvWriter + path: predictions.csv + keys: [id, pred, target] + ``` + """ + + def __init__(self, path: str | Path, keys: list[str]) -> None: + super().__init__(path) + self.keys = keys + self._temp_path: Path | None = None + self._csv_writer: Any = None # csv.writer type is not easily annotated + self._csv_file: TextIOWrapper | None = None + + def _close_file(self) -> None: + """Close the CSV file if it's open and reset related state.""" + if self._csv_file is not None and not self._csv_file.closed: + self._csv_file.close() + self._csv_file = None + self._csv_writer = None + + def setup(self, trainer: Trainer, pl_module: LighterModule, stage: str) -> None: + if stage != Stage.PREDICT: + return + super().setup(trainer, pl_module, stage) + + # Create a temporary file for writing predictions + self._temp_path = self.path.with_suffix(f".tmp_rank{trainer.global_rank}{self.path.suffix}") + self._csv_file = open(self._temp_path, "w", newline="") + self._csv_writer = csv.writer(self._csv_file) + # Write header + self._csv_writer.writerow(self.keys) + + def _get_sequence_length(self, value: Any) -> int | None: + if 
isinstance(value, (list, tuple)): + return len(value) + elif isinstance(value, torch.Tensor): + if value.ndim == 0: # Scalar tensor + return 1 + else: + return len(value) # For non-scalar tensors, len() works + return None # Not a sequence type we care about + + def _get_record_value(self, value: Any, index: int) -> Any: + if isinstance(value, (list, tuple)): + return value[index] + elif isinstance(value, torch.Tensor): + if value.ndim == 0: # Scalar tensor + return value.item() # Get Python scalar + else: + # For non-scalar tensors, get the item at index. + # If the item itself is a scalar tensor, convert to Python scalar. + # Otherwise, convert to a list (e.g., for image data). + item = value[index] + return item.item() if item.ndim == 0 else item.tolist() + else: + return value # Non-sequence value, return as is (assumed to be for all samples) + + def write(self, outputs: dict[str, Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None: + if self._csv_writer is None: + return + + # Validate that at least one configured key is present in outputs + present_keys = [key for key in self.keys if key in outputs] + if not present_keys: + missing_keys = self.keys + raise KeyError( + f"CsvWriter: none of the configured keys {missing_keys} were found in outputs. " + f"Available keys in outputs: {list(outputs.keys())}" + ) + + # Determine the number of samples in the batch. 
+ num_samples = 0 + for key in self.keys: + if key in outputs: + length = self._get_sequence_length(outputs[key]) + if length is not None: + num_samples = length + break + else: + # If it's not a sequence type we handle, assume it's a single sample + if num_samples == 0: + num_samples = 1 + + # Validate that all list-like or tensor outputs have the same length + for key in self.keys: + if key in outputs: + current_len = self._get_sequence_length(outputs[key]) + + # Only validate if it's a sequence type and its length is not None + if current_len is not None and current_len != num_samples: + raise ValueError( + f"CsvWriter found inconsistent lengths for keys: " + f"expected {num_samples}, but found {current_len} for key '{key}'." + ) + + # Transpose the dictionary of lists into a list of per-sample records and write to CSV + for i in range(num_samples): + record = [] + for key in self.keys: + if key not in outputs: + raise KeyError(f"CsvWriter expected key '{key}' in outputs but it was missing.") + + value = outputs[key] + record.append(self._get_record_value(value, i)) + self._csv_writer.writerow(record) + + def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterModule) -> None: + """ + At the end of the prediction epoch, it saves the temporary file to the final destination. 
+ """ + if self._csv_file is None: + return + + # Close the temporary file + self._close_file() + + all_temp_paths: list[Path | None] = [None] * trainer.world_size + if dist.is_initialized(): + dist.all_gather_object(all_temp_paths, self._temp_path) + else: + all_temp_paths = [self._temp_path] + + if trainer.is_global_zero: + # Read all temporary files into pandas DataFrames and concatenate them + dfs = [pd.read_csv(path) for path in all_temp_paths if path is not None] + if not dfs: + return + df = pd.concat(dfs, ignore_index=True) + + # Save the final CSV file + df.to_csv(self.path, index=False) + + # Remove all temporary files + for path in all_temp_paths: + if path is not None: + path.unlink() + + # Reset temporary path + self._temp_path = None + + def on_exception(self, trainer: Trainer, pl_module: LighterModule, exception: BaseException) -> None: + """Close the file on errors to prevent file handle leaks.""" + self._close_file() + + def teardown(self, trainer: Trainer, pl_module: LighterModule, stage: str) -> None: + """Guarantee cleanup when stage is PREDICT.""" + if stage == Stage.PREDICT: + self._close_file() diff --git a/src/lighter/callbacks/file_writer.py b/src/lighter/callbacks/file_writer.py new file mode 100644 index 00000000..64ae6d32 --- /dev/null +++ b/src/lighter/callbacks/file_writer.py @@ -0,0 +1,230 @@ +"""Callback for persisting prediction artifacts to the filesystem.""" + +from collections.abc import Callable, Sequence +from pathlib import Path +from typing import Any + +import torch +import torchvision +from loguru import logger +from pytorch_lightning import Trainer + +from lighter.callbacks.base_writer import BaseWriter +from lighter.model import LighterModule +from lighter.utils.types.enums import Stage + +# +# Registry +# + + +class WriterRegistry: + """A registry for writer functions, allowing them to be registered by name and retrieved later.""" + + def __init__(self) -> None: + self._registry: dict[str, Callable] = {} + + def 
register(self, name: str) -> Callable: + """Register a new writer function in this registry as a decorator. + + Args: + name: The unique name to register the writer under. + + Returns: + A decorator that registers the decorated function. + + Raises: + ValueError: If a writer with the given name is already registered. + """ + + def decorator(fn: Callable) -> Callable: + if name in self._registry: + raise ValueError(f"Writer with name '{name}' is already registered.") + self._registry[name] = fn + return fn + + return decorator + + def get(self, name: str) -> Callable: + """Get a writer from the registry by its registered name. + + Args: + name: The name of the writer to retrieve. + + Returns: + The writer function associated with the given name. + + Raises: + ValueError: If no writer with the given name is registered. + """ + if name not in self._registry: + raise ValueError(f"Writer with name '{name}' is not registered.") + return self._registry[name] + + +writer_registry = WriterRegistry() + + +# +# Writer Functions +# +@writer_registry.register(name="tensor") +def write_tensor(path: Path, tensor: torch.Tensor, *, suffix: str = ".pt") -> None: + """Serialise a tensor to disk using :func:`torch.save`.""" + + torch.save(tensor, path.with_suffix(suffix)) # nosec B614 + + +@writer_registry.register(name="image_2d") +def write_image_2d(path: Path, tensor: torch.Tensor, *, suffix: str = ".png") -> None: + """Write a 2D tensor as an image using PNG encoding.""" + if tensor.ndim != 3: + raise ValueError(f"write_image_2d expects a 3D tensor (CHW), got {tensor.ndim} dimensions.") + path = path.with_suffix(suffix) + # Scale to [0, 255] and convert to uint8 + tensor = (tensor.float().clamp(0, 1) * 255).to(torch.uint8) + torchvision.io.write_png(tensor, str(path)) + + +@writer_registry.register(name="image_3d") +def write_image_3d(path: Path, tensor: torch.Tensor, *, suffix: str = ".png") -> None: + """Write a 3D tensor as a 2D image by stacking slices vertically.""" + if 
tensor.ndim != 4: + raise ValueError(f"write_image_3d expects a 4D tensor (CDHW), got {tensor.ndim} dimensions.") + path = path.with_suffix(suffix) + # CDHW -> C(D*H)W + shape = tensor.shape + tensor = tensor.view(shape[0], shape[1] * shape[2], shape[3]) + # Scale to [0, 255] and convert to uint8 + tensor = (tensor.float().clamp(0, 1) * 255).to(torch.uint8) + torchvision.io.write_png(tensor, str(path)) + + +@writer_registry.register(name="text") +def write_text(path: Path, value: Any, *, suffix: str = ".txt", encoding: str = "utf-8") -> None: + """Write the string representation of *value* to disk.""" + + path = path.with_suffix(suffix) + with path.open("w", encoding=encoding) as file: + file.write(str(value)) + + +class FileWriter(BaseWriter): + """ + Persist a prediction value per sample to disk. + + Args: + directory: Directory to save prediction files. + value_key: Key in the prediction outputs dict containing values to save. + writer_fn: Writer function name (e.g., "tensor", "image_2d", "text") or callable. + name_key: Optional key for custom file names. If None, uses sequential numbering. 
+ + Example: + ```yaml + trainer: + callbacks: + - _target_: lighter.callbacks.FileWriter + directory: predictions/ + value_key: pred + writer_fn: tensor + ``` + """ + + def __init__( + self, + directory: str | Path, + value_key: str, + writer_fn: str | Callable[[Path, Any], None], + name_key: str | None = None, + ) -> None: + super().__init__(directory) + self.value_key = value_key + self.name_key = name_key + if isinstance(writer_fn, str): + self.writer_fn = writer_registry.get(writer_fn) + elif callable(writer_fn): + self.writer_fn = writer_fn + else: + raise TypeError("writer_fn must be a string or a callable") + + self._counter: int | None = None + self._step: int = 1 + + def setup(self, trainer: Trainer, pl_module: LighterModule, stage: str) -> None: + super().setup(trainer, pl_module, stage) + if stage != Stage.PREDICT: + return + + if self.path.suffix: + raise ValueError("FileWriter expects 'directory' to be a directory path, not a file path") + + if trainer.is_global_zero: + self.path.mkdir(parents=True, exist_ok=True) + + if trainer.world_size > 1: + self._step = trainer.world_size + self._counter = trainer.global_rank + else: + self._step = 1 + self._counter = 0 + + def write(self, outputs: dict[str, Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None: # noqa: ARG002 + if self._counter is None: + logger.debug("FileWriter received outputs before setup; skipping batch") + return + + values = self._to_sequence(outputs, self.value_key) + if not values: + logger.debug("FileWriter value key '{}' yielded no samples; skipping batch", self.value_key) + return + + if self.name_key is not None: + names = self._to_sequence(outputs, self.name_key) + if len(names) != len(values): + raise ValueError( + "Length mismatch between value key " + f"'{self.value_key}' ({len(values)}) and name key " + f"'{self.name_key}' ({len(names)})." 
+ ) + else: + names = [] + + for offset, value in enumerate(values): + global_index = self._counter + offset * self._step + name = self._prepare_name(names[offset]) if names else global_index + + target_path = self.path / str(name) + target_path.parent.mkdir(parents=True, exist_ok=True) + + prepared_value = self._prepare_value(value) + self.writer_fn(target_path, prepared_value) + + self._counter += len(values) * self._step + + @staticmethod + def _prepare_value(value: Any) -> Any: + if isinstance(value, torch.Tensor): + return value.detach().cpu() + return value + + @staticmethod + def _prepare_name(value: Any) -> Any: + if isinstance(value, torch.Tensor): + return value.detach().cpu().item() if value.ndim == 0 else value.detach().cpu().tolist() + return value + + @staticmethod + def _to_sequence(outputs: dict[str, Any], key: str) -> list: + if key not in outputs: + raise KeyError(f"FileWriter expected key '{key}' in outputs but it was missing.") + + value = outputs[key] + if isinstance(value, torch.Tensor): + if value.ndim == 0: + return [value] + return [tensor for tensor in value] + if isinstance(value, (list, tuple)): + return list(value) + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + return list(value) + return [value] diff --git a/src/lighter/callbacks/freezer.py b/src/lighter/callbacks/freezer.py index ac4eef99..d4357419 100644 --- a/src/lighter/callbacks/freezer.py +++ b/src/lighter/callbacks/freezer.py @@ -5,10 +5,8 @@ from typing import Any from loguru import logger -from pytorch_lightning import Callback, Trainer -from torch.nn import Module +from pytorch_lightning import Callback, LightningModule, Trainer -from lighter import System from lighter.utils.misc import ensure_list @@ -57,101 +55,94 @@ def __init__( self._frozen_state = False - def on_train_batch_start(self, trainer: Trainer, pl_module: System, batch: Any, batch_idx: int) -> None: + def on_train_batch_start(self, trainer: Trainer, pl_module: 
LightningModule, batch: Any, batch_idx: int) -> None: """ - Called at the start of each training batch to potentially freeze parameters. + Called at the start of each training batch to freeze or unfreeze model parameters. Args: trainer: The trainer instance. - pl_module: The System instance. + pl_module: The LightningModule instance. batch: The current batch. batch_idx: The index of the batch. """ - self._on_batch_start(trainer, pl_module) - - def on_validation_batch_start( - self, trainer: Trainer, pl_module: System, batch: Any, batch_idx: int, dataloader_idx: int = 0 - ) -> None: - self._on_batch_start(trainer, pl_module) - - def on_test_batch_start( - self, trainer: Trainer, pl_module: System, batch: Any, batch_idx: int, dataloader_idx: int = 0 - ) -> None: - self._on_batch_start(trainer, pl_module) - - def on_predict_batch_start( - self, trainer: Trainer, pl_module: System, batch: Any, batch_idx: int, dataloader_idx: int = 0 - ) -> None: - self._on_batch_start(trainer, pl_module) - - def _on_batch_start(self, trainer: Trainer, pl_module: System) -> None: - """ - Freezes or unfreezes model parameters based on the current step or epoch. - - Args: - trainer: The trainer instance. - pl_module: The System instance. - """ current_step = trainer.global_step current_epoch = trainer.current_epoch - if self.until_step is not None and current_step >= self.until_step: - if self._frozen_state: - logger.info(f"Reached step {self.until_step} - unfreezing the previously frozen layers.") - self._set_model_requires_grad(pl_module, True) - return - - if self.until_epoch is not None and current_epoch >= self.until_epoch: + # Unfreeze if the step or epoch limit has been reached. 
+ unfreeze_step = self.until_step is not None and current_step >= self.until_step + unfreeze_epoch = self.until_epoch is not None and current_epoch >= self.until_epoch + if unfreeze_step or unfreeze_epoch: if self._frozen_state: - logger.info(f"Reached epoch {self.until_epoch} - unfreezing the previously frozen layers.") - self._set_model_requires_grad(pl_module, True) + logger.info("Unfreezing the model.") + self._set_model_requires_grad(pl_module, requires_grad=True) + self._frozen_state = False return + # Freeze if not already frozen. if not self._frozen_state: - self._set_model_requires_grad(pl_module, False) + logger.info("Freezing the model.") + self._set_model_requires_grad(pl_module, requires_grad=False) + self._frozen_state = True - def _set_model_requires_grad(self, model: Module | System, requires_grad: bool) -> None: + def _set_model_requires_grad(self, model: LightningModule, requires_grad: bool) -> None: """ - Sets the requires_grad attribute for model parameters, effectively freezing or unfreezing them. + Sets the `requires_grad` attribute for model parameters. + + When freezing (requires_grad=False): + - Freeze specified parameters + - Keep all others trainable (requires_grad=True) + - Respect exception rules + + When unfreezing (requires_grad=True): + - Unfreeze specified parameters + - Keep all others trainable Args: model: The model whose parameters to modify. requires_grad: Whether to allow gradients (unfreeze) or not (freeze). """ - # If the model is a `System`, get the underlying PyTorch model. - if isinstance(model, System): - model = model.model + # If the model is a `LighterModule`, get the underlying network so users + # can specify layer names without the "network." prefix. + from lighter import LighterModule + + target = model.network if isinstance(model, LighterModule) else model frozen_layers = [] - # Freeze the specified parameters. 
- for name, param in model.named_parameters(): - # Leave the excluded-from-freezing parameters trainable. - if self.except_names and name in self.except_names: - param.requires_grad = True - continue - if self.except_name_starts_with and any(name.startswith(prefix) for prefix in self.except_name_starts_with): + unfrozen_layers = [] + + for name, param in target.named_parameters(): + # Check if the parameter should be excluded from freezing. + is_excepted = (self.except_names and name in self.except_names) or ( + self.except_name_starts_with and any(name.startswith(prefix) for prefix in self.except_name_starts_with) + ) + if is_excepted: + # Exceptions are always trainable param.requires_grad = True + if not requires_grad: # Only log when we're in freezing mode + unfrozen_layers.append(name) continue - # Freeze/unfreeze the specified parameters, based on the `requires_grad` argument. - if self.names and name in self.names: - param.requires_grad = requires_grad - frozen_layers.append(name) - continue - if self.name_starts_with and any(name.startswith(prefix) for prefix in self.name_starts_with): + # Check if the parameter should be frozen/unfrozen. + is_to_freeze = (self.names and name in self.names) or ( + self.name_starts_with and any(name.startswith(prefix) for prefix in self.name_starts_with) + ) + if is_to_freeze: param.requires_grad = requires_grad - frozen_layers.append(name) - continue - - # Otherwise, leave the parameter trainable. - param.requires_grad = True + if not requires_grad: + frozen_layers.append(name) + else: + unfrozen_layers.append(name) + else: + # Not specified and not excepted - keep trainable + param.requires_grad = True - self._frozen_state = not requires_grad - # Log only when freezing the parameters. - if self._frozen_state: + # Log the frozen/unfrozen layers. 
+ if frozen_layers: logger.info( - f"Setting requires_grad={requires_grad} the following layers" + f"Froze layers: {frozen_layers}" + (f" until step {self.until_step}" if self.until_step is not None else "") + (f" until epoch {self.until_epoch}" if self.until_epoch is not None else "") - + f": {frozen_layers}" ) + if unfrozen_layers: + suffix = " (excepted from freeze)" if not requires_grad else "" + logger.info(f"Unfroze layers: {unfrozen_layers}{suffix}") diff --git a/src/lighter/callbacks/utils.py b/src/lighter/callbacks/utils.py deleted file mode 100644 index 9028edb6..00000000 --- a/src/lighter/callbacks/utils.py +++ /dev/null @@ -1,31 +0,0 @@ -""" -This module provides utility functions for callbacks, including mode conversion and image preprocessing. -""" - -import torch -import torchvision -from torch import Tensor - - -def preprocess_image(image: Tensor) -> Tensor: - """ - Preprocess image for logging. For multiple 2D images, creates a grid. - For 3D images, stacks slices vertically. For multiple 3D images, creates a grid - with each column showing a different 3D image stacked vertically. - - Args: - image: A 2D or 3D image tensor. - - Returns: - Tensor: The preprocessed image ready for logging. - """ - # If 3D (BCDHW), concat the images vertically and horizontally. - if image.ndim == 5: - shape = image.shape - # BCDHW -> BC(D*H)W. Combine slices of a 3D images vertically into a single 2D image. - image = image.view(shape[0], shape[1], shape[2] * shape[3], shape[4]) - # BCDHW -> 1CDH(B*W). Concat images in the batch horizontally, and unsqueeze to add back the B dim. - image = torch.cat([*image], dim=-1).unsqueeze(0) - # If only one image in the batch, select it and return it. Same happens when the images are 3D as they - # are combined into a single image. `make_grid` is called when a batch of multiple 2D image is provided. 
- return image[0] if image.shape[0] == 1 else torchvision.utils.make_grid(image, nrow=8) diff --git a/src/lighter/callbacks/writer/__init__.py b/src/lighter/callbacks/writer/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/lighter/callbacks/writer/base.py b/src/lighter/callbacks/writer/base.py deleted file mode 100644 index 4a29c356..00000000 --- a/src/lighter/callbacks/writer/base.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -This module provides the base class for defining custom writers in Lighter, -allowing predictions to be saved in various formats. -""" - -import gc -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Any, Callable - -import torch -from loguru import logger -from pytorch_lightning import Callback, Trainer -from torch import Tensor - -from lighter import System -from lighter.utils.types.enums import Data, Stage - - -class BaseWriter(ABC, Callback): - """ - Base class for defining custom Writer. It provides the structure to save predictions in various formats. - - Subclasses should implement: - 1) `self.writers` attribute to specify the supported formats and their corresponding writer functions. - 2) `self.write()` method to specify the saving strategy for a prediction. - - Args: - path (str | Path): Path for saving predictions. - writer (str | Callable): Writer function or name of a registered writer. - """ - - def __init__(self, path: str | Path, writer: str | Callable) -> None: - self.path = Path(path) - - # Check if the writer is a string and if it exists in the writers dictionary - if isinstance(writer, str): - if writer not in self.writers: - raise ValueError(f"Writer for format {writer} does not exist. Available writers: {self.writers.keys()}.") - self.writer = self.writers[writer] - else: - # If the writer is not a string, it is assumed to be a callable function - self.writer = writer - - # Prediction counter. Used when IDs are not provided. 
Initialized in `self.setup()` based on the DDP rank. - self._pred_counter = None - - @property - @abstractmethod - def writers(self) -> dict[str, Callable]: - """ - Property to define the default writer functions. - """ - - @abstractmethod - def write(self, tensor: Tensor, identifier: int | str) -> None: - """ - Method to define how a tensor should be saved. The input tensor will be a single tensor without - the batch dimension. - - For each supported format, there should be a corresponding writer function registered in `self.writers` - A specific writer function can be retrieved using `self.get_writer(self.format)`. - - Args: - tensor (Tensor): Tensor, without the batch dimension, to be saved. - identifier (int): Identifier for the tensor, can be used for naming files or adding table records. - """ - - def setup(self, trainer: Trainer, pl_module: System, stage: str) -> None: - """ - Sets up the writer, ensuring the path is ready for saving predictions. - - Args: - trainer (Trainer): The trainer instance. - pl_module (System): The System instance. - stage (str): The current stage of training. - """ - if stage != Stage.PREDICT: - return - - # Initialize the prediction count with the rank of the current process - self._pred_counter = torch.distributed.get_rank() if trainer.world_size > 1 else 0 - - # Ensure all distributed nodes write to the same path - self.path = trainer.strategy.broadcast(self.path, src=0) - directory = self.path.parent if self.path.suffix else self.path - - # Warn if the path already exists - if self.path.exists(): - logger.warning(f"{self.path} already exists, existing predictions will be overwritten.") - - if trainer.is_global_zero: - directory.mkdir(parents=True, exist_ok=True) - - # Wait for rank 0 to create the directory - trainer.strategy.barrier() - - # Ensure all distributed nodes have access to the path - if not directory.exists(): - raise RuntimeError( - f"Rank {trainer.global_rank} does not share storage with rank 0. 
Ensure nodes have common storage access." - ) - - def on_predict_batch_end( - self, trainer: Trainer, pl_module: System, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0 - ) -> None: - """ - Callback method executed at the end of each prediction batch to write predictions with unique IDs. - - Args: - trainer (Trainer): The trainer instance. - pl_module (System): The System instance. - outputs (Any): The outputs from the prediction step. - batch (Any): The current batch. - batch_idx (int): The index of the batch. - dataloader_idx (int): The index of the dataloader. - """ - # If the IDs are not provided, generate global unique IDs based on the prediction count. DDP supported. - if outputs[Data.IDENTIFIER] is None: - batch_size = len(outputs[Data.PRED]) - world_size = trainer.world_size - outputs[Data.IDENTIFIER] = list( - range( - self._pred_counter, # Start: counted globally, initialized with the rank of the current process - self._pred_counter + batch_size * world_size, # Stop: count the total batch size across all processes - world_size, # Step: each process writes predictions for every Nth sample - ) - ) - self._pred_counter += batch_size * world_size - - # Ensure equal number of predictions and identifiers - if len(outputs[Data.IDENTIFIER]) != len(outputs[Data.PRED]): - raise ValueError( - f"The number of predictions ({len(outputs[Data.PRED])}) does not" - f"match the number of identifiers ({len(outputs[Data.IDENTIFIER])})" - ) - - for pred, identifier in zip(outputs[Data.PRED], outputs[Data.IDENTIFIER], strict=True): - self.write(tensor=pred, identifier=identifier) - - # Clear the predictions to save CPU memory. 
https://github.com/Lightning-AI/pytorch-lightning/issues/19398 - trainer.predict_loop._predictions = [[] for _ in range(trainer.predict_loop.num_dataloaders)] - gc.collect() diff --git a/src/lighter/callbacks/writer/file.py b/src/lighter/callbacks/writer/file.py deleted file mode 100644 index 95ea9242..00000000 --- a/src/lighter/callbacks/writer/file.py +++ /dev/null @@ -1,103 +0,0 @@ -""" -This module provides the FileWriter class, which writes predictions to files in various formats. -""" - -from typing import Callable - -import torch -import torchvision -from torch import Tensor - -from lighter.callbacks.utils import preprocess_image -from lighter.callbacks.writer.base import BaseWriter - - -class FileWriter(BaseWriter): - """ - Writer for saving predictions to files in various formats including tensors, images, and videos. - Custom writer functions can be provided to extend supported formats. - Args: - path: Directory path where output files will be saved. - writer: Either a string specifying a built-in writer or a custom writer function. - Built-in writers: - - "tensor": Saves raw tensor data (.pt) - - "image": Saves as image file (.png) - - "video": Saves as video file (.mp4) - Custom writers must: - - Accept (path, tensor) arguments - - Handle single tensor input (no batch dimension) - - Save output to the specified path - """ - - @property - def writers(self) -> dict[str, Callable]: - return { - "tensor": write_tensor, - "image": write_image, - "video": write_video, - } - - def write(self, tensor: Tensor, identifier: int | str) -> None: - """ - Writes the tensor to a file using the specified writer. - - Args: - tensor: The tensor to write. - identifier: Identifier for naming the file. - """ - if not self.path.is_dir(): - raise RuntimeError(f"FileWriter expects a directory path, got {self.path}") - - # Determine the path for the file based on prediction count. The suffix must be added by the writer function. 
- path = self.path / str(identifier) - # Write the tensor to the file. - self.writer(path, tensor) - - -def write_tensor(path, tensor): - """ - Writes a tensor to a file in .pt format. - - Args: - path: The path to save the tensor. - tensor: The tensor to save. - """ - torch.save(tensor, path.with_suffix(".pt")) # nosec B614 - - -def write_image(path, tensor): - """ - Writes a tensor as an image file in .png format. - - Args: - path: The path to save the image. - tensor: The tensor representing the image. - """ - path = path.with_suffix(".png") - tensor = preprocess_image(tensor) - torchvision.io.write_png(tensor, str(path)) - - -def write_video(path, tensor): - """ - Writes a tensor as a video file in .mp4 format. - - Args: - path: The path to save the video. - tensor: The tensor representing the video in CTHW format. - """ - path = path.with_suffix(".mp4") - # Video tensor must be divisible by 2. Pad the height and width if needed. - _, _, h, w = tensor.shape - pad_h = (2 - h % 2) % 2 - pad_w = (2 - w % 2) % 2 - if pad_h > 0 or pad_w > 0: - tensor = torch.nn.functional.pad(tensor, (0, pad_w, 0, pad_h)) - # Video tensor must be THWC. Permute CTHW -> THWC. - tensor = tensor.permute(1, 2, 3, 0) - # Video tensor must have 3 channels (RGB). Repeat the channel dim to convert grayscale to RGB. - if tensor.shape[-1] == 1: - tensor = tensor.repeat(1, 1, 1, 3) - # Video tensor must be in the range [0, 1]. Scale to [0, 255]. - tensor = (tensor * 255).to(torch.uint8) - torchvision.io.write_video(str(path), tensor, fps=24) diff --git a/src/lighter/callbacks/writer/table.py b/src/lighter/callbacks/writer/table.py deleted file mode 100644 index f9cd7baa..00000000 --- a/src/lighter/callbacks/writer/table.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -This module provides the TableWriter class, which saves predictions in a table format, such as CSV. 
-""" - -import itertools -from pathlib import Path -from typing import Any, Callable - -import pandas as pd -import torch -from pytorch_lightning import Trainer - -from lighter import System -from lighter.callbacks.writer.base import BaseWriter - - -class TableWriter(BaseWriter): - """ - Writer for saving predictions in a table format, such as CSV. - - Args: - path: CSV filepath. - writer: Writer function or name of a registered writer. - """ - - def __init__(self, path: str | Path, writer: str | Callable) -> None: - super().__init__(path, writer) - self.csv_records = [] - - @property - def writers(self) -> dict[str, Callable]: - return { - "tensor": lambda tensor: tensor.item() if tensor.numel() == 1 else tensor.tolist(), - } - - def write(self, tensor: Any, identifier: int | str) -> None: - """ - Writes the tensor as a table record using the specified writer. - - Args: - tensor: The tensor to record. Should not have a batch dimension. - identifier: Identifier for the record. - """ - self.csv_records.append({"identifier": identifier, "pred": self.writer(tensor)}) - - def on_predict_epoch_end(self, trainer: Trainer, pl_module: System) -> None: - """ - Called at the end of the prediction epoch to save predictions to a CSV file. - - Args: - trainer: The trainer instance. - pl_module: The System instance. - """ - # If in distributed data parallel mode, gather records from all processes to rank 0. 
- if trainer.world_size > 1: - gather_csv_records = [None] * trainer.world_size if trainer.is_global_zero else None - torch.distributed.gather_object(self.csv_records, gather_csv_records, dst=0) - if trainer.is_global_zero: - self.csv_records = list(itertools.chain(*gather_csv_records)) - - # Save the records to a CSV file - if trainer.is_global_zero: - df = pd.DataFrame(self.csv_records) - try: - df = df.sort_values("identifier") - except TypeError: - pass - df = df.set_index("identifier") - df.to_csv(self.path) - - # Clear the records after saving - self.csv_records = [] diff --git a/src/lighter/data.py b/src/lighter/data.py new file mode 100644 index 00000000..010b4e51 --- /dev/null +++ b/src/lighter/data.py @@ -0,0 +1,110 @@ +""" +LighterDataModule - A simple wrapper for organizing dataloaders in YAML configs. + +This module provides LighterDataModule, a helper class that wraps PyTorch dataloaders +so they can be configured in YAML without requiring a custom LightningDataModule. +""" + +from pytorch_lightning import LightningDataModule +from torch.utils.data import DataLoader + + +class LighterDataModule(LightningDataModule): + """ + A lightweight wrapper for organizing dataloaders in configuration files. + + This class exists purely as a convenience helper - it wraps pre-configured + PyTorch DataLoaders so you can use Lighter's configuration system without + having to write a custom LightningDataModule from scratch. 
+ + When to use LighterDataModule: + - Simple datasets that don't need complex preprocessing + - Quick experiments where you want to configure dataloaders in YAML + - Cases where your data pipeline is straightforward + + When to write a custom LightningDataModule: + - Complex data preparation (downloading, extraction, processing) + - Multi-process data setup with prepare_data() and setup() + - Advanced preprocessing pipelines + - Data that requires stage-specific transformations + - Sharing reusable data modules across projects + + Args: + train_dataloader: DataLoader for training (used in fit stage) + val_dataloader: DataLoader for validation (used in fit and validate stages) + test_dataloader: DataLoader for testing (used in test stage) + predict_dataloader: DataLoader for predictions (used in predict stage) + + Example: + ```yaml + # config.yaml + data: + _target_: lighter.LighterDataModule + train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 32 + shuffle: true + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: true + transform: + _target_: torchvision.transforms.ToTensor + val_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 32 + shuffle: false + dataset: + _target_: torchvision.datasets.CIFAR10 + root: ./data + train: false + transform: + _target_: torchvision.transforms.ToTensor + + model: + _target_: project.MyModel + network: ... + optimizer: ... + + trainer: + _target_: pytorch_lightning.Trainer + max_epochs: 10 + ``` + + Note: + This is just a thin wrapper around PyTorch Lightning's LightningDataModule. + It doesn't add any special logic - it simply holds your dataloaders and + returns them when Lightning asks for them. + + If you need more control (prepare_data, setup, etc.), write a custom + LightningDataModule instead. 
+ """ + + def __init__( + self, + train_dataloader: DataLoader | None = None, + val_dataloader: DataLoader | None = None, + test_dataloader: DataLoader | None = None, + predict_dataloader: DataLoader | None = None, + ) -> None: + super().__init__() + self._train_dataloader = train_dataloader + self._val_dataloader = val_dataloader + self._test_dataloader = test_dataloader + self._predict_dataloader = predict_dataloader + + def train_dataloader(self) -> DataLoader | None: + """Return the training dataloader.""" + return self._train_dataloader + + def val_dataloader(self) -> DataLoader | None: + """Return the validation dataloader.""" + return self._val_dataloader + + def test_dataloader(self) -> DataLoader | None: + """Return the test dataloader.""" + return self._test_dataloader + + def predict_dataloader(self) -> DataLoader | None: + """Return the prediction dataloader.""" + return self._predict_dataloader diff --git a/src/lighter/engine/runner.py b/src/lighter/engine/runner.py index 80d292d7..60a6d77f 100644 --- a/src/lighter/engine/runner.py +++ b/src/lighter/engine/runner.py @@ -4,166 +4,244 @@ """ import argparse +from pathlib import Path +from typing import Any -from pytorch_lightning import Trainer, seed_everything +import yaml +from loguru import logger +from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything from sparkwheel import Config, ValidationError -from lighter.engine.schema import LighterConfig -from lighter.system import System from lighter.utils.dynamic_imports import import_module_from_path -from lighter.utils.types.enums import Mode, Stage +from lighter.utils.types.enums import Stage +# ============================================================================ +# Helper Classes - Each Does One Thing +# ============================================================================ + + +class ProjectImporter: + """Discovers and imports user project modules.""" + + @staticmethod + def auto_discover_and_import() 
-> bool: + """ + Auto-discover project from __lighter__.py marker file. + Returns True if project was imported, False otherwise. + """ + cwd = Path.cwd() + marker = cwd / "__lighter__.py" + + if not marker.exists(): + return False + + import_module_from_path("project", cwd) + logger.info(f"Imported 'project' module from '{cwd}'") + return True + + +class ConfigLoader: + """Loads and validates configuration using Sparkwheel.""" + + @staticmethod + def load(inputs: list) -> Config: + """ + Load config from inputs (files, dicts, overrides). + + Sparkwheel auto-detects: + - Strings without '=' → file paths + - Strings with '=' → overrides + - Dicts → merged into config + """ + try: + config = Config() # No schema validation for now + for item in inputs: + config.update(item) + return config + except ValidationError as e: + raise ValueError(f"Configuration loading failed:\n{e}") from e -class Runner: - """ - Executes training stages using validated and resolved configurations. - The Runner loads configurations using Sparkwheel, applies CLI overrides, - validates against the schema, prunes unused components for the stage, - and executes the appropriate PyTorch Lightning trainer method. +class Runner: """ + Orchestrates training stage execution by coordinating helper classes. - STAGE_MODES = { - Stage.FIT: [Mode.TRAIN, Mode.VAL], - Stage.VALIDATE: [Mode.VAL], - Stage.TEST: [Mode.TEST], - Stage.PREDICT: [Mode.PREDICT], - } + Runner delegates responsibilities to specialized helper classes: + - ProjectImporter: Auto-discovers and imports user project modules via __lighter__.py marker + - ConfigLoader: Loads and validates configurations using Sparkwheel - def __init__(self) -> None: - """Initialize the runner with empty state.""" - self.config: Config | None = None - self.system: System | None = None - self.trainer: Trainer | None = None + Runner focuses on resolving and validating components (model, trainer, datamodule) + and executing the requested training stage. 
+ """ def run( self, stage: Stage, - config: str | list[str] | dict, - overrides: list[str] | None = None, + inputs: list, + **stage_kwargs: Any, ) -> None: """ - Run a training stage with configuration and overrides. + Run a training stage with configuration inputs. + + Orchestrates the complete training workflow: + 1. Loads configuration via ConfigLoader (delegates to Sparkwheel for auto-detection) + 2. Auto-discovers and imports project modules via ProjectImporter + 3. Resolves and validates model, trainer, and datamodule components + 4. Saves configuration (to log directory, logger, and model hyperparameters) + 5. Executes the requested training stage Args: stage: Stage to run (fit, validate, test, predict) - config: Config file path(s) or dict. If string, supports comma-separated paths. - overrides: List of CLI override strings in format "key::path=value" + inputs: List of config file paths, dicts, and/or overrides. + Passed to ConfigLoader.load() which delegates to Sparkwheel for auto-detection: + - Strings without '=' → file paths + - Strings with '=' → overrides + - Dicts → merged into config + **stage_kwargs: Additional keyword arguments from CLI (e.g., ckpt_path, verbose) + passed directly to the trainer stage method Raises: ValueError: If config validation fails or required components are missing - TypeError: If system or trainer are not the correct type + TypeError: If model or trainer are not the correct type """ seed_everything() - # Handle comma-separated config files - if isinstance(config, str) and "," in config: - config = config.split(",") + # 1. Load configuration + config = ConfigLoader.load(inputs) - # Load config with CLI overrides and validation (all in one step!) - try: - self.config = Config.from_cli( - config, - overrides or [], - schema=LighterConfig, - ) - except ValidationError as e: - raise ValueError(f"Configuration validation failed:\n{e}") from e + # 2. 
Auto-discover and import project + ProjectImporter.auto_discover_and_import() - # Prune unused components for this stage - self._prune_for_stage(stage) + # 3. Resolve components + model = self._resolve_model(config) + trainer = self._resolve_trainer(config) + datamodule = self._resolve_datamodule(config, model) - # Setup and run - self._setup(stage) - self._execute(stage) + # 4. Save configuration to trainer's log directory, logger, and model hparams for checkpoint access + self._save_config(config, trainer, model) - def _prune_for_stage(self, stage: Stage) -> None: - """ - Remove unused components using Sparkwheel's delete directive (~). + # 5. Execute stage + self._execute(stage, model, trainer, datamodule, **stage_kwargs) - Args: - stage: Current stage being executed - """ - if self.config is None: - raise ValueError("Config must be loaded before pruning") - - required = set(self.STAGE_MODES[stage]) - all_modes = {Mode.TRAIN, Mode.VAL, Mode.TEST, Mode.PREDICT} - - # Build delete directives for unused modes - deletes = {} - for mode in all_modes - required: - deletes[f"~system::dataloaders::{mode}"] = None - deletes[f"~system::metrics::{mode}"] = None - - # Remove optimizer/scheduler/criterion for non-training stages - if stage != Stage.FIT: - deletes["~system::optimizer"] = None - deletes["~system::scheduler"] = None - if stage != Stage.VALIDATE: - deletes["~system::criterion"] = None - - # Keep only args for this stage - for s in [Stage.FIT, Stage.VALIDATE, Stage.TEST, Stage.PREDICT]: - if s != stage: - deletes[f"~args::{s}"] = None - - # Apply deletions - self.config.update(deletes) - - def _setup(self, stage: Stage) -> None: + def _resolve_model(self, config: Config) -> LightningModule: + """Resolve and validate model from config.""" + model = config.resolve("model") + if not isinstance(model, LightningModule): + raise TypeError(f"model must be LightningModule or LighterModule, got {type(model)}") + return model + + def _resolve_trainer(self, config: Config) 
-> Trainer: + """Resolve and validate trainer from config.""" + trainer = config.resolve("trainer") + if not isinstance(trainer, Trainer): + raise TypeError(f"trainer must be Trainer, got {type(trainer)}") + return trainer + + def _resolve_datamodule(self, config: Config, model: LightningModule) -> LightningDataModule | None: """ - Setup system and trainer from configuration. + Resolve and validate datamodule from config. Args: - stage: Current stage being executed + config: Configuration object + model: Resolved model (checked for built-in dataloaders) + + Returns: + LightningDataModule instance or None if model defines its own dataloaders Raises: - TypeError: If system or trainer are not the correct type + TypeError: If data key exists but is not a LightningDataModule """ - if self.config is None: - raise ValueError("Config must be loaded before setup") - - # Import project module if specified - project = self.config.get("project") - if project: - import_module_from_path("project", project) - - # Resolve system - self.system = self.config.resolve("system") - if not isinstance(self.system, System): - raise TypeError(f"system must be System, got {type(self.system)}") - - # Resolve trainer - self.trainer = self.config.resolve("trainer") - if not isinstance(self.trainer, Trainer): - raise TypeError(f"trainer must be Trainer, got {type(self.trainer)}") - - # Save config to system checkpoint and trainer logger - if self.system: - self.system.save_hyperparameters(self.config.get()) - if self.trainer and self.trainer.logger: - self.trainer.logger.log_hyperparams(self.config.get()) - - def _execute(self, stage: Stage) -> None: + # Data key is optional - plain Lightning modules can define their own dataloaders + if config.get("data") is None: + # Check if model has dataloader methods (plain Lightning module) + has_dataloaders = any( + hasattr(model, method) + for method in ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"] + ) + if not 
has_dataloaders: + raise ValueError( + "Missing required 'data:' config key and model does not define dataloader methods. " + "Either:\n" + "1. Add 'data:' config key:\n" + " data:\n" + " _target_: lighter.LighterDataModule\n" + " train_dataloader: ...\n" + "2. Or define dataloader methods in your LightningModule (train_dataloader, val_dataloader, etc.)" + ) + return None + + # Resolve and validate data key + datamodule = config.resolve("data") + if not isinstance(datamodule, LightningDataModule): + raise TypeError( + f"data must be LightningDataModule (or lighter.LighterDataModule), got {type(datamodule)}. " + "Example:\n" + "data:\n" + " _target_: lighter.LighterDataModule\n" + " train_dataloader:\n" + " _target_: torch.utils.data.DataLoader\n" + " # ... config ..." + ) + + return datamodule + + def _save_config(self, config: Config, trainer: Trainer, model: LightningModule) -> None: """ - Execute the training stage. + Save configuration to multiple destinations. - Args: - stage: Stage to execute + Saves the configuration to: + - Model (for checkpoint access via model.hparams) + - Logger (for experiment tracking via log_hyperparams) + - Log directory (as config.yaml file) - Raises: - AttributeError: If trainer doesn't have the stage method + Args: + config: Configuration object to save + trainer: Trainer (uses trainer.logger and trainer.log_dir) + model: Model to save hyperparameters to """ - if self.config is None or self.trainer is None or self.system is None: - raise ValueError("Config, trainer, and system must be set up before execution") - # Get stage-specific arguments - args = self.config.resolve(f"args::{stage}", default={}) + # Save to model checkpoint (for model.hparams access) + model.save_hyperparameters({"config": config.get()}) + + # If no logger, skip other saves + if not trainer.logger: + return - # Execute the stage method - stage_method = getattr(self.trainer, str(stage)) - stage_method(self.system, **args) + # Save to logger (for experiment 
tracking) + trainer.logger.log_hyperparams(config.get()) + + # Save as config.yaml to log directory if it exists + if trainer.log_dir: + config_file = Path(trainer.log_dir) / "config.yaml" + config_file.parent.mkdir(parents=True, exist_ok=True) + with open(config_file, "w") as f: + yaml.dump(config.get(), f, default_flow_style=False, sort_keys=False, indent=4) + logger.info(f"Saved config to: {config_file}") + + def _execute( + self, + stage: Stage, + model: LightningModule, + trainer: Trainer, + datamodule: LightningDataModule | None, + **stage_kwargs: Any, + ) -> None: + """ + Execute the training stage. + + Args: + stage: Stage to execute (fit, validate, test, predict) + model: Resolved model + trainer: Resolved trainer + datamodule: Resolved datamodule (None if model defines its own dataloaders) + **stage_kwargs: Additional keyword arguments from CLI (e.g., ckpt_path, verbose) + """ + stage_method = getattr(trainer, str(stage)) + if datamodule is not None: + stage_method(model, datamodule=datamodule, **stage_kwargs) + else: + # Plain Lightning module with built-in dataloaders + stage_method(model, **stage_kwargs) def cli() -> None: @@ -180,6 +258,27 @@ def cli() -> None: help="Available commands", ) + # Common arguments shared by all stages + def add_common_args(stage_parser): + """Add common arguments to a stage subparser.""" + stage_parser.add_argument( + "inputs", + nargs="+", + help="Config files and overrides. Example: config.yaml model::optimizer::lr=0.001", + ) + stage_parser.add_argument( + "--ckpt_path", + type=str, + default=None, + help='Path to checkpoint. 
Can be "last", "best", or a file path.', + ) + stage_parser.add_argument( + "--weights_only", + action="store_true", + default=None, + help="Restrict checkpoint loading to state_dicts of torch.Tensor (safer for untrusted sources).", + ) + # Fit subcommand fit_parser = subparsers.add_parser( "fit", @@ -187,20 +286,12 @@ def cli() -> None: description="Train a model using the specified configuration file.", epilog="Examples:\n" " lighter fit config.yaml\n" - " lighter fit config.yaml system::optimizer::lr=0.001\n" - " lighter fit base.yaml,experiment.yaml trainer::max_epochs=100", + " lighter fit config.yaml --ckpt_path checkpoint.ckpt\n" + " lighter fit config.yaml model::optimizer::lr=0.001\n" + " lighter fit base.yaml experiment.yaml --ckpt_path last trainer::max_epochs=100", formatter_class=argparse.RawDescriptionHelpFormatter, ) - fit_parser.add_argument( - "config", - help="Path to config file(s), comma-separated for multiple files", - ) - fit_parser.add_argument( - "overrides", - nargs="*", - default=[], - help='Configuration overrides in format "key::path=value"', - ) + add_common_args(fit_parser) # Validate subcommand validate_parser = subparsers.add_parser( @@ -209,18 +300,16 @@ def cli() -> None: description="Validate a model using the specified configuration file.", epilog="Examples:\n" " lighter validate config.yaml\n" - " lighter validate config.yaml system::model::weights=checkpoint.ckpt", + " lighter validate config.yaml --ckpt_path best\n" + " lighter validate config.yaml --ckpt_path checkpoint.ckpt --verbose", formatter_class=argparse.RawDescriptionHelpFormatter, ) + add_common_args(validate_parser) validate_parser.add_argument( - "config", - help="Path to config file(s), comma-separated for multiple files", - ) - validate_parser.add_argument( - "overrides", - nargs="*", - default=[], - help='Configuration overrides in format "key::path=value"', + "--verbose", + action="store_true", + default=None, + help="Print validation results (default: True).", 
) # Test subcommand @@ -228,18 +317,18 @@ def cli() -> None: "test", help="Test a model", description="Test a model using the specified configuration file.", - epilog="Examples:\n lighter test config.yaml\n lighter test config.yaml system::model::weights=checkpoint.ckpt", + epilog="Examples:\n" + " lighter test config.yaml\n" + " lighter test config.yaml --ckpt_path best\n" + " lighter test config.yaml --ckpt_path checkpoint.ckpt --verbose", formatter_class=argparse.RawDescriptionHelpFormatter, ) + add_common_args(test_parser) test_parser.add_argument( - "config", - help="Path to config file(s), comma-separated for multiple files", - ) - test_parser.add_argument( - "overrides", - nargs="*", - default=[], - help='Configuration overrides in format "key::path=value"', + "--verbose", + action="store_true", + default=None, + help="Print test results (default: True).", ) # Predict subcommand @@ -249,26 +338,27 @@ def cli() -> None: description="Run predictions using the specified configuration file.", epilog="Examples:\n" " lighter predict config.yaml\n" - " lighter predict config.yaml system::model::weights=checkpoint.ckpt", + " lighter predict config.yaml --ckpt_path best\n" + " lighter predict config.yaml --ckpt_path checkpoint.ckpt --return_predictions", formatter_class=argparse.RawDescriptionHelpFormatter, ) + add_common_args(predict_parser) predict_parser.add_argument( - "config", - help="Path to config file(s), comma-separated for multiple files", - ) - predict_parser.add_argument( - "overrides", - nargs="*", - default=[], - help='Configuration overrides in format "key::path=value"', + "--return_predictions", + action="store_true", + default=None, + help="Whether to return predictions (default: True except with process-spawning accelerators).", ) # Parse arguments args = parser.parse_args() + # Extract stage kwargs (exclude command and inputs) + stage_kwargs = {k: v for k, v in vars(args).items() if k not in ["command", "inputs"] and v is not None} + # Execute 
command try: - Runner().run(args.command, args.config, args.overrides) + Runner().run(args.command, args.inputs, **stage_kwargs) except Exception as e: # Suppress exception chain to avoid duplicate tracebacks raise e from None diff --git a/src/lighter/engine/schema.py b/src/lighter/engine/schema.py deleted file mode 100644 index cd3c38bf..00000000 --- a/src/lighter/engine/schema.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -Defines the schema for configuration validation using Sparkwheel's validation with dataclasses. -""" - -from dataclasses import dataclass -from typing import Optional - - -@dataclass -class AdapterConfig: - """Adapter configuration for a specific mode.""" - - batch: Optional[dict] = None - criterion: Optional[dict] = None - metrics: Optional[dict] = None - logging: Optional[dict] = None - - -@dataclass -class PredictAdapterConfig: - """Adapter configuration for predict mode (no criterion).""" - - batch: Optional[dict] = None - logging: Optional[dict] = None - - -@dataclass -class AdaptersConfig: - """Adapters configuration for all modes.""" - - train: Optional[dict] = None # Can be AdapterConfig but keep flexible - val: Optional[dict] = None - test: Optional[dict] = None - predict: Optional[dict] = None - - -@dataclass -class MetricsConfig: - """Metrics configuration for different stages.""" - - train: Optional[list | dict] = None - val: Optional[list | dict] = None - test: Optional[list | dict] = None - - -@dataclass -class DataloadersConfig: - """Dataloaders configuration for different stages.""" - - train: Optional[dict] = None - val: Optional[dict] = None - test: Optional[dict] = None - predict: Optional[dict] = None - - -@dataclass -class SystemConfig: - """System configuration with model, optimizer, scheduler, etc.""" - - model: Optional[dict] = None - criterion: Optional[dict] = None - optimizer: Optional[dict] = None - scheduler: Optional[dict] = None - inferer: Optional[dict] = None - metrics: Optional[MetricsConfig] = None - dataloaders: 
Optional[DataloadersConfig] = None - adapters: Optional[AdaptersConfig] = None - - -@dataclass -class ArgsConfig: - """Arguments to pass to Trainer stage methods.""" - - fit: Optional[dict] = None - validate: Optional[dict] = None - test: Optional[dict] = None - predict: Optional[dict] = None - - -@dataclass -class LighterConfig: - """Main Lighter configuration schema.""" - - trainer: dict # pytorch_lightning.Trainer - system: SystemConfig # lighter.System - project: Optional[str] = None - vars: Optional[dict] = None - args: Optional[ArgsConfig] = None diff --git a/src/lighter/model.py b/src/lighter/model.py new file mode 100644 index 00000000..071354bf --- /dev/null +++ b/src/lighter/model.py @@ -0,0 +1,423 @@ +""" +This module provides the core LighterModule class that extends PyTorch Lightning's LightningModule. +Users implement abstract step methods while the framework handles automatic dual logging. +""" + +from collections.abc import Callable +from typing import Any + +import pytorch_lightning as pl +import torch +from torch.nn import Module +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler +from torchmetrics import Metric, MetricCollection + +from lighter.utils.misc import get_optimizer_stats +from lighter.utils.types.enums import Mode + + +class LighterModule(pl.LightningModule): + """ + Minimal base class for deep learning models in Lighter. + + Users should: + - Subclass and implement the step methods they need (training_step, validation_step, etc.) 
+ - Define their own batch processing, loss computation, metric updates + - Configure data separately using the 'data:' config key + + Framework provides: + - Automatic dual logging of losses (step + epoch) + - Automatic dual logging of metrics (step + epoch) + - Optimizer configuration + + Args: + network: Neural network model + criterion: Loss function (optional, user can compute loss manually in step) + optimizer: Optimizer (required for training) + scheduler: Learning rate scheduler (optional) + train_metrics: Training metrics (optional, user calls them in step) + val_metrics: Validation metrics (optional) + test_metrics: Test metrics (optional) + + Example: + class MyModel(LighterModule): + def training_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + + # Option 1: Use self.criterion if provided + loss = self.criterion(pred, y) if self.criterion else F.cross_entropy(pred, y) + + # User calls metrics themselves + if self.train_metrics: + self.train_metrics(pred, y) + + return {"loss": loss, "pred": pred, "target": y} + + def validation_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + loss = self.criterion(pred, y) if self.criterion else F.cross_entropy(pred, y) + if self.val_metrics: + self.val_metrics(pred, y) + return {"loss": loss, "pred": pred, "target": y} + + def test_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + if self.test_metrics: + self.test_metrics(pred, y) + return {"pred": pred, "target": y} + + def predict_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + return pred + """ + + def __init__( + self, + network: Module, + criterion: Callable | None = None, + optimizer: Optimizer | None = None, + scheduler: LRScheduler | None = None, + train_metrics: Metric | MetricCollection | None = None, + val_metrics: Metric | MetricCollection | None = None, + test_metrics: Metric | MetricCollection | None = None, + ) -> None: + super().__init__() + + # Core components + self.network = network + 
self.criterion = criterion + self.optimizer = optimizer + self.scheduler = scheduler + + # Metrics (registered as modules) + self.train_metrics = self._prepare_metrics(train_metrics) + self.val_metrics = self._prepare_metrics(val_metrics) + self.test_metrics = self._prepare_metrics(test_metrics) + + def _prepare_metrics(self, metrics: Metric | MetricCollection | None) -> Metric | MetricCollection | None: + """Validate metrics - must be Metric or MetricCollection.""" + if metrics is None: + return None + + if isinstance(metrics, (Metric, MetricCollection)): + return metrics + + raise TypeError( + f"metrics must be Metric or MetricCollection, got {type(metrics).__name__}.\n\n" + f"Single metric:\n" + f" train_metrics:\n" + f" _target_: torchmetrics.Accuracy\n" + f" task: multiclass\n\n" + f"Multiple metrics:\n" + f" train_metrics:\n" + f" _target_: torchmetrics.MetricCollection\n" + f" metrics:\n" + f" - _target_: torchmetrics.Accuracy\n" + f" task: multiclass\n" + f" - _target_: torchmetrics.F1Score\n" + f" task: multiclass" + ) + + # ============================================================================ + # Step Methods - Override as Needed + # ============================================================================ + + def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor | dict[str, Any]: + """ + Define training logic. + + User responsibilities: + - Extract data from batch + - Call self(input) for forward pass + - Compute loss + - Call self.train_metrics(pred, target) if configured + - Return loss tensor or dict with 'loss' key + + Framework automatically logs loss and metrics. 
+ + Returns: + Either: + - Tensor: The loss value (simplest option) + - Dict with required 'loss' key and optional keys: + - pred: Model predictions (for callbacks) + - target: Target labels (for callbacks) + - input: Input data (for callbacks) + - Any other keys you need + """ + raise NotImplementedError( + f"{self.__class__.__name__} must implement training_step() to use trainer.fit(). " + f"See https://project-lighter.github.io/lighter/guides/lighter-module/" + ) + + def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor | dict[str, Any]: + """ + Define validation logic. + + Similar to training_step but typically without gradients. + Call self.val_metrics(pred, target) if configured. + + Returns: + Either: + - Tensor: The loss value + - Dict with 'loss' key + """ + raise NotImplementedError( + f"{self.__class__.__name__} must implement validation_step() to use validation. " + f"See https://project-lighter.github.io/lighter/guides/lighter-module/" + ) + + def test_step(self, batch: Any, batch_idx: int) -> torch.Tensor | dict[str, Any]: + """ + Define test logic. + + Loss is optional. Call self.test_metrics(pred, target) if configured. + + Returns: + Either: + - Tensor: The loss value (optional in test mode) + - Dict with optional 'loss' key. Can include pred, target, etc. + """ + raise NotImplementedError( + f"{self.__class__.__name__} must implement test_step() to use trainer.test(). " + f"See https://project-lighter.github.io/lighter/guides/lighter-module/" + ) + + def predict_step(self, batch: Any, batch_idx: int) -> Any: + """ + Define prediction logic. + + User responsibilities: + - Extract data from batch + - Call self(input) for forward pass + - Return predictions in desired format + + No automatic logging happens in predict mode. + Return any format you need (tensor, dict, list, etc.). + """ + raise NotImplementedError( + f"{self.__class__.__name__} must implement predict_step() to use trainer.predict(). 
" + f"See https://project-lighter.github.io/lighter/guides/lighter-module/" + ) + + # ============================================================================ + # Forward Pass - Simple Delegation + # ============================================================================ + + def forward(self, *args: Any, **kwargs: Any) -> Any: + """ + Forward pass - simply delegates to self.network. + + Override if you need custom forward logic. + """ + return self.network(*args, **kwargs) + + # ============================================================================ + # Batch-End Hooks - Automatic Logging + # ============================================================================ + + def _on_batch_end(self, outputs: torch.Tensor | dict[str, Any], batch_idx: int) -> None: + """Common batch-end logic for all modes.""" + outputs = self._normalize_output(outputs) + self._log_outputs(outputs, batch_idx) + + def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: + """Framework hook - automatically logs training outputs.""" + self._on_batch_end(outputs, batch_idx) + + def on_validation_batch_end( + self, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Framework hook - automatically logs validation outputs.""" + self._on_batch_end(outputs, batch_idx) + + def on_test_batch_end( + self, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Framework hook - automatically logs test outputs.""" + self._on_batch_end(outputs, batch_idx) + + def _normalize_output(self, output: torch.Tensor | dict[str, Any]) -> dict[str, Any]: + """ + Normalize step output to dict format. + + Args: + output: Either: + - torch.Tensor: Loss value (normalized to {"loss": tensor}) + - dict: Must contain outputs. 
Can include: + - "loss": torch.Tensor or dict with "total" key + - "pred", "target", "input": Additional data for callbacks + + Returns: + Dict with normalized structure + + Raises: + TypeError: If output is neither Tensor nor dict + ValueError: If loss dict is missing 'total' key + """ + if isinstance(output, torch.Tensor): + return {"loss": output} + elif isinstance(output, dict): + # Validate loss structure if present + if "loss" in output and isinstance(output["loss"], dict): + if "total" not in output["loss"]: + raise ValueError( + f"Loss dict must include 'total' key. " + f"Got keys: {list(output['loss'].keys())}. " + f"Example: {{'loss': {{'total': combined, 'ce': ce_loss, 'reg': reg_loss}}}}" + ) + return output + else: + raise TypeError( + f"Step method must return torch.Tensor or dict. " + f"Got {type(output).__name__} instead. " + f"Examples:\n" + f" - return loss # Simple tensor\n" + f' - return {{"loss": loss, "pred": pred}}' + ) + + def _log_outputs(self, outputs: dict[str, Any], batch_idx: int) -> None: + """ + Log all outputs from a step. + + Override this method to customize logging behavior. + Default: dual logging (step + epoch) for loss and metrics. + + Args: + outputs: Dict from user's step method + batch_idx: Current batch index + """ + if self.trainer.logger is None: + return + self._log_loss(outputs.get("loss")) + self._log_metrics() + self._log_optimizer_stats(batch_idx) + + def _log_loss(self, loss: torch.Tensor | dict[str, Any] | None) -> None: + """ + Log loss with dual pattern (step + epoch). + + Args: + loss: Loss tensor or dict from step method. + If dict, must have 'total' key (validated in _normalize_output). 
+ """ + if loss is None: + return + + # Log scalar or dict + if isinstance(loss, dict): + for name, value in loss.items(): + name = f"{self.mode}/loss/{name}" + self._log(name, value, on_step=True) + self._log(name, value, on_epoch=True, sync_dist=True) + else: + name = f"{self.mode}/loss" + self._log(name, loss, on_step=True) + self._log(name, loss, on_epoch=True, sync_dist=True) + + def _log_metrics(self) -> None: + """ + Log metrics with dual pattern (step + epoch). + + User already called metrics in their step method. + Handles both single Metric and MetricCollection. + """ + metrics = getattr(self, f"{self.mode}_metrics", None) + if metrics is None: + return + + if isinstance(metrics, MetricCollection): + # MetricCollection - iterate over named metrics + for name, metric in metrics.items(): + name = f"{self.mode}/metrics/{name}" + self._log(name, metric, on_step=True) + self._log(name, metric, on_epoch=True, sync_dist=True) + else: + # Single Metric - use class name (consistent with MetricCollection auto-naming) + name = f"{self.mode}/metrics/{metrics.__class__.__name__}" + self._log(name, metrics, on_step=True) + self._log(name, metrics, on_epoch=True, sync_dist=True) + + def _log_optimizer_stats(self, batch_idx: int) -> None: + """ + Log optimizer stats once per epoch in train mode. 
+ + Args: + batch_idx: Current batch index + """ + if self.mode != Mode.TRAIN or batch_idx != 0 or self.optimizer is None: + return + + # Optimizer stats only logged per epoch + for name, stat in get_optimizer_stats(self.optimizer).items(): + name = f"{self.mode}/{name}" + self._log(name, stat, on_epoch=True, sync_dist=False) + + def _log(self, name: str, value: Any, on_step: bool = False, on_epoch: bool = False, sync_dist: bool = False) -> None: + suffix = "step" if on_step and not on_epoch else "epoch" + self.log( + f"{name}/{suffix}", + value, + logger=True, + on_step=on_step, + on_epoch=on_epoch, + sync_dist=sync_dist, + ) + + # ============================================================================ + # Lightning Optimizer Configuration + # ============================================================================ + + def configure_optimizers(self): + """Configure optimizer and scheduler.""" + if self.optimizer is None: + raise ValueError("Optimizer not configured.") + + if self.scheduler is None: + return {"optimizer": self.optimizer} + else: + return {"optimizer": self.optimizer, "lr_scheduler": self.scheduler} + + # ============================================================================ + # Properties + # ============================================================================ + + @property + def mode(self) -> str: + """ + Current execution mode. 
+ + Returns: + "train", "val", "test", or "predict" + + Raises: + RuntimeError: If called outside trainer context + """ + if self.trainer is None: + raise RuntimeError("LighterModule is not attached to a Trainer.") + + if self.trainer.sanity_checking: + return Mode.VAL + + if self.trainer.training: + return Mode.TRAIN + elif self.trainer.validating: + return Mode.VAL + elif self.trainer.testing: + return Mode.TEST + elif self.trainer.predicting: + return Mode.PREDICT + else: + raise RuntimeError("Cannot determine mode outside Lightning execution.") diff --git a/src/lighter/system.py b/src/lighter/system.py deleted file mode 100644 index 1962e29b..00000000 --- a/src/lighter/system.py +++ /dev/null @@ -1,348 +0,0 @@ -""" -This module defines the System class, which encapsulates the components of a deep learning system, -including the model, optimizer, datasets, and more. It extends PyTorch Lightning's LightningModule. -""" - -from collections.abc import Callable -from dataclasses import asdict -from typing import Any - -import pytorch_lightning as pl -from torch import Tensor -from torch.nn import Module -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LRScheduler -from torch.utils.data import DataLoader -from torch.utils.data._utils.collate import collate_str_fn, default_collate_fn_map -from torchmetrics import Metric, MetricCollection - -from lighter.utils.misc import get_optimizer_stats, hasarg -from lighter.utils.patches import PatchedModuleDict -from lighter.utils.types.containers import Adapters, DataLoaders, Metrics -from lighter.utils.types.enums import Data, Mode - -# Patch the original collate function to allow None values in the batch. -default_collate_fn_map.update({type(None): collate_str_fn}) - - -class System(pl.LightningModule): - """ - System encapsulates the components of a deep learning system, extending PyTorch Lightning's LightningModule. - - Args: - model: Model. - optimizer: Optimizer. 
- scheduler: Learning rate scheduler. - criterion: Criterion (loss) function. - metrics: Metrics for train, val, and test. Supports a single/list/dict of `torchmetrics` metrics. - dataloaders: Dataloaders for train, val, test, and predict. - adapters: Adapters for batch preparation, criterion argument adaptation, metrics argument adaptation, and logging data adaptation. - inferer: Inferer to use in val/test/predict modes. Custom inferers can be defined to handle inference logic. - - """ - - def __init__( - self, - model: Module, - dataloaders: dict[str, DataLoader], - optimizer: Optimizer | None = None, - scheduler: LRScheduler | None = None, - criterion: Callable | None = None, - metrics: dict[str, Metric | list[Metric] | dict[str, Metric]] | None = None, - adapters: dict[str, Callable] | None = None, - inferer: Callable | None = None, - ) -> None: - super().__init__() - - self.model = model - self.optimizer = optimizer - self.scheduler = scheduler - self.criterion = criterion - self.inferer = inferer - - # Containers - self.dataloaders = DataLoaders(**(dataloaders or {})) - self.metrics = Metrics(**(metrics or {})) - self.adapters = Adapters(**(adapters or {})) - - # Turn metrics container into a ModuleDict to register them properly. - self.metrics = PatchedModuleDict(asdict(self.metrics)) - - self.mode = None - self._setup_mode_hooks() - - def _step(self, batch: dict, batch_idx: int) -> dict[str, Any] | Any: - """ - Performs a step in the specified mode, processing the batch and calculating loss and metrics. - - Args: - batch: The batch of data. - batch_idx: The index of the batch. - Returns: - dict or Any: For predict step, returns prediction only. For other steps, - returns dict with loss, metrics, input, target, pred, and identifier. Loss is None - for test step, metrics is None if unspecified. 
- """ - input, target, identifier = self._prepare_batch(batch) - pred = self.forward(input) - - loss = self._calculate_loss(input, target, pred) - metrics = self._calculate_metrics(input, target, pred) - - self._log_stats(loss, metrics, batch_idx) - output = self._prepare_output(identifier, input, target, pred, loss, metrics) - return output - - def _prepare_batch(self, batch: dict) -> tuple[Any, Any, Any]: - """ - Prepares the batch data. - - Args: - batch: The input batch dictionary. - - Returns: - tuple: A tuple containing (input, target, identifier). - """ - adapters = getattr(self.adapters, self.mode) - input, target, identifier = adapters.batch(batch) - return input, target, identifier - - def forward(self, input: Any) -> Any: - """ - Forward pass through the model. - - Args: - input: The input data. - - Returns: - Any: The model's output. - """ - - # Pass `epoch` and/or `step` argument to forward if it accepts them - kwargs = {} - if hasarg(self.model.forward, Data.EPOCH): - kwargs[Data.EPOCH] = self.current_epoch - if hasarg(self.model.forward, Data.STEP): - kwargs[Data.STEP] = self.global_step - - # Predict. Use inferer if available in val, test, and predict modes. - if self.inferer and self.mode in [Mode.VAL, Mode.TEST, Mode.PREDICT]: - return self.inferer(input, self.model, **kwargs) - return self.model(input, **kwargs) - - def _calculate_loss(self, input: Any, target: Any, pred: Any) -> Tensor | dict[str, Tensor] | None: - """ - Calculates the loss using the criterion if in train or validation mode. - - Args: - input: The input data. - target: The target data. - pred: The model predictions. - - Returns: - The calculated loss or None if not in train/val mode. - - Raises: - ValueError: If criterion is not specified in train/val mode or if loss dict is missing 'total' key. 
- """ - loss = None - if self.mode in [Mode.TRAIN, Mode.VAL]: - if self.criterion is None: - raise ValueError("Please specify 'system.criterion' in the config.") - - adapters = getattr(self.adapters, self.mode) - loss = adapters.criterion(self.criterion, input, target, pred) - - if isinstance(loss, dict) and "total" not in loss: - raise ValueError( - "The loss dictionary must include a 'total' key that combines all sublosses. " - "Example: {'total': combined_loss, 'subloss1': loss1, ...}" - ) - return loss - - def _calculate_metrics(self, input: Any, target: Any, pred: Any) -> Any | None: - """ - Calculates the metrics if not in predict mode. - - Args: - input: The input data. - target: The target data. - pred: The model predictions. - - Returns: - The calculated metrics or None if in predict mode or no metrics specified. - """ - if self.mode == Mode.PREDICT or self.metrics[self.mode] is None: - return None - - adapters = getattr(self.adapters, self.mode) - metrics = adapters.metrics(self.metrics[self.mode], input, target, pred) - return metrics - - def _log_stats(self, loss: Tensor | dict[str, Tensor], metrics: MetricCollection, batch_idx: int) -> None: - """ - Logs the loss, metrics, and optimizer statistics. - - Args: - loss: The calculated loss. - metrics: The calculated metrics. - batch_idx: The index of the batch. 
- """ - if self.trainer.logger is None: - return - - # Loss - if loss is not None: - if not isinstance(loss, dict): - self._log(f"{self.mode}/{Data.LOSS}/{Data.STEP}", loss, on_step=True) - self._log(f"{self.mode}/{Data.LOSS}/{Data.EPOCH}", loss, on_epoch=True) - else: - for name, subloss in loss.items(): - self._log(f"{self.mode}/{Data.LOSS}/{name}/{Data.STEP}", subloss, on_step=True) - self._log(f"{self.mode}/{Data.LOSS}/{name}/{Data.EPOCH}", subloss, on_epoch=True) - - # Metrics - if metrics is not None: - for name, metric in metrics.items(): - self._log(f"{self.mode}/{Data.METRICS}/{name}/{Data.STEP}", metric, on_step=True) - self._log(f"{self.mode}/{Data.METRICS}/{name}/{Data.EPOCH}", metric, on_epoch=True) - - # Optimizer's lr, momentum, beta. Logged in train mode and once per epoch. - if self.mode == Mode.TRAIN and batch_idx == 0: - for name, optimizer_stat in get_optimizer_stats(self.optimizer).items(): - self._log(f"{self.mode}/{name}", optimizer_stat, on_epoch=True) - - def _log(self, name: str, value: Any, on_step: bool = False, on_epoch: bool = False) -> None: - """Log a key, value pair. Syncs across distributed nodes if `on_epoch` is True. - - Args: - name (str): key to log. - value (Any): value to log. - on_step (bool, optional): if True, logs on step. - on_epoch (bool, optional): if True, logs on epoch with sync_dist=True. - """ - batch_size = getattr(self.dataloaders, self.mode).batch_size - self.log(name, value, logger=True, batch_size=batch_size, on_step=on_step, on_epoch=on_epoch, sync_dist=on_epoch) - - def _prepare_output( - self, - identifier: Any, - input: Any, - target: Any, - pred: Any, - loss: Tensor | dict[str, Tensor] | None, - metrics: Any | None, - ) -> dict[str, Any]: - """ - Prepares the data to be returned by the step function to callbacks. - - Args: - identifier: The batch identifier. - input: The input data. - target: The target data. - pred: The model predictions. - loss: The calculated loss. - metrics: The calculated metrics. 
- - Returns: - dict: A dictionary containing all the step information. - """ - adapters = getattr(self.adapters, self.mode) - input, target, pred = adapters.logging(input, target, pred) - return { - Data.IDENTIFIER: identifier, - Data.INPUT: input, - Data.TARGET: target, - Data.PRED: pred, - Data.LOSS: loss, - Data.METRICS: metrics, - Data.STEP: self.global_step, - Data.EPOCH: self.current_epoch, - } - - def configure_optimizers(self) -> dict[str, Optimizer | LRScheduler] | None: - """ - Configures the optimizers and learning rate schedulers. - - Returns: - dict: A dictionary containing the optimizer and scheduler. - - Raises: - ValueError: If optimizer is not specified. - """ - if self.optimizer is None: - raise ValueError("Please specify 'system.optimizer' in the config.") - if self.scheduler is None: - return {"optimizer": self.optimizer} - else: - return {"optimizer": self.optimizer, "lr_scheduler": self.scheduler} - - def _setup_mode_hooks(self): - """ - Sets up the training, validation, testing, and prediction hooks based on defined dataloaders. 
- """ - if self.dataloaders.train is not None: - self.training_step = self._step - self.train_dataloader = lambda: self.dataloaders.train - self.on_train_start = lambda: self._on_mode_start(Mode.TRAIN) - self.on_train_end = self._on_mode_end - if self.dataloaders.val is not None: - self.validation_step = self._step - self.val_dataloader = lambda: self.dataloaders.val - self.on_validation_start = lambda: self._on_mode_start(Mode.VAL) - self.on_validation_end = self._on_mode_end - if self.dataloaders.test is not None: - self.test_step = self._step - self.test_dataloader = lambda: self.dataloaders.test - self.on_test_start = lambda: self._on_mode_start(Mode.TEST) - self.on_test_end = self._on_mode_end - if self.dataloaders.predict is not None: - self.predict_step = self._step - self.predict_dataloader = lambda: self.dataloaders.predict - self.on_predict_start = lambda: self._on_mode_start(Mode.PREDICT) - self.on_predict_end = self._on_mode_end - - def _on_mode_start(self, mode: str | None) -> None: - """ - Sets the current mode at the start of a phase. - - Args: - mode: The mode to set (train, val, test, or predict). - """ - self.mode = mode - - def _on_mode_end(self) -> None: - """ - Resets the mode at the end of a phase. - """ - self.mode = None - - @property - def learning_rate(self) -> float: - """ - Gets the learning rate of the optimizer. - - Returns: - float: The learning rate. - - Raises: - ValueError: If there are multiple optimizer parameter groups. - """ - if len(self.optimizer.param_groups) > 1: - raise ValueError("The learning rate is not available when there are multiple optimizer parameter groups.") - return self.optimizer.param_groups[0]["lr"] - - @learning_rate.setter - def learning_rate(self, value: float) -> None: - """ - Sets the learning rate of the optimizer. - - Args: - value: The new learning rate. - - Raises: - ValueError: If there are multiple optimizer parameter groups. 
- """ - if len(self.optimizer.param_groups) > 1: - raise ValueError("The learning rate is not available when there are multiple optimizer parameter groups.") - self.optimizer.param_groups[0]["lr"] = value diff --git a/src/lighter/utils/data.py b/src/lighter/utils/data.py index c0e56817..4bed4c31 100644 --- a/src/lighter/utils/data.py +++ b/src/lighter/utils/data.py @@ -1,37 +1,71 @@ import random -from typing import Any, Callable +from collections.abc import Callable +from typing import Any -import torch +from loguru import logger +from torch.utils.data import Dataset from torch.utils.data.dataloader import default_collate def collate_replace_corrupted( - batch: Any, dataset: torch.utils.data.Dataset, default_collate_fn: Callable | None = None + batch: Any, dataset: Dataset, default_collate_fn: Callable | None = None, max_retries: int = 100 ) -> Any: """ - Collate function to handle corrupted examples in a batch by replacing them with valid ones. + Collate function that handles corrupted examples in a batch by replacing them with valid ones. + + This function is designed to prevent training interruptions due to data corruption. + It logs a warning to alert the user about the number of corrupted samples found. Args: batch: The batch of data from the DataLoader. dataset: The dataset being used, which should return `None` for corrupted examples. default_collate_fn: The default collate function to use once the batch is clean. + max_retries: Maximum number of retry iterations to prevent infinite loops when replacements + are also corrupted. Defaults to 100. Returns: A batch with corrupted examples replaced by valid ones. + + Raises: + RuntimeError: If max_retries is reached and corrupted samples still remain, indicating + a high corruption rate in the dataset. """ # Use `torch.utils.data.dataloader.default_collate` if no other default collate function is specified. 
default_collate_fn = default_collate_fn if default_collate_fn is not None else default_collate - # Idea from https://stackoverflow.com/a/57882783 - original_batch_len = len(batch) - # Filter out all the Nones (corrupted examples). - batch = list(filter(lambda x: x is not None, batch)) - filtered_batch_len = len(batch) - # Num of corrupted examples. - num_corrupted = original_batch_len - filtered_batch_len + + num_corrupted = 0 + iterations = 0 + while True: + # Filter out corrupted samples (None). + original_len = len(batch) + batch = [sample for sample in batch if sample is not None] + current_len = len(batch) + + # Calculate the number of corrupted samples in this iteration. + newly_corrupted = original_len - current_len + if newly_corrupted == 0: + # No more corrupted samples, break the loop. + break + + # Check if we've exceeded the maximum retry limit. + iterations += 1 + if iterations > max_retries: + raise RuntimeError( + f"Reached maximum retry limit ({max_retries}) while trying to replace corrupted samples. " + f"Found {num_corrupted + newly_corrupted} total corrupted samples with {newly_corrupted} " + f"still remaining. This indicates a high corruption rate in the dataset. " + f"Consider investigating the dataset integrity or increasing max_retries." + ) + + num_corrupted += newly_corrupted + + # Replace corrupted samples with new random samples from the dataset. + replacements = [dataset[random.randint(0, len(dataset) - 1)] for _ in range(newly_corrupted)] # type: ignore[arg-type] + batch.extend(replacements) + + # Log a warning if any corrupted samples were found and replaced. if num_corrupted > 0: - # Replace a corrupted example with another randomly selected example. - batch.extend([dataset[random.randint(0, len(dataset) - 1)] for _ in range(num_corrupted)]) - # Recursive call to replace the replacements if they are corrupted. 
- return collate_replace_corrupted(batch, dataset) - # Finally, when the whole batch is fine, apply the default collate function. + logger.warning(f"Found and replaced {num_corrupted} corrupted samples in a batch.") + + # Apply the default collate function to the clean batch. return default_collate_fn(batch) diff --git a/src/lighter/utils/dynamic_imports.py b/src/lighter/utils/dynamic_imports.py index 3ec9c368..2c393e26 100644 --- a/src/lighter/utils/dynamic_imports.py +++ b/src/lighter/utils/dynamic_imports.py @@ -1,109 +1,204 @@ +"""Dynamic module imports with multiprocessing spawn support. + +Enables importing Python packages from arbitrary paths while maintaining compatibility +with multiprocessing spawn workers (used by PyTorch DataLoader with num_workers > 0). + +Uses cloudpickle to serialize dynamically imported modules by value, embedding class +definitions in the pickle stream rather than storing module paths that workers can't resolve. + +The key insight is that we need cloudpickle for user-defined classes (to serialize them +by value), but we must preserve ForkingPickler's special reducers for multiprocessing +internals (pipes, queues, connections, file descriptors). We achieve this by creating +a hybrid pickler that inherits from ForkingPickler but also includes cloudpickle's +dispatch table. """ -This module provides utilities for dynamic imports, allowing optional imports and importing modules from paths. 
-""" -import importlib +from __future__ import annotations + +import importlib.abc +import importlib.machinery import importlib.util import sys -from dataclasses import dataclass, field +from collections.abc import Sequence +from io import BytesIO +from multiprocessing import reduction +from multiprocessing.reduction import ForkingPickler from pathlib import Path +from types import ModuleType +from typing import IO, Any, cast -from loguru import logger +import cloudpickle +__all__ = ["import_module_from_path"] -def optional_import(module_name: str) -> tuple[object, bool]: - """ - Import a module optionally, returning a tuple of (module, is_available). - Args: - module_name: Name of the module to import. +class _ModuleRegistry: + """Maps dynamically imported module names to their filesystem paths.""" - Returns: - Tuple of (module or None, bool indicating if import succeeded). - """ - try: - module = importlib.import_module(module_name) - return module, True - except ImportError: - return None, False + def __init__(self) -> None: + self._modules: dict[str, Path] = {} + def register(self, name: str, path: Path) -> None: + self._modules[name] = path -@dataclass -class OptionalImports: - """ - Handles optional imports, allowing modules to be imported only if they are available. + def get(self, name: str) -> Path | None: + """Get the registered path for a module name.""" + return self._modules.get(name) - Attributes: - imports: A dictionary to store the imported modules. 
+ def find_root(self, fullname: str) -> tuple[str, Path] | None: + """Find the registered root module for an import name (e.g., 'project.sub' -> 'project').""" + for name, path in self._modules.items(): + if fullname == name or fullname.startswith(f"{name}."): + return name, path + return None - Example: - ``` - from lighter.utils.dynamic_imports import OPTIONAL_IMPORTS - writer = OPTIONAL_IMPORTS["tensorboard"].SummaryWriter() - ``` - """ - imports: dict[str, object] = field(default_factory=dict) +_registry = _ModuleRegistry() + + +class _DynamicModuleFinder(importlib.abc.MetaPathFinder): + """Meta path finder for submodules of dynamically imported packages.""" + + def find_spec( + self, + fullname: str, + path: Sequence[str | bytes] | None, + target: ModuleType | None = None, + ) -> importlib.machinery.ModuleSpec | None: + root = _registry.find_root(fullname) + if root is None: + return None + + root_name, root_path = root + file_path = self._resolve_path(fullname, root_name, root_path) + if file_path is None or not file_path.is_file(): + return None + + return importlib.machinery.ModuleSpec( + fullname, + _DynamicModuleLoader(str(file_path), fullname), + origin=str(file_path), + is_package=file_path.name == "__init__.py", + ) - def __getitem__(self, module_name: str) -> object: - """ - Get the imported module by name, importing it if necessary. + def _resolve_path(self, fullname: str, root_name: str, root_path: Path) -> Path | None: + """Resolve 'project.models.net' to '/path/to/project/models/net.py'.""" + if fullname == root_name: + return root_path / "__init__.py" - Args: - module_name: Name of the module to import. + relative = fullname[len(root_name) + 1 :].replace(".", "/") - Raises: - ImportError: If the module is not available. + # Try package first, then module + package_init = root_path / relative / "__init__.py" + if package_init.is_file(): + return package_init + return root_path / f"{relative}.py" - Returns: - object: The imported module. 
- """ - """Get the imported module by name. - Args: - module_name: Name of the module to import. +class _DynamicModuleLoader(importlib.abc.Loader): + """Loader that registers modules with cloudpickle for by-value serialization.""" - Raises: - ImportError: If the module is not available. + def __init__(self, filepath: str, fullname: str) -> None: + self._filepath = filepath + self._fullname = fullname - Returns: - Imported module. - """ - if module_name not in self.imports: - self.imports[module_name], module_available = optional_import(module_name) - if not module_available: - raise ImportError(f"'{module_name}' is not available. Make sure that it is installed and spelled correctly.") - return self.imports[module_name] + def create_module(self, spec: importlib.machinery.ModuleSpec) -> ModuleType | None: + return None + def exec_module(self, module: ModuleType) -> None: + importlib.machinery.SourceFileLoader(self._fullname, self._filepath).exec_module(module) + cloudpickle.register_pickle_by_value(module) -OPTIONAL_IMPORTS = OptionalImports() +class _HybridPickler(ForkingPickler): + """Pickler combining ForkingPickler's reducers with cloudpickle's by-value serialization. -def import_module_from_path(module_name: str, module_path: Path) -> None: + ForkingPickler has special reducers for multiprocessing internals (pipes, queues, + connections, file descriptors) that must be preserved. cloudpickle can serialize + dynamically defined classes by value. This hybrid uses both: ForkingPickler's + dispatch table takes priority (for multiprocessing internals), then cloudpickle's + reducer_override handles user-defined classes. """ - Import a module from a given path and assign it a specified name. 
+ + def __init__(self, file: IO[bytes], protocol: int | None = None) -> None: + super().__init__(file, protocol) + # Merge cloudpickle's dispatch into our dispatch_table + # ForkingPickler's reducers (from _extra_reducers) take priority + cloudpickle_dispatch = cloudpickle.CloudPickler.dispatch.copy() + cloudpickle_dispatch.update(self.dispatch_table) + self.dispatch_table = cloudpickle_dispatch + + def reducer_override(self, obj: Any) -> Any: + """Use cloudpickle's reducer for objects not in dispatch_table.""" + # If it's in ForkingPickler's extra reducers, let standard dispatch handle it + # _extra_reducers is a class attribute that exists at runtime but isn't typed + extra_reducers = cast(dict[type, Any], getattr(ForkingPickler, "_extra_reducers", {})) + if type(obj) in extra_reducers: + return NotImplemented + + # For everything else, try cloudpickle's reducer_override + # This handles dynamically imported classes registered with register_pickle_by_value + pickler = cloudpickle.CloudPickler(BytesIO()) + return pickler.reducer_override(obj) + + +def _hybrid_dump(obj: Any, file: IO[bytes], protocol: int | None = None) -> None: + """Replacement for reduction.dump using hybrid pickler.""" + _HybridPickler(file, protocol).dump(obj) + + +# Module initialization: install finder and patch multiprocessing +sys.meta_path.insert(0, _DynamicModuleFinder()) +reduction.dump = _hybrid_dump # type: ignore[assignment] + + +def import_module_from_path(module_name: str, module_path: Path | str) -> ModuleType: + """Import a package from a filesystem path with multiprocessing support. Args: - module_name: Name to assign to the imported module. - module_path: Path to the module being imported. + module_name: Name to assign to the module (e.g., "project"). + module_path: Path to the package directory (must contain __init__.py). + + Returns: + The imported module. Raises: - ValueError: If the module has already been imported. 
- FileNotFoundError: If the `__init__.py` file is not found in the module path. + FileNotFoundError: If module_path doesn't contain __init__.py. + ModuleNotFoundError: If the module cannot be loaded. + ValueError: If module_name was already imported from a different path. + + Example: + >>> import_module_from_path("project", "/path/to/project") + >>> from project.models import MyModel # Works in DataLoader workers! """ - # Based on https://stackoverflow.com/a/41595552. + module_path = Path(module_path).resolve() + # Check if already imported if module_name in sys.modules: - logger.warning(f"{module_name} has already been imported as module.") - return - - module_path = Path(module_path).resolve() / "__init__.py" - if not module_path.is_file(): - raise FileNotFoundError(f"No `__init__.py` in `{module_path}`.") - spec = importlib.util.spec_from_file_location(module_name, str(module_path)) - if spec is None: - raise ModuleNotFoundError(f"Could not find module '{module_name}' at '{module_path}'.") + existing_path = _registry.get(module_name) + if existing_path is not None and existing_path != module_path: + raise ValueError(f"Module '{module_name}' was already imported from '{existing_path}'.") + # Same path - return cached module (normal Python behavior) + return sys.modules[module_name] + + init_file = module_path / "__init__.py" + + if not init_file.is_file(): + raise FileNotFoundError(f"No __init__.py in '{module_path}'.") + + spec = importlib.util.spec_from_file_location(module_name, str(init_file)) + if spec is None or spec.loader is None: + raise ModuleNotFoundError(f"Could not load '{module_name}' from '{module_path}'.") + module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) sys.modules[module_name] = module - logger.info(f"Imported {module_path.parent} as module '{module_name}'.") + try: + spec.loader.exec_module(module) + cloudpickle.register_pickle_by_value(module) + _registry.register(module_name, module_path) + except Exception: 
+ # Clean up on failure so retry sees a clean state + sys.modules.pop(module_name, None) + raise + + return module diff --git a/src/lighter/utils/logging.py b/src/lighter/utils/logging.py index ce9f37dc..3d94c038 100644 --- a/src/lighter/utils/logging.py +++ b/src/lighter/utils/logging.py @@ -7,6 +7,7 @@ """ import importlib +from typing import Any # List of modules to suppress in Rich traceback for cleaner output SUPPRESSED_MODULES = [ @@ -43,7 +44,7 @@ def _setup_logging(): import rich.traceback from loguru import logger - def formatter(record: dict) -> str: + def formatter(record: dict[str, Any]) -> str: """Format log messages for better readability and clarity. Used to configure Loguru with a Rich handler.""" lvl_name = record["level"].name lvl_color = LOGGING_COLOR_MAP.get(lvl_name, "cyan") diff --git a/src/lighter/utils/misc.py b/src/lighter/utils/misc.py index 28abf4e4..508130b4 100644 --- a/src/lighter/utils/misc.py +++ b/src/lighter/utils/misc.py @@ -27,25 +27,6 @@ def ensure_list(input: Any) -> list: return [input] -def setattr_dot_notation(obj: Callable, attr: str, value: Any) -> None: - """ - Sets an attribute on an object using dot notation. - - Args: - obj: The object on which to set the attribute. - attr: The attribute name, which can use dot notation for nested attributes. - value: The value to set the attribute to. - """ - if "." not in attr: - if not hasattr(obj, attr): - raise AttributeError(f"`{get_name(obj, True)}` has no attribute `{attr}`.") - setattr(obj, attr, value) - # Solve recursively if the attribute is defined in dot-notation - else: - obj_name, attr = attr.split(".", maxsplit=1) - setattr_dot_notation(getattr(obj, obj_name), attr, value) - - def hasarg(fn: Callable, arg_name: str) -> bool: """ Checks if a callable (function, method, or class) has a specific argument. 
@@ -85,9 +66,9 @@ def get_name(_callable: Callable, include_module_name: bool = False) -> str: def get_optimizer_stats(optimizer: Optimizer) -> dict[str, float]: """ - Extract learning rates and momentum values from a PyTorch optimizer. + Extract hyperparameters from a PyTorch optimizer. - Collects learning rate and momentum/beta values from each parameter group + Collects learning rate and other key hyperparameters from each parameter group in the optimizer and returns them in a dictionary. Keys are formatted to show the optimizer type and group number (if multiple groups exist). @@ -95,29 +76,37 @@ def get_optimizer_stats(optimizer: Optimizer) -> dict[str, float]: optimizer: The PyTorch optimizer to extract values from. Returns: - dict[str, float]: dictionary containing: - - Learning rates: "optimizer/{name}/lr[/group{N}]" - - Momentum values: "optimizer/{name}/momentum[/group{N}]" + dict[str, float]: dictionary containing optimizer hyperparameters: + - Learning rate: "optimizer/{name}/lr[/group{N}]" + - Momentum: "optimizer/{name}/momentum[/group{N}]" (SGD, RMSprop) + - Beta1: "optimizer/{name}/beta1[/group{N}]" (Adam variants) + - Beta2: "optimizer/{name}/beta2[/group{N}]" (Adam variants) + - Weight decay: "optimizer/{name}/weight_decay[/group{N}]" Where [/group{N}] is only added for optimizers with multiple groups. 
""" stats_dict = {} for group_idx, group in enumerate(optimizer.param_groups): - lr_key = f"optimizer/{optimizer.__class__.__name__}/lr" - momentum_key = f"optimizer/{optimizer.__class__.__name__}/momentum" + base_key = f"optimizer/{optimizer.__class__.__name__}" - # Add group index to the key if there are multiple parameter groups - if len(optimizer.param_groups) > 1: - lr_key += f"/group{group_idx + 1}" - momentum_key += f"/group{group_idx + 1}" + # Add group index suffix if there are multiple parameter groups + suffix = f"/group{group_idx + 1}" if len(optimizer.param_groups) > 1 else "" - # Extracting learning rate - stats_dict[lr_key] = group["lr"] + # Always extract learning rate (present in all optimizers) + stats_dict[f"{base_key}/lr{suffix}"] = group["lr"] - # Extracting momentum or betas[0] if available + # Extract momentum (SGD, RMSprop) if "momentum" in group: - stats_dict[momentum_key] = group["momentum"] + stats_dict[f"{base_key}/momentum{suffix}"] = group["momentum"] + + # Extract betas (Adam, AdamW, NAdam, RAdam, etc.) if "betas" in group: - stats_dict[momentum_key] = group["betas"][0] + stats_dict[f"{base_key}/beta1{suffix}"] = group["betas"][0] + if len(group["betas"]) > 1: + stats_dict[f"{base_key}/beta2{suffix}"] = group["betas"][1] + + # Extract weight decay if non-zero + if "weight_decay" in group and group["weight_decay"] != 0: + stats_dict[f"{base_key}/weight_decay{suffix}"] = group["weight_decay"] return stats_dict diff --git a/src/lighter/utils/model.py b/src/lighter/utils/model.py deleted file mode 100644 index fddabbfa..00000000 --- a/src/lighter/utils/model.py +++ /dev/null @@ -1,133 +0,0 @@ -""" -This module provides utility functions for manipulating PyTorch models, such as replacing layers or loading state_dicts. 
-""" - -import torch -from loguru import logger -from torch.nn import Identity, Module, Sequential - -from lighter.utils.misc import setattr_dot_notation - - -def replace_layer_with(model: Module, layer_name: str, new_layer: Module) -> Module: - """ - Replaces a specified layer in a PyTorch model with a new layer. - - Args: - model: The model to modify. - layer_name: The name of the layer to replace, - using dot notation if necessary (e.g. "layer10.fc.weights"). - new_layer: The new layer to insert. - - Returns: - Module: The modified model with the new layer. - """ - setattr_dot_notation(model, layer_name, new_layer) - return model - - -def replace_layer_with_identity(model: Module, layer_name: str) -> Module: - """ - Replaces a specified layer in a PyTorch model with an Identity layer. - - Args: - model: The model to modify. - layer_name: The name of the layer to replace with an Identity layer, - using dot notation if necessary (e.g. "layer10.fc.weights"). - - Returns: - Module: The modified model with the Identity layer. - """ - return replace_layer_with(model, layer_name, Identity()) - - -def remove_n_last_layers_sequentially(model: Module(), num_layers=1) -> Sequential: - """ - Removes a specified number of layers from the end of a model and returns it as a Sequential model. - - Args: - model: The model to modify. - num_layers: The number of layers to remove from the end. - - Returns: - Sequential: The modified model as a Sequential container. - """ - return Sequential(*list(model.children())[:-num_layers]) - - -def adjust_prefix_and_load_state_dict( - model: Module, - ckpt_path: str, - ckpt_to_model_prefix: dict[str, str] | None = None, - layers_to_ignore: list[str] | None = None, -) -> Module: - """ - This function loads a state dictionary from a checkpoint file into a model using `torch.load(strict=False)`. - It supports remapping layer names between the checkpoint and model through the `ckpt_to_model_prefix` parameter. 
- - This is useful when loading weights from a model that was trained as part of a larger architecture, - where the layer names may not match the standalone version of the model. - - Before using `ckpt_to_model_prefix`, it's recommended to: - 1. Check the layer names in both the checkpoint and target model - 2. Map the mismatched prefixes accordingly - - Args: - model: The model to load the state_dict into. - ckpt_path: The path to the checkpoint file. - ckpt_to_model_prefix: Mapping of checkpoint prefixes to model prefixes. - layers_to_ignore: Layers to ignore when loading the state_dict. - - Returns: - Module: The model with the loaded state_dict. - - Raises: - ValueError: If there is no overlap between the checkpoint's and model's state_dict. - """ - # Load checkpoint and handle if state_dict is nested. - ckpt = torch.load(ckpt_path) # nosec B614 - if "state_dict" in ckpt: - # System has a model attribute that contains the actual model, remove the "model." prefix - ckpt = {key.replace("model.", ""): value for key, value in ckpt["state_dict"].items()} - - # Adjust checkpoint keys based on prefix mapping - adjusted_ckpt = {} - if ckpt_to_model_prefix: - for ckpt_prefix, model_prefix in ckpt_to_model_prefix.items(): - ckpt_prefix = f"{ckpt_prefix}." if ckpt_prefix and not ckpt_prefix.endswith(".") else ckpt_prefix - model_prefix = f"{model_prefix}." 
if model_prefix and not model_prefix.endswith(".") else model_prefix - - if ckpt_prefix: - adjusted_ckpt.update( - {key.replace(ckpt_prefix, model_prefix): value for key, value in ckpt.items() if ckpt_prefix in key} - ) - else: - adjusted_ckpt.update({f"{model_prefix}{key}": value for key, value in ckpt.items()}) - - if not adjusted_ckpt: - adjusted_ckpt = ckpt - else: - adjusted_ckpt = ckpt - - # Remove ignored layers if specified - if layers_to_ignore: - for layer in layers_to_ignore: - adjusted_ckpt.pop(layer) - - # Verify overlap between model and checkpoint keys - model_keys = list(model.state_dict().keys()) - ckpt_keys = list(adjusted_ckpt.keys()) - if not set(model_keys) & set(ckpt_keys): - raise ValueError( - "There is no overlap between checkpoint's and model's state_dict." - f"\nModel keys: {model_keys[0] + ', ..., ' + model_keys[-1] if model_keys else '[]'}" - f"\nCheckpoint keys: {ckpt_keys[0] + ', ..., ' + ckpt_keys[-1] if ckpt_keys else '[]'}" - ) - # Load state dict and handle incompatible keys - incompatible_keys = model.load_state_dict(adjusted_ckpt, strict=False) - if incompatible_keys.missing_keys or incompatible_keys.unexpected_keys: - logger.info(f"Encountered incompatible keys during checkpoint loading. If intended, ignore.\n{incompatible_keys}") - else: - logger.info("Checkpoint loaded successfully.") - - return model diff --git a/src/lighter/utils/patches.py b/src/lighter/utils/patches.py deleted file mode 100644 index 8946f55b..00000000 --- a/src/lighter/utils/patches.py +++ /dev/null @@ -1,99 +0,0 @@ -""" -Contains code that patches certain issues from other libraries that we expect will be resolved in the future. -""" - -from torch.nn import ModuleDict - - -class PatchedModuleDict(ModuleDict): - """ - This class provides a workaround for key conflicts in PyTorch's ModuleDict by ensuring unique internal keys. 
- """ - - # https://github.com/pytorch/pytorch/issues/71203 - def __init__(self, modules=None): - """ - Initializes the PatchedModuleDict with optional modules. - - Args: - modules (dict, optional): A dictionary of modules to initialize the ModuleDict. - """ - self._key_map = {} - super().__init__(modules) - - def __setitem__(self, key, module): - """ - Sets the module for the given key, ensuring a unique internal key. - - Args: - key (str): The key to associate with the module. - module (torch.nn.Module): The module to store. - """ - internal_key = f"_{key}" - while internal_key in self._modules: - internal_key = f"_{internal_key}" - self._key_map[key] = internal_key - super().__setitem__(internal_key, module) - - def __getitem__(self, key): - """ - Retrieves the module associated with the given key. - - Args: - key (str): The key for which to retrieve the module. - - Returns: - torch.nn.Module: The module associated with the key. - """ - internal_key = self._key_map.get(key, key) - return super().__getitem__(internal_key) - - def __delitem__(self, key): - """ - Deletes the module associated with the given key. - - Args: - key (str): The key for which to delete the module. - """ - internal_key = self._key_map.pop(key, key) - super().__delitem__(internal_key) - - def __contains__(self, key): - """ - Checks if a module is associated with the given key. - - Args: - key (str): The key to check. - - Returns: - bool: True if the key exists, False otherwise. - """ - internal_key = self._key_map.get(key, key) - return super().__contains__(internal_key) - - def keys(self): - """ - Returns the keys of the modules. - - Returns: - KeysView: A view of the keys in the dictionary. - """ - return self._key_map.keys() - - def items(self): - """ - Returns the items (key, module) in the dictionary. - - Returns: - Generator: A generator yielding key, module pairs. 
- """ - return ((key, self._modules[internal_key]) for key, internal_key in self._key_map.items()) - - def values(self): - """ - Returns the modules in the dictionary. - - Returns: - Generator: A generator yielding modules. - """ - return (self._modules[internal_key] for internal_key in self._key_map.values()) diff --git a/src/lighter/utils/types/containers.py b/src/lighter/utils/types/containers.py deleted file mode 100644 index c434ba80..00000000 --- a/src/lighter/utils/types/containers.py +++ /dev/null @@ -1,102 +0,0 @@ -from dataclasses import dataclass, field, fields, is_dataclass -from typing import Any - -from torchmetrics import Metric, MetricCollection - -from lighter.adapters import BatchAdapter, CriterionAdapter, LoggingAdapter, MetricsAdapter - - -def nested(cls): - """ - Decorator to handle nested dataclass creation. - Example: - ``` - @nested - @dataclass - class Example: - ... - ``` - """ - original_init = cls.__init__ - - def __init__(self, *args, **kwargs): - for f in fields(cls): - if is_dataclass(f.type) and f.name in kwargs: - kwargs[f.name] = f.type(**kwargs[f.name]) - original_init(self, *args, **kwargs) - - cls.__init__ = __init__ - return cls - - -@dataclass -class Metrics: - train: Metric | MetricCollection | None = None - val: Metric | MetricCollection | None = None - test: Metric | MetricCollection | None = None - - def __post_init__(self): - self.train = self._convert_to_collection(self.train) - self.val = self._convert_to_collection(self.val) - self.test = self._convert_to_collection(self.test) - - def _convert_to_collection(self, x): - if x is not None and not isinstance(x, MetricCollection): - return MetricCollection(x) - return x - - -@dataclass -class DataLoaders: - train: Any | None = None - val: Any | None = None - test: Any | None = None - predict: Any | None = None - - -@dataclass -class Train: - """Train mode sub-dataclass for Adapters.""" - - batch: BatchAdapter = field(default_factory=lambda: BatchAdapter(input_accessor=0, 
target_accessor=1)) - criterion: CriterionAdapter = field(default_factory=lambda: CriterionAdapter(pred_argument=0, target_argument=1)) - metrics: MetricsAdapter = field(default_factory=lambda: MetricsAdapter(pred_argument=0, target_argument=1)) - logging: LoggingAdapter = field(default_factory=LoggingAdapter) - - -@dataclass -class Val: - """Val mode sub-dataclass for Adapters.""" - - batch: BatchAdapter = field(default_factory=lambda: BatchAdapter(input_accessor=0, target_accessor=1)) - criterion: CriterionAdapter = field(default_factory=lambda: CriterionAdapter(pred_argument=0, target_argument=1)) - metrics: MetricsAdapter = field(default_factory=lambda: MetricsAdapter(pred_argument=0, target_argument=1)) - logging: LoggingAdapter = field(default_factory=LoggingAdapter) - - -@dataclass -class Test: - """Test mode sub-dataclass for Adapters.""" - - batch: BatchAdapter = field(default_factory=lambda: BatchAdapter(input_accessor=0, target_accessor=1)) - metrics: MetricsAdapter = field(default_factory=lambda: MetricsAdapter(pred_argument=0, target_argument=1)) - logging: LoggingAdapter = field(default_factory=LoggingAdapter) - - -@dataclass -class Predict: - """Predict mode sub-dataclass for Adapters.""" - - batch: BatchAdapter = field(default_factory=lambda: BatchAdapter(input_accessor=lambda batch: batch)) - logging: LoggingAdapter = field(default_factory=LoggingAdapter) - - -@nested -@dataclass -class Adapters: - """Root configuration class for all adapters across different modes.""" - - train: Train = field(default_factory=Train) - val: Val = field(default_factory=Val) - test: Test = field(default_factory=Test) - predict: Predict = field(default_factory=Predict) diff --git a/src/lighter/utils/types/enums.py b/src/lighter/utils/types/enums.py index a2dc6cfa..69a1645e 100644 --- a/src/lighter/utils/types/enums.py +++ b/src/lighter/utils/types/enums.py @@ -11,17 +11,6 @@ def __str__(self) -> str: return str(self.value) -class Data(StrEnum): - IDENTIFIER = 
"identifier" - INPUT = "input" - TARGET = "target" - PRED = "pred" - LOSS = "loss" - METRICS = "metrics" - STEP = "step" - EPOCH = "epoch" - - class Stage(StrEnum): FIT = "fit" VALIDATE = "validate" diff --git a/tests/configs/test1.yaml b/tests/configs/test1.yaml index fb99ddcc..e05280ff 100644 --- a/tests/configs/test1.yaml +++ b/tests/configs/test1.yaml @@ -1,12 +1,3 @@ _requires_: ["$import os"] vars: some_var: "value" -args: - fit: - some_arg: "value" - validate: - some_arg: "value" - predict: - some_arg: "value" - test: - some_arg: "value" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..f9555d80 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,140 @@ +"""Shared test fixtures for the Lighter test suite. + +This module provides common fixtures used across multiple test files to reduce +duplication and maintain consistency. +""" + +from unittest.mock import MagicMock + +import pytest +import torch +from torch import nn +from torch.utils.data import Dataset + +# ============================================================================ +# Mock Fixtures +# ============================================================================ + + +@pytest.fixture +def mock_trainer(): + """Create a mock PyTorch Lightning Trainer for testing. + + Returns: + MagicMock: A configured mock trainer with common attributes. 
+ """ + trainer = MagicMock() + trainer.world_size = 1 + trainer.global_rank = 0 + trainer.is_global_zero = True + trainer.logger = MagicMock() + trainer.strategy.broadcast = lambda x, src: x + trainer.strategy.barrier = MagicMock() + trainer.predict_loop._predictions = [[]] + trainer.predict_loop.num_dataloaders = 1 + + # State flags + trainer.training = False + trainer.validating = False + trainer.testing = False + trainer.predicting = False + trainer.sanity_checking = False + + return trainer + + +# ============================================================================ +# Dataset Fixtures +# ============================================================================ + + +class DummyDataset(Dataset): + """Simple dataset returning random tensors and integer labels. + + Args: + size: Number of samples in the dataset + input_dim: Dimension of input tensors + num_classes: Number of classes for labels + """ + + def __init__(self, size=32, input_dim=4, num_classes=2): + super().__init__() + self.size = size + self.input_dim = input_dim + self.num_classes = num_classes + self.data = [] + + for _ in range(self.size): + x = torch.randn(input_dim) + y = torch.randint(0, num_classes, size=()).item() + self.data.append((x, y)) + + def __getitem__(self, idx): + return self.data[idx] + + def __len__(self): + return self.size + + +@pytest.fixture +def dummy_dataset(): + """Create a simple dataset for testing. + + Returns: + DummyDataset: Dataset with 32 samples, 4D inputs, 2 classes + """ + return DummyDataset(size=32, input_dim=4, num_classes=2) + + +# ============================================================================ +# Model Fixtures +# ============================================================================ + + +class SimpleModel(nn.Module): + """Simple feedforward model with a single linear layer. 
+ + Args: + in_features: Input dimension + out_features: Output dimension + """ + + def __init__(self, in_features=4, out_features=2): + super().__init__() + self.linear = nn.Linear(in_features, out_features) + + def forward(self, x): + return self.linear(x) + + +@pytest.fixture +def simple_model(): + """Create a simple model for testing. + + Returns: + SimpleModel: Model with 4 input features, 2 output features + """ + return SimpleModel(in_features=4, out_features=2) + + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def mock_trainer_state(trainer, training=False, validating=False, testing=False, predicting=False, sanity_checking=False): + """Helper function to set trainer state flags for testing mode detection. + + Args: + trainer: Mock trainer object + training: Set training flag + validating: Set validating flag + testing: Set testing flag + predicting: Set predicting flag + sanity_checking: Set sanity_checking flag + """ + trainer.training = training + trainer.validating = validating + trainer.testing = testing + trainer.predicting = predicting + trainer.sanity_checking = sanity_checking diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 00000000..5da10265 --- /dev/null +++ b/tests/fixtures/__init__.py @@ -0,0 +1 @@ +"""Test fixtures and helper modules for Lighter tests.""" diff --git a/tests/fixtures/plain_lightning_modules.py b/tests/fixtures/plain_lightning_modules.py new file mode 100644 index 00000000..3bf2ca7b --- /dev/null +++ b/tests/fixtures/plain_lightning_modules.py @@ -0,0 +1,118 @@ +"""Plain PyTorch Lightning modules for testing Lighter compatibility. + +These modules demonstrate that Lighter works with ANY PyTorch Lightning module, +not just lighter.System. Users have complete freedom to use plain Lightning. 
+""" + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.data import DataLoader, Dataset + +from lighter import LighterModule + + +class SimpleDataset(Dataset): + """Minimal dataset for testing.""" + + def __init__(self, size=32): + self.size = size + + def __len__(self): + return self.size + + def __getitem__(self, idx): + x = torch.randn(10) + y = torch.randint(0, 2, (1,)).item() + return x, y + + +class PlainLightningModule(pl.LightningModule): + """ + A plain PyTorch Lightning module with NO Lighter-specific code. + + This demonstrates that Lighter works with any LightningModule. + """ + + def __init__(self, input_size=10, hidden_size=20, output_size=2, learning_rate=0.001): + super().__init__() + self.save_hyperparameters() + + self.model = nn.Sequential( + nn.Linear(input_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, output_size), + ) + self.learning_rate = learning_rate + + def forward(self, x): + return self.model(x) + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + self.log("train/loss", loss) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + self.log("val/loss", loss) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.learning_rate) + + def train_dataloader(self): + return DataLoader(SimpleDataset(32), batch_size=8) + + def val_dataloader(self): + return DataLoader(SimpleDataset(16), batch_size=8) + + +class LightningModuleWithDataloaders(pl.LightningModule): + """Lightning module that defines dataloaders internally.""" + + def __init__(self): + super().__init__() + self.model = nn.Linear(10, 2) + + def forward(self, x): + return self.model(x) + + def training_step(self, batch, batch_idx): + x, y = batch + loss = F.cross_entropy(self(x), y) + return loss + + def 
configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.001) + + def train_dataloader(self): + return DataLoader(SimpleDataset(32), batch_size=8) + + def val_dataloader(self): + return DataLoader(SimpleDataset(16), batch_size=8) + + +class MyLighterModule(LighterModule): + """Example Lighter module for testing mixed usage.""" + + def training_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + loss = F.cross_entropy(pred, y) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + return self.training_step(batch, batch_idx) + + def test_step(self, batch, batch_idx): + x, y = batch + return {"pred": self(x)} + + def predict_step(self, batch, batch_idx): + return self(batch[0]) diff --git a/tests/integration/test_cifar.py b/tests/integration/test_cifar.py index 0cef31ae..4d79f05f 100644 --- a/tests/integration/test_cifar.py +++ b/tests/integration/test_cifar.py @@ -1,37 +1,45 @@ """Tests for running CIFAR training to verify integrity of the pipeline""" +from pathlib import Path + import pytest from lighter.engine.runner import Runner, Stage -test_overrides = "./tests/integration/test_overrides.yaml" - @pytest.mark.parametrize( ("stage", "config"), [ ( Stage.FIT, - "./projects/cifar10/experiments/example.yaml", + "configs/example.yaml", ), ( Stage.TEST, - "./projects/cifar10/experiments/example.yaml", + "configs/example.yaml", ), ( Stage.PREDICT, - "./projects/cifar10/experiments/example.yaml", + "configs/example.yaml", ), ], ) @pytest.mark.slow -def test_trainer_stage(stage: Stage, config: str): +def test_trainer_stage(stage: Stage, config: str, monkeypatch: pytest.MonkeyPatch) -> None: """ Test the specified stage using the given configuration. Args: stage: The stage to run (e.g., Stage.FIT, Stage.TEST, Stage.PREDICT). config: Path to the configuration file. + monkeypatch: Pytest fixture for changing working directory. 
""" + # Change to CIFAR10 project directory for auto-discovery + project_dir = Path(__file__).parent.parent.parent / "projects" / "cifar10" + monkeypatch.chdir(project_dir) + + # Paths relative to project directory + test_overrides = "../../tests/integration/test_overrides.yaml" + runner = Runner() - runner.run(stage, config=f"{config},{test_overrides}") - assert runner.trainer.state.finished, f"Stage {stage} did not finish successfully." + runner.run(stage, [config, test_overrides]) + # Runner no longer stores trainer, just verify it completed without error diff --git a/tests/integration/test_config_validation.py b/tests/integration/test_config_validation.py index 0f551d93..7c993bda 100644 --- a/tests/integration/test_config_validation.py +++ b/tests/integration/test_config_validation.py @@ -2,7 +2,7 @@ import tempfile from pathlib import Path -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -19,33 +19,39 @@ def temp_config_dir(self): with tempfile.TemporaryDirectory() as tmpdir: yield Path(tmpdir) - def test_minimal_valid_config_loads_and_validates(self, temp_config_dir): - """Test that minimal valid config loads and validates.""" + def test_minimal_valid_config_loads(self, temp_config_dir): + """Test that minimal valid config loads.""" config_path = temp_config_dir / "minimal.yaml" config_content = """ trainer: _target_: pytorch_lightning.Trainer max_epochs: 1 -system: - _target_: lighter.System - model: +model: + _target_: lighter.LighterModule + network: _target_: torch.nn.Identity - dataloaders: - train: {} - val: {} + +data: + _target_: lighter.LighterDataModule + train_dataloader: {} """ config_path.write_text(config_content) runner = Runner() - # Patch _setup and _execute to avoid needing real components - with patch.object(runner, "_setup"), patch.object(runner, "_execute"): - # Should not raise ValidationError - runner.run(Stage.FIT, str(config_path)) - assert runner.config is not None - - def 
test_config_with_all_optional_fields_validates(self, temp_config_dir): - """Test config with all optional fields validates.""" + # Patch internal methods to avoid needing real components + with ( + patch.object(runner, "_resolve_model"), + patch.object(runner, "_resolve_trainer"), + patch.object(runner, "_resolve_datamodule"), + patch.object(runner, "_save_config"), + patch.object(runner, "_execute"), + ): + # Should not raise error - just verify config loads + runner.run(Stage.FIT, [str(config_path)]) + + def test_config_with_all_components(self, temp_config_dir): + """Test config with all components.""" config_path = temp_config_dir / "full.yaml" config_content = """ project: ./path/to/project @@ -57,9 +63,9 @@ def test_config_with_all_optional_fields_validates(self, temp_config_dir): _target_: pytorch_lightning.Trainer max_epochs: 10 -system: - _target_: lighter.System - model: +model: + _target_: lighter.LighterModule + network: _target_: torch.nn.Identity criterion: _target_: torch.nn.MSELoss @@ -69,22 +75,16 @@ def test_config_with_all_optional_fields_validates(self, temp_config_dir): scheduler: _target_: torch.optim.lr_scheduler.StepLR step_size: 10 - inferer: - _target_: lighter.Inferer - metrics: - train: [] - val: [] - test: [] - dataloaders: - train: {} - val: {} - test: {} - predict: {} - adapters: - train: {} - val: {} - test: {} - predict: {} + train_metrics: [] + val_metrics: [] + test_metrics: [] + +data: + _target_: lighter.LighterDataModule + train_dataloader: {} + val_dataloader: {} + test_dataloader: {} + predict_dataloader: {} args: fit: @@ -96,259 +96,25 @@ def test_config_with_all_optional_fields_validates(self, temp_config_dir): config_path.write_text(config_content) runner = Runner() - with patch.object(runner, "_setup"), patch.object(runner, "_execute"): - runner.run(Stage.FIT, str(config_path)) - assert runner.config.get("project") == "./path/to/project" - assert runner.config.get("vars::learning_rate") == 0.001 - assert 
runner.config.get("system::optimizer::lr") == 0.001 - - def test_missing_trainer_raises_validation_error(self, temp_config_dir): - """Test that missing trainer field raises ValidationError.""" - config_path = temp_config_dir / "invalid_no_trainer.yaml" - config_content = """ -system: - _target_: lighter.System - model: - _target_: torch.nn.Identity - dataloaders: - train: {} -""" - config_path.write_text(config_content) - - runner = Runner() - with pytest.raises(ValueError, match="validation failed"): - runner.run(Stage.FIT, str(config_path)) - - def test_missing_system_raises_validation_error(self, temp_config_dir): - """Test that missing system field raises ValidationError.""" - config_path = temp_config_dir / "invalid_no_system.yaml" - config_content = """ -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 1 -""" - config_path.write_text(config_content) - - runner = Runner() - with pytest.raises(ValueError, match="validation failed"): - runner.run(Stage.FIT, str(config_path)) - - def test_wrong_type_for_trainer_raises_validation_error(self, temp_config_dir): - """Test that wrong type for trainer raises ValidationError.""" - config_path = temp_config_dir / "invalid_trainer_type.yaml" - config_content = """ -trainer: - - this_should_be_a_dict - - not_a_list - -system: - _target_: lighter.System - model: - _target_: torch.nn.Identity - dataloaders: - train: {} -""" - config_path.write_text(config_content) - - runner = Runner() - with pytest.raises(ValueError, match="validation failed"): - runner.run(Stage.FIT, str(config_path)) - - def test_wrong_type_for_system_raises_validation_error(self, temp_config_dir): - """Test that wrong type for system raises ValidationError.""" - config_path = temp_config_dir / "invalid_system_type.yaml" - config_content = """ -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 1 - -system: "this should be a dict" -""" - config_path.write_text(config_content) - - runner = Runner() - with pytest.raises(ValueError, 
match="validation failed"): - runner.run(Stage.FIT, str(config_path)) - - def test_pruning_removes_unused_dataloaders_for_fit_stage(self, temp_config_dir): - """Test that pruning removes test/predict dataloaders for FIT stage.""" - config_path = temp_config_dir / "all_dataloaders.yaml" - config_content = """ -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 1 - -system: - _target_: lighter.System - model: - _target_: torch.nn.Identity - dataloaders: - train: {batch_size: 32} - val: {batch_size: 32} - test: {batch_size: 32} - predict: {batch_size: 32} -""" - config_path.write_text(config_content) - - runner = Runner() - with patch.object(runner, "_setup"), patch.object(runner, "_execute"): - runner.run(Stage.FIT, str(config_path)) - # After pruning, only train and val should remain - assert runner.config.get("system::dataloaders::train") is not None - assert runner.config.get("system::dataloaders::val") is not None - assert runner.config.get("system::dataloaders::test") is None - assert runner.config.get("system::dataloaders::predict") is None - - def test_pruning_removes_unused_dataloaders_for_test_stage(self, temp_config_dir): - """Test that pruning removes train/val/predict dataloaders for TEST stage.""" - config_path = temp_config_dir / "all_dataloaders.yaml" - config_content = """ -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 1 - -system: - _target_: lighter.System - model: - _target_: torch.nn.Identity - dataloaders: - train: {batch_size: 32} - val: {batch_size: 32} - test: {batch_size: 32} - predict: {batch_size: 32} -""" - config_path.write_text(config_content) - - runner = Runner() - with patch.object(runner, "_setup"), patch.object(runner, "_execute"): - runner.run(Stage.TEST, str(config_path)) - # After pruning, only test should remain - assert runner.config.get("system::dataloaders::test") is not None - assert runner.config.get("system::dataloaders::train") is None - assert runner.config.get("system::dataloaders::val") is None 
- assert runner.config.get("system::dataloaders::predict") is None - - def test_pruning_removes_unused_dataloaders_for_predict_stage(self, temp_config_dir): - """Test that pruning removes train/val/test dataloaders for PREDICT stage.""" - config_path = temp_config_dir / "all_dataloaders.yaml" - config_content = """ -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 1 - -system: - _target_: lighter.System - model: - _target_: torch.nn.Identity - dataloaders: - train: {batch_size: 32} - val: {batch_size: 32} - test: {batch_size: 32} - predict: {batch_size: 32} -""" - config_path.write_text(config_content) - - runner = Runner() - with patch.object(runner, "_setup"), patch.object(runner, "_execute"): - runner.run(Stage.PREDICT, str(config_path)) - # After pruning, only predict should remain - assert runner.config.get("system::dataloaders::predict") is not None - assert runner.config.get("system::dataloaders::train") is None - assert runner.config.get("system::dataloaders::val") is None - assert runner.config.get("system::dataloaders::test") is None - - def test_pruning_removes_unused_metrics(self, temp_config_dir): - """Test that pruning removes unused metrics for each stage.""" - config_path = temp_config_dir / "all_metrics.yaml" - config_content = """ -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 1 - -system: - _target_: lighter.System - model: - _target_: torch.nn.Identity - metrics: - train: [] - val: [] - test: [] - dataloaders: - train: {} - val: {} - test: {} -""" - config_path.write_text(config_content) - - runner = Runner() - with patch.object(runner, "_setup"), patch.object(runner, "_execute"): - runner.run(Stage.FIT, str(config_path)) - # FIT keeps train and val metrics - assert runner.config.get("system::metrics::train") is not None - assert runner.config.get("system::metrics::val") is not None - assert runner.config.get("system::metrics::test") is None - - def test_pruning_removes_optimizer_for_non_fit_stages(self, 
temp_config_dir): - """Test that pruning removes optimizer/scheduler for non-FIT stages.""" - config_path = temp_config_dir / "with_optimizer.yaml" - config_content = """ -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 1 - -system: - _target_: lighter.System - model: - _target_: torch.nn.Identity - optimizer: - _target_: torch.optim.Adam - lr: 0.001 - scheduler: - _target_: torch.optim.lr_scheduler.StepLR - step_size: 10 - criterion: - _target_: torch.nn.MSELoss - dataloaders: - test: {} -""" - config_path.write_text(config_content) - - runner = Runner() - with patch.object(runner, "_setup"), patch.object(runner, "_execute"): - runner.run(Stage.TEST, str(config_path)) - # TEST stage should remove optimizer, scheduler, and criterion - assert runner.config.get("system::optimizer") is None - assert runner.config.get("system::scheduler") is None - assert runner.config.get("system::criterion") is None - - def test_pruning_keeps_criterion_for_validate_stage(self, temp_config_dir): - """Test that VALIDATE stage keeps criterion but removes optimizer.""" - config_path = temp_config_dir / "validate_config.yaml" - config_content = """ -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 1 - -system: - _target_: lighter.System - model: - _target_: torch.nn.Identity - optimizer: - _target_: torch.optim.Adam - lr: 0.001 - criterion: - _target_: torch.nn.MSELoss - dataloaders: - val: {} -""" - config_path.write_text(config_content) - - runner = Runner() - with patch.object(runner, "_setup"), patch.object(runner, "_execute"): - runner.run(Stage.VALIDATE, str(config_path)) - # VALIDATE removes optimizer but keeps criterion - assert runner.config.get("system::optimizer") is None - assert runner.config.get("system::criterion") is not None + # Capture config to verify values + captured_config = None + + def capture_system(config): + nonlocal captured_config + captured_config = config + return MagicMock() + + with ( + patch.object(runner, "_resolve_model", 
side_effect=capture_system), + patch.object(runner, "_resolve_trainer"), + patch.object(runner, "_resolve_datamodule"), + patch.object(runner, "_save_config"), + patch.object(runner, "_execute"), + ): + runner.run(Stage.FIT, [str(config_path)]) + assert captured_config.get("project") == "./path/to/project" + assert captured_config.get("vars::learning_rate") == 0.001 + assert captured_config.get("model::optimizer::lr") == 0.001 def test_multi_file_config_merge(self, temp_config_dir): """Test that multiple config files merge correctly.""" @@ -359,12 +125,14 @@ def test_multi_file_config_merge(self, temp_config_dir): max_epochs: 100 devices: 1 -system: - _target_: lighter.System - model: +model: + _target_: lighter.LighterModule + network: _target_: torch.nn.Identity - dataloaders: - train: {} + +data: + _target_: lighter.LighterDataModule + train_dataloader: {} """ base_path.write_text(base_content) @@ -374,63 +142,43 @@ def test_multi_file_config_merge(self, temp_config_dir): max_epochs: 1 devices: 2 -system: - model: +model: + network: _target_: torch.nn.Identity criterion: _target_: torch.nn.MSELoss - dataloaders: - train: {} - val: {} + +data: + train_dataloader: {} + val_dataloader: {} """ override_path.write_text(override_content) runner = Runner() - with patch.object(runner, "_setup"), patch.object(runner, "_execute"): + # Capture config to verify values + captured_config = None + + def capture_system(config): + nonlocal captured_config + captured_config = config + return MagicMock() + + with ( + patch.object(runner, "_resolve_model", side_effect=capture_system), + patch.object(runner, "_resolve_trainer"), + patch.object(runner, "_resolve_datamodule"), + patch.object(runner, "_save_config"), + patch.object(runner, "_execute"), + ): runner.run(Stage.FIT, [str(base_path), str(override_path)]) # Override values should be applied - assert runner.config.get("trainer::max_epochs") == 1 - assert runner.config.get("trainer::devices") == 2 + assert 
captured_config.get("trainer::max_epochs") == 1 + assert captured_config.get("trainer::devices") == 2 # Model should be present (from override) - assert runner.config.get("system::model::_target_") == "torch.nn.Identity" + assert captured_config.get("model::network::_target_") == "torch.nn.Identity" # New values from override should be added - assert runner.config.get("system::criterion::_target_") == "torch.nn.MSELoss" - assert runner.config.get("system::dataloaders::val") is not None - - def test_comma_separated_config_files(self, temp_config_dir): - """Test that comma-separated config file paths work.""" - base_path = temp_config_dir / "base.yaml" - base_content = """ -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 100 - -system: - _target_: lighter.System - model: - _target_: torch.nn.Identity - dataloaders: - train: {} -""" - base_path.write_text(base_content) - - override_path = temp_config_dir / "override.yaml" - override_content = """ -trainer: - max_epochs: 1 - -system: - dataloaders: - val: {} -""" - override_path.write_text(override_content) - - runner = Runner() - with patch.object(runner, "_setup"), patch.object(runner, "_execute"): - # Use comma-separated string - config_str = f"{base_path},{override_path}" - runner.run(Stage.FIT, config_str) - assert runner.config.get("trainer::max_epochs") == 1 + assert captured_config.get("model::criterion::_target_") == "torch.nn.MSELoss" + assert captured_config.get("data::val_dataloader") is not None def test_cli_overrides_apply(self, temp_config_dir): """Test that CLI overrides are applied correctly.""" @@ -439,65 +187,40 @@ def test_cli_overrides_apply(self, temp_config_dir): trainer: _target_: pytorch_lightning.Trainer max_epochs: 10 - devices: 1 -system: - _target_: lighter.System - model: +model: + _target_: lighter.LighterModule + network: _target_: torch.nn.Identity optimizer: _target_: torch.optim.Adam lr: 0.001 - dataloaders: - train: {} - val: {} -""" - 
config_path.write_text(config_content) - - runner = Runner() - with patch.object(runner, "_setup"), patch.object(runner, "_execute"): - overrides = ["trainer::max_epochs=5", "system::optimizer::lr=0.1"] - runner.run(Stage.FIT, str(config_path), overrides) - assert runner.config.get("trainer::max_epochs") == 5 - assert runner.config.get("system::optimizer::lr") == 0.1 - - def test_stage_specific_args_pruned(self, temp_config_dir): - """Test that stage-specific args are preserved and others pruned.""" - config_path = temp_config_dir / "with_args.yaml" - config_content = """ -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 1 - -system: - _target_: lighter.System - model: - _target_: torch.nn.Identity - dataloaders: - train: {} - val: {} -args: - fit: - ckpt_path: checkpoint.ckpt - validate: - verbose: true - test: - verbose: false - predict: - return_predictions: true +data: + _target_: lighter.LighterDataModule + train_dataloader: {} """ config_path.write_text(config_content) runner = Runner() - with patch.object(runner, "_setup"), patch.object(runner, "_execute"): - runner.run(Stage.FIT, str(config_path)) - # Fit args should be preserved - assert runner.config.get("args::fit") is not None - # Other args should be pruned - assert runner.config.get("args::validate") is None - assert runner.config.get("args::test") is None - assert runner.config.get("args::predict") is None + captured_config = None + + def capture_config(config): + nonlocal captured_config + captured_config = config + return MagicMock() + + with ( + patch.object(runner, "_resolve_model", side_effect=capture_config), + patch.object(runner, "_resolve_trainer"), + patch.object(runner, "_resolve_datamodule"), + patch.object(runner, "_save_config"), + patch.object(runner, "_execute"), + ): + overrides = ["trainer::max_epochs=5", "model::optimizer::lr=0.1"] + runner.run(Stage.FIT, [str(config_path)] + overrides) + assert captured_config.get("trainer::max_epochs") == 5 + assert 
captured_config.get("model::optimizer::lr") == 0.1 def test_config_with_references(self, temp_config_dir): """Test configuration with Sparkwheel references.""" @@ -507,26 +230,48 @@ def test_config_with_references(self, temp_config_dir): _target_: pytorch_lightning.Trainer max_epochs: 1 -system: - _target_: lighter.System - model: +model: + _target_: lighter.LighterModule + network: _target_: torch.nn.Identity - metrics: - train: - - _target_: torchmetrics.MeanSquaredError - val: "%::train" - dataloaders: - train: {batch_size: 32} - val: "%::train" + train_metrics: + - _target_: torchmetrics.MeanSquaredError + val_metrics: "%::train_metrics" + +data: + _target_: lighter.LighterDataModule + train_dataloader: {batch_size: 32} + val_dataloader: "%::train_dataloader" """ config_path.write_text(config_content) runner = Runner() - with patch.object(runner, "_setup"), patch.object(runner, "_execute"): - runner.run(Stage.FIT, str(config_path)) - # Before resolution, references should exist - assert runner.config.get("system::metrics::val") == "%::train" - assert runner.config.get("system::dataloaders::val") == "%::train" + # Capture config to verify values + captured_config = None + + def capture_system(config): + nonlocal captured_config + captured_config = config + return MagicMock() + + with ( + patch.object(runner, "_resolve_model", side_effect=capture_system), + patch.object(runner, "_resolve_trainer"), + patch.object(runner, "_resolve_datamodule"), + patch.object(runner, "_save_config"), + patch.object(runner, "_execute"), + ): + runner.run(Stage.FIT, [str(config_path)]) + # Raw references (%) remain as strings until resolution + # They are expanded when resolved, not during config loading + train_metrics = captured_config.get("model::train_metrics") + val_metrics_ref = captured_config.get("model::val_metrics") + + # train_metrics should be the list of metric configs + assert train_metrics == [{"_target_": "torchmetrics.MeanSquaredError"}] + # val_metrics uses a raw 
reference (%::) which Sparkwheel normalizes to absolute path + # The reference stays as a string until instantiation time + assert val_metrics_ref == "%model::train_metrics" def test_config_with_vars(self, temp_config_dir): """Test configuration with vars section.""" @@ -540,24 +285,40 @@ def test_config_with_vars(self, temp_config_dir): _target_: pytorch_lightning.Trainer max_epochs: 1 -system: - _target_: lighter.System - model: +model: + _target_: lighter.LighterModule + network: _target_: torch.nn.Identity optimizer: _target_: torch.optim.Adam lr: "@vars::learning_rate" - dataloaders: - train: - batch_size: "@vars::batch_size" - val: - batch_size: "@vars::batch_size" + +data: + _target_: lighter.LighterDataModule + train_dataloader: + batch_size: "@vars::batch_size" + val_dataloader: + batch_size: "@vars::batch_size" """ config_path.write_text(config_content) runner = Runner() - with patch.object(runner, "_setup"), patch.object(runner, "_execute"): - runner.run(Stage.FIT, str(config_path)) + # Capture config to verify values + captured_config = None + + def capture_system(config): + nonlocal captured_config + captured_config = config + return MagicMock() + + with ( + patch.object(runner, "_resolve_model", side_effect=capture_system), + patch.object(runner, "_resolve_trainer"), + patch.object(runner, "_resolve_datamodule"), + patch.object(runner, "_save_config"), + patch.object(runner, "_execute"), + ): + runner.run(Stage.FIT, [str(config_path)]) # Vars should be accessible - assert runner.config.get("vars::learning_rate") == 0.001 - assert runner.config.get("vars::batch_size") == 32 + assert captured_config.get("vars::learning_rate") == 0.001 + assert captured_config.get("vars::batch_size") == 32 diff --git a/tests/integration/test_overrides.yaml b/tests/integration/test_overrides.yaml index a890e169..8add4e67 100644 --- a/tests/integration/test_overrides.yaml +++ b/tests/integration/test_overrides.yaml @@ -1,8 +1,43 @@ trainer::fast_dev_run: True 
trainer::accelerator: cpu -system::dataloaders::train::batch_size: 16 -system::dataloaders::train::num_workers: 2 +# Override batch_size by replacing the dataloader config +=data::train_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 16 + num_workers: 2 + pin_memory: True + shuffle: True + dataset: + _target_: torchvision.datasets.CIFAR10 + download: True + root: ./.datasets/ + train: True + transform: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + target_transform: null -system::dataloaders::val::batch_size: 16 -system::dataloaders::val::num_workers: 2 +=data::val_dataloader: + _target_: torch.utils.data.DataLoader + batch_size: 16 + num_workers: 2 + pin_memory: True + shuffle: False + dataset: + _target_: torchvision.datasets.CIFAR10 + download: True + root: ./.datasets/ + train: False + transform: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + target_transform: null diff --git a/tests/integration/test_plain_lightning.py b/tests/integration/test_plain_lightning.py new file mode 100644 index 00000000..0ddb1000 --- /dev/null +++ b/tests/integration/test_plain_lightning.py @@ -0,0 +1,134 @@ +"""Integration test showing Lighter works with plain PyTorch Lightning modules.""" + +import sys +import tempfile +from pathlib import Path + +# Add tests directory to path so we can import fixtures +sys.path.insert(0, str(Path(__file__).parent.parent)) + + +from lighter.engine.runner import Runner +from lighter.utils.types.enums import Stage + + +def test_lighter_with_plain_lightning_module(): + """Test that Lighter can run a plain PyTorch Lightning module.""" + with tempfile.TemporaryDirectory() as tmpdir: + config_path = Path(tmpdir) / "config.yaml" + + # 
Config using a plain Lightning module + # Dataloaders are defined in the module itself + config_content = """ +trainer: + _target_: pytorch_lightning.Trainer + max_epochs: 1 + enable_checkpointing: false + logger: false + +model: + _target_: fixtures.plain_lightning_modules.PlainLightningModule + input_size: 10 + hidden_size: 20 + output_size: 2 + learning_rate: 0.001 +""" + config_path.write_text(config_content) + + # Run with Lighter + runner = Runner() + runner.run(Stage.FIT, [str(config_path)]) + + # Runner no longer stores system, just verify it ran successfully without error + + +def test_lighter_with_lightning_module_and_external_dataloaders(): + """Test using a Lightning module that defines its own dataloaders.""" + with tempfile.TemporaryDirectory() as tmpdir: + config_path = Path(tmpdir) / "config.yaml" + + # Minimal config - dataloaders defined in the module itself + config_content = """ +trainer: + _target_: pytorch_lightning.Trainer + max_epochs: 1 + enable_checkpointing: false + logger: false + +model: + _target_: fixtures.plain_lightning_modules.LightningModuleWithDataloaders +""" + config_path.write_text(config_content) + + runner = Runner() + runner.run(Stage.FIT, [str(config_path)]) + + # Runner no longer stores system, just verify it ran successfully without error + + +def test_mixed_lighter_system_and_plain_lightning(): + """ + Test that you can switch between Lighter System and plain Lightning + by just changing the config. 
+ """ + + with tempfile.TemporaryDirectory() as tmpdir: + # Test 1: Use LighterModule + config_path1 = Path(tmpdir) / "lighter_module.yaml" + config_content1 = """ +trainer: + _target_: pytorch_lightning.Trainer + max_epochs: 1 + enable_checkpointing: false + logger: false + +model: + _target_: fixtures.plain_lightning_modules.MyLighterModule + network: + _target_: torch.nn.Linear + in_features: 10 + out_features: 2 + optimizer: + _target_: torch.optim.Adam + params: "$@model::network.parameters()" + lr: 0.001 + +data: + _target_: lighter.LighterDataModule + train_dataloader: + _target_: torch.utils.data.DataLoader + dataset: + _target_: fixtures.plain_lightning_modules.SimpleDataset + batch_size: 8 + val_dataloader: + _target_: torch.utils.data.DataLoader + dataset: + _target_: fixtures.plain_lightning_modules.SimpleDataset + batch_size: 8 +""" + config_path1.write_text(config_content1) + + runner1 = Runner() + runner1.run(Stage.FIT, [str(config_path1)]) + + # Runner no longer stores system, just verify it ran successfully without error + + # Test 2: Use plain Lightning module (same codebase!) 
+ # Plain Lightning module defines dataloaders in the module itself + config_path2 = Path(tmpdir) / "plain_lightning.yaml" + config_content2 = """ +trainer: + _target_: pytorch_lightning.Trainer + max_epochs: 1 + enable_checkpointing: false + logger: false + +model: + _target_: fixtures.plain_lightning_modules.PlainLightningModule +""" + config_path2.write_text(config_content2) + + runner2 = Runner() + runner2.run(Stage.FIT, [str(config_path2)]) + + # Runner no longer stores system, just verify it ran successfully without error diff --git a/tests/unit/test_adapters.py b/tests/unit/test_adapters.py deleted file mode 100644 index f25c9646..00000000 --- a/tests/unit/test_adapters.py +++ /dev/null @@ -1,301 +0,0 @@ -"""Unit tests for the adapters in lighter/adapters.py""" - -import pytest -import torch -from torch import Tensor - -from lighter.adapters import ( - BatchAdapter, - CriterionAdapter, - LoggingAdapter, - MetricsAdapter, - _ArgumentsAdapter, - _ArgumentsAndTransformsAdapter, - _TransformsAdapter, -) - - -def create_tensor(shape=(1,), value=None): - """Helper function to create dummy tensors.""" - if value is not None: - return Tensor([value]) - return torch.rand(*shape) - - -class TestTransformsAdapter: - def test_no_transforms(self): - """Test that adapter works correctly with no transforms.""" - adapter = _TransformsAdapter() - input, target, pred = create_tensor(), create_tensor(), create_tensor() - transformed_input, transformed_target, transformed_pred = adapter(input, target, pred) - assert transformed_input.equal(input) - assert transformed_target.equal(target) - assert transformed_pred.equal(pred) - - def test_single_transform(self): - """Test adapter with a single transform for each input.""" - transform = lambda x: x + 1 # noqa: E731 - adapter = _TransformsAdapter(input_transforms=transform, target_transforms=transform, pred_transforms=transform) - input, target, pred = create_tensor(), create_tensor(), create_tensor() - transformed_input, 
transformed_target, transformed_pred = adapter(input, target, pred) - assert transformed_input.equal(input + 1) - assert transformed_target.equal(target + 1) - assert transformed_pred.equal(pred + 1) - - def test_multiple_transforms(self): - """Test adapter with multiple transforms that are applied in sequence.""" - transform1 = lambda x: x * 2 # noqa: E731 - transform2 = lambda x: x - 1 # noqa: E731 - adapter = _TransformsAdapter( - input_transforms=[transform1, transform2], - target_transforms=[transform1, transform2], - pred_transforms=[transform1, transform2], - ) - input, target, pred = create_tensor(), create_tensor(), create_tensor() - transformed_input, transformed_target, transformed_pred = adapter(input, target, pred) - assert transformed_input.equal((input * 2) - 1) - assert transformed_target.equal((target * 2) - 1) - assert transformed_pred.equal((pred * 2) - 1) - - def test_invalid_transform(self): - """Test that adapter raises ValueError for non-callable transforms.""" - adapter = _TransformsAdapter(input_transforms="not_a_callable") - with pytest.raises(ValueError, match="Invalid transform type"): - adapter(create_tensor(), create_tensor(), create_tensor()) - - -class TestArgumentsAdapter: - def test_positional_arguments(self): - """Test adapter with all positional arguments.""" - adapter = _ArgumentsAdapter(input_argument=0, target_argument=1, pred_argument=2) - input, target, pred = create_tensor(), create_tensor(), create_tensor() - args, kwargs = adapter(input, target, pred) - assert args == [input, target, pred] - assert kwargs == {} - - def test_keyword_arguments(self): - """Test adapter with all keyword arguments.""" - adapter = _ArgumentsAdapter(input_argument="in", target_argument="tar", pred_argument="pre") - input, target, pred = create_tensor(), create_tensor(), create_tensor() - args, kwargs = adapter(input, target, pred) - assert args == [] - assert kwargs == {"in": input, "tar": target, "pre": pred} - - def test_mixed_arguments(self): 
- """Test adapter with mixed positional and keyword arguments.""" - adapter = _ArgumentsAdapter(input_argument=0, target_argument="tar", pred_argument=1) - input, target, pred = create_tensor(), create_tensor(), create_tensor() - args, kwargs = adapter(input, target, pred) - assert args == [input, pred] - assert kwargs == {"tar": target} - - def test_invalid_positional_arguments(self): - """Test that adapter validates consecutive positional arguments.""" - # Test for non-consecutive positional arguments - with pytest.raises(ValueError, match="Positional arguments must be consecutive integers starting from 0"): - _ArgumentsAdapter(input_argument=0, target_argument=2) - - # Test for non-zero starting positional argument - with pytest.raises(ValueError, match="Positional arguments must be consecutive integers starting from 0"): - _ArgumentsAdapter(input_argument=1, target_argument=2) - - -class TestArgumentsAndTransformsAdapter: - def test_valid_initialization(self): - """Test that adapter initializes correctly with valid arguments.""" - adapter = _ArgumentsAndTransformsAdapter(input_argument=0, target_argument=1, pred_argument=2) - assert adapter.input_argument == 0 - assert adapter.target_argument == 1 - assert adapter.pred_argument == 2 - - def test_invalid_initialization(self): - """Test that adapter raises appropriate errors for invalid initialization.""" - with pytest.raises(ValueError, match="Input transforms provided but input_argument is None"): - _ArgumentsAndTransformsAdapter(input_transforms=lambda x: x) - with pytest.raises(ValueError, match="Target transforms provided but target_argument is None"): - _ArgumentsAndTransformsAdapter(target_transforms=lambda x: x) - with pytest.raises(ValueError, match="Pred transforms provided but pred_argument is None"): - _ArgumentsAndTransformsAdapter(pred_transforms=lambda x: x) - - def test_call_with_transforms_and_arguments(self): - """Test adapter correctly applies transforms and arranges arguments.""" - - def 
mock_fn(pred, target, input): - return pred + target + input - - adapter = _ArgumentsAndTransformsAdapter( - pred_argument=0, - target_argument=1, - input_argument=2, - input_transforms=lambda x: x * 2, - target_transforms=lambda x: x + 1, - pred_transforms=lambda x: x - 1, - ) - input = create_tensor(value=1.0) - target = create_tensor(value=2.0) - pred = create_tensor(value=3.0) - - # Calculate expected result: (pred-1) + (target+1) + (input*2) - # (3-1) + (2+1) + (1*2) = 2 + 3 + 2 = 7 - expected = Tensor([7.0]) - - result = adapter(mock_fn, input, target, pred) - assert result.equal(expected) - - def test_call_with_only_arguments(self): - """Test adapter works correctly with only argument adaptation.""" - - def mock_fn(a, b, c): - return a, b, c - - adapter = _ArgumentsAndTransformsAdapter(input_argument="a", target_argument="b", pred_argument="c") - input, target, pred = create_tensor(), create_tensor(), create_tensor() - result = adapter(mock_fn, input, target, pred) - assert result[0].equal(input) - assert result[1].equal(target) - assert result[2].equal(pred) - - def test_call_with_only_transforms(self): - """Test adapter works correctly with only transforms.""" - - def mock_fn(x): - return x - - adapter = _ArgumentsAndTransformsAdapter(input_argument=0, input_transforms=lambda x: x * 2) - input = create_tensor() - result = adapter(mock_fn, input, None, None) - assert result.equal(input * 2) - - -class TestBatchAdapter: - def test_list_access(self): - """Test adapter correctly accesses list elements.""" - adapter = BatchAdapter(input_accessor=0, target_accessor=1, identifier_accessor=2) - batch = [create_tensor(), create_tensor(), "id"] - input, target, identifier = adapter(batch) - assert input.equal(batch[0]) - assert target.equal(batch[1]) - assert identifier == "id" - - def test_dict_access(self): - """Test adapter correctly accesses dictionary elements.""" - adapter = BatchAdapter(input_accessor="in", target_accessor="tar", identifier_accessor="id") - 
batch = {"in": create_tensor(), "tar": create_tensor(), "id": "identifier"} - input, target, identifier = adapter(batch) - assert input.equal(batch["in"]) - assert target.equal(batch["tar"]) - assert identifier == "identifier" - - def test_callable_access(self): - """Test adapter correctly uses callable accessors.""" - adapter = BatchAdapter(input_accessor=lambda b: b["data"], target_accessor=lambda b: b["label"]) - batch = {"data": create_tensor(), "label": create_tensor()} - input, target, identifier = adapter(batch) - assert input.equal(batch["data"]) - assert target.equal(batch["label"]) - assert identifier is None - - def test_invalid_access(self): - """Test adapter handles invalid access attempts appropriately.""" - # Test KeyError for non-existent dictionary key - adapter = BatchAdapter(input_accessor="invalid_key") - batch = {"in": create_tensor()} - with pytest.raises(KeyError): - adapter(batch) - - # Test ValueError for invalid accessor type - adapter = BatchAdapter(input_accessor=0) - batch = "invalid_batch_type" - with pytest.raises(ValueError, match="Invalid accessor"): - adapter(batch) - - -class TestCriterionAdapter: - def test_call_with_arguments_and_transforms(self): - """Test CriterionAdapter correctly applies transforms and arranges arguments.""" - - def mock_criterion(pred, target, input): - return pred + target + input - - adapter = CriterionAdapter( - pred_argument=0, - target_argument=1, - input_argument=2, - input_transforms=lambda x: x * 2, - target_transforms=lambda x: x + 1, - pred_transforms=lambda x: x - 1, - ) - input = create_tensor(value=1.0) - target = create_tensor(value=2.0) - pred = create_tensor(value=3.0) - - # Calculate expected result: (pred-1) + (target+1) + (input*2) - # (3-1) + (2+1) + (1*2) = 2 + 3 + 2 = 7 - expected = Tensor([7.0]) - - result = adapter(mock_criterion, input, target, pred) - assert result.equal(expected) - - -class TestMetricsAdapter: - def test_call_with_arguments_and_transforms(self): - """Test 
MetricsAdapter correctly applies transforms and arranges arguments.""" - - def mock_metric(pred, target, input): - return pred + target + input - - adapter = MetricsAdapter( - pred_argument=0, - target_argument=1, - input_argument=2, - input_transforms=lambda x: x * 2, - target_transforms=lambda x: x + 1, - pred_transforms=lambda x: x - 1, - ) - input = create_tensor(value=1.0) - target = create_tensor(value=2.0) - pred = create_tensor(value=3.0) - - # Calculate expected result: (pred-1) + (target+1) + (input*2) - # (3-1) + (2+1) + (1*2) = 2 + 3 + 2 = 7 - expected = Tensor([7.0]) - - result = adapter(mock_metric, input, target, pred) - assert result.equal(expected) - - -class TestLoggingAdapter: - def test_no_transforms(self): - """Test LoggingAdapter works correctly with no transforms.""" - adapter = LoggingAdapter() - input, target, pred = create_tensor(), create_tensor(), create_tensor() - transformed_input, transformed_target, transformed_pred = adapter(input, target, pred) - assert transformed_input.equal(input) - assert transformed_target.equal(target) - assert transformed_pred.equal(pred) - - def test_single_transform(self): - """Test LoggingAdapter works correctly with single transform.""" - transform = lambda x: x + 1 # noqa: E731 - adapter = LoggingAdapter(input_transforms=transform, target_transforms=transform, pred_transforms=transform) - input, target, pred = create_tensor(), create_tensor(), create_tensor() - transformed_input, transformed_target, transformed_pred = adapter(input, target, pred) - assert transformed_input.equal(input + 1) - assert transformed_target.equal(target + 1) - assert transformed_pred.equal(pred + 1) - - def test_multiple_transforms(self): - """Test LoggingAdapter works correctly with multiple transforms.""" - transform1 = lambda x: x * 2 # noqa: E731 - transform2 = lambda x: x - 1 # noqa: E731 - - adapter = LoggingAdapter( - input_transforms=[transform1, transform2], - target_transforms=[transform1, transform2], - 
pred_transforms=[transform1, transform2], - ) - input, target, pred = create_tensor(), create_tensor(), create_tensor() - transformed_input, transformed_target, transformed_pred = adapter(input, target, pred) - assert transformed_input.equal((input * 2) - 1) - assert transformed_target.equal((target * 2) - 1) - assert transformed_pred.equal((pred * 2) - 1) diff --git a/tests/unit/test_callbacks_freezer.py b/tests/unit/test_callbacks_freezer.py index 7da5bd2d..6e743274 100644 --- a/tests/unit/test_callbacks_freezer.py +++ b/tests/unit/test_callbacks_freezer.py @@ -1,3 +1,5 @@ +import re + import pytest import torch from pytorch_lightning import Trainer @@ -5,15 +7,17 @@ from torch.utils.data import DataLoader, Dataset from lighter.callbacks.freezer import Freezer -from lighter.system import System +from lighter.model import LighterModule -class DummyDataset(Dataset): - """ - A simple dataset for testing purposes. +def strip_ansi(text: str) -> str: + """Remove ANSI escape sequences from text.""" + ansi_escape = re.compile(r"\x1b\[[0-9;]*m|\x1b\]8;[^\\]*\\|\x1b\\") + return ansi_escape.sub("", text) - Generates random input data and labels for training. - """ + +class DummyDataset(Dataset): + """Simple dataset for testing.""" def __init__(self, num_samples=100): self.num_samples = num_samples @@ -29,11 +33,7 @@ def __getitem__(self, idx): class DummyModel(Module): - """ - A simple neural network model for testing purposes. - - Contains three linear layers that can be selectively frozen during training. 
- """ + """Three-layer network for testing freezing behavior.""" def __init__(self): super().__init__() @@ -48,30 +48,48 @@ def forward(self, x): return x +class DummyLighterModule(LighterModule): + """Concrete System implementation for testing.""" + + def training_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + loss = self.criterion(pred, y) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + return self.training_step(batch, batch_idx) + + def test_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + return {"pred": pred, "target": y} + + def predict_step(self, batch, batch_idx): + return self(batch) + + def train_dataloader(self): + return DataLoader(DummyDataset(), batch_size=32) + + def val_dataloader(self): + return DataLoader(DummyDataset(), batch_size=32) + + @pytest.fixture def dummy_system(): - """ - Fixture that creates a System instance with a dummy model for testing. - - Returns: - System: A configured system with DummyModel, SGD optimizer, and DummyDataset. - """ + """Create a LighterModule with DummyModel for freezer tests.""" model = DummyModel() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) criterion = torch.nn.BCEWithLogitsLoss() - train_dataloader = DataLoader(DummyDataset(), batch_size=32) - return System(model=model, criterion=criterion, optimizer=optimizer, dataloaders={"train": train_dataloader}) + return DummyLighterModule( + network=model, + criterion=criterion, + optimizer=optimizer, + ) def test_freezer_initialization(): - """ - Test the initialization of Freezer with various parameter combinations. 
- - Verifies: - - Raises ValueError when neither names nor name_starts_with is specified - - Raises ValueError when both until_step and until_epoch are specified - - Correctly stores the names parameter - """ + """Test Freezer initialization validates parameters correctly.""" with pytest.raises(ValueError, match="At least one of `names` or `name_starts_with` must be specified."): Freezer() @@ -82,136 +100,102 @@ def test_freezer_initialization(): def test_freezer_functionality(dummy_system): - """ - Test the basic functionality of Freezer during training. - - Verifies: - - Specified layers are correctly frozen (requires_grad=False) - - Non-specified layers remain unfrozen (requires_grad=True) - """ + """Test that specified layers are frozen while others remain trainable.""" freezer = Freezer(names=["layer1.weight", "layer1.bias"]) trainer = Trainer(callbacks=[freezer], max_epochs=1) trainer.fit(dummy_system) - assert not dummy_system.model.layer1.weight.requires_grad - assert not dummy_system.model.layer1.bias.requires_grad - assert dummy_system.model.layer2.weight.requires_grad + assert not dummy_system.network.layer1.weight.requires_grad + assert not dummy_system.network.layer1.bias.requires_grad + assert dummy_system.network.layer2.weight.requires_grad def test_freezer_exceed_until_step(dummy_system): - """ - Test that layers are unfrozen after exceeding the specified step limit. - - Verifies that layers become trainable (requires_grad=True) after the until_step threshold. 
- """ + """Test that layers are unfrozen after exceeding until_step.""" freezer = Freezer(names=["layer1.weight", "layer1.bias"], until_step=0) trainer = Trainer(callbacks=[freezer], max_epochs=1) trainer.fit(dummy_system) - assert dummy_system.model.layer1.weight.requires_grad - assert dummy_system.model.layer1.bias.requires_grad + assert dummy_system.network.layer1.weight.requires_grad + assert dummy_system.network.layer1.bias.requires_grad # Test unfreezing after exceeding until_step freezer = Freezer(names=["layer1.weight", "layer1.bias"], until_step=1) trainer = Trainer(callbacks=[freezer], max_epochs=1) trainer.fit(dummy_system) - assert dummy_system.model.layer1.weight.requires_grad - assert dummy_system.model.layer1.bias.requires_grad + assert dummy_system.network.layer1.weight.requires_grad + assert dummy_system.network.layer1.bias.requires_grad def test_freezer_exceed_until_epoch(dummy_system): - """ - Test that layers are unfrozen after exceeding the specified epoch limit. - - Verifies that layers become trainable (requires_grad=True) after the until_epoch threshold. 
- """ + """Test that layers are unfrozen after exceeding until_epoch.""" freezer = Freezer(names=["layer1.weight", "layer1.bias"], until_epoch=0) trainer = Trainer(callbacks=[freezer], max_epochs=1) trainer.fit(dummy_system) - assert dummy_system.model.layer1.weight.requires_grad - assert dummy_system.model.layer1.bias.requires_grad + assert dummy_system.network.layer1.weight.requires_grad + assert dummy_system.network.layer1.bias.requires_grad # Test unfreezing after exceeding until_epoch freezer = Freezer(names=["layer1.weight", "layer1.bias"], until_epoch=1) trainer = Trainer(callbacks=[freezer], max_epochs=2) trainer.fit(dummy_system) - assert dummy_system.model.layer1.weight.requires_grad - assert dummy_system.model.layer1.bias.requires_grad + assert dummy_system.network.layer1.weight.requires_grad + assert dummy_system.network.layer1.bias.requires_grad def test_freezer_set_model_requires_grad(dummy_system): - """ - Test the internal _set_model_requires_grad method of Freezer. - - Verifies: - - Method correctly freezes specified parameters - - Method correctly unfreezes specified parameters - """ + """Test _set_model_requires_grad freezes and unfreezes parameters.""" freezer = Freezer(names=["layer1.weight", "layer1.bias"]) - freezer._set_model_requires_grad(dummy_system.model, requires_grad=False) - assert not dummy_system.model.layer1.weight.requires_grad - assert not dummy_system.model.layer1.bias.requires_grad - freezer._set_model_requires_grad(dummy_system.model, requires_grad=True) - assert dummy_system.model.layer1.weight.requires_grad - assert dummy_system.model.layer1.bias.requires_grad + freezer._set_model_requires_grad(dummy_system.network, requires_grad=False) + assert not dummy_system.network.layer1.weight.requires_grad + assert not dummy_system.network.layer1.bias.requires_grad + freezer._set_model_requires_grad(dummy_system.network, requires_grad=True) + assert dummy_system.network.layer1.weight.requires_grad + assert 
dummy_system.network.layer1.bias.requires_grad # Test with exceptions freezer = Freezer(names=["layer1.weight", "layer1.bias"], except_names=["layer1.bias"]) - freezer._set_model_requires_grad(dummy_system.model, requires_grad=False) - assert not dummy_system.model.layer1.weight.requires_grad - assert dummy_system.model.layer1.bias.requires_grad - freezer._set_model_requires_grad(dummy_system.model, requires_grad=True) - assert dummy_system.model.layer1.weight.requires_grad - assert dummy_system.model.layer1.bias.requires_grad + freezer._set_model_requires_grad(dummy_system.network, requires_grad=False) + assert not dummy_system.network.layer1.weight.requires_grad + assert dummy_system.network.layer1.bias.requires_grad + freezer._set_model_requires_grad(dummy_system.network, requires_grad=True) + assert dummy_system.network.layer1.weight.requires_grad + assert dummy_system.network.layer1.bias.requires_grad def test_freezer_with_exceptions(dummy_system): - """ - Test Freezer with exception patterns for layer freezing. 
- - Verifies: - - Layers matching name_starts_with are frozen - - Layers in except_names remain unfrozen - - Other layers behave as expected - """ + """Test Freezer respects except_names and except_name_starts_with.""" freezer = Freezer(name_starts_with=["layer"], except_names=["layer2.weight", "layer2.bias"]) trainer = Trainer(callbacks=[freezer], max_epochs=1) trainer.fit(dummy_system) - assert not dummy_system.model.layer1.weight.requires_grad - assert not dummy_system.model.layer1.bias.requires_grad - assert dummy_system.model.layer2.weight.requires_grad - assert dummy_system.model.layer2.bias.requires_grad - assert not dummy_system.model.layer3.weight.requires_grad - assert not dummy_system.model.layer3.bias.requires_grad + assert not dummy_system.network.layer1.weight.requires_grad + assert not dummy_system.network.layer1.bias.requires_grad + assert dummy_system.network.layer2.weight.requires_grad + assert dummy_system.network.layer2.bias.requires_grad + assert not dummy_system.network.layer3.weight.requires_grad + assert not dummy_system.network.layer3.bias.requires_grad # Test with except_name_starts_with freezer = Freezer(name_starts_with=["layer"], except_name_starts_with=["layer2"]) trainer = Trainer(callbacks=[freezer], max_epochs=1) trainer.fit(dummy_system) - assert not dummy_system.model.layer1.weight.requires_grad - assert not dummy_system.model.layer1.bias.requires_grad - assert dummy_system.model.layer2.weight.requires_grad - assert dummy_system.model.layer2.bias.requires_grad - assert not dummy_system.model.layer3.weight.requires_grad - assert not dummy_system.model.layer3.bias.requires_grad + assert not dummy_system.network.layer1.weight.requires_grad + assert not dummy_system.network.layer1.bias.requires_grad + assert dummy_system.network.layer2.weight.requires_grad + assert dummy_system.network.layer2.bias.requires_grad + assert not dummy_system.network.layer3.weight.requires_grad + assert not dummy_system.network.layer3.bias.requires_grad def 
test_freezer_except_name_starts_with(dummy_system): - """ - Test Freezer with except_name_starts_with parameter. - - Verifies: - - Layers matching name_starts_with are frozen - - Layers matching except_name_starts_with remain unfrozen - - Other layers behave as expected - """ + """Test Freezer with except_name_starts_with parameter.""" freezer = Freezer(name_starts_with=["layer"], except_name_starts_with=["layer2"]) trainer = Trainer(callbacks=[freezer], max_epochs=1) trainer.fit(dummy_system) - assert not dummy_system.model.layer1.weight.requires_grad - assert not dummy_system.model.layer1.bias.requires_grad - assert dummy_system.model.layer2.weight.requires_grad - assert dummy_system.model.layer2.bias.requires_grad - assert not dummy_system.model.layer3.weight.requires_grad - assert not dummy_system.model.layer3.bias.requires_grad + assert not dummy_system.network.layer1.weight.requires_grad + assert not dummy_system.network.layer1.bias.requires_grad + assert dummy_system.network.layer2.weight.requires_grad + assert dummy_system.network.layer2.bias.requires_grad + assert not dummy_system.network.layer3.weight.requires_grad + assert not dummy_system.network.layer3.bias.requires_grad # Test with both except_names and except_name_starts_with freezer = Freezer( @@ -221,49 +205,73 @@ def test_freezer_except_name_starts_with(dummy_system): ) trainer = Trainer(callbacks=[freezer], max_epochs=1) trainer.fit(dummy_system) - assert not dummy_system.model.layer1.weight.requires_grad - assert not dummy_system.model.layer1.bias.requires_grad - assert not dummy_system.model.layer2.weight.requires_grad - assert dummy_system.model.layer2.bias.requires_grad - assert dummy_system.model.layer3.weight.requires_grad - assert dummy_system.model.layer3.bias.requires_grad + assert not dummy_system.network.layer1.weight.requires_grad + assert not dummy_system.network.layer1.bias.requires_grad + assert not dummy_system.network.layer2.weight.requires_grad + assert 
dummy_system.network.layer2.bias.requires_grad + assert dummy_system.network.layer3.weight.requires_grad + assert dummy_system.network.layer3.bias.requires_grad def test_freezer_set_model_requires_grad_with_exceptions(dummy_system): - """ - Test the _set_model_requires_grad method with various exception patterns. - - Verifies: - - Correct handling of specific parameter exceptions - - Proper behavior with name_starts_with and except_names combinations - - Consistent freezing/unfreezing across multiple configurations - """ + """Test _set_model_requires_grad with various exception patterns.""" freezer = Freezer(names=["layer1.weight", "layer1.bias"], except_names=["layer1.bias"]) - freezer._set_model_requires_grad(dummy_system.model, requires_grad=False) - assert not dummy_system.model.layer1.weight.requires_grad - assert dummy_system.model.layer1.bias.requires_grad - freezer._set_model_requires_grad(dummy_system.model, requires_grad=True) - assert dummy_system.model.layer1.weight.requires_grad - assert dummy_system.model.layer1.bias.requires_grad + freezer._set_model_requires_grad(dummy_system.network, requires_grad=False) + assert not dummy_system.network.layer1.weight.requires_grad + assert dummy_system.network.layer1.bias.requires_grad + freezer._set_model_requires_grad(dummy_system.network, requires_grad=True) + assert dummy_system.network.layer1.weight.requires_grad + assert dummy_system.network.layer1.bias.requires_grad freezer = Freezer(name_starts_with=["layer"], except_names=["layer2.weight", "layer2.bias"]) trainer = Trainer(callbacks=[freezer], max_epochs=1) trainer.fit(dummy_system) - assert not dummy_system.model.layer1.weight.requires_grad - assert not dummy_system.model.layer1.bias.requires_grad - assert dummy_system.model.layer2.weight.requires_grad - assert dummy_system.model.layer2.bias.requires_grad - assert not dummy_system.model.layer3.weight.requires_grad - assert not dummy_system.model.layer3.bias.requires_grad + assert not 
dummy_system.network.layer1.weight.requires_grad + assert not dummy_system.network.layer1.bias.requires_grad + assert dummy_system.network.layer2.weight.requires_grad + assert dummy_system.network.layer2.bias.requires_grad + assert not dummy_system.network.layer3.weight.requires_grad + assert not dummy_system.network.layer3.bias.requires_grad # Test with until_step and until_epoch freezer = Freezer(names=["layer1.weight", "layer1.bias"], until_step=1) trainer = Trainer(callbacks=[freezer], max_epochs=1) trainer.fit(dummy_system) - assert dummy_system.model.layer1.weight.requires_grad - assert dummy_system.model.layer1.bias.requires_grad + assert dummy_system.network.layer1.weight.requires_grad + assert dummy_system.network.layer1.bias.requires_grad freezer = Freezer(names=["layer1.weight", "layer1.bias"], until_epoch=1) trainer = Trainer(callbacks=[freezer], max_epochs=2) trainer.fit(dummy_system) - assert dummy_system.model.layer1.weight.requires_grad - assert dummy_system.model.layer1.bias.requires_grad + assert dummy_system.network.layer1.weight.requires_grad + assert dummy_system.network.layer1.bias.requires_grad + + +def test_freezer_logs_excepted_layers(dummy_system, capsys): + """Test that excepted layers are logged with the correct suffix during freeze.""" + freezer = Freezer(name_starts_with=["layer"], except_names=["layer2.weight", "layer2.bias"]) + + # Directly call _set_model_requires_grad to test logging + freezer._set_model_requires_grad(dummy_system, requires_grad=False) + + # Loguru logs to stderr by default; strip ANSI codes for CI compatibility + captured = capsys.readouterr() + output = strip_ansi(captured.out + captured.err) + assert "(excepted from freeze)" in output + assert "layer2.weight" in output + assert "layer2.bias" in output + + +def test_freezer_logs_unfrozen_layers_without_suffix(dummy_system, capsys): + """Test that unfrozen layers are logged without suffix when explicitly unfreezing.""" + freezer = Freezer(names=["layer1.weight", 
"layer1.bias"]) + + # First freeze, then unfreeze + freezer._set_model_requires_grad(dummy_system, requires_grad=False) + capsys.readouterr() # Clear previous output + freezer._set_model_requires_grad(dummy_system, requires_grad=True) + + # Loguru logs to stderr by default; strip ANSI codes for CI compatibility + captured = capsys.readouterr() + output = strip_ansi(captured.out + captured.err) + assert "Unfroze layers" in output + assert "(excepted from freeze)" not in output diff --git a/tests/unit/test_callbacks_utils.py b/tests/unit/test_callbacks_utils.py deleted file mode 100644 index 711bdd74..00000000 --- a/tests/unit/test_callbacks_utils.py +++ /dev/null @@ -1,49 +0,0 @@ -import torch - -from lighter.callbacks.utils import preprocess_image - - -def test_preprocess_image_single_3d(): - """ - Test preprocess_image function with a single 3D image input. - - Tests the reshaping of a single 3D image with dimensions: - - Input: (1, 1, depth, height, width) - - Expected output: (1, depth*height, width) - - The function verifies that a 3D medical image is correctly - reshaped while preserving spatial relationships. - """ - depth = 20 - height = 64 - width = 64 - image = torch.rand(1, 1, depth, height, width) # Single 3D image - processed_image = preprocess_image(image) - assert processed_image.shape == (1, depth * height, width) - - -def test_preprocess_image_batch_3d(): - """ - Test preprocess_image function with a batch of 3D images. - - Tests the reshaping of multiple 3D images with dimensions: - - Input: (batch_size, 1, depth, height, width) - - Expected output: (1, depth*height, batch_size*width) - - The function verifies that a batch of 3D medical images is correctly - reshaped into a single 2D representation while maintaining the - spatial relationships and batch information. 
- - Args used in test: - batch_size: 8 - depth: 20 - height: 64 - width: 64 - """ - batch_size = 8 - depth = 20 - height = 64 - width = 64 - image = torch.rand(batch_size, 1, depth, height, width) # Batch of 3D images - processed_image = preprocess_image(image) - assert processed_image.shape == (1, depth * height, batch_size * width) diff --git a/tests/unit/test_callbacks_writer_base.py b/tests/unit/test_callbacks_writer_base.py deleted file mode 100644 index 0a1295b3..00000000 --- a/tests/unit/test_callbacks_writer_base.py +++ /dev/null @@ -1,281 +0,0 @@ -import logging -from pathlib import Path -from unittest.mock import MagicMock, patch - -import pytest -import torch -from loguru import logger - -from lighter.callbacks.writer.base import BaseWriter - - -@pytest.fixture -def target_path(): - """ - Fixture that provides a test path for the writer. - - Returns: - Path: A Path object pointing to "test" directory - """ - return Path("test") - - -class MockWriter(BaseWriter): - """ - Mock implementation of BaseWriter for testing purposes. - - This class provides a minimal implementation of the abstract base class - with a simple tensor writer function. - """ - - @property - def writers(self): - """ - Define available writers for the mock class. - - Returns: - dict: Dictionary containing writer name and corresponding function - """ - return {"tensor": lambda x: None} - - def write(self, tensor, identifier): - """ - Mock implementation of the write method. - - Args: - tensor: The tensor to write - identifier: Identifier for the tensor - """ - pass - - -def test_writer_initialization(target_path): - """ - Test the initialization of writers. 
- - Tests that: - - MockWriter initializes correctly with valid writer - - Base class raises TypeError when instantiated directly - - Raises ValueError when initialized with invalid writer - - Raises TypeError when initialized with invalid path - """ - # Test initialization with a valid writer - writer = MockWriter(path=target_path, writer="tensor") - assert callable(writer.writer) - with pytest.raises(TypeError): - BaseWriter(path=target_path, writer="tensor") - - # Test initialization with invalid writer - with pytest.raises(ValueError, match="Writer for format invalid_writer does not exist"): - MockWriter(path=target_path, writer="invalid_writer") - - # Test initialization with invalid path - with pytest.raises(TypeError): - MockWriter(path=123, writer="tensor") - - -def test_on_predict_batch_end(target_path): - """ - Test the on_predict_batch_end callback functionality. - - Verifies that: - - Prediction IDs are properly assigned - - Prediction counter increments correctly - - Trainer's prediction list is maintained - - Args: - target_path (Path): Fixture providing test directory path - """ - logging.basicConfig(level=logging.INFO) - trainer = MagicMock() - trainer.world_size = 1 - trainer.predict_loop.num_dataloaders = 1 - trainer.predict_loop._predictions = [[]] - - pl_module = MagicMock() - - writer = MockWriter(path=target_path, writer="tensor") - writer._pred_counter = 0 - - # Test with batch size of 4 and no provided identifiers - outputs = {"pred": torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]]), "identifier": None} - batch = MagicMock() - batch_idx = 0 - - writer.on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx) - - assert outputs["identifier"] == [0, 1, 2, 3] - assert trainer.predict_loop._predictions == [[]] - assert writer._pred_counter == 4 - - # Test with provided identifiers (batch size 4) - outputs = {"pred": torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]]), "identifier": [10, 11, 12, 13]} - writer.on_predict_batch_end(trainer, 
pl_module, outputs, batch, batch_idx) - assert writer._pred_counter == 4 # Should not increment when identifiers are provided - - # Check for incorrect identifier length (3 identifiers for 4 predictions) - outputs = {"pred": torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]]), "identifier": [10, 11, 12]} - with pytest.raises(ValueError, match="The number of predictions"): - writer.on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx) - - # Test with list of tensors (batch size 4) and no provided identifiers - writer._pred_counter = 0 # Reset counter - outputs = { - "pred": [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6]), torch.tensor([7, 8])], - "identifier": None, - } - - writer.on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx) - - assert outputs["identifier"] == [0, 1, 2, 3] - assert trainer.predict_loop._predictions == [[]] - assert writer._pred_counter == 4 - - # Test with list of tensors and provided identifiers - outputs = { - "pred": [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6]), torch.tensor([7, 8])], - "identifier": [20, 21, 22, 23], - } - - writer.on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx) - assert writer._pred_counter == 4 # Should not increment when identifiers are provided - - # Test list of tensors with incorrect identifier length - outputs = { - "pred": [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6]), torch.tensor([7, 8])], - "identifier": [20, 21], - } - with pytest.raises(ValueError, match="The number of predictions"): - writer.on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx) - - # Test with mixed tensor shapes in list - outputs = { - "pred": [ - torch.tensor([[1, 2, 3]]), - torch.tensor([[4, 5, 6]]), - torch.tensor([[7, 8, 9]]), - torch.tensor([[10, 11, 12]]), - ], - "identifier": None, - } - - writer._pred_counter = 0 # Reset counter - writer.on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx) - - assert 
outputs["identifier"] == [0, 1, 2, 3] - assert writer._pred_counter == 4 - - -def test_writer_setup_predict(target_path, caplog): - """ - Test writer setup for prediction stage. - - Verifies that: - - Writer initializes correctly for prediction - - Prediction counter is properly reset - - Global synchronization works as expected - - Args: - target_path (Path): Fixture providing test directory path - caplog: Pytest fixture for capturing log output - """ - trainer = MagicMock() - trainer.world_size = 1 - trainer.is_global_zero = True - trainer.strategy.broadcast.return_value = target_path - trainer.strategy.barrier = MagicMock() - - pl_module = MagicMock() - - writer = MockWriter(path=target_path, writer="tensor") - writer.setup(trainer, pl_module, stage="predict") - assert writer._pred_counter == 0 - - # Test setup for non-predict stage - writer = MockWriter(path=target_path, writer="tensor") - writer.setup(trainer, pl_module, stage="train") - assert writer._pred_counter is None - - -def test_writer_setup_non_predict(target_path): - """ - Test writer setup for non-prediction stages. 
- - Verifies that: - - Writer initializes correctly for non-prediction stages (e.g., train) - - Prediction counter remains None - - Path is properly set - - Args: - target_path (Path): Fixture providing test directory path - """ - trainer = MagicMock() - trainer.world_size = 1 - trainer.is_global_zero = True - trainer.strategy.broadcast.return_value = target_path - trainer.strategy.barrier = MagicMock() - - pl_module = MagicMock() - - writer = MockWriter(path=target_path, writer="tensor") - writer.setup(trainer, pl_module, stage="train") - assert writer._pred_counter is None - assert writer.path == target_path - - # Test with invalid path - with pytest.raises(ValueError, match="Writer for format invalid_writer does not exist"): - MockWriter(path=target_path, writer="invalid_writer") - - # Test with invalid path type - with pytest.raises(TypeError): - MockWriter(path=123, writer="tensor") - - -def test_writer_setup_existing_path(target_path): - """ - Test writer setup when the target path already exists. - - Args: - target_path (Path): Fixture providing test directory path - """ - trainer = MagicMock() - trainer.world_size = 1 - trainer.is_global_zero = True - trainer.strategy.broadcast.return_value = target_path - trainer.strategy.barrier = MagicMock() - - pl_module = MagicMock() - writer = MockWriter(path=target_path, writer="tensor") - - # Mock path.exists() to return True and capture loguru warning - warning_messages = [] - logger.add(lambda msg: warning_messages.append(msg.record["message"]), level="WARNING") - - with patch.object(Path, "exists", return_value=True): - writer.setup(trainer, pl_module, stage="predict") - assert any("already exists" in msg for msg in warning_messages) - assert any("existing predictions will be overwritten" in msg for msg in warning_messages) - - -def test_writer_setup_directory_not_shared(target_path): - """ - Test writer setup when directory is not shared between nodes. 
- - Args: - target_path (Path): Fixture providing test directory path - """ - trainer = MagicMock() - trainer.world_size = 2 - trainer.is_global_zero = False - trainer.global_rank = 1 - trainer.strategy.broadcast.return_value = target_path - trainer.strategy.barrier = MagicMock() - - pl_module = MagicMock() - writer = MockWriter(path=target_path, writer="tensor") - - # Mock path.exists() to return False to simulate directory not being shared - # Also mock torch.distributed.get_rank() to avoid initialization error - with patch.object(Path, "exists", return_value=False), patch("torch.distributed.get_rank", return_value=1): - with pytest.raises(RuntimeError, match="Rank 1 does not share storage with rank 0"): - writer.setup(trainer, pl_module, stage="predict") diff --git a/tests/unit/test_callbacks_writer_file.py b/tests/unit/test_callbacks_writer_file.py deleted file mode 100644 index afcce8d7..00000000 --- a/tests/unit/test_callbacks_writer_file.py +++ /dev/null @@ -1,138 +0,0 @@ -from pathlib import Path - -import numpy as np -import pytest -import torch -from PIL import Image - -from lighter.callbacks.writer.file import FileWriter - - -def test_file_writer_initialization(tmp_path): - """Test the initialization of FileWriter class. - - This test verifies that: - 1. The FileWriter is initialized with the correct path - 2. The writer function is properly assigned based on the writer type - 3. The directory is created and cleaned up properly - - The test creates a temporary directory, initializes a writer, checks its attributes, - and then cleans up the directory. - """ - writer = FileWriter(path=tmp_path, writer="tensor") - assert writer.path == tmp_path - assert writer.writer.__name__ == "write_tensor" # Verify writer function - - -def test_file_writer_write_tensor(tmp_path): - """Test tensor writing functionality of FileWriter. - - This test verifies that: - 1. A tensor can be successfully written to disk - 2. The saved file exists at the expected location - 3. 
The loaded tensor matches the original tensor exactly - - The test creates a simple tensor, saves it, loads it back, and verifies - the content matches the original. - """ - writer = FileWriter(path=tmp_path, writer="tensor") - tensor = torch.tensor([1, 2, 3]) - writer.write(tensor, identifier=1) - - # Verify file exists - saved_path = writer.path / "1.pt" - assert saved_path.exists() - - # Verify tensor contents - loaded_tensor = torch.load(saved_path) # nosec B614 - assert torch.equal(loaded_tensor, tensor) - - -def test_file_writer_write_image(tmp_path): - """Test image writing functionality of FileWriter. - - This test verifies that: - 1. A tensor representing an image can be successfully written to disk as PNG - 2. The saved file exists at the expected location - 3. The loaded image has the correct dimensions and format - - The test creates a random RGB image tensor, saves it, and verifies - the saved image properties. - """ - writer = FileWriter(path=tmp_path, writer="image") - tensor = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8) - writer.write(tensor, identifier="image_test") - - # Verify file exists - saved_path = writer.path / "image_test.png" - assert saved_path.exists() - - # Verify image contents - image = Image.open(saved_path) - image_array = np.array(image) - assert image_array.shape == (64, 64, 3) - - -def test_file_writer_write_video(tmp_path): - """Test video writing functionality of LighterFileWriter. - - This test verifies that: - 1. A tensor representing a video can be successfully written to disk as MP4 - 2. The saved file exists at the expected location - - The test creates a random RGB video tensor and verifies it can be saved - to disk in the correct format. 
- """ - writer = FileWriter(path=tmp_path, writer="video") - tensor = torch.randint(0, 256, (3, 10, 64, 64), dtype=torch.uint8) - writer.write(tensor, identifier="video_test") - - # Verify file exists - saved_path = writer.path / "video_test.mp4" - assert saved_path.exists() - - -def test_file_writer_write_grayscale_video(tmp_path): - """Test grayscale video writing functionality of FileWriter. - - This test verifies that: - 1. A single-channel (grayscale) video tensor can be successfully written to disk - 2. The writer correctly handles the conversion from grayscale to RGB format - 3. The saved file exists at the expected location - - The test creates a grayscale video tensor and verifies it can be properly - converted and saved as an MP4 file. - """ - writer = FileWriter(path=tmp_path, writer="video") - # Create a grayscale video tensor with 1 channel - tensor = torch.randint(0, 256, (1, 10, 64, 64), dtype=torch.uint8) - writer.write(tensor, identifier="grayscale_video_test") - - # Verify file exists - saved_path = writer.path / "grayscale_video_test.mp4" - assert saved_path.exists() - - -def test_file_writer_invalid_directory(): - """Test error handling for invalid directory paths in FileWriter. - - This test verifies that: - 1. Using a file path instead of a directory path raises appropriate errors - 2. Using a non-existent directory path raises appropriate errors - 3. Error messages are clear and descriptive - - The test attempts to initialize writers with invalid paths and verifies - the correct error handling behavior. 
- """ - test_file = Path("test_file.txt") - test_file.touch() # Create a file instead of a directory - try: - with pytest.raises(RuntimeError, match="FileWriter expects a directory path"): - writer = FileWriter(path=test_file, writer="tensor") - writer.write(torch.tensor([1, 2, 3]), identifier=1) - finally: - test_file.unlink() # Clean up the file after test - - with pytest.raises(RuntimeError): - writer = FileWriter(path=Path("invalid_dir"), writer="tensor") - writer.write(torch.tensor([1, 2, 3]), identifier=1) diff --git a/tests/unit/test_callbacks_writer_table.py b/tests/unit/test_callbacks_writer_table.py deleted file mode 100644 index 96ea000c..00000000 --- a/tests/unit/test_callbacks_writer_table.py +++ /dev/null @@ -1,279 +0,0 @@ -from pathlib import Path -from unittest import mock - -import pandas as pd -import torch -from pytorch_lightning import Trainer - -from lighter.callbacks.writer.table import TableWriter - - -def custom_writer(tensor): - """ - Custom writer function that sums a tensor and returns it in a dictionary. - - Args: - tensor (torch.Tensor): Input tensor to be processed - - Returns: - dict: Dictionary with key 'custom' and value as the sum of the input tensor - """ - return {"custom": tensor.sum().item()} - - -def test_table_writer_initialization(): - """ - Test proper initialization of TableWriter. - - Verifies that: - - The writer is correctly instantiated with the given path - - The path attribute is properly converted to a Path object - """ - writer = TableWriter(path="test.csv", writer="tensor") - assert writer.path == Path("test.csv") - - -def test_table_writer_custom_writer(): - """ - Test TableWriter with a custom writer function. 
- - Verifies that: - - Custom writer function is properly integrated - - Writer correctly processes tensor input using custom function - - Resulting CSV records contain expected values - """ - writer = TableWriter(path="test.csv", writer=custom_writer) - test_tensor = torch.tensor([1, 2, 3]) - writer.write(tensor=test_tensor, identifier=1) - assert writer.csv_records[0]["pred"] == {"custom": 6} - - -def test_table_writer_write(): - """ - Test TableWriter write functionality with various inputs. - - Tests: - - Basic tensor writing with integer IDENTIFIER - - Writing with string IDENTIFIER - - Writing floating point tensors - - CSV file creation and content verification - - Proper handling of different tensor shapes and types - - File Operations: - - Creates a temporary CSV file - - Writes multiple records with different formats - - Verifies file content matches expected records - - Cleans up by removing the test file - """ - test_file = Path("test.csv") - writer = TableWriter(path="test.csv", writer="tensor") - - expected_records = [ - {"identifier": 1, "pred": [1, 2, 3]}, - {"identifier": "some_id", "pred": -1}, - {"identifier": 1331, "pred": [1.5, 2.5]}, - ] - # Test basic write - writer.write(tensor=torch.tensor(expected_records[0]["pred"]), identifier=expected_records[0]["identifier"]) - assert len(writer.csv_records) == 1 - assert writer.csv_records[0]["pred"] == expected_records[0]["pred"] - assert writer.csv_records[0]["identifier"] == expected_records[0]["identifier"] - - # Test edge cases - writer.write(tensor=torch.tensor(expected_records[1]["pred"]), identifier=expected_records[1]["identifier"]) - writer.write(tensor=torch.tensor(expected_records[2]["pred"]), identifier=expected_records[2]["identifier"]) - trainer = Trainer(max_epochs=1) - writer.on_predict_epoch_end(trainer, mock.Mock()) - - # Verify file creation and content - assert test_file.exists() - df = pd.read_csv(test_file) - df["identifier"] = df["identifier"].astype(str) - df["pred"] = 
df["pred"].apply(eval) - - for record in expected_records: - row = df[df["identifier"] == str(record["identifier"])] - assert not row.empty - pred_value = row["pred"].iloc[0] # get the value from the Series - assert pred_value == record["pred"] - - # Cleanup - test_file.unlink() - - -def test_table_writer_write_multi_process_rank0(tmp_path, monkeypatch): - """ - Test TableWriter in a multi-process environment from rank 0. - - Tests: - - Writing records from multiple processes - - Proper gathering of records across processes - - Correct file creation and content verification in distributed setting - - Args: - tmp_path (Path): Pytest fixture providing temporary directory path - monkeypatch (MonkeyPatch): Pytest fixture for mocking - - Mocked Behaviors: - - Simulates 2-process distributed environment - - Mocks torch.distributed functions for testing - - Simulates gathering of records from multiple ranks - - Verifies: - - All records from different processes are properly gathered - - CSV file contains correct combined records - - Record order and content integrity is maintained - """ - test_file = tmp_path / "test.csv" - writer = TableWriter(path=test_file, writer="tensor") - - # Expected records after gathering from all processes - rank0_records = [{"identifier": 1, "pred": [1, 2, 3]}] # records from rank 0 - rank1_records = [{"identifier": 2, "pred": [4, 5, 6]}] # records from rank 1 - expected_records = rank0_records + rank1_records - - # Mock distributed functions for multi-process simulation - def mock_gather(obj, gather_list, dst=0): - if gather_list is not None: - # Fill gather_list with records from each rank - gather_list[0] = rank0_records - gather_list[1] = rank1_records - - def mock_get_rank(): - return 0 - - # Create a mock strategy with is_global_zero property - mock_strategy = mock.MagicMock() - mock_strategy.world_size = 2 - type(mock_strategy).is_global_zero = mock.PropertyMock(return_value=True) - - # Create a mock trainer with the strategy - trainer 
= mock.MagicMock() - type(trainer).strategy = mock.PropertyMock(return_value=mock_strategy) - type(trainer).world_size = mock.PropertyMock(return_value=2) - - monkeypatch.setattr(torch.distributed, "gather_object", mock_gather) - monkeypatch.setattr(torch.distributed, "get_rank", mock_get_rank) - - writer.on_predict_epoch_end(trainer, mock.Mock()) - - # Verify file creation - assert test_file.exists() - - # Verify file content - df = pd.read_csv(test_file) - df["identifier"] = df["identifier"].astype(str) - df["pred"] = df["pred"].apply(eval) - - # Check that all expected records are in the CSV - for record in expected_records: - row = df[df["identifier"] == str(record["identifier"])] - assert not row.empty - pred_value = row["pred"].iloc[0] - assert pred_value == record["pred"] - - # Verify total number of records - assert len(df) == len(expected_records) - - -def test_table_writer_write_multi_process_rank1(tmp_path, monkeypatch): - """ - Test TableWriter in a multi-process environment from rank 1. 
- - Tests: - - Writing records from non-zero rank - - Proper gathering of records to rank 0 - - No file creation on non-zero rank - - Args: - tmp_path (Path): Pytest fixture providing temporary directory path - monkeypatch (MonkeyPatch): Pytest fixture for mocking - - Mocked Behaviors: - - Simulates 2-process distributed environment from rank 1 - - Mocks torch.distributed functions for testing - - Simulates gathering of records to rank 0 - - Verifies: - - Records are properly gathered to rank 0 - - No file is created on rank 1 - """ - test_file = tmp_path / "test.csv" - writer = TableWriter(path=test_file, writer="tensor") - - # Create some records for rank 1 - writer.write(tensor=torch.tensor([4, 5, 6]), identifier=2) - - # Mock distributed functions for multi-process simulation - def mock_gather(obj, gather_list, dst=0): - # Just verify obj is our records - assert len(obj) == 1 - assert obj[0]["identifier"] == 2 - assert obj[0]["pred"] == [4, 5, 6] - # In a real distributed environment, gather_object would handle sending - # our records to rank 0, but we don't need to simulate that in the test - # since we're just verifying rank 1's behavior - - def mock_get_rank(): - return 1 - - # Create a mock trainer for rank 1 - trainer = mock.MagicMock() - trainer.world_size = 2 - trainer.is_global_zero = False - - # Mock torch.distributed.get_rank to return 1 (non-zero rank) - monkeypatch.setattr(torch.distributed, "gather_object", mock_gather) - monkeypatch.setattr(torch.distributed, "get_rank", mock_get_rank) - - # Run the test - writer.on_predict_epoch_end(trainer, mock.Mock()) - - # Verify no file was created on rank 1 - assert not test_file.exists() - # Verify records were cleared - assert len(writer.csv_records) == 0 - - -def test_table_writer_unsortable_identifiers(tmp_path): - """ - Test TableWriter with identifiers that cannot be sorted. 
- - Tests: - - Writing records with unsortable identifiers - - Proper handling of TypeError during sorting - - Successful file creation despite sorting failure - - Args: - tmp_path (Path): Pytest fixture providing temporary directory path - """ - test_file = tmp_path / "test.csv" - writer = TableWriter(path=test_file, writer="tensor") - - # Create records with unsortable identifiers (mix of types) - records = [ - {"identifier": 1, "pred": [1, 2]}, - {"identifier": "a", "pred": [3, 4]}, - {"identifier": [1, 2], "pred": [5, 6]}, # List is not comparable with str/int - ] - - for record in records: - writer.write(tensor=torch.tensor(record["pred"]), identifier=record["identifier"]) - - trainer = Trainer(max_epochs=1) - writer.on_predict_epoch_end(trainer, mock.Mock()) - - # Verify file creation and content - assert test_file.exists() - df = pd.read_csv(test_file) - df["pred"] = df["pred"].apply(eval) - - # Check that all records are in the CSV (order doesn't matter) - assert len(df) == len(records) - for record in records: - # Convert identifier to string since pandas reads it as string - identifier = str(record["identifier"]) - row = df[df["identifier"] == identifier] - assert not row.empty - pred_value = row["pred"].iloc[0] - assert pred_value == record["pred"] diff --git a/tests/unit/test_callbacks_writers.py b/tests/unit/test_callbacks_writers.py new file mode 100644 index 00000000..ba89f080 --- /dev/null +++ b/tests/unit/test_callbacks_writers.py @@ -0,0 +1,764 @@ +"""Unit tests for FileWriter and CsvWriter callbacks.""" + +from unittest.mock import MagicMock + +import pandas as pd +import pytest +import torch + +from lighter.callbacks.base_writer import BaseWriter +from lighter.callbacks.csv_writer import CsvWriter +from lighter.callbacks.file_writer import FileWriter, writer_registry +from lighter.model import LighterModule +from lighter.utils.types.enums import Stage + + +@pytest.fixture +def mock_system(): + """Create a mock LighterModule for testing.""" + 
return MagicMock(spec=LighterModule) + + +# ============================================================================= +# BaseWriter Tests +# ============================================================================= + + +class TestBaseWriter: + """Test suite for BaseWriter callback.""" + + def test_setup_warns_on_existing_path(self, tmp_path, mock_trainer, mock_system): + """Test that setup warns when path already exists.""" + + # Create a concrete implementation for testing + class ConcreteWriter(BaseWriter): + def write(self, outputs, batch, batch_idx, dataloader_idx): + pass + + # Create existing file + existing_file = tmp_path / "existing.csv" + existing_file.touch() + + writer = ConcreteWriter(path=existing_file) + + # Should complete without error (just logs a warning) + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + assert writer.path == existing_file + + def test_on_predict_batch_end_with_outputs(self, tmp_path, mock_trainer, mock_system): + """Test on_predict_batch_end calls write when outputs exist.""" + write_called = [] + + class ConcreteWriter(BaseWriter): + def write(self, outputs, batch, batch_idx, dataloader_idx): + write_called.append((outputs, batch, batch_idx, dataloader_idx)) + + writer = ConcreteWriter(path=tmp_path / "test.csv") + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + outputs = {"pred": torch.tensor([1, 2, 3])} + batch = {"input": torch.tensor([1])} + + writer.on_predict_batch_end(mock_trainer, mock_system, outputs, batch, 5, 0) + + assert len(write_called) == 1 + assert write_called[0][0] == outputs + assert write_called[0][2] == 5 # batch_idx + + def test_on_predict_batch_end_empty_outputs(self, tmp_path, mock_trainer, mock_system): + """Test on_predict_batch_end skips when outputs are empty.""" + write_called = [] + + class ConcreteWriter(BaseWriter): + def write(self, outputs, batch, batch_idx, dataloader_idx): + write_called.append(True) + + writer = ConcreteWriter(path=tmp_path / "test.csv") + 
writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + # Empty outputs + writer.on_predict_batch_end(mock_trainer, mock_system, {}, None, 0, 0) + + # Write should not be called + assert len(write_called) == 0 + + +# ============================================================================= +# FileWriter Tests +# ============================================================================= + + +class TestFileWriter: + """Test suite for FileWriter callback.""" + + def test_initialization(self, tmp_path): + """Test FileWriter initialization with built-in writer.""" + writer = FileWriter(directory=tmp_path, value_key="pred", writer_fn="tensor") + assert writer.path == tmp_path + assert writer.value_key == "pred" + assert writer.name_key is None + assert callable(writer.writer_fn) + + def test_initialization_custom_writer(self, tmp_path): + """Test FileWriter initialization with custom writer function.""" + + def custom_writer(path, tensor): + pass + + writer = FileWriter(directory=tmp_path, value_key="pred", writer_fn=custom_writer) + assert writer.writer_fn == custom_writer + + def test_initialization_invalid_writer(self, tmp_path): + """Test FileWriter raises error for invalid writer.""" + with pytest.raises(ValueError, match="Writer with name 'invalid' is not registered"): + FileWriter(directory=tmp_path, value_key="pred", writer_fn="invalid") + + def test_initialization_non_callable(self, tmp_path): + """Test FileWriter raises error for non-callable writer.""" + with pytest.raises(TypeError, match="writer_fn must be a string or a callable"): + FileWriter(directory=tmp_path, value_key="pred", writer_fn=123) + + def test_setup(self, tmp_path, mock_trainer, mock_system): + """Test FileWriter setup creates directory and initializes counter.""" + writer = FileWriter(directory=tmp_path / "predictions", value_key="pred", writer_fn="tensor") + + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + assert writer.path.exists() + assert writer._counter == 0 + 
assert writer._step == 1 + + def test_setup_distributed(self, tmp_path, mock_trainer, mock_system): + """Test FileWriter setup in distributed setting.""" + mock_trainer.world_size = 4 + mock_trainer.global_rank = 2 + + writer = FileWriter(directory=tmp_path / "predictions", value_key="pred", writer_fn="tensor") + + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + assert writer._counter == 2 # Starts at rank + assert writer._step == 4 # Step size = world_size + + def test_setup_file_path_error(self, tmp_path, mock_trainer, mock_system): + """Test FileWriter raises error if path is a file.""" + file_path = tmp_path / "file.pt" + + writer = FileWriter(directory=file_path, value_key="pred", writer_fn="tensor") + + with pytest.raises(ValueError, match="expects 'directory' to be a directory path"): + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + def test_write_with_tensor_batch(self, tmp_path, mock_trainer, mock_system): + """Test writing a batch of tensor predictions.""" + writer = FileWriter(directory=tmp_path, value_key="pred", writer_fn="tensor") + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + outputs = { + "pred": torch.tensor([[1, 2], [3, 4], [5, 6]]) # 3 predictions + } + + writer.write(outputs, None, 0, 0) + + # Check files were created + assert (tmp_path / "0.pt").exists() + assert (tmp_path / "1.pt").exists() + assert (tmp_path / "2.pt").exists() + + # Verify content + loaded = torch.load(tmp_path / "0.pt") + assert torch.equal(loaded, torch.tensor([1, 2])) + + def test_write_with_custom_names(self, tmp_path, mock_trainer, mock_system): + """Test writing with custom sample names.""" + writer = FileWriter(directory=tmp_path, value_key="pred", writer_fn="tensor", name_key="sample_id") + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + outputs = {"pred": torch.tensor([[1, 2], [3, 4]]), "sample_id": ["patient_001", "patient_002"]} + + writer.write(outputs, None, 0, 0) + + assert (tmp_path / "patient_001.pt").exists() + 
assert (tmp_path / "patient_002.pt").exists() + + def test_write_length_mismatch(self, tmp_path, mock_trainer, mock_system): + """Test error when predictions and names have different lengths.""" + writer = FileWriter(directory=tmp_path, value_key="pred", writer_fn="tensor", name_key="sample_id") + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + outputs = { + "pred": torch.tensor([[1, 2], [3, 4], [5, 6]]), # 3 items + "sample_id": ["id1", "id2"], # 2 items + } + + with pytest.raises(ValueError, match="Length mismatch"): + writer.write(outputs, None, 0, 0) + + def test_write_missing_key(self, tmp_path, mock_trainer, mock_system): + """Test error when required key is missing from outputs.""" + writer = FileWriter(directory=tmp_path, value_key="pred", writer_fn="tensor") + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + outputs = {"other_key": torch.tensor([1, 2, 3])} + + with pytest.raises(KeyError, match="expected key 'pred'"): + writer.write(outputs, None, 0, 0) + + def test_to_sequence_scalar_tensor(self): + """Test _to_sequence handles scalar tensors.""" + outputs = {"pred": torch.tensor(5.0)} + result = FileWriter._to_sequence(outputs, "pred") + assert len(result) == 1 + assert torch.equal(result[0], torch.tensor(5.0)) + + def test_to_sequence_list(self): + """Test _to_sequence handles lists.""" + outputs = {"pred": [1, 2, 3, 4]} + result = FileWriter._to_sequence(outputs, "pred") + assert result == [1, 2, 3, 4] + + def test_to_sequence_tuple(self): + """Test _to_sequence handles tuples.""" + outputs = {"pred": (1, 2, 3)} + result = FileWriter._to_sequence(outputs, "pred") + assert result == [1, 2, 3] + + def test_prepare_value_tensor(self): + """Test _prepare_value moves tensor to CPU.""" + tensor = torch.tensor([1, 2, 3]) + result = FileWriter._prepare_value(tensor) + assert result.device.type == "cpu" + + def test_prepare_name_scalar_tensor(self): + """Test _prepare_name handles scalar tensor.""" + name = torch.tensor(42) + result = 
FileWriter._prepare_name(name) + assert result == 42 + + def test_prepare_name_vector_tensor(self): + """Test _prepare_name handles vector tensor.""" + name = torch.tensor([1, 2, 3]) + result = FileWriter._prepare_name(name) + assert result == [1, 2, 3] + + def test_prepare_name_non_tensor(self): + """Test _prepare_name handles non-tensor values.""" + result = FileWriter._prepare_name("sample_001") + assert result == "sample_001" + + def test_prepare_value_non_tensor(self): + """Test _prepare_value handles non-tensor values.""" + result = FileWriter._prepare_value("text_value") + assert result == "text_value" + + def test_write_before_setup(self, tmp_path, mock_trainer, mock_system): + """Test that write skips batch when called before setup.""" + writer = FileWriter(directory=tmp_path, value_key="pred", writer_fn="tensor") + # Don't call setup - counter will be None + + outputs = {"pred": torch.tensor([[1, 2], [3, 4]])} + + # Should skip without error + writer.write(outputs, None, 0, 0) + + # No files should be created + assert len(list(tmp_path.glob("*.pt"))) == 0 + + def test_write_empty_values(self, tmp_path, mock_trainer, mock_system): + """Test write handles empty values gracefully.""" + writer = FileWriter(directory=tmp_path, value_key="pred", writer_fn="tensor") + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + # Empty tensor + outputs = {"pred": torch.tensor([])} + + # Should skip without error + writer.write(outputs, None, 0, 0) + + # No files should be created + assert len(list(tmp_path.glob("*.pt"))) == 0 + + def test_to_sequence_generic_sequence(self): + """Test _to_sequence handles generic sequences (not str/bytes).""" + + outputs = {"pred": range(5)} # range is a Sequence + result = FileWriter._to_sequence(outputs, "pred") + assert result == [0, 1, 2, 3, 4] + + def test_to_sequence_single_value(self): + """Test _to_sequence wraps single non-sequence value.""" + outputs = {"pred": 42} + result = FileWriter._to_sequence(outputs, "pred") + 
assert result == [42] + + def test_to_sequence_string_not_split(self): + """Test _to_sequence doesn't split strings.""" + outputs = {"pred": "sample_001"} + result = FileWriter._to_sequence(outputs, "pred") + assert result == ["sample_001"] + + def test_write_with_nested_directory(self, tmp_path, mock_trainer, mock_system): + """Test FileWriter creates nested directories when using custom names.""" + writer = FileWriter(directory=tmp_path / "outputs", value_key="pred", writer_fn="tensor", name_key="path") + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + outputs = {"pred": torch.tensor([[1, 2]]), "path": ["subfolder/sample"]} + + writer.write(outputs, None, 0, 0) + + assert (tmp_path / "outputs" / "subfolder" / "sample.pt").exists() + + +class TestWriterRegistry: + """Test suite for writer registry.""" + + def test_builtin_writers_exist(self): + """Test that all built-in writers are registered.""" + assert "tensor" in writer_registry._registry + assert "image_2d" in writer_registry._registry + assert "image_3d" in writer_registry._registry + assert "text" in writer_registry._registry + + def test_get_existing_writer(self): + """Test getting an existing writer.""" + writer = writer_registry.get("tensor") + assert callable(writer) + + def test_get_nonexistent_writer(self): + """Test error when getting non-existent writer.""" + with pytest.raises(ValueError, match="Writer with name 'nonexistent' is not registered"): + writer_registry.get("nonexistent") + + def test_tensor_writer(self, tmp_path): + """Test tensor writer function.""" + from lighter.callbacks.file_writer import write_tensor + + path = tmp_path / "test" + tensor = torch.tensor([1, 2, 3, 4]) + + write_tensor(path, tensor) + + assert (tmp_path / "test.pt").exists() + loaded = torch.load(tmp_path / "test.pt") + assert torch.equal(loaded, tensor) + + def test_text_writer(self, tmp_path): + """Test text writer function.""" + from lighter.callbacks.file_writer import write_text + + path = tmp_path / 
"test" + value = "Hello, World!" + + write_text(path, value) + + assert (tmp_path / "test.txt").exists() + content = (tmp_path / "test.txt").read_text() + assert content == "Hello, World!" + + def test_image_2d_writer(self, tmp_path): + """Test image 2D writer function.""" + from lighter.callbacks.file_writer import write_image_2d + + path = tmp_path / "test" + # Create a valid 3D tensor (CHW format) + tensor = torch.rand(3, 64, 64) + + write_image_2d(path, tensor) + + assert (tmp_path / "test.png").exists() + + def test_image_2d_writer_invalid_dimensions(self, tmp_path): + """Test image 2D writer raises error for wrong dimensions.""" + from lighter.callbacks.file_writer import write_image_2d + + path = tmp_path / "test" + # Create invalid 2D tensor instead of 3D + tensor = torch.rand(64, 64) + + with pytest.raises(ValueError, match="write_image_2d expects a 3D tensor"): + write_image_2d(path, tensor) + + def test_image_3d_writer(self, tmp_path): + """Test image 3D writer function.""" + from lighter.callbacks.file_writer import write_image_3d + + path = tmp_path / "test" + # Create a valid 4D tensor (CDHW format) + tensor = torch.rand(3, 10, 64, 64) + + write_image_3d(path, tensor) + + assert (tmp_path / "test.png").exists() + + def test_image_3d_writer_invalid_dimensions(self, tmp_path): + """Test image 3D writer raises error for wrong dimensions.""" + from lighter.callbacks.file_writer import write_image_3d + + path = tmp_path / "test" + # Create invalid 3D tensor instead of 4D + tensor = torch.rand(3, 64, 64) + + with pytest.raises(ValueError, match="write_image_3d expects a 4D tensor"): + write_image_3d(path, tensor) + + def test_writer_registry_register_duplicate(self): + """Test that registering duplicate writer raises error.""" + from lighter.callbacks.file_writer import WriterRegistry + + registry = WriterRegistry() + + @registry.register("test_writer") + def writer1(path, value): + pass + + with pytest.raises(ValueError, match="Writer with name 
'test_writer' is already registered"):
+
+        @registry.register("test_writer")
+        def writer2(path, value):
+            pass
+
+
+# =============================================================================
+# CsvWriter Tests
+# =============================================================================
+
+
+class TestCsvWriter:
+    """Test suite for CsvWriter callback."""
+
+    def test_initialization(self, tmp_path):
+        """Test CsvWriter initialization."""
+        writer = CsvWriter(path=tmp_path / "results.csv", keys=["pred", "target", "loss"])
+        assert writer.path == tmp_path / "results.csv"
+        assert writer.keys == ["pred", "target", "loss"]
+
+    def test_setup(self, tmp_path, mock_trainer, mock_system):
+        """Test CsvWriter setup creates temp file and writes header."""
+        writer = CsvWriter(path=tmp_path / "results.csv", keys=["pred", "target"])
+
+        writer.setup(mock_trainer, mock_system, Stage.PREDICT)
+
+        # Check temp file created with header
+        temp_file = tmp_path / "results.tmp_rank0.csv"
+        assert temp_file.exists()
+        # Header is written during setup but may still sit in the write buffer;
+        # close the file to flush it to disk before reading
+        writer._csv_file.close()
+        content = temp_file.read_text()
+        assert "pred,target" in content
+
+    def test_write_tensor_batch(self, tmp_path, mock_trainer, mock_system):
+        """Test writing a batch with tensor values."""
+        writer = CsvWriter(path=tmp_path / "results.csv", keys=["pred", "target"])
+        writer.setup(mock_trainer, mock_system, Stage.PREDICT)
+
+        outputs = {"pred": torch.tensor([0.1, 0.2, 0.3]), "target": torch.tensor([0, 1, 0])}
+
+        writer.write(outputs, None, 0, 0)
+
+        # Close file to flush writes
+        writer._csv_file.close()
+
+        # Read temp file
+        temp_file = tmp_path / "results.tmp_rank0.csv"
+        df = pd.read_csv(temp_file)
+
+        assert len(df) == 3
+        assert list(df.columns) == ["pred", "target"]
+        assert df["pred"].tolist() == pytest.approx([0.1, 0.2, 0.3], rel=1e-5)
+        assert df["target"].tolist() == [0, 1, 0]
+
+    def test_write_mixed_types(self, 
tmp_path, mock_trainer, mock_system): + """Test writing with mixed data types.""" + writer = CsvWriter(path=tmp_path / "results.csv", keys=["pred", "target", "id"]) + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + outputs = { + "pred": torch.tensor([0.1, 0.2]), + "target": [1, 0], # List + "id": ["sample1", "sample2"], # Strings + } + + writer.write(outputs, None, 0, 0) + + # Close file to flush writes + writer._csv_file.close() + + temp_file = tmp_path / "results.tmp_rank0.csv" + df = pd.read_csv(temp_file) + + assert len(df) == 2 + assert df["id"].tolist() == ["sample1", "sample2"] + + def test_write_inconsistent_lengths(self, tmp_path, mock_trainer, mock_system): + """Test error when outputs have inconsistent lengths.""" + writer = CsvWriter(path=tmp_path / "results.csv", keys=["pred", "target"]) + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + outputs = { + "pred": torch.tensor([0.1, 0.2, 0.3]), # 3 items + "target": torch.tensor([0, 1]), # 2 items + } + + with pytest.raises(ValueError, match="inconsistent lengths"): + writer.write(outputs, None, 0, 0) + + def test_write_missing_key(self, tmp_path, mock_trainer, mock_system): + """Test error when required key is missing.""" + writer = CsvWriter(path=tmp_path / "results.csv", keys=["pred", "target"]) + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + outputs = {"pred": torch.tensor([0.1, 0.2])} # Missing "target" + + with pytest.raises(KeyError, match="expected key 'target'"): + writer.write(outputs, None, 0, 0) + + def test_get_sequence_length_tensor(self): + """Test _get_sequence_length with tensor.""" + writer = CsvWriter(path="test.csv", keys=["pred"]) + + # Scalar tensor + assert writer._get_sequence_length(torch.tensor(5.0)) == 1 + + # Vector tensor + assert writer._get_sequence_length(torch.tensor([1, 2, 3])) == 3 + + def test_get_sequence_length_list(self): + """Test _get_sequence_length with list.""" + writer = CsvWriter(path="test.csv", keys=["pred"]) + assert 
writer._get_sequence_length([1, 2, 3, 4]) == 4 + + def test_get_sequence_length_non_sequence(self): + """Test _get_sequence_length with non-sequence.""" + writer = CsvWriter(path="test.csv", keys=["pred"]) + assert writer._get_sequence_length("string") is None + + def test_get_record_value_tensor(self): + """Test _get_record_value with tensor.""" + writer = CsvWriter(path="test.csv", keys=["pred"]) + + # Scalar tensor + value = torch.tensor(5.0) + assert writer._get_record_value(value, 0) == 5.0 + + # Vector tensor + value = torch.tensor([1.0, 2.0, 3.0]) + assert writer._get_record_value(value, 1) == 2.0 + + def test_get_record_value_list(self): + """Test _get_record_value with list.""" + writer = CsvWriter(path="test.csv", keys=["pred"]) + value = [10, 20, 30] + assert writer._get_record_value(value, 2) == 30 + + def test_on_predict_epoch_end_single_rank(self, tmp_path, mock_trainer, mock_system): + """Test epoch end combines temp file and creates final CSV.""" + writer = CsvWriter(path=tmp_path / "results.csv", keys=["pred", "target"]) + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + # Write some data + outputs = {"pred": torch.tensor([0.1, 0.2]), "target": torch.tensor([0, 1])} + writer.write(outputs, None, 0, 0) + + # Trigger epoch end + writer.on_predict_epoch_end(mock_trainer, mock_system) + + # Check final CSV exists + assert (tmp_path / "results.csv").exists() + df = pd.read_csv(tmp_path / "results.csv") + assert len(df) == 2 + + # Check temp file was removed + temp_file = tmp_path / "results.tmp_rank0.csv" + assert not temp_file.exists() + + def test_write_empty_outputs_raises(self, tmp_path, mock_trainer, mock_system): + """Test write raises KeyError when outputs is empty.""" + writer = CsvWriter(path=tmp_path / "results.csv", keys=["pred", "target"]) + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + # Empty outputs - none of the configured keys are present + outputs = {} + + with pytest.raises(KeyError, match="none of the 
configured keys"): + writer.write(outputs, None, 0, 0) + + def test_write_no_configured_keys_present_raises(self, tmp_path, mock_trainer, mock_system): + """Test write raises KeyError when outputs has keys but none match configured keys.""" + writer = CsvWriter(path=tmp_path / "results.csv", keys=["pred", "target"]) + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + # Outputs has keys, but none of them match the configured keys + outputs = {"other_key": [1, 2, 3], "another_key": [4, 5, 6]} + + with pytest.raises(KeyError, match="none of the configured keys.*pred.*target.*were found") as exc_info: + writer.write(outputs, None, 0, 0) + + # Verify error message includes available keys + assert "other_key" in str(exc_info.value) or "another_key" in str(exc_info.value) + + def test_write_non_sequence_values(self, tmp_path, mock_trainer, mock_system): + """Test write with non-sequence values (single sample).""" + writer = CsvWriter(path=tmp_path / "results.csv", keys=["pred", "label"]) + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + # Single values (not sequences) + outputs = { + "pred": 0.95, # Single float + "label": "positive", # Single string + } + + writer.write(outputs, None, 0, 0) + + # Close file to flush + writer._csv_file.close() + + temp_file = tmp_path / "results.tmp_rank0.csv" + df = pd.read_csv(temp_file) + + assert len(df) == 1 + assert df["pred"].iloc[0] == 0.95 + assert df["label"].iloc[0] == "positive" + + def test_get_record_value_non_sequence(self): + """Test _get_record_value with non-sequence value.""" + writer = CsvWriter(path="test.csv", keys=["pred"]) + value = "constant_value" + # Should return the value as-is for any index + assert writer._get_record_value(value, 0) == "constant_value" + assert writer._get_record_value(value, 5) == "constant_value" + + def test_on_predict_epoch_end_before_setup(self, mock_trainer, mock_system): + """Test on_predict_epoch_end handles case when setup wasn't called.""" + writer = 
CsvWriter(path="results.csv", keys=["pred"]) + + # Call epoch end without setup - should return without error + writer.on_predict_epoch_end(mock_trainer, mock_system) + + def test_close_file_closes_open_file(self, tmp_path, mock_trainer, mock_system): + """Test _close_file closes the file handle and resets state.""" + writer = CsvWriter(path=tmp_path / "results.csv", keys=["pred"]) + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + # File should be open after setup + assert writer._csv_file is not None + assert not writer._csv_file.closed + assert writer._csv_writer is not None + + # Call _close_file + writer._close_file() + + # File handle should be None and writer reset + assert writer._csv_file is None + assert writer._csv_writer is None + + def test_close_file_handles_none_file(self, tmp_path): + """Test _close_file handles case when file is already None.""" + writer = CsvWriter(path=tmp_path / "results.csv", keys=["pred"]) + + # File is None by default + assert writer._csv_file is None + + # Should not raise + writer._close_file() + + # Still None + assert writer._csv_file is None + + def test_close_file_handles_already_closed_file(self, tmp_path, mock_trainer, mock_system): + """Test _close_file handles already closed file.""" + writer = CsvWriter(path=tmp_path / "results.csv", keys=["pred"]) + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + # Manually close the file + writer._csv_file.close() + assert writer._csv_file.closed + + # Should not raise when calling _close_file on closed file + writer._close_file() + + assert writer._csv_file is None + assert writer._csv_writer is None + + def test_on_exception_closes_file(self, tmp_path, mock_trainer, mock_system): + """Test on_exception closes file to prevent handle leaks.""" + writer = CsvWriter(path=tmp_path / "results.csv", keys=["pred"]) + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + # File should be open + assert writer._csv_file is not None + assert not 
writer._csv_file.closed + + # Simulate an exception occurring + writer.on_exception(mock_trainer, mock_system, RuntimeError("Test error")) + + # File should be closed and state reset + assert writer._csv_file is None + assert writer._csv_writer is None + + def test_teardown_closes_file_on_predict_stage(self, tmp_path, mock_trainer, mock_system): + """Test teardown closes file when stage is PREDICT.""" + writer = CsvWriter(path=tmp_path / "results.csv", keys=["pred"]) + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + # File should be open + assert writer._csv_file is not None + + # Call teardown with PREDICT stage + writer.teardown(mock_trainer, mock_system, Stage.PREDICT) + + # File should be closed + assert writer._csv_file is None + assert writer._csv_writer is None + + def test_teardown_does_not_close_on_other_stages(self, tmp_path, mock_trainer, mock_system): + """Test teardown does not close file for non-PREDICT stages.""" + writer = CsvWriter(path=tmp_path / "results.csv", keys=["pred"]) + writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + # Store reference to check if still open + csv_file = writer._csv_file + assert csv_file is not None + + # Call teardown with FIT stage - should not close the file + writer.teardown(mock_trainer, mock_system, Stage.FIT) + + # File should still be open (same reference) + assert writer._csv_file is csv_file + assert not writer._csv_file.closed + + # Clean up + writer._close_file() + + +# ============================================================================= +# Integration Tests +# ============================================================================= + + +class TestWritersIntegration: + """Integration tests for writers working together.""" + + def test_filewriter_and_csvwriter_together(self, tmp_path, mock_trainer, mock_system): + """Test using FileWriter and CsvWriter on same outputs.""" + file_writer = FileWriter(directory=tmp_path / "predictions", value_key="pred", writer_fn="tensor") + 
csv_writer = CsvWriter(path=tmp_path / "results.csv", keys=["pred", "confidence"]) + + # Setup both + file_writer.setup(mock_trainer, mock_system, Stage.PREDICT) + csv_writer.setup(mock_trainer, mock_system, Stage.PREDICT) + + # Write same outputs to both + outputs = {"pred": torch.tensor([[1, 2], [3, 4]]), "confidence": torch.tensor([0.9, 0.8])} + + file_writer.write(outputs, None, 0, 0) + csv_writer.write(outputs, None, 0, 0) + + # Check FileWriter created files + assert (tmp_path / "predictions" / "0.pt").exists() + assert (tmp_path / "predictions" / "1.pt").exists() + + # Check CsvWriter created temp file + temp_csv = tmp_path / "results.tmp_rank0.csv" + assert temp_csv.exists() diff --git a/tests/unit/test_callbacks_writers_ddp.py b/tests/unit/test_callbacks_writers_ddp.py new file mode 100644 index 00000000..da216f69 --- /dev/null +++ b/tests/unit/test_callbacks_writers_ddp.py @@ -0,0 +1,341 @@ +"""Unit tests for distributed (DDP) functionality of FileWriter and CsvWriter callbacks. + +These tests emulate DDP on a single device by mocking the distributed environment. 
+""" + +import csv +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest +import torch + +from lighter.callbacks.csv_writer import CsvWriter +from lighter.callbacks.file_writer import FileWriter +from lighter.model import LighterModule +from lighter.utils.types.enums import Stage + + +def make_mock_trainer(global_rank=0, world_size=2, is_global_zero=True): + """Create a mock trainer configured for DDP with proper strategy mocking.""" + trainer = MagicMock() + trainer.global_rank = global_rank + trainer.world_size = world_size + trainer.is_global_zero = is_global_zero + + # Mock strategy with broadcast and barrier + strategy = MagicMock() + strategy.broadcast = lambda x, src=0: x # Return path as-is + strategy.barrier = MagicMock() + strategy.root_device = torch.device("cpu") + trainer.strategy = strategy + + # Mock the predict_loop to avoid the private API issue + trainer.predict_loop = MagicMock() + trainer.predict_loop.num_dataloaders = 1 + trainer.predict_loop._predictions = [[]] + + return trainer + + +@pytest.fixture +def mock_trainer_ddp(): + """Create a mock trainer configured for DDP rank 0.""" + return make_mock_trainer(global_rank=0, world_size=2, is_global_zero=True) + + +@pytest.fixture +def mock_trainer_ddp_rank1(): + """Create a mock trainer for rank 1 in DDP.""" + return make_mock_trainer(global_rank=1, world_size=2, is_global_zero=False) + + +@pytest.fixture +def mock_system(): + """Create a mock LighterModule for testing.""" + return MagicMock(spec=LighterModule) + + +class TestCsvWriterDDP: + """Test suite for CsvWriter in distributed settings.""" + + def test_csv_writer_creates_rank_specific_temp_files(self, tmp_path, mock_trainer_ddp, mock_system): + """Test that CsvWriter creates temporary files with rank suffix.""" + csv_path = tmp_path / "predictions.csv" + writer = CsvWriter(path=csv_path, keys=["pred", "target"]) + + writer.setup(mock_trainer_ddp, mock_system, Stage.PREDICT) + + # Check that temporary file has 
rank suffix + expected_temp = tmp_path / "predictions.tmp_rank0.csv" + assert writer._temp_path == expected_temp + assert expected_temp.exists() + assert writer._csv_file is not None + + def test_csv_writer_multi_rank_temp_files(self, tmp_path, mock_system): + """Test that different ranks create different temporary files.""" + csv_path = tmp_path / "predictions.csv" + + # Rank 0 + trainer_rank0 = make_mock_trainer(global_rank=0, world_size=2, is_global_zero=True) + writer_rank0 = CsvWriter(path=csv_path, keys=["pred"]) + writer_rank0.setup(trainer_rank0, mock_system, Stage.PREDICT) + + # Rank 1 + trainer_rank1 = make_mock_trainer(global_rank=1, world_size=2, is_global_zero=False) + writer_rank1 = CsvWriter(path=csv_path, keys=["pred"]) + writer_rank1.setup(trainer_rank1, mock_system, Stage.PREDICT) + + # Verify different temp files + assert writer_rank0._temp_path == tmp_path / "predictions.tmp_rank0.csv" + assert writer_rank1._temp_path == tmp_path / "predictions.tmp_rank1.csv" + assert writer_rank0._temp_path != writer_rank1._temp_path + + def test_csv_writer_merge_without_dist(self, tmp_path, mock_trainer_ddp, mock_system): + """Test CSV merging when distributed is not initialized (single process).""" + csv_path = tmp_path / "predictions.csv" + writer = CsvWriter(path=csv_path, keys=["pred", "target"]) + + writer.setup(mock_trainer_ddp, mock_system, Stage.PREDICT) + + # Save temp path before it's cleared + temp_path = writer._temp_path + + # Write some predictions + outputs = {"pred": torch.tensor([1, 2, 3]), "target": torch.tensor([4, 5, 6])} + writer.write(outputs, None, 0, 0) + + # Mock dist.is_initialized() to return False (non-distributed) + with patch("torch.distributed.is_initialized", return_value=False): + writer.on_predict_epoch_end(mock_trainer_ddp, mock_system) + + # Verify final CSV exists + assert csv_path.exists() + + # Verify content + df = pd.read_csv(csv_path) + assert len(df) == 3 + assert list(df.columns) == ["pred", "target"] + assert 
df["pred"].tolist() == [1, 2, 3] + assert df["target"].tolist() == [4, 5, 6] + + # Verify temp file cleaned up + assert not temp_path.exists() + + # Verify writer state reset + assert writer._temp_path is None + + def test_csv_writer_merge_with_dist(self, tmp_path, mock_trainer_ddp, mock_system): + """Test CSV merging when distributed is initialized (simulated DDP).""" + csv_path = tmp_path / "predictions.csv" + writer = CsvWriter(path=csv_path, keys=["pred", "target"]) + + writer.setup(mock_trainer_ddp, mock_system, Stage.PREDICT) + + # Save temp path before it's cleared + temp_path_rank0 = writer._temp_path + + # Write predictions for rank 0 + outputs_rank0 = {"pred": torch.tensor([1, 2]), "target": torch.tensor([4, 5])} + writer.write(outputs_rank0, None, 0, 0) + + # Close the file to simulate end of predictions + writer._csv_file.close() + + # Create a second temporary file to simulate rank 1 + temp_path_rank1 = tmp_path / "predictions.tmp_rank1.csv" + with open(temp_path_rank1, "w", newline="") as f: + csv_writer = csv.writer(f) + csv_writer.writerow(["pred", "target"]) + csv_writer.writerow([3, 6]) + + # Mock distributed gathering + temp_paths = [temp_path_rank0, temp_path_rank1] + + with ( + patch("torch.distributed.is_initialized", return_value=True), + patch("torch.distributed.all_gather_object") as mock_gather, + ): + # Simulate all_gather_object by filling the list + def side_effect(tensor_list, tensor): + tensor_list[:] = temp_paths + + mock_gather.side_effect = side_effect + + writer.on_predict_epoch_end(mock_trainer_ddp, mock_system) + + # Verify final CSV exists and contains data from both ranks + assert csv_path.exists() + df = pd.read_csv(csv_path) + assert len(df) == 3 # 2 from rank 0 + 1 from rank 1 + assert list(df.columns) == ["pred", "target"] + assert sorted(df["pred"].tolist()) == [1, 2, 3] + assert sorted(df["target"].tolist()) == [4, 5, 6] + + # Verify temp files cleaned up (rank 0 is global_zero, so it cleans) + assert not 
temp_path_rank0.exists() + assert not temp_path_rank1.exists() + + # Verify writer state reset + assert writer._temp_path is None + + def test_csv_writer_rank1_does_not_cleanup(self, tmp_path, mock_trainer_ddp_rank1, mock_system): + """Test that non-zero ranks don't perform cleanup.""" + csv_path = tmp_path / "predictions.csv" + writer = CsvWriter(path=csv_path, keys=["pred"]) + + writer.setup(mock_trainer_ddp_rank1, mock_system, Stage.PREDICT) + + # Save temp path before it's cleared + temp_path_rank1 = writer._temp_path + + # Write predictions for rank 1 + outputs = {"pred": torch.tensor([7, 8])} + writer.write(outputs, None, 0, 0) + + # Close the file + writer._csv_file.close() + + # Mock distributed gathering + temp_path_rank0 = tmp_path / "predictions.tmp_rank0.csv" + temp_paths = [temp_path_rank0, temp_path_rank1] + + # Create a fake rank 0 file + temp_path_rank0.write_text("pred\n1\n2\n") + + with ( + patch("torch.distributed.is_initialized", return_value=True), + patch("torch.distributed.all_gather_object") as mock_gather, + ): + + def side_effect(tensor_list, tensor): + tensor_list[:] = temp_paths + + mock_gather.side_effect = side_effect + + writer.on_predict_epoch_end(mock_trainer_ddp_rank1, mock_system) + + # Rank 1 should NOT create the final CSV (only rank 0 does) + assert not csv_path.exists() + + # Temp files should still exist (rank 1 doesn't clean up) + assert temp_path_rank1.exists() + assert temp_path_rank0.exists() + + # Verify writer state reset + assert writer._temp_path is None + + # Cleanup for test + temp_path_rank1.unlink() + temp_path_rank0.unlink() + + +class TestFileWriterDDP: + """Test suite for FileWriter in distributed settings.""" + + def test_file_writer_creates_rank_directories(self, tmp_path, mock_system): + """Test that FileWriter can work with rank-specific directories.""" + # Create a writer that could use rank in directory structure + writer = FileWriter(directory=tmp_path / "rank_0", value_key="pred", writer_fn="tensor") + 
+ trainer_rank0 = make_mock_trainer(global_rank=0, world_size=2, is_global_zero=True) + writer.setup(trainer_rank0, mock_system, Stage.PREDICT) + + # Verify directory created + assert writer.path.exists() + assert writer.path.is_dir() + + def test_file_writer_barrier_synchronization(self, tmp_path, mock_system): + """Test that FileWriter calls barrier for synchronization.""" + writer = FileWriter(directory=tmp_path, value_key="pred", writer_fn="tensor") + + # Create a trainer with a mock strategy that tracks barrier calls + trainer = make_mock_trainer(global_rank=0, world_size=2, is_global_zero=True) + + writer.setup(trainer, mock_system, Stage.PREDICT) + + # Verify barrier was called on the strategy + trainer.strategy.barrier.assert_called_once() + + def test_file_writer_no_barrier_without_dist(self, tmp_path, mock_trainer_ddp, mock_system): + """Test that FileWriter doesn't call barrier when distributed is not initialized.""" + writer = FileWriter(directory=tmp_path, value_key="pred", writer_fn="tensor") + + with patch("torch.distributed.is_initialized", return_value=False), patch("torch.distributed.barrier") as mock_barrier: + writer.setup(mock_trainer_ddp, mock_system, Stage.PREDICT) + + # Barrier should not be called + mock_barrier.assert_not_called() + + +class TestDistributedEdgeCases: + """Test edge cases in distributed settings.""" + + def test_csv_writer_handles_empty_rank(self, tmp_path, mock_trainer_ddp, mock_system): + """Test that CSV writer handles ranks with no predictions.""" + csv_path = tmp_path / "predictions.csv" + writer = CsvWriter(path=csv_path, keys=["pred"]) + + writer.setup(mock_trainer_ddp, mock_system, Stage.PREDICT) + + # Don't write anything (simulating empty rank) + writer._csv_file.close() + + # Create another rank with data + temp_path_rank1 = tmp_path / "predictions.tmp_rank1.csv" + with open(temp_path_rank1, "w", newline="") as f: + csv_writer = csv.writer(f) + csv_writer.writerow(["pred"]) + csv_writer.writerow([1]) + 
csv_writer.writerow([2]) + + temp_paths = [writer._temp_path, temp_path_rank1] + + with ( + patch("torch.distributed.is_initialized", return_value=True), + patch("torch.distributed.all_gather_object") as mock_gather, + ): + + def side_effect(tensor_list, tensor): + tensor_list[:] = temp_paths + + mock_gather.side_effect = side_effect + + writer.on_predict_epoch_end(mock_trainer_ddp, mock_system) + + # Final CSV should have only rank 1's data + assert csv_path.exists() + df = pd.read_csv(csv_path) + assert len(df) == 2 + assert df["pred"].tolist() == [1, 2] + + def test_csv_writer_handles_none_paths(self, tmp_path, mock_trainer_ddp, mock_system): + """Test that CSV writer handles None paths in gathered list.""" + csv_path = tmp_path / "predictions.csv" + writer = CsvWriter(path=csv_path, keys=["pred"]) + + writer.setup(mock_trainer_ddp, mock_system, Stage.PREDICT) + outputs = {"pred": torch.tensor([1, 2])} + writer.write(outputs, None, 0, 0) + writer._csv_file.close() + + # Simulate gathering with some None values (failed/missing ranks) + temp_paths = [writer._temp_path, None] + + with ( + patch("torch.distributed.is_initialized", return_value=True), + patch("torch.distributed.all_gather_object") as mock_gather, + ): + + def side_effect(tensor_list, tensor): + tensor_list[:] = temp_paths + + mock_gather.side_effect = side_effect + + writer.on_predict_epoch_end(mock_trainer_ddp, mock_system) + + # Should still work with None paths filtered out + assert csv_path.exists() + df = pd.read_csv(csv_path) + assert len(df) == 2 + assert df["pred"].tolist() == [1, 2] diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py new file mode 100644 index 00000000..ebdf573d --- /dev/null +++ b/tests/unit/test_data.py @@ -0,0 +1,86 @@ +"""Tests for LighterDataModule.""" + +import pytest +import torch +from torch.utils.data import DataLoader, TensorDataset + +from lighter import LighterDataModule + + +class TestLighterDataModule: + """Test suite for LighterDataModule.""" + + 
@pytest.fixture + def sample_dataset(self): + """Create a simple tensor dataset for testing.""" + x = torch.randn(100, 10) + y = torch.randint(0, 2, (100,)) + return TensorDataset(x, y) + + @pytest.fixture + def train_dataloader(self, sample_dataset): + """Create a training dataloader.""" + return DataLoader(sample_dataset, batch_size=32, shuffle=True) + + @pytest.fixture + def val_dataloader(self, sample_dataset): + """Create a validation dataloader.""" + return DataLoader(sample_dataset, batch_size=32, shuffle=False) + + def test_initialization_all_dataloaders(self, train_dataloader, val_dataloader): + """Test initialization with all dataloaders.""" + datamodule = LighterDataModule( + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + test_dataloader=val_dataloader, + predict_dataloader=val_dataloader, + ) + + assert datamodule.train_dataloader() is train_dataloader + assert datamodule.val_dataloader() is val_dataloader + assert datamodule.test_dataloader() is val_dataloader + assert datamodule.predict_dataloader() is val_dataloader + + def test_initialization_partial_dataloaders(self, train_dataloader): + """Test initialization with only some dataloaders.""" + datamodule = LighterDataModule(train_dataloader=train_dataloader) + + assert datamodule.train_dataloader() is train_dataloader + assert datamodule.val_dataloader() is None + assert datamodule.test_dataloader() is None + assert datamodule.predict_dataloader() is None + + def test_initialization_no_dataloaders(self): + """Test initialization with no dataloaders.""" + datamodule = LighterDataModule() + + assert datamodule.train_dataloader() is None + assert datamodule.val_dataloader() is None + assert datamodule.test_dataloader() is None + assert datamodule.predict_dataloader() is None + + def test_is_lightning_datamodule(self): + """Test that LighterDataModule is a LightningDataModule.""" + from pytorch_lightning import LightningDataModule + + datamodule = LighterDataModule() + assert 
isinstance(datamodule, LightningDataModule) + + def test_dataloader_returns_original_instance(self, train_dataloader): + """Test that dataloaders return the exact instance passed in.""" + datamodule = LighterDataModule(train_dataloader=train_dataloader) + + # Should be the same instance, not a copy + assert datamodule.train_dataloader() is train_dataloader + + def test_batch_iteration(self, train_dataloader): + """Test that we can iterate over batches from the datamodule.""" + datamodule = LighterDataModule(train_dataloader=train_dataloader) + + # Should be able to iterate + dataloader = datamodule.train_dataloader() + batch = next(iter(dataloader)) + + assert len(batch) == 2 # x, y + assert batch[0].shape[0] == 32 # batch size + assert batch[0].shape[1] == 10 # input dim diff --git a/tests/unit/test_engine_cli.py b/tests/unit/test_engine_cli.py index ec159f2a..a8733806 100644 --- a/tests/unit/test_engine_cli.py +++ b/tests/unit/test_engine_cli.py @@ -22,15 +22,10 @@ def temp_config_file(self): _target_: pytorch_lightning.Trainer max_epochs: 10 -system: - _target_: lighter.System - model: +model: + _target_: lighter.LighterModule + network: _target_: torch.nn.Identity - dataloaders: - train: {} - val: {} - test: {} - predict: {} """) config_path = f.name yield config_path @@ -49,7 +44,7 @@ def test_cli_fit_command_basic(self, temp_config_file): # Verify Runner was instantiated mock_runner_class.assert_called_once() # Verify run was called with correct arguments - mock_runner.run.assert_called_once_with("fit", temp_config_file, []) + mock_runner.run.assert_called_once_with("fit", [temp_config_file]) def test_cli_fit_command_with_overrides(self, temp_config_file): """Test fit command with CLI overrides.""" @@ -70,8 +65,7 @@ def test_cli_fit_command_with_overrides(self, temp_config_file): # Verify run was called with overrides mock_runner.run.assert_called_once_with( "fit", - temp_config_file, - ["trainer::max_epochs=5", "trainer::devices=2"], + [temp_config_file, 
"trainer::max_epochs=5", "trainer::devices=2"], ) def test_cli_validate_command(self, temp_config_file): @@ -84,7 +78,7 @@ def test_cli_validate_command(self, temp_config_file): cli() - mock_runner.run.assert_called_once_with("validate", temp_config_file, []) + mock_runner.run.assert_called_once_with("validate", [temp_config_file]) def test_cli_validate_command_with_overrides(self, temp_config_file): """Test validate command with overrides.""" @@ -92,7 +86,7 @@ def test_cli_validate_command_with_overrides(self, temp_config_file): "lighter", "validate", temp_config_file, - "system::model::weights=checkpoint.ckpt", + "model::model::weights=checkpoint.ckpt", ] with patch.object(sys, "argv", test_args), patch("lighter.engine.runner.Runner") as mock_runner_class: @@ -103,8 +97,7 @@ def test_cli_validate_command_with_overrides(self, temp_config_file): mock_runner.run.assert_called_once_with( "validate", - temp_config_file, - ["system::model::weights=checkpoint.ckpt"], + [temp_config_file, "model::model::weights=checkpoint.ckpt"], ) def test_cli_test_command(self, temp_config_file): @@ -117,7 +110,7 @@ def test_cli_test_command(self, temp_config_file): cli() - mock_runner.run.assert_called_once_with("test", temp_config_file, []) + mock_runner.run.assert_called_once_with("test", [temp_config_file]) def test_cli_test_command_with_overrides(self, temp_config_file): """Test test command with overrides.""" @@ -126,7 +119,7 @@ def test_cli_test_command_with_overrides(self, temp_config_file): "test", temp_config_file, "trainer::devices=1", - "system::model::dropout=0.5", + "model::model::dropout=0.5", ] with patch.object(sys, "argv", test_args), patch("lighter.engine.runner.Runner") as mock_runner_class: @@ -137,8 +130,7 @@ def test_cli_test_command_with_overrides(self, temp_config_file): mock_runner.run.assert_called_once_with( "test", - temp_config_file, - ["trainer::devices=1", "system::model::dropout=0.5"], + [temp_config_file, "trainer::devices=1", 
"model::model::dropout=0.5"], ) def test_cli_predict_command(self, temp_config_file): @@ -151,7 +143,7 @@ def test_cli_predict_command(self, temp_config_file): cli() - mock_runner.run.assert_called_once_with("predict", temp_config_file, []) + mock_runner.run.assert_called_once_with("predict", [temp_config_file]) def test_cli_predict_command_with_overrides(self, temp_config_file): """Test predict command with overrides.""" @@ -159,7 +151,7 @@ def test_cli_predict_command_with_overrides(self, temp_config_file): "lighter", "predict", temp_config_file, - "system::model::weights=best.ckpt", + "model::model::weights=best.ckpt", "trainer::devices=4", ] @@ -171,8 +163,7 @@ def test_cli_predict_command_with_overrides(self, temp_config_file): mock_runner.run.assert_called_once_with( "predict", - temp_config_file, - ["system::model::weights=best.ckpt", "trainer::devices=4"], + [temp_config_file, "model::model::weights=best.ckpt", "trainer::devices=4"], ) def test_cli_missing_command(self): @@ -220,8 +211,8 @@ def test_cli_multiple_overrides(self, temp_config_file): temp_config_file, "trainer::max_epochs=100", "trainer::devices=2", - "system::optimizer::lr=0.001", - "system::optimizer::weight_decay=0.0001", + "model::optimizer::lr=0.001", + "model::optimizer::weight_decay=0.0001", ] with patch.object(sys, "argv", test_args), patch("lighter.engine.runner.Runner") as mock_runner_class: @@ -230,16 +221,17 @@ def test_cli_multiple_overrides(self, temp_config_file): cli() - # Verify all overrides are passed + # Verify all inputs are passed (config file + overrides) mock_runner.run.assert_called_once() _, args, kwargs = mock_runner.run.mock_calls[0] assert args[0] == "fit" - assert args[1] == temp_config_file - assert len(args[2]) == 4 - assert "trainer::max_epochs=100" in args[2] - assert "trainer::devices=2" in args[2] - assert "system::optimizer::lr=0.001" in args[2] - assert "system::optimizer::weight_decay=0.0001" in args[2] + assert isinstance(args[1], list) + assert 
args[1][0] == temp_config_file # First item is config file + assert len(args[1]) == 5 # 1 config file + 4 overrides + assert "trainer::max_epochs=100" in args[1] + assert "trainer::devices=2" in args[1] + assert "model::optimizer::lr=0.001" in args[1] + assert "model::optimizer::weight_decay=0.0001" in args[1] def test_cli_all_stages_independent(self, temp_config_file): """Test that each stage command is independent.""" @@ -255,11 +247,11 @@ def test_cli_all_stages_independent(self, temp_config_file): cli() # Verify correct stage is called - mock_runner.run.assert_called_once_with(stage, temp_config_file, []) + mock_runner.run.assert_called_once_with(stage, [temp_config_file]) - def test_cli_comma_separated_configs_as_single_arg(self): - """Test that comma-separated config paths work as a single argument.""" - test_args = ["lighter", "fit", "config1.yaml,config2.yaml"] + def test_cli_multiple_configs_as_separate_args(self): + """Test that multiple config paths work as separate arguments.""" + test_args = ["lighter", "fit", "config1.yaml", "config2.yaml"] with patch.object(sys, "argv", test_args), patch("lighter.engine.runner.Runner") as mock_runner_class: mock_runner = MagicMock() @@ -267,11 +259,10 @@ def test_cli_comma_separated_configs_as_single_arg(self): cli() - # Runner should receive the comma-separated string as-is + # Runner should receive both configs in the inputs list mock_runner.run.assert_called_once_with( "fit", - "config1.yaml,config2.yaml", - [], + ["config1.yaml", "config2.yaml"], ) def test_cli_override_with_equals_in_value(self, temp_config_file): @@ -280,7 +271,7 @@ def test_cli_override_with_equals_in_value(self, temp_config_file): "lighter", "fit", temp_config_file, - "system::model::config=key1=value1", + "model::model::config=key1=value1", ] with patch.object(sys, "argv", test_args), patch("lighter.engine.runner.Runner") as mock_runner_class: @@ -292,7 +283,7 @@ def test_cli_override_with_equals_in_value(self, temp_config_file): # Verify 
override is passed correctly mock_runner.run.assert_called_once() _, args, kwargs = mock_runner.run.mock_calls[0] - assert "system::model::config=key1=value1" in args[2] + assert "model::model::config=key1=value1" in args[1] def test_cli_override_with_special_characters(self, temp_config_file): """Test overrides with special characters.""" @@ -300,7 +291,7 @@ def test_cli_override_with_special_characters(self, temp_config_file): "lighter", "fit", temp_config_file, - "system::model::name=my-model_v1.0", + "model::model::name=my-model_v1.0", ] with patch.object(sys, "argv", test_args), patch("lighter.engine.runner.Runner") as mock_runner_class: @@ -311,10 +302,10 @@ def test_cli_override_with_special_characters(self, temp_config_file): mock_runner.run.assert_called_once() _, args, kwargs = mock_runner.run.mock_calls[0] - assert "system::model::name=my-model_v1.0" in args[2] + assert "model::model::name=my-model_v1.0" in args[1] def test_cli_no_overrides(self, temp_config_file): - """Test that no overrides results in empty list.""" + """Test that no overrides results in list with just config file.""" test_args = ["lighter", "fit", temp_config_file] with patch.object(sys, "argv", test_args), patch("lighter.engine.runner.Runner") as mock_runner_class: @@ -325,4 +316,5 @@ def test_cli_no_overrides(self, temp_config_file): mock_runner.run.assert_called_once() _, args, kwargs = mock_runner.run.mock_calls[0] - assert args[2] == [] # Empty overrides list + assert args[0] == "fit" + assert args[1] == [temp_config_file] # Just the config file, no overrides diff --git a/tests/unit/test_engine_runner.py b/tests/unit/test_engine_runner.py index 215d1783..5941734f 100644 --- a/tests/unit/test_engine_runner.py +++ b/tests/unit/test_engine_runner.py @@ -1,28 +1,40 @@ """Unit tests for the Runner class in lighter/engine/runner.py""" -from copy import deepcopy -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, mock_open, patch import pytest from 
pytorch_lightning import Trainer from sparkwheel import Config from lighter.engine.runner import Runner -from lighter.system import System +from lighter.model import LighterModule from lighter.utils.types.enums import Stage @pytest.fixture -def mock_system(): - """Fixture providing a mock System instance.""" - system = MagicMock(spec=System) - system.save_hyperparameters = MagicMock() - return system +def mock_model(): + """Fixture providing a mock LightningModule instance.""" + model = MagicMock(spec=LighterModule) + model.save_hyperparameters = MagicMock() + return model @pytest.fixture -def mock_trainer(): - """Fixture providing a mock Trainer instance.""" +def mock_datamodule(): + """Fixture providing a mock LightningDataModule instance.""" + from pytorch_lightning import LightningDataModule + + datamodule = MagicMock(spec=LightningDataModule) + return datamodule + + +@pytest.fixture +def mock_runner_trainer(): + """Fixture providing a mock Trainer instance for runner tests. + + Note: Named differently from the shared mock_trainer to avoid conflicts + since this has runner-specific configuration. 
+ """ trainer = MagicMock(spec=Trainer) trainer.logger = MagicMock() trainer.logger.log_hyperparams = MagicMock() @@ -41,21 +53,17 @@ def base_config(): "_target_": "pytorch_lightning.Trainer", "max_epochs": 10, }, - "system": { - "_target_": "lighter.system.System", + "model": { + "_target_": "lighter.LighterModule", "model": {"_target_": "torch.nn.Identity"}, "optimizer": {"_target_": "torch.optim.Adam", "lr": 0.001}, - "dataloaders": { - "train": {"batch_size": 32}, - "val": {"batch_size": 32}, - "test": {"batch_size": 32}, - "predict": {"batch_size": 32}, - }, - "metrics": { - "train": [], - "val": [], - "test": [], - }, + }, + "data": { + "_target_": "lighter.LighterDataModule", + "train_dataloader": {"batch_size": 32}, + "val_dataloader": {"batch_size": 32}, + "test_dataloader": {"batch_size": 32}, + "predict_dataloader": {"batch_size": 32}, }, "args": { "fit": {"some_arg": "value"}, @@ -73,158 +81,306 @@ def runner(): def test_runner_initialization(runner): - """Test that Runner initializes with correct default values.""" - assert runner.config is None - assert runner.system is None - assert runner.trainer is None + """Test that Runner initializes correctly.""" + # Runner no longer stores state, just verify it exists + assert runner is not None def test_runner_applies_overrides(runner, base_config): """Test that CLI overrides are applied correctly.""" - overrides = ["trainer::max_epochs=100"] + # Combine config and overrides into single inputs list + inputs = [base_config, "trainer::max_epochs=100"] + + # Mock all methods that would cause side effects + with ( + patch.object(runner, "_resolve_model") as mock_resolve_model, + patch.object(runner, "_resolve_trainer") as mock_resolve_trainer, + patch.object(runner, "_resolve_datamodule") as mock_resolve_datamodule, + patch.object(runner, "_save_config") as mock_save_config, + patch.object(runner, "_execute") as mock_execute, + ): + mock_model = MagicMock() + mock_resolve_model.return_value = mock_model + + 
mock_trainer = MagicMock() + mock_resolve_trainer.return_value = mock_trainer - # Mock the setup and execute to avoid needing real models - with patch.object(runner, "_setup") as mock_setup, patch.object(runner, "_execute") as mock_execute: - runner.run(Stage.FIT, base_config, overrides) + mock_datamodule = MagicMock() + mock_resolve_datamodule.return_value = mock_datamodule - # Verify override was applied - assert runner.config.get("trainer::max_epochs") == 100 + runner.run(Stage.FIT, inputs) # Verify methods were called - mock_setup.assert_called_once() + mock_resolve_model.assert_called_once() + mock_resolve_trainer.assert_called_once() + mock_resolve_datamodule.assert_called_once() + mock_save_config.assert_called_once() mock_execute.assert_called_once() -def test_prune_removes_unused_modes(runner, base_config): - """Test that pruning removes unused dataloaders/metrics for each stage.""" - # Load config - use deepcopy since Config.load modifies the dict - runner.config = Config.load(deepcopy(base_config)) - - # Test FIT stage pruning - runner._prune_for_stage(Stage.FIT) - system = runner.config.get("system", {}) - dataloaders = system.get("dataloaders", {}) - metrics = system.get("metrics", {}) - - # FIT keeps train and val - assert "train" in dataloaders - assert "val" in dataloaders - assert "test" not in dataloaders - assert "predict" not in dataloaders - assert "train" in metrics - assert "val" in metrics - assert "test" not in metrics - - # Reset config and test TEST stage - use fresh copy - runner.config = Config.load(deepcopy(base_config)) - runner._prune_for_stage(Stage.TEST) - system = runner.config.get("system", {}) - dataloaders = system.get("dataloaders", {}) - metrics = system.get("metrics", {}) - - # TEST stage should only have test dataloader - assert "test" in dataloaders - assert "train" not in dataloaders - assert "val" not in dataloaders - assert "predict" not in dataloaders - - -def test_prune_removes_optimizer_for_non_fit(runner, 
base_config): - """Test that optimizer/scheduler are removed for non-FIT stages.""" - runner.config = Config.load(deepcopy(base_config)) - - # Test VALIDATE stage - runner._prune_for_stage(Stage.VALIDATE) - system = runner.config.get("system", {}) - - # VALIDATE removes optimizer/scheduler but keeps criterion - assert "optimizer" not in system - assert "scheduler" not in system - - # Test TEST stage - use fresh copy - runner.config = Config.load(deepcopy(base_config)) - runner._prune_for_stage(Stage.TEST) - system = runner.config.get("system", {}) - - # TEST removes everything - assert "optimizer" not in system - assert "scheduler" not in system - - -def test_setup_with_invalid_system(runner, mock_trainer): - """Test that _setup raises error for invalid system type.""" - # Create config that resolves to non-System object (just a dict) +def test_setup_with_invalid_model(runner, mock_trainer): + """Test that _resolve_model raises error for invalid model type.""" + # Create config that resolves to non-LightningModule object (just a dict) bad_config = { "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 10}, - "system": {"_target_": "builtins.dict"}, # This will resolve to dict type, not System + "model": {"_target_": "builtins.dict"}, # This will resolve to dict type, not LightningModule } - runner.config = Config.load(bad_config) + config = Config().update(bad_config) - # Mock trainer resolution to avoid needing real trainer - with patch.object(runner.config, "resolve") as mock_resolve: - mock_resolve.side_effect = [dict(), mock_trainer] # First call returns dict instead of System + # Mock resolve to return dict instead of LightningModule + with patch.object(config, "resolve") as mock_resolve: + mock_resolve.return_value = dict() # Returns dict instead of LightningModule - with pytest.raises(TypeError, match="system must be System"): - runner._setup(Stage.FIT) + with pytest.raises(TypeError, match="model must be LightningModule or LighterModule"): + 
runner._resolve_model(config) -def test_setup_with_invalid_trainer(runner, mock_system): - """Test that _setup raises error for invalid trainer type.""" +def test_setup_with_invalid_trainer(runner, mock_model): + """Test that _resolve_trainer raises error for invalid trainer type.""" # Create simple config bad_config = { "trainer": {"_target_": "builtins.dict"}, # This will resolve to dict type - "system": {"_target_": "lighter.system.System"}, + "model": {"_target_": "lighter.LighterModule"}, } - runner.config = Config.load(bad_config) + config = Config().update(bad_config) - # Mock system and trainer resolution - with patch.object(runner.config, "resolve") as mock_resolve: - mock_resolve.side_effect = [mock_system, dict()] # Second call returns dict instead of Trainer + # Mock resolve to return dict instead of Trainer + with patch.object(config, "resolve") as mock_resolve: + mock_resolve.return_value = dict() # Returns dict instead of Trainer with pytest.raises(TypeError, match="trainer must be Trainer"): - runner._setup(Stage.FIT) - - -@patch("lighter.engine.runner.import_module_from_path") -def test_setup_with_project(mock_import, runner, base_config, mock_system, mock_trainer): - """Test that _setup correctly imports project module.""" - config_with_project = base_config.copy() - config_with_project["project"] = "path/to/project" - - runner.config = Config.load(config_with_project) - - # Mock resolve to return our mocks - with patch.object(runner.config, "resolve") as mock_resolve: - mock_resolve.side_effect = [mock_system, mock_trainer] + runner._resolve_trainer(config) - runner._setup(Stage.FIT) - mock_import.assert_called_once_with("project", "path/to/project") - -def test_execute_calls_stage_method(runner, mock_system, mock_trainer): +def test_execute_calls_stage_method(runner, mock_model, mock_trainer, mock_datamodule): """Test that _execute calls the correct trainer method.""" - runner.config = MagicMock() - runner.config.resolve.return_value = 
{"some_arg": "value"} - runner.system = mock_system - runner.trainer = mock_trainer - # Test fit - runner._execute(Stage.FIT) - mock_trainer.fit.assert_called_once_with(mock_system, some_arg="value") + runner._execute(Stage.FIT, mock_model, mock_trainer, mock_datamodule) + mock_trainer.fit.assert_called_once_with(mock_model, datamodule=mock_datamodule) mock_trainer.reset_mock() # Test validate - runner._execute(Stage.VALIDATE) - mock_trainer.validate.assert_called_once_with(mock_system, some_arg="value") + runner._execute(Stage.VALIDATE, mock_model, mock_trainer, mock_datamodule) + mock_trainer.validate.assert_called_once_with(mock_model, datamodule=mock_datamodule) mock_trainer.reset_mock() # Test test - runner._execute(Stage.TEST) - mock_trainer.test.assert_called_once_with(mock_system, some_arg="value") + runner._execute(Stage.TEST, mock_model, mock_trainer, mock_datamodule) + mock_trainer.test.assert_called_once_with(mock_model, datamodule=mock_datamodule) mock_trainer.reset_mock() # Test predict - runner._execute(Stage.PREDICT) - mock_trainer.predict.assert_called_once_with(mock_system, some_arg="value") + runner._execute(Stage.PREDICT, mock_model, mock_trainer, mock_datamodule) + mock_trainer.predict.assert_called_once_with(mock_model, datamodule=mock_datamodule) + + +# Tests for auto-discovery feature + + +def test_auto_discover_project_with_marker(runner, tmp_path, monkeypatch): + """Test that ProjectImporter finds __lighter__.py marker file.""" + import sys + + from lighter.engine.runner import ProjectImporter + + # Create a project directory with marker file and __init__.py + project_dir = tmp_path / "my_project" + project_dir.mkdir() + marker_file = project_dir / "__lighter__.py" + marker_file.touch() + init_file = project_dir / "__init__.py" + init_file.touch() + + # Change working directory to project + monkeypatch.chdir(project_dir) + + # Clean up any previous "project" module to avoid conflicts + sys.modules.pop("project", None) + + # Test 
auto-discovery + importer = ProjectImporter() + found = importer.auto_discover_and_import() + + assert found is True + + # Cleanup + sys.modules.pop("project", None) + + +def test_auto_discover_project_without_marker(runner, tmp_path, monkeypatch): + """Test that ProjectImporter returns False when marker is absent.""" + from lighter.engine.runner import ProjectImporter + + # Create a directory without marker file + project_dir = tmp_path / "not_a_project" + project_dir.mkdir() + + # Change working directory to directory without marker + monkeypatch.chdir(project_dir) + + # Test auto-discovery + importer = ProjectImporter() + found = importer.auto_discover_and_import() + + assert found is False + + +@patch("lighter.engine.runner.import_module_from_path") +def test_setup_with_auto_discovery(mock_import, runner, base_config, mock_model, mock_trainer, tmp_path, monkeypatch): + """Test that ProjectImporter auto-discovers and imports project.""" + from lighter.engine.runner import ProjectImporter + + # Create project directory with marker + project_dir = tmp_path / "auto_project" + project_dir.mkdir() + (project_dir / "__lighter__.py").touch() + monkeypatch.chdir(project_dir) + + # Test auto-discovery and import + result = ProjectImporter.auto_discover_and_import() + + # Verify auto-discovered project was imported as 'project' + assert result is True + mock_import.assert_called_once_with("project", project_dir) + + +@patch("lighter.engine.runner.import_module_from_path") +def test_setup_without_project_or_discovery(mock_import, runner, base_config, mock_model, mock_trainer, tmp_path, monkeypatch): + """Test that ProjectImporter works without project module (plain Lightning).""" + from lighter.engine.runner import ProjectImporter + + # Create directory without marker + no_project_dir = tmp_path / "plain_lightning" + no_project_dir.mkdir() + monkeypatch.chdir(no_project_dir) + + # Test auto-discovery without project + result = ProjectImporter.auto_discover_and_import() + + 
# Verify no project module was imported + assert result is False + mock_import.assert_not_called() + + +# Tests for configuration saving + + +@patch("builtins.open", new_callable=mock_open) +@patch("lighter.engine.runner.yaml.dump") +def test_save_config_to_trainer_log_dir(mock_yaml_dump, mock_file, runner, base_config, tmp_path, mock_model): + """Test that _save_config saves configuration to trainer's log directory.""" + config = Config().update(base_config) + + # Create mock trainer with log_dir + mock_trainer = MagicMock(spec=Trainer) + log_dir = tmp_path / "lightning_logs" / "version_0" + mock_trainer.log_dir = str(log_dir) + + # Call _save_config with model + runner._save_config(config, mock_trainer, mock_model) + + # Verify file was opened at correct path + expected_path = log_dir / "config.yaml" + mock_file.assert_called_once_with(expected_path, "w") + + # Verify yaml.dump was called with config + mock_yaml_dump.assert_called_once() + call_args = mock_yaml_dump.call_args + assert call_args[0][0] == config.get() + assert call_args[1]["default_flow_style"] is False + assert call_args[1]["sort_keys"] is False + + +def test_save_config_without_log_dir(runner, base_config, mock_model): + """Test that _save_config handles trainer without log_dir gracefully.""" + config = Config().update(base_config) + + # Create mock trainer without log_dir and without logger + mock_trainer = MagicMock(spec=Trainer) + mock_trainer.log_dir = None + mock_trainer.logger = None + + # Should not raise error + runner._save_config(config, mock_trainer, mock_model) + + +def test_save_config_saves_to_model_and_logger(runner, base_config, mock_model): + """Test that _save_config saves to model and logger.""" + config = Config().update(base_config) + + # Create mock trainer with logger (no log_dir to avoid file operations) + mock_trainer = MagicMock(spec=Trainer) + mock_logger = MagicMock() + mock_trainer.logger = mock_logger + mock_trainer.log_dir = None + + # Call _save_config + 
runner._save_config(config, mock_trainer, mock_model) + + # Verify model.save_hyperparameters was called with config wrapped in dict + mock_model.save_hyperparameters.assert_called_once_with({"config": config.get()}) + + # Verify logger.log_hyperparams was called + mock_logger.log_hyperparams.assert_called_once_with(config.get()) + + +def test_save_config_without_logger(runner, base_config, mock_model): + """Test that _save_config handles trainer without logger.""" + config = Config().update(base_config) + + # Create mock trainer without logger + mock_trainer = MagicMock(spec=Trainer) + mock_trainer.logger = None + + # Should save to model but not crash + runner._save_config(config, mock_trainer, mock_model) + + # Verify model.save_hyperparameters was still called + mock_model.save_hyperparameters.assert_called_once_with({"config": config.get()}) + + +# Tests for CLI kwargs handling + + +def test_execute_with_cli_kwargs(runner, mock_model, mock_trainer, mock_datamodule): + """Test that CLI kwargs are passed to trainer method.""" + # Test with CLI kwargs + runner._execute(Stage.FIT, mock_model, mock_trainer, mock_datamodule, ckpt_path="checkpoint.ckpt") + + mock_trainer.fit.assert_called_once_with(mock_model, datamodule=mock_datamodule, ckpt_path="checkpoint.ckpt") + + +def test_execute_without_cli_kwargs(runner, mock_model, mock_trainer, mock_datamodule): + """Test that execute works without CLI kwargs.""" + # Test with no CLI kwargs + runner._execute(Stage.FIT, mock_model, mock_trainer, mock_datamodule) + + mock_trainer.fit.assert_called_once_with(mock_model, datamodule=mock_datamodule) + + +def test_run_passes_kwargs_to_execute(runner, base_config, mock_model, mock_trainer, mock_datamodule): + """Test that Runner.run() passes CLI kwargs to _execute().""" + # Mock all methods that would cause side effects + with ( + patch.object(runner, "_resolve_model") as mock_resolve_model, + patch.object(runner, "_resolve_trainer") as mock_resolve_trainer, + 
patch.object(runner, "_resolve_datamodule") as mock_resolve_datamodule, + patch.object(runner, "_save_config"), + patch.object(runner, "_execute") as mock_execute, + ): + mock_resolve_model.return_value = mock_model + mock_resolve_trainer.return_value = mock_trainer + mock_resolve_datamodule.return_value = mock_datamodule + + # Run with CLI kwargs + runner.run(Stage.FIT, [base_config], ckpt_path="checkpoint.ckpt", verbose=True) + + # Verify _execute was called with kwargs + mock_execute.assert_called_once() + call_kwargs = mock_execute.call_args[1] + assert "ckpt_path" in call_kwargs + assert call_kwargs["ckpt_path"] == "checkpoint.ckpt" + assert "verbose" in call_kwargs + assert call_kwargs["verbose"] is True diff --git a/tests/unit/test_engine_runner_errors.py b/tests/unit/test_engine_runner_errors.py index be0357cb..e0998068 100644 --- a/tests/unit/test_engine_runner_errors.py +++ b/tests/unit/test_engine_runner_errors.py @@ -1,399 +1,107 @@ -"""Unit tests for error handling in Runner class""" +"""Unit tests for error handling in the Runner class""" import tempfile from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest -import yaml -from sparkwheel import Config +from sparkwheel.utils.exceptions import ConfigKeyError from lighter.engine.runner import Runner from lighter.utils.types.enums import Stage class TestRunnerErrorHandling: - """Tests for Runner error handling and edge cases.""" + """Test class for Runner error handling scenarios.""" - def test_run_with_nonexistent_config_file_raises_error(self): - """Test that non-existent config file raises appropriate error.""" + def test_run_without_config_raises_error(self): + """Test that calling run without config_paths raises ConfigKeyError.""" runner = Runner() - with pytest.raises(FileNotFoundError): - runner.run(Stage.FIT, "nonexistent_config.yaml") + with pytest.raises(ConfigKeyError): # Sparkwheel raises ConfigKeyError for missing keys + 
runner.run(Stage.FIT, []) - def test_run_with_invalid_yaml_raises_error(self): - """Test that invalid YAML raises appropriate error.""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - f.write("invalid: yaml: content: [") - config_path = f.name - - try: - runner = Runner() - with pytest.raises(yaml.YAMLError): - runner.run(Stage.FIT, config_path) - finally: - Path(config_path).unlink() + def test_run_with_nonexistent_config_raises_error(self): + """Test that nonexistent config file raises FileNotFoundError.""" + runner = Runner() + with pytest.raises(FileNotFoundError): + runner.run(Stage.FIT, ["/nonexistent/path/config.yaml"]) def test_run_with_empty_config_raises_validation_error(self): - """Test that empty config raises validation error.""" + """Test that empty config raises ConfigKeyError.""" with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: f.write("") # Empty config config_path = f.name try: runner = Runner() - with pytest.raises(ValueError, match="validation failed"): - runner.run(Stage.FIT, config_path) - finally: - Path(config_path).unlink() - - def test_prune_without_loaded_config_raises_error(self): - """Test that pruning without loaded config raises ValueError.""" - runner = Runner() - with pytest.raises(ValueError, match="Config must be loaded"): - runner._prune_for_stage(Stage.FIT) - - def test_setup_without_loaded_config_raises_error(self): - """Test that setup without loaded config raises ValueError.""" - runner = Runner() - with pytest.raises(ValueError, match="Config must be loaded"): - runner._setup(Stage.FIT) - - def test_execute_without_loaded_config_raises_error(self): - """Test that execute without loaded config raises ValueError.""" - runner = Runner() - with pytest.raises(ValueError, match="Config.*must be set up"): - runner._execute(Stage.FIT) - - def test_execute_without_trainer_raises_error(self): - """Test that execute without trainer raises ValueError.""" - runner = 
Runner() - runner.config = MagicMock() - runner.system = MagicMock() - runner.trainer = None - with pytest.raises(ValueError, match="trainer.*must be set up"): - runner._execute(Stage.FIT) - - def test_execute_without_system_raises_error(self): - """Test that execute without system raises ValueError.""" - runner = Runner() - runner.config = MagicMock() - runner.trainer = MagicMock() - runner.system = None - with pytest.raises(ValueError, match="system.*must be set up"): - runner._execute(Stage.FIT) - - def test_run_with_invalid_override_format_raises_error(self): - """Test that invalid override format raises ValueError.""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - f.write(""" -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 10 - -system: - _target_: lighter.System - model: - _target_: torch.nn.Identity - dataloaders: - train: {} - val: {} -""") - config_path = f.name - - try: - runner = Runner() - # Sparkwheel raises ValueError for invalid override format - with pytest.raises(ValueError, match="Invalid override format"): - runner.run(Stage.FIT, config_path, ["invalid_override_no_equals"]) - finally: - Path(config_path).unlink() - - def test_run_with_invalid_project_path_raises_error(self): - """Test that invalid project path raises FileNotFoundError.""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - f.write(""" -project: /nonexistent/path/to/project - -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 10 - -system: - _target_: lighter.System - model: - _target_: torch.nn.Identity - dataloaders: - train: {} - val: {} -""") - config_path = f.name - - try: - runner = Runner() - with pytest.raises(FileNotFoundError): - runner.run(Stage.FIT, config_path) - finally: - Path(config_path).unlink() - - def test_run_with_conflicting_overrides(self): - """Test behavior with conflicting CLI overrides (last one wins).""" - with tempfile.NamedTemporaryFile(mode="w", 
suffix=".yaml", delete=False) as f: - f.write(""" -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 10 - -system: - _target_: lighter.System - model: - _target_: torch.nn.Identity - dataloaders: - train: {} - val: {} -""") - config_path = f.name - - try: - runner = Runner() - overrides = [ - "trainer::max_epochs=5", - "trainer::max_epochs=20", # Conflicting - should override previous - ] - with patch.object(runner, "_setup"), patch.object(runner, "_execute"): - runner.run(Stage.FIT, config_path, overrides) - # Last override should win - assert runner.config.get("trainer::max_epochs") == 20 + # Empty config should raise ConfigKeyError when trying to resolve 'model' + with pytest.raises(ConfigKeyError): + runner.run(Stage.FIT, [config_path]) finally: Path(config_path).unlink() - def test_prune_with_all_stages_defined(self): - """Test pruning behavior when all stages are defined.""" - config_dict = { - "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 10}, - "system": { - "_target_": "lighter.System", - "model": {"_target_": "torch.nn.Identity"}, - "optimizer": {"_target_": "torch.optim.Adam", "lr": 0.001}, - "scheduler": {"_target_": "torch.optim.lr_scheduler.StepLR"}, - "criterion": {"_target_": "torch.nn.MSELoss"}, - "dataloaders": { - "train": {}, - "val": {}, - "test": {}, - "predict": {}, - }, - "metrics": { - "train": [], - "val": [], - "test": [], - }, - "adapters": { - "train": {}, - "val": {}, - "test": {}, - "predict": {}, - }, - }, - "args": { - "fit": {}, - "validate": {}, - "test": {}, - "predict": {}, - }, - } - - runner = Runner() - runner.config = Config.load(config_dict) - - # Test pruning for VALIDATE stage - runner._prune_for_stage(Stage.VALIDATE) - - # VALIDATE keeps: val dataloader, val metrics, criterion - # VALIDATE removes: train/test/predict dataloaders, train/test metrics, optimizer, scheduler - assert runner.config.get("system::dataloaders::val") is not None - assert 
runner.config.get("system::dataloaders::train") is None - assert runner.config.get("system::dataloaders::test") is None - assert runner.config.get("system::dataloaders::predict") is None - - assert runner.config.get("system::metrics::val") is not None - assert runner.config.get("system::metrics::train") is None - assert runner.config.get("system::metrics::test") is None - - assert runner.config.get("system::optimizer") is None - assert runner.config.get("system::scheduler") is None - assert runner.config.get("system::criterion") is not None # Kept for VALIDATE - - assert runner.config.get("args::validate") is not None - assert runner.config.get("args::fit") is None - assert runner.config.get("args::test") is None - assert runner.config.get("args::predict") is None - def test_multiple_config_files_with_list(self): """Test loading multiple config files as a list.""" with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f1: - f1.write(""" -trainer: - _target_: pytorch_lightning.Trainer - max_epochs: 100 - -system: - _target_: lighter.System - model: - _target_: torch.nn.Identity - dataloaders: - train: {} -""") + f1.write("trainer:\n _target_: pytorch_lightning.Trainer\n max_epochs: 1\n") config_path1 = f1.name with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f2: - f2.write(""" -trainer: - max_epochs: 1 - -system: - model: - _target_: torch.nn.Identity - dataloaders: - train: {} - val: {} -""") + f2.write("model:\n _target_: lighter.LighterModule\n model:\n _target_: torch.nn.Identity\n") config_path2 = f2.name try: runner = Runner() - with patch.object(runner, "_setup"), patch.object(runner, "_execute"): + # Should load multiple configs successfully + # Just test that it doesn't raise during loading + with ( + patch.object(runner, "_resolve_model"), + patch.object(runner, "_resolve_trainer"), + patch.object(runner, "_resolve_datamodule"), + patch.object(runner, "_save_config"), + patch.object(runner, "_execute"), + ): 
runner.run(Stage.FIT, [config_path1, config_path2]) - # Second file should override - assert runner.config.get("trainer::max_epochs") == 1 - # Both dataloaders should exist (from second file) - assert runner.config.get("system::dataloaders::train") is not None - assert runner.config.get("system::dataloaders::val") is not None finally: Path(config_path1).unlink() Path(config_path2).unlink() - def test_run_with_dict_config(self): - """Test running with dict config instead of file path.""" - config_dict = { - "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 10}, - "system": { - "_target_": "lighter.System", - "model": {"_target_": "torch.nn.Identity"}, - "dataloaders": {"train": {}, "val": {}}, - }, - } - - runner = Runner() - with patch.object(runner, "_setup"), patch.object(runner, "_execute"): - runner.run(Stage.FIT, config_dict) - assert runner.config is not None - assert runner.config.get("trainer::max_epochs") == 10 - - def test_stage_modes_mapping(self): - """Test that STAGE_MODES mapping is correct.""" - from lighter.utils.types.enums import Mode - - assert Runner.STAGE_MODES[Stage.FIT] == [Mode.TRAIN, Mode.VAL] - assert Runner.STAGE_MODES[Stage.VALIDATE] == [Mode.VAL] - assert Runner.STAGE_MODES[Stage.TEST] == [Mode.TEST] - assert Runner.STAGE_MODES[Stage.PREDICT] == [Mode.PREDICT] + def test_resolve_datamodule_missing_data_and_no_dataloaders(self): + """Test that missing data config raises error when model has no dataloaders.""" + from unittest.mock import MagicMock - def test_runner_initialization_state(self): - """Test that Runner initializes with None state.""" - runner = Runner() - assert runner.config is None - assert runner.system is None - assert runner.trainer is None + from pytorch_lightning import LightningModule + from sparkwheel import Config - def test_execute_with_stage_args(self): - """Test that execute passes stage-specific args correctly.""" runner = Runner() - runner.config = MagicMock() - runner.system = MagicMock() - 
runner.trainer = MagicMock() - - # Mock resolve to return specific args - runner.config.resolve.return_value = {"ckpt_path": "checkpoint.ckpt"} - - runner._execute(Stage.FIT) - # Verify trainer.fit was called with the args - runner.trainer.fit.assert_called_once_with( - runner.system, - ckpt_path="checkpoint.ckpt", - ) - - def test_execute_with_empty_stage_args(self): - """Test that execute handles empty stage args (uses default).""" - runner = Runner() - runner.config = MagicMock() - runner.system = MagicMock() - runner.trainer = MagicMock() + # Create a model without dataloader methods + mock_model = MagicMock(spec=LightningModule) + # Remove dataloader methods from spec + del mock_model.train_dataloader + del mock_model.val_dataloader + del mock_model.test_dataloader + del mock_model.predict_dataloader - # Mock resolve to return default empty dict - runner.config.resolve.return_value = {} + # Create config without data key + config = Config().update({"trainer": {"_target_": "pytorch_lightning.Trainer"}}) - runner._execute(Stage.TEST) + with pytest.raises(ValueError, match="Missing required 'data:' config key"): + runner._resolve_datamodule(config, mock_model) - # Verify trainer.test was called with no extra args - runner.trainer.test.assert_called_once_with(runner.system) + def test_resolve_datamodule_invalid_type(self): + """Test that invalid datamodule type raises TypeError.""" + from unittest.mock import MagicMock - def test_setup_logs_config_to_trainer_logger(self): - """Test that setup logs config to trainer logger if available.""" - from pytorch_lightning import Trainer - - from lighter.system import System - - config_dict = { - "trainer": {"_target_": "pytorch_lightning.Trainer"}, - "system": { - "_target_": "lighter.System", - "model": {"_target_": "torch.nn.Identity"}, - "dataloaders": {"train": {}}, - }, - } + from pytorch_lightning import LightningModule + from sparkwheel import Config runner = Runner() - runner.config = Config.load(config_dict) - - 
mock_system = MagicMock(spec=System) - mock_trainer = MagicMock(spec=Trainer) - mock_logger = MagicMock() - mock_trainer.logger = mock_logger - - with patch.object(runner.config, "resolve") as mock_resolve: - mock_resolve.side_effect = [mock_system, mock_trainer] - runner._setup(Stage.FIT) - - # Verify hyperparameters were logged - mock_system.save_hyperparameters.assert_called_once() - mock_logger.log_hyperparams.assert_called_once() - def test_setup_handles_no_trainer_logger(self): - """Test that setup handles case when trainer has no logger.""" - from pytorch_lightning import Trainer - - from lighter.system import System - - config_dict = { - "trainer": {"_target_": "pytorch_lightning.Trainer"}, - "system": { - "_target_": "lighter.System", - "model": {"_target_": "torch.nn.Identity"}, - "dataloaders": {"train": {}}, - }, - } - - runner = Runner() - runner.config = Config.load(config_dict) + mock_model = MagicMock(spec=LightningModule) - mock_system = MagicMock(spec=System) - mock_trainer = MagicMock(spec=Trainer) - mock_trainer.logger = None # No logger + # Create config with data key that resolves to wrong type + config = Config().update({"data": {"_target_": "builtins.dict"}}) - with patch.object(runner.config, "resolve") as mock_resolve: - mock_resolve.side_effect = [mock_system, mock_trainer] - # Should not raise even without logger - runner._setup(Stage.FIT) - mock_system.save_hyperparameters.assert_called_once() + with pytest.raises(TypeError, match="data must be LightningDataModule"): + runner._resolve_datamodule(config, mock_model) diff --git a/tests/unit/test_engine_schema.py b/tests/unit/test_engine_schema.py deleted file mode 100644 index d4d531f6..00000000 --- a/tests/unit/test_engine_schema.py +++ /dev/null @@ -1,403 +0,0 @@ -"""Unit tests for schema validation in lighter/engine/schema.py""" - -import pytest -from sparkwheel import Config, ValidationError - -from lighter.engine.schema import ( - AdapterConfig, - AdaptersConfig, - ArgsConfig, - 
DataloadersConfig, - LighterConfig, - MetricsConfig, - PredictAdapterConfig, - SystemConfig, -) - - -class TestAdapterConfig: - """Tests for AdapterConfig dataclass.""" - - def test_empty_adapter_config(self): - """Test creating an empty AdapterConfig.""" - config = AdapterConfig() - assert config.batch is None - assert config.criterion is None - assert config.metrics is None - assert config.logging is None - - def test_partial_adapter_config(self): - """Test creating an AdapterConfig with some fields.""" - config = AdapterConfig(batch={"key": "value"}, metrics={"key2": "value2"}) - assert config.batch == {"key": "value"} - assert config.criterion is None - assert config.metrics == {"key2": "value2"} - assert config.logging is None - - def test_full_adapter_config(self): - """Test creating a full AdapterConfig.""" - config = AdapterConfig( - batch={"b": 1}, - criterion={"c": 2}, - metrics={"m": 3}, - logging={"l": 4}, - ) - assert config.batch == {"b": 1} - assert config.criterion == {"c": 2} - assert config.metrics == {"m": 3} - assert config.logging == {"l": 4} - - -class TestPredictAdapterConfig: - """Tests for PredictAdapterConfig dataclass.""" - - def test_empty_predict_adapter_config(self): - """Test creating an empty PredictAdapterConfig.""" - config = PredictAdapterConfig() - assert config.batch is None - assert config.logging is None - - def test_predict_adapter_no_criterion(self): - """Test that PredictAdapterConfig doesn't have criterion field.""" - config = PredictAdapterConfig(batch={"key": "value"}) - assert not hasattr(config, "criterion") - assert not hasattr(config, "metrics") - - -class TestAdaptersConfig: - """Tests for AdaptersConfig dataclass.""" - - def test_empty_adapters_config(self): - """Test creating an empty AdaptersConfig.""" - config = AdaptersConfig() - assert config.train is None - assert config.val is None - assert config.test is None - assert config.predict is None - - def test_partial_adapters_config(self): - """Test creating an 
AdaptersConfig with some stages.""" - config = AdaptersConfig(train={"batch": {}}, val={"batch": {}}) - assert config.train is not None - assert config.val is not None - assert config.test is None - assert config.predict is None - - -class TestMetricsConfig: - """Tests for MetricsConfig dataclass.""" - - def test_empty_metrics_config(self): - """Test creating an empty MetricsConfig.""" - config = MetricsConfig() - assert config.train is None - assert config.val is None - assert config.test is None - - def test_metrics_config_with_list(self): - """Test MetricsConfig accepts lists.""" - config = MetricsConfig(train=[{"_target_": "torchmetrics.Accuracy"}]) - assert isinstance(config.train, list) - assert len(config.train) == 1 - - def test_metrics_config_with_dict(self): - """Test MetricsConfig accepts dicts.""" - config = MetricsConfig(train={"accuracy": {"_target_": "torchmetrics.Accuracy"}}) - assert isinstance(config.train, dict) - assert "accuracy" in config.train - - -class TestDataloadersConfig: - """Tests for DataloadersConfig dataclass.""" - - def test_empty_dataloaders_config(self): - """Test creating an empty DataloadersConfig.""" - config = DataloadersConfig() - assert config.train is None - assert config.val is None - assert config.test is None - assert config.predict is None - - def test_dataloaders_config_all_stages(self): - """Test DataloadersConfig with all stages.""" - config = DataloadersConfig( - train={"batch_size": 32}, - val={"batch_size": 64}, - test={"batch_size": 128}, - predict={"batch_size": 256}, - ) - assert config.train["batch_size"] == 32 - assert config.val["batch_size"] == 64 - assert config.test["batch_size"] == 128 - assert config.predict["batch_size"] == 256 - - -class TestSystemConfig: - """Tests for SystemConfig dataclass.""" - - def test_empty_system_config(self): - """Test creating an empty SystemConfig.""" - config = SystemConfig() - assert config.model is None - assert config.criterion is None - assert config.optimizer is 
None - assert config.scheduler is None - assert config.inferer is None - assert config.metrics is None - assert config.dataloaders is None - assert config.adapters is None - - def test_system_config_with_nested_dataclasses(self): - """Test SystemConfig accepts nested dataclass instances.""" - metrics = MetricsConfig(train=[]) - dataloaders = DataloadersConfig(train={}) - adapters = AdaptersConfig(train={}) - - config = SystemConfig( - model={"_target_": "torch.nn.Identity"}, - metrics=metrics, - dataloaders=dataloaders, - adapters=adapters, - ) - assert config.model is not None - assert isinstance(config.metrics, MetricsConfig) - assert isinstance(config.dataloaders, DataloadersConfig) - assert isinstance(config.adapters, AdaptersConfig) - - -class TestArgsConfig: - """Tests for ArgsConfig dataclass.""" - - def test_empty_args_config(self): - """Test creating an empty ArgsConfig.""" - config = ArgsConfig() - assert config.fit is None - assert config.validate is None - assert config.test is None - assert config.predict is None - - def test_args_config_all_stages(self): - """Test ArgsConfig with all stages.""" - config = ArgsConfig( - fit={"ckpt_path": "checkpoint.ckpt"}, - validate={"verbose": True}, - test={"verbose": False}, - predict={"return_predictions": True}, - ) - assert config.fit["ckpt_path"] == "checkpoint.ckpt" - assert config.validate["verbose"] is True - assert config.test["verbose"] is False - assert config.predict["return_predictions"] is True - - -class TestLighterConfig: - """Tests for main LighterConfig schema.""" - - def test_minimal_valid_config(self): - """Test minimal valid configuration with only required fields.""" - config_dict = { - "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 10}, - "system": {"_target_": "lighter.System", "model": {"_target_": "torch.nn.Identity"}}, - } - config = Config.load(config_dict, schema=LighterConfig) - assert config.get("trainer::max_epochs") == 10 - assert 
config.get("system::model::_target_") == "torch.nn.Identity" - - def test_full_valid_config(self): - """Test full valid configuration with all optional fields.""" - config_dict = { - "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 10}, - "system": { - "_target_": "lighter.System", - "model": {"_target_": "torch.nn.Identity"}, - "criterion": {"_target_": "torch.nn.CrossEntropyLoss"}, - "optimizer": {"_target_": "torch.optim.Adam", "lr": 0.001}, - "scheduler": {"_target_": "torch.optim.lr_scheduler.StepLR", "step_size": 10}, - "inferer": {"_target_": "lighter.Inferer"}, - "metrics": {"train": [], "val": [], "test": []}, - "dataloaders": {"train": {}, "val": {}, "test": {}, "predict": {}}, - "adapters": {"train": {}, "val": {}, "test": {}, "predict": {}}, - }, - "project": "./path/to/project", - "vars": {"learning_rate": 0.001}, - "args": {"fit": {}, "validate": {}, "test": {}, "predict": {}}, - } - config = Config.load(config_dict, schema=LighterConfig) - assert config.get("project") == "./path/to/project" - assert config.get("vars::learning_rate") == 0.001 - assert config.get("system::optimizer::lr") == 0.001 - - def test_missing_trainer_raises_error(self): - """Test that missing trainer field raises ValidationError.""" - config_dict = { - "system": {"_target_": "lighter.System", "model": {"_target_": "torch.nn.Identity"}}, - } - with pytest.raises(ValidationError) as exc_info: - Config.load(config_dict, schema=LighterConfig) - assert "trainer" in str(exc_info.value).lower() - - def test_missing_system_raises_error(self): - """Test that missing system field raises ValidationError.""" - config_dict = { - "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 10}, - } - with pytest.raises(ValidationError) as exc_info: - Config.load(config_dict, schema=LighterConfig) - assert "system" in str(exc_info.value).lower() - - def test_trainer_wrong_type_raises_error(self): - """Test that trainer with wrong type raises ValidationError.""" - 
config_dict = { - "trainer": ["not", "a", "dict"], # Should be dict, not list - "system": {"_target_": "lighter.System", "model": {"_target_": "torch.nn.Identity"}}, - } - with pytest.raises(ValidationError): - Config.load(config_dict, schema=LighterConfig) - - def test_system_wrong_type_raises_error(self): - """Test that system with wrong type raises ValidationError.""" - config_dict = { - "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 10}, - "system": "not a dict", # Should be SystemConfig/dict - } - with pytest.raises(ValidationError): - Config.load(config_dict, schema=LighterConfig) - - def test_optional_project_field(self): - """Test that project field is optional.""" - config_dict = { - "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 10}, - "system": {"_target_": "lighter.System", "model": {"_target_": "torch.nn.Identity"}}, - "project": "./my_project", - } - config = Config.load(config_dict, schema=LighterConfig) - assert config.get("project") == "./my_project" - - def test_optional_vars_field(self): - """Test that vars field is optional and accepts any dict.""" - config_dict = { - "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 10}, - "system": {"_target_": "lighter.System", "model": {"_target_": "torch.nn.Identity"}}, - "vars": {"custom_var": 42, "another_var": "value"}, - } - config = Config.load(config_dict, schema=LighterConfig) - assert config.get("vars::custom_var") == 42 - assert config.get("vars::another_var") == "value" - - def test_nested_metrics_structure(self): - """Test nested metrics configuration structure.""" - config_dict = { - "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 10}, - "system": { - "_target_": "lighter.System", - "model": {"_target_": "torch.nn.Identity"}, - "metrics": { - "train": [{"_target_": "torchmetrics.Accuracy", "task": "binary"}], - "val": [{"_target_": "torchmetrics.F1Score", "task": "binary"}], - }, - }, - } - config = 
Config.load(config_dict, schema=LighterConfig) - train_metrics = config.get("system::metrics::train") - assert isinstance(train_metrics, list) - assert len(train_metrics) == 1 - assert train_metrics[0]["_target_"] == "torchmetrics.Accuracy" - - def test_nested_dataloaders_structure(self): - """Test nested dataloaders configuration structure.""" - config_dict = { - "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 10}, - "system": { - "_target_": "lighter.System", - "model": {"_target_": "torch.nn.Identity"}, - "dataloaders": { - "train": {"_target_": "torch.utils.data.DataLoader", "batch_size": 32}, - "val": {"_target_": "torch.utils.data.DataLoader", "batch_size": 64}, - }, - }, - } - config = Config.load(config_dict, schema=LighterConfig) - assert config.get("system::dataloaders::train::batch_size") == 32 - assert config.get("system::dataloaders::val::batch_size") == 64 - - def test_nested_adapters_structure(self): - """Test nested adapters configuration structure.""" - config_dict = { - "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 10}, - "system": { - "_target_": "lighter.System", - "model": {"_target_": "torch.nn.Identity"}, - "adapters": { - "train": { - "batch": {"_target_": "lighter.adapters.BatchAdapter", "input_accessor": 0}, - }, - }, - }, - } - config = Config.load(config_dict, schema=LighterConfig) - batch_adapter = config.get("system::adapters::train::batch") - assert batch_adapter["_target_"] == "lighter.adapters.BatchAdapter" - assert batch_adapter["input_accessor"] == 0 - - def test_args_for_all_stages(self): - """Test args configuration for all stages.""" - config_dict = { - "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 10}, - "system": {"_target_": "lighter.System", "model": {"_target_": "torch.nn.Identity"}}, - "args": { - "fit": {"ckpt_path": "checkpoint.ckpt"}, - "validate": {"verbose": True}, - "test": {"verbose": False}, - "predict": {"return_predictions": True}, - }, - } - config = 
Config.load(config_dict, schema=LighterConfig) - assert config.get("args::fit::ckpt_path") == "checkpoint.ckpt" - assert config.get("args::validate::verbose") is True - assert config.get("args::test::verbose") is False - assert config.get("args::predict::return_predictions") is True - - def test_extra_fields_rejected(self): - """Test that extra fields not in schema are rejected (Sparkwheel behavior).""" - config_dict = { - "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 10}, - "system": {"_target_": "lighter.System", "model": {"_target_": "torch.nn.Identity"}}, - "extra_field": "this is extra", - } - # Sparkwheel validates strictly against the schema - with pytest.raises(ValidationError, match="extra_field"): - Config.load(config_dict, schema=LighterConfig) - - def test_config_with_references(self): - """Test configuration with Sparkwheel references (@, %, $).""" - config_dict = { - "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 10}, - "system": { - "_target_": "lighter.System", - "model": {"_target_": "torch.nn.Identity"}, - "metrics": { - "train": [{"_target_": "torchmetrics.Accuracy", "task": "binary"}], - "val": "%::train", # Reference to train metrics - }, - }, - } - config = Config.load(config_dict, schema=LighterConfig) - # Before resolution, this is a reference string - assert config.get("system::metrics::val") == "%::train" - - def test_empty_nested_configs(self): - """Test empty nested configurations are valid.""" - config_dict = { - "trainer": {"_target_": "pytorch_lightning.Trainer", "max_epochs": 10}, - "system": { - "_target_": "lighter.System", - "model": {"_target_": "torch.nn.Identity"}, - "metrics": {}, - "dataloaders": {}, - "adapters": {}, - }, - } - config = Config.load(config_dict, schema=LighterConfig) - assert config.get("system::metrics") == {} - assert config.get("system::dataloaders") == {} - assert config.get("system::adapters") == {} diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py 
new file mode 100644 index 00000000..a9365c13 --- /dev/null +++ b/tests/unit/test_model.py @@ -0,0 +1,554 @@ +"""Unit tests for the LighterModule class.""" + +from unittest.mock import MagicMock + +import pytest +import torch +import torch.nn.functional as F +from torch import nn +from torch.optim import SGD +from torch.optim.lr_scheduler import StepLR +from torchmetrics import Accuracy, MetricCollection + +from lighter.model import LighterModule +from lighter.utils.types.enums import Mode + +# Shared fixtures (mock_trainer, simple_model, dummy_dataset) are available +# from conftest.py and auto-discovered by pytest + +# ============================================================================ +# Helper Functions (imported from conftest for direct use) +# ============================================================================ + + +def mock_trainer_state(trainer, training=False, validating=False, testing=False, predicting=False, sanity_checking=False): + """Helper function to set trainer state flags for testing mode detection. + + Note: This is duplicated from conftest.py because it's used as a direct function + call, not as a fixture. 
+ """ + trainer.training = training + trainer.validating = validating + trainer.testing = testing + trainer.predicting = predicting + trainer.sanity_checking = sanity_checking + + +# ============================================================================ +# Helper Classes and Fixtures +# ============================================================================ + + +class SimpleLighterModule(LighterModule): + """Concrete implementation of LighterModule for testing.""" + + def training_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + loss = F.cross_entropy(pred, y) + + if self.train_metrics is not None: + self.train_metrics(pred, y) + + return {"loss": loss, "pred": pred, "target": y} + + def validation_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + loss = F.cross_entropy(pred, y) + + if self.val_metrics is not None: + self.val_metrics(pred, y) + + return {"loss": loss, "pred": pred, "target": y} + + def test_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + + if self.test_metrics is not None: + self.test_metrics(pred, y) + + return {"pred": pred, "target": y} + + def predict_step(self, batch, batch_idx): + x, y = batch + pred = self(x) + return pred + + +@pytest.fixture +def simple_system(simple_model): + """Creates a SimpleLighterModule instance using the shared simple_model fixture.""" + criterion = nn.CrossEntropyLoss() + optimizer = SGD(simple_model.parameters(), lr=0.01) + scheduler = StepLR(optimizer, step_size=10, gamma=0.1) + + train_metrics = MetricCollection([Accuracy(task="multiclass", num_classes=2)]) + val_metrics = MetricCollection([Accuracy(task="multiclass", num_classes=2)]) + test_metrics = MetricCollection([Accuracy(task="multiclass", num_classes=2)]) + + system = SimpleLighterModule( + network=simple_model, + criterion=criterion, + optimizer=optimizer, + scheduler=scheduler, + train_metrics=train_metrics, + val_metrics=val_metrics, + test_metrics=test_metrics, + ) + + return system + + +# 
============================================================================ +# Test System Class +# ============================================================================ + + +def test_system_requires_step_implementation(simple_model): + """Test that LighterModule raises NotImplementedError when step methods are called.""" + system = LighterModule(network=simple_model) + + # System can be instantiated, but calling unimplemented steps should raise + batch = (torch.randn(2, 4), torch.tensor([0, 1])) + + with pytest.raises(NotImplementedError, match="must implement training_step"): + system.training_step(batch, batch_idx=0) + + with pytest.raises(NotImplementedError, match="must implement validation_step"): + system.validation_step(batch, batch_idx=0) + + with pytest.raises(NotImplementedError, match="must implement test_step"): + system.test_step(batch, batch_idx=0) + + with pytest.raises(NotImplementedError, match="must implement predict_step"): + system.predict_step(batch, batch_idx=0) + + +def test_system_initialization(simple_system): + """Check that attributes are correctly set after initialization.""" + assert isinstance(simple_system.network, nn.Module) + assert simple_system.criterion is not None + assert simple_system.optimizer is not None + assert simple_system.scheduler is not None + assert simple_system.train_metrics is not None + assert simple_system.val_metrics is not None + assert simple_system.test_metrics is not None + + +def test_prepare_metrics(simple_model): + """Test _prepare_metrics validates input types correctly.""" + from torchmetrics import Metric + + # Test with None + system = SimpleLighterModule(network=simple_model) + assert system.train_metrics is None + + # Test with single Metric + metric = Accuracy(task="multiclass", num_classes=2) + system = SimpleLighterModule(network=simple_model, train_metrics=metric) + assert isinstance(system.train_metrics, Metric) + assert system.train_metrics is metric # Should be unchanged + + # Test 
with MetricCollection + metrics = MetricCollection([Accuracy(task="multiclass", num_classes=2)]) + system = SimpleLighterModule(network=simple_model, train_metrics=metrics) + assert isinstance(system.train_metrics, MetricCollection) + assert system.train_metrics is metrics # Should be unchanged + + # Test with list raises TypeError + with pytest.raises(TypeError, match="metrics must be Metric or MetricCollection"): + SimpleLighterModule(network=simple_model, train_metrics=[Accuracy(task="multiclass", num_classes=2)]) + + # Test with dict raises TypeError + with pytest.raises(TypeError, match="metrics must be Metric or MetricCollection"): + SimpleLighterModule(network=simple_model, train_metrics={"acc": Accuracy(task="multiclass", num_classes=2)}) + + +def test_configure_optimizers(simple_system): + """Test configure_optimizers returns correct structure.""" + opt_config = simple_system.configure_optimizers() + assert isinstance(opt_config, dict) + assert "optimizer" in opt_config + assert "lr_scheduler" in opt_config + + +def test_configure_optimizers_without_optimizer(simple_model): + """Test configure_optimizers when no optimizer is provided.""" + system = SimpleLighterModule(network=simple_model, optimizer=None) + + with pytest.raises(ValueError, match="Optimizer not configured"): + system.configure_optimizers() + + +def test_configure_optimizers_without_scheduler(simple_model): + """Test configure_optimizers when no scheduler is provided.""" + optimizer = SGD(simple_model.parameters(), lr=0.01) + system = SimpleLighterModule(network=simple_model, optimizer=optimizer, scheduler=None) + + opt_config = system.configure_optimizers() + assert isinstance(opt_config, dict) + assert "optimizer" in opt_config + assert "lr_scheduler" not in opt_config + + +def test_forward_delegates_to_model(simple_system): + """Test that forward() delegates to self.network().""" + x = torch.randn(2, 4) + output = simple_system(x) + + # Should have same output as calling network directly 
+ expected = simple_system.network(x) + assert torch.allclose(output, expected) + + +def test_mode_property(simple_system, mock_trainer): + """Test mode property detects mode from trainer state.""" + simple_system.trainer = mock_trainer + + # Test training mode + mock_trainer_state(mock_trainer, training=True) + assert simple_system.mode == Mode.TRAIN + + # Test validation mode + mock_trainer_state(mock_trainer, validating=True) + assert simple_system.mode == Mode.VAL + + # Test test mode + mock_trainer_state(mock_trainer, testing=True) + assert simple_system.mode == Mode.TEST + + # Test predict mode + mock_trainer_state(mock_trainer, predicting=True) + assert simple_system.mode == Mode.PREDICT + + # Test sanity checking (should return VAL) + mock_trainer_state(mock_trainer, sanity_checking=True) + assert simple_system.mode == Mode.VAL + + +def test_mode_without_trainer(simple_system): + """Test mode property raises error when no trainer attached.""" + simple_system.trainer = None + with pytest.raises(RuntimeError, match="is not attached to a"): + _ = simple_system.mode + + +def test_mode_undetermined_state(simple_system, mock_trainer): + """Test mode property raises error when trainer is in undetermined state.""" + simple_system.trainer = mock_trainer + # Set all flags to False (no active mode) + mock_trainer_state(mock_trainer, training=False, validating=False, testing=False, predicting=False, sanity_checking=False) + + with pytest.raises(RuntimeError, match="Cannot determine mode"): + _ = simple_system.mode + + +# Dataloader tests removed - dataloaders are now configured via data: key in config + + +def test_validate_and_log_simple_loss(simple_system, mock_trainer): + """Test _log_outputs with simple scalar loss.""" + simple_system.trainer = mock_trainer + mock_trainer.training = True + simple_system.log = MagicMock() + + output = {"loss": torch.tensor(1.0)} + simple_system._log_outputs(output, batch_idx=0) + + # Should log loss twice (step + epoch) + assert 
simple_system.log.call_count >= 2 + + +def test_validate_and_log_dict_loss(simple_system, mock_trainer): + """Test _log_outputs with multi-component loss dict.""" + simple_system.trainer = mock_trainer + mock_trainer.training = True + simple_system.log = MagicMock() + + loss_dict = {"total": torch.tensor(3.0), "ce": torch.tensor(2.0), "reg": torch.tensor(1.0)} + output = {"loss": loss_dict} + simple_system._log_outputs(output, batch_idx=0) + + # Should log each loss component twice (step + epoch) + # 3 components * 2 (step + epoch) = 6 calls minimum + assert simple_system.log.call_count >= 6 + + +def test_validate_and_log_none_loss(simple_system, mock_trainer): + """Test _log_outputs handles None loss gracefully (doesn't crash).""" + simple_system.trainer = mock_trainer + mock_trainer.training = True + simple_system.log = MagicMock() + + output = {"loss": None} + simple_system._log_outputs(output, batch_idx=0) + # Test passes if no exception is raised + + +def test_validate_and_log_dict_loss_without_total(simple_model, mock_trainer): + """Test _on_batch_end raises error when loss dict missing 'total' key.""" + system = SimpleLighterModule(network=simple_model) + system.trainer = mock_trainer + mock_trainer.training = True + + loss_dict = {"ce": torch.tensor(2.0), "reg": torch.tensor(1.0)} # Missing 'total' + output = {"loss": loss_dict} + + with pytest.raises(ValueError, match="Loss dict.*must include 'total' key"): + system._on_batch_end(output, batch_idx=0) + + +def test_validate_and_log_without_logger(simple_system, mock_trainer): + """Test _log_outputs does nothing when no logger.""" + simple_system.trainer = mock_trainer + mock_trainer.logger = None + + output = {"loss": torch.tensor(1.0)} + # Should not raise any errors + simple_system._log_outputs(output, batch_idx=0) + + +def test_batch_end_hooks_call_validate_and_log(simple_system, mock_trainer): + """Test that batch-end hooks call _log_outputs.""" + simple_system.trainer = mock_trainer + 
simple_system._log_outputs = MagicMock() + + output = {"loss": torch.tensor(1.0)} + batch = (torch.randn(2, 4), torch.tensor([0, 1])) + + # Test on_train_batch_end + simple_system.on_train_batch_end(output, batch, batch_idx=0) + simple_system._log_outputs.assert_called_once_with(output, 0) + + # Test on_validation_batch_end + simple_system._log_outputs.reset_mock() + simple_system.on_validation_batch_end(output, batch, batch_idx=0) + simple_system._log_outputs.assert_called_once_with(output, 0) + + # Test on_test_batch_end + simple_system._log_outputs.reset_mock() + simple_system.on_test_batch_end(output, batch, batch_idx=0) + simple_system._log_outputs.assert_called_once_with(output, 0) + + +def test_predict_step_implementation(simple_system): + """Test predict_step implementation.""" + batch = (torch.randn(2, 4), torch.tensor([0, 1])) + output = simple_system.predict_step(batch, batch_idx=0) + + # Should return predictions + x, y = batch + expected = simple_system(x) + assert torch.allclose(output, expected) + + +def test_training_step_returns_dict(simple_system): + """Test that training_step returns dict.""" + simple_system.trainer = MagicMock() + mock_trainer_state(simple_system.trainer, training=True) + + # Create batch manually + x = torch.randn(2, 4) + y = torch.tensor([0, 1]) + batch = (x, y) + + output = simple_system.training_step(batch, batch_idx=0) + + assert isinstance(output, dict) + assert "loss" in output + assert "pred" in output + assert "target" in output + + +def test_validation_step_returns_dict(simple_system): + """Test that validation_step returns dict.""" + simple_system.trainer = MagicMock() + mock_trainer_state(simple_system.trainer, validating=True) + + # Create batch manually + x = torch.randn(2, 4) + y = torch.tensor([0, 1]) + batch = (x, y) + + output = simple_system.validation_step(batch, batch_idx=0) + + assert isinstance(output, dict) + assert "loss" in output + assert "pred" in output + + +def 
test_test_step_returns_dict(simple_system): + """Test that test_step returns dict.""" + simple_system.trainer = MagicMock() + mock_trainer_state(simple_system.trainer, testing=True) + + # Create batch manually + x = torch.randn(2, 4) + y = torch.tensor([0, 1]) + batch = (x, y) + + output = simple_system.test_step(batch, batch_idx=0) + + assert isinstance(output, dict) + assert "loss" not in output # Test mode doesn't require loss + assert "pred" in output + + +def test_metrics_logging_in_validate_and_log(simple_system): + """Test that metrics are logged in _log_outputs.""" + simple_system.trainer = MagicMock() + simple_system.trainer.logger = MagicMock() + simple_system.trainer.training = True + simple_system.log = MagicMock() + + # Call train_metrics manually to update them + pred = torch.randn(2, 2) + target = torch.tensor([0, 1]) + simple_system.train_metrics(pred, target) + + output = {"loss": torch.tensor(1.0), "pred": pred, "target": target} + simple_system._log_outputs(output, batch_idx=0) + + # Check that log was called for metrics + # Should have calls for loss (2) + metrics (2 per metric) + optimizer stats (2) + assert simple_system.log.call_count >= 2 + + +def test_optimizer_stats_logged_once_per_epoch(simple_system): + """Test that optimizer stats are logged only once per epoch (batch_idx=0).""" + from unittest.mock import patch + + simple_system.trainer = MagicMock() + simple_system.trainer.training = True + simple_system.trainer.validating = False + simple_system.trainer.testing = False + simple_system.trainer.predicting = False + simple_system.trainer.sanity_checking = False + + # batch_idx=0 should call get_optimizer_stats + with patch("lighter.model.get_optimizer_stats") as mock_get_stats: + mock_get_stats.return_value = {"lr": 0.01} + simple_system._log_optimizer_stats(batch_idx=0) + assert mock_get_stats.called + + # batch_idx=1 should NOT call get_optimizer_stats + with patch("lighter.model.get_optimizer_stats") as mock_get_stats: + 
mock_get_stats.return_value = {"lr": 0.01} + simple_system._log_optimizer_stats(batch_idx=1) + assert not mock_get_stats.called + + +def test_log_method_sets_sync_dist_for_epoch(simple_system): + """Test that _log sets sync_dist=True for epoch logging.""" + simple_system.trainer = MagicMock() + mock_trainer_state(simple_system.trainer, training=True) + simple_system.log = MagicMock() + + simple_system._log("test/metric", torch.tensor(1.0), on_epoch=True, sync_dist=True) + + # Check that epoch logging was called with sync_dist=True + assert simple_system.log.call_count == 1 + call_kwargs = simple_system.log.call_args[1] + assert call_kwargs["sync_dist"] is True + assert call_kwargs["on_epoch"] is True + assert call_kwargs["on_step"] is False + + +def test_normalize_output_accepts_dict(simple_model): + """Test that _normalize_output accepts and passes through dict.""" + system = SimpleLighterModule(network=simple_model) + + output = {"loss": torch.tensor(1.0), "pred": torch.randn(2, 2)} + normalized = system._normalize_output(output) + assert normalized is output # Should be the same dict + + +def test_normalize_output_accepts_tensor(simple_model): + """Test that _normalize_output converts Tensor to dict.""" + system = SimpleLighterModule(network=simple_model) + + output = torch.tensor(1.0) + normalized = system._normalize_output(output) + assert isinstance(normalized, dict) + assert "loss" in normalized + assert torch.allclose(normalized["loss"], output) + + +def test_normalize_output_rejects_invalid_types(simple_model): + """Test that _normalize_output rejects invalid types.""" + system = SimpleLighterModule(network=simple_model) + + # List should be rejected + bad_output = [1.0, 2.0] + with pytest.raises(TypeError, match="must return torch.Tensor or dict"): + system._normalize_output(bad_output) + + # String should be rejected + bad_output = "invalid" + with pytest.raises(TypeError, match="must return torch.Tensor or dict"): + system._normalize_output(bad_output) 
+ + +def test_normalize_output_validates_loss_dict_has_total(simple_model): + """Test that _normalize_output validates loss dict has 'total' key.""" + system = SimpleLighterModule(network=simple_model) + + # Loss dict with 'total' should be accepted + valid_output = {"loss": {"total": torch.tensor(3.0), "ce": torch.tensor(2.0)}} + normalized = system._normalize_output(valid_output) + assert normalized is valid_output + + # Loss dict without 'total' should be rejected + invalid_output = {"loss": {"ce": torch.tensor(2.0), "reg": torch.tensor(1.0)}} + with pytest.raises(ValueError, match="Loss dict must include 'total' key"): + system._normalize_output(invalid_output) + + # Loss as tensor should still be accepted + tensor_loss_output = {"loss": torch.tensor(1.0), "pred": torch.randn(2, 2)} + normalized = system._normalize_output(tensor_loss_output) + assert normalized is tensor_loss_output + + +def test_batch_end_hooks_accept_tensor(simple_system, mock_trainer): + """Test that batch-end hooks accept and normalize tensor outputs.""" + simple_system.trainer = mock_trainer + mock_trainer.training = True + simple_system.log = MagicMock() + + # Tensor should be accepted and normalized + tensor_output = torch.tensor(1.0) + batch = (torch.randn(2, 4), torch.tensor([0, 1])) + + # Should not raise + simple_system.on_train_batch_end(tensor_output, batch, batch_idx=0) + # Should have called log + assert simple_system.log.call_count > 0 + + +def test_mode_property_without_trainer(simple_model): + """Test that mode property raises RuntimeError when trainer is not attached.""" + system = SimpleLighterModule(network=simple_model) + + # Accessing mode without trainer should raise RuntimeError + # (either from Lightning's trainer property or our own check) + with pytest.raises(RuntimeError): + _ = system.mode + + +def test_log_metrics_with_single_metric(simple_model, mock_trainer): + """Test that _log_metrics handles single Metric (not MetricCollection).""" + # Create system with 
single metric (not MetricCollection) + single_metric = Accuracy(task="multiclass", num_classes=2) + system = SimpleLighterModule(network=simple_model, train_metrics=single_metric) + system.trainer = mock_trainer + mock_trainer_state(mock_trainer, training=True) + system.log = MagicMock() + system._log = MagicMock() + + # Call _log_metrics - it gets metrics from self.train_metrics + system._log_metrics() + + # Should log on_step and on_epoch (2 calls for single metric) + assert system._log.call_count == 2 diff --git a/tests/unit/test_system.py b/tests/unit/test_system.py deleted file mode 100644 index 6c20524a..00000000 --- a/tests/unit/test_system.py +++ /dev/null @@ -1,613 +0,0 @@ -"""Unit tests for the System class.""" - -from unittest.mock import MagicMock - -import pytest -import pytorch_lightning as pl -import torch -from torch import nn -from torch.optim import SGD -from torch.optim.lr_scheduler import StepLR -from torch.utils.data import DataLoader, Dataset -from torchmetrics import Accuracy - -from lighter.system import System -from lighter.utils.types.enums import Data, Mode - - -class DummyDataset(Dataset): - """Dataset returning (input_tensor, target_int)""" - - def __init__(self, size=8, with_target=True): - super().__init__() - self.size = size - self.with_target = with_target - self.data = [] - for _ in range(self.size): - x = torch.randn(4) - if self.with_target: - y = torch.randint(0, 2, size=()).item() # scalar int - self.data.append((x, y)) - else: - self.data.append(x) - - def __getitem__(self, idx): - return self.data[idx] - - def __len__(self): - return self.size - - -class SimpleModel(nn.Module): - """Simple model with a single linear layer""" - - def __init__(self, in_features=4, out_features=2): - super().__init__() - self.linear = nn.Linear(in_features, out_features) - - def forward(self, x, epoch=None, step=None): - return self.linear(x) - - -@pytest.fixture -def dummy_dataloaders(): - """Provides train/val/test/predict DataLoaders""" - 
return { - "train": DataLoader(DummyDataset(size=8, with_target=True), batch_size=2), - "val": DataLoader(DummyDataset(size=4, with_target=True), batch_size=2), - "test": DataLoader(DummyDataset(size=4, with_target=True), batch_size=2), - "predict": DataLoader(DummyDataset(size=4, with_target=False), batch_size=2), - } - - -@pytest.fixture -def simple_system(dummy_dataloaders): - """Creates a System instance with a mock trainer and mocked log method""" - model = SimpleModel() - optimizer = SGD(model.parameters(), lr=0.01) - scheduler = StepLR(optimizer, step_size=10, gamma=0.1) - criterion = nn.CrossEntropyLoss() - metrics = { - "train": Accuracy(task="multiclass", num_classes=2), - "val": Accuracy(task="multiclass", num_classes=2), - "test": Accuracy(task="multiclass", num_classes=2), - } - - system = System( - model=model, - dataloaders=dummy_dataloaders, - optimizer=optimizer, - scheduler=scheduler, - criterion=criterion, - metrics=metrics, - adapters=None, - inferer=None, - ) - - # Initialize a Trainer without logger and checkpointing - trainer = pl.Trainer(logger=None, enable_checkpointing=False, max_epochs=1) - system.trainer = trainer - - # Mock _log_stats to prevent actual logging - system._log_stats = MagicMock() - - return system - - -def test_system_initialization(simple_system): - """Check that attributes are correctly set after initialization.""" - assert isinstance(simple_system.model, nn.Module) - assert simple_system.optimizer is not None - assert simple_system.scheduler is not None - assert simple_system.criterion is not None - assert simple_system.metrics["train"] is not None - assert simple_system.metrics["val"] is not None - assert simple_system.metrics["test"] is not None - - -def test_configure_optimizers(simple_system): - """ - Tests that configure_optimizers returns the correct structure: - { - 'optimizer': ..., - 'lr_scheduler': ... 
- } - """ - opt_config = simple_system.configure_optimizers() - assert isinstance(opt_config, dict), "configure_optimizers should return a dictionary." - assert "optimizer" in opt_config, "Optimizer key missing in configure_optimizers output." - assert "lr_scheduler" in opt_config, "LR scheduler key missing in configure_optimizers output." - - -def test_configure_optimizers_without_optimizer(dummy_dataloaders): - """Test configure_optimizers when no optimizer is provided.""" - model = SimpleModel() - system = System( - model=model, - dataloaders=dummy_dataloaders, - optimizer=None, - scheduler=None, - criterion=nn.CrossEntropyLoss(), - metrics=None, - adapters=None, - inferer=None, - ) - with pytest.raises(ValueError, match="Please specify 'system.optimizer' in the config."): - system.configure_optimizers() - - -def test_configure_optimizers_without_scheduler(dummy_dataloaders): - """Test configure_optimizers when no scheduler is provided.""" - model = SimpleModel() - optimizer = SGD(model.parameters(), lr=0.01) - system = System( - model=model, - dataloaders=dummy_dataloaders, - optimizer=optimizer, - scheduler=None, - criterion=nn.CrossEntropyLoss(), - metrics=None, - adapters=None, - inferer=None, - ) - - opt_config = system.configure_optimizers() - assert isinstance(opt_config, dict) - assert "optimizer" in opt_config - assert "lr_scheduler" not in opt_config - - -def test_on_mode_start_and_end_train(simple_system): - """Check that _on_mode_start sets the correct mode and _on_mode_end resets it.""" - simple_system._on_mode_start(Mode.TRAIN) - assert simple_system.mode == Mode.TRAIN - simple_system._on_mode_end() - assert simple_system.mode is None - - -def test_training_step_runs(simple_system): - """ - Simulate a training step by calling lightning's hooks: - - on_train_start - - training_step - """ - simple_system.on_train_start() - batch = next(iter(simple_system.dataloaders.train)) - output = simple_system.training_step(batch, batch_idx=0) - - assert 
isinstance(output, dict), "Expected a dictionary output." - assert Data.LOSS in output, "Loss should be in output for training mode." - assert output[Data.LOSS] is not None, "Loss must not be None in train mode." - assert Data.METRICS in output, "Metrics should be in output for training mode." - assert Data.PRED in output, "Prediction tensor must be in the output." - assert output[Data.PRED] is not None, "Pred must not be None." - assert simple_system.mode == Mode.TRAIN - - simple_system.on_train_end() - assert simple_system.mode is None - - -def test_validation_step_runs(simple_system): - """ - Simulate a validation step by calling: - - on_validation_start - - validation_step - """ - simple_system.on_validation_start() - batch = next(iter(simple_system.dataloaders.val)) - output = simple_system.validation_step(batch, batch_idx=0) - - assert isinstance(output, dict), "Expected a dictionary output." - assert Data.LOSS in output, "Loss should be in output for validation mode." - assert output[Data.LOSS] is not None, "Loss must not be None in validation mode." - assert Data.METRICS in output, "Metrics should be in output for validation mode." - assert Data.PRED in output, "Prediction tensor must be in the output." - - simple_system.on_validation_end() - assert simple_system.mode is None - - -def test_test_step_runs(simple_system): - """ - Simulate a test step by calling: - - on_test_start - - test_step - """ - simple_system.on_test_start() - batch = next(iter(simple_system.dataloaders.test)) - output = simple_system.test_step(batch, batch_idx=0) - - assert isinstance(output, dict), "Expected a dictionary output." - assert Data.LOSS in output, "Loss should be in output for test mode." - assert output[Data.LOSS] is None, "Loss must be None in test mode." - assert Data.METRICS in output, "Metrics should be in output for test mode." - assert Data.PRED in output, "Prediction tensor must be in the output." 
- - simple_system.on_test_end() - assert simple_system.mode is None - - -def test_predict_step_runs(simple_system): - """ - Simulate a predict step using the predict_dataloader and check outputs. - """ - simple_system.on_predict_start() - batch = next(iter(simple_system.dataloaders.predict)) - output = simple_system.predict_step(batch, batch_idx=0) - - assert isinstance(output, dict), "Expected a dictionary output." - assert Data.PRED in output, "Predict should contain PRED." - assert output.get(Data.METRICS) is None, "Metrics must be None in predict mode." - assert output.get(Data.LOSS) is None, "Loss must be None in predict mode." - - simple_system.on_predict_end() - assert simple_system.mode is None - - -def test_no_criterion_in_train_raises_error(simple_system): - """ - If no criterion is specified in train mode, training_step should raise ValueError. - """ - # Explicitly set criterion to None - simple_system.criterion = None - - simple_system.on_train_start() - batch = next(iter(simple_system.dataloaders.train)) - with pytest.raises(ValueError, match="Please specify 'system.criterion'"): - _ = simple_system.training_step(batch, 0) - - -class DictLossNoTotal(nn.Module): - """ - A module-based "loss" returning a dict with no "total" key to trigger the ValueError. - """ - - def forward(self, pred, target): - return {"not_total": torch.tensor(1.0)} - - -def test_dict_loss_without_total_raises_error(simple_system): - """ - If the criterion returns a dictionary but does not contain 'total' key, - it should raise ValueError. 
- """ - simple_system.criterion = DictLossNoTotal() - simple_system.on_train_start() - - batch = next(iter(simple_system.dataloaders.train)) - with pytest.raises(ValueError, match="The loss dictionary must include a 'total' key that combines all sublosses."): - _ = simple_system.training_step(batch, 0) - - -def test_learning_rate_property(simple_system): - """Check that learning_rate getter/setter works properly with a single param group.""" - initial_lr = simple_system.learning_rate - assert initial_lr == 0.01 - - simple_system.learning_rate = 0.005 - assert simple_system.learning_rate == 0.005 - - -def test_learning_rate_multiple_param_groups_raises(): - """Ensure accessing .learning_rate with multiple param groups raises ValueError""" - model = SimpleModel() - param_groups = [ - {"params": model.linear.weight, "lr": 0.01}, - {"params": model.linear.bias, "lr": 0.001}, - ] - optimizer = SGD(param_groups) - system = System( - model=model, - dataloaders={"train": DataLoader(DummyDataset())}, - optimizer=optimizer, - scheduler=None, - criterion=nn.CrossEntropyLoss(), - metrics=None, - adapters=None, - inferer=None, - ) - system.trainer = pl.Trainer(logger=False, enable_checkpointing=False, max_epochs=1) - system.log = MagicMock() - - with pytest.raises(ValueError, match="multiple optimizer parameter groups"): - _ = system.learning_rate - - with pytest.raises(ValueError, match="multiple optimizer parameter groups"): - system.learning_rate = 0.0001 - - -def test_inferer_called_in_validation(simple_system): - """Ensure the inferer function is called in validation mode""" - mock_inferer = MagicMock(return_value=torch.randn(2, 2)) - simple_system.inferer = mock_inferer - - simple_system.on_validation_start() - batch = next(iter(simple_system.dataloaders.val)) - _ = simple_system.validation_step(batch, batch_idx=0) - - mock_inferer.assert_called_once() - - -def test_inferer_called_in_test(simple_system): - """Ensure the inferer function is called in test mode""" - 
mock_inferer = MagicMock(return_value=torch.randn(2, 2)) - simple_system.inferer = mock_inferer - - simple_system.on_test_start() - batch = next(iter(simple_system.dataloaders.test)) - _ = simple_system.test_step(batch, batch_idx=0) - - mock_inferer.assert_called_once() - - -def test_loss_logging_single_value(simple_system): - """Ensure loss logging occurs correctly when it's a single tensor""" - simple_system.on_train_start() - batch = next(iter(simple_system.dataloaders.train)) - output = simple_system.training_step(batch, batch_idx=0) - - assert Data.LOSS in output - simple_system._log_stats.assert_called_once_with(output[Data.LOSS], output[Data.METRICS], 0) - - -def test_loss_logging_dict_values(simple_system): - """Ensure loss logging occurs correctly when it's a dict of losses""" - - class MultiLoss(nn.Module): - def forward(self, pred, target): - return {"total": torch.tensor(1.0), "aux": torch.tensor(0.5)} - - simple_system.criterion = MultiLoss() - - simple_system.on_train_start() - batch = next(iter(simple_system.dataloaders.train)) - output = simple_system.training_step(batch, batch_idx=0) - - assert "total" in output[Data.LOSS] - assert "aux" in output[Data.LOSS] - simple_system._log_stats.assert_called_once_with(output[Data.LOSS], output[Data.METRICS], 0) - - -def test_metric_logging(simple_system): - """Ensure metric logging occurs""" - simple_system.on_train_start() - batch = next(iter(simple_system.dataloaders.train)) - output = simple_system.training_step(batch, batch_idx=0) - - assert Data.METRICS in output - simple_system._log_stats.assert_called_once_with(output[Data.LOSS], output[Data.METRICS], 0) - - -def test_dynamic_mode_hooks(): - """ - Test the dynamic attachment of mode-specific hooks in the System class. - - This test verifies that the appropriate hooks are dynamically attached - based on the availability of dataloaders for different modes (train, val, test, predict). 
- It checks that the hooks are overridden when a dataloader is provided and remain - as the default (super) implementation when not provided. - """ - - # Test case 1: All dataloaders are provided - model = SimpleModel() - optimizer = SGD(model.parameters(), lr=0.01) - system = System( - model=model, - dataloaders={ - "train": DataLoader(DummyDataset()), - "val": DataLoader(DummyDataset()), - "test": DataLoader(DummyDataset()), - "predict": DataLoader(DummyDataset()), - }, - optimizer=optimizer, - scheduler=None, - criterion=nn.CrossEntropyLoss(), - metrics=None, - adapters=None, - inferer=None, - ) - - # Assert that all hooks are overridden - assert system.training_step != super(System, system).training_step - assert system.train_dataloader != super(System, system).train_dataloader - assert system.on_train_start != super(System, system).on_train_start - assert system.on_train_end != super(System, system).on_train_end - - assert system.validation_step != super(System, system).validation_step - assert system.val_dataloader != super(System, system).val_dataloader - assert system.on_validation_start != super(System, system).on_validation_start - assert system.on_validation_end != super(System, system).on_validation_end - - assert system.test_step != super(System, system).test_step - assert system.test_dataloader != super(System, system).test_dataloader - assert system.on_test_start != super(System, system).on_test_start - assert system.on_test_end != super(System, system).on_test_end - - assert system.predict_step != super(System, system).predict_step - assert system.predict_dataloader != super(System, system).predict_dataloader - assert system.on_predict_start != super(System, system).on_predict_start - assert system.on_predict_end != super(System, system).on_predict_end - - # Test case 2: Only train dataloader is provided - model = SimpleModel() - optimizer = SGD(model.parameters(), lr=0.01) - system = System( - model=model, - dataloaders={"train": 
DataLoader(DummyDataset())}, - optimizer=optimizer, - scheduler=None, - criterion=nn.CrossEntropyLoss(), - metrics=None, - adapters=None, - inferer=None, - ) - - # Assert that only train hooks are overridden, other hooks remain as default - assert system.training_step != super(System, system).training_step - assert system.train_dataloader != super(System, system).train_dataloader - assert system.on_train_start != super(System, system).on_train_start - assert system.on_train_end != super(System, system).on_train_end - - assert system.validation_step == super(System, system).validation_step - assert system.val_dataloader == super(System, system).val_dataloader - assert system.on_validation_start == super(System, system).on_validation_start - assert system.on_validation_end == super(System, system).on_validation_end - - assert system.test_step == super(System, system).test_step - assert system.test_dataloader == super(System, system).test_dataloader - assert system.on_test_start == super(System, system).on_test_start - assert system.on_test_end == super(System, system).on_test_end - - assert system.predict_step == super(System, system).predict_step - assert system.predict_dataloader == super(System, system).predict_dataloader - assert system.on_predict_start == super(System, system).on_predict_start - assert system.on_predict_end == super(System, system).on_predict_end - - # Test case 3: Only val dataloader is provided - model = SimpleModel() - optimizer = SGD(model.parameters(), lr=0.01) - system = System( - model=model, - dataloaders={"val": DataLoader(DummyDataset())}, - optimizer=optimizer, - scheduler=None, - criterion=nn.CrossEntropyLoss(), - metrics=None, - adapters=None, - inferer=None, - ) - - # Assert that only validation hooks are overridden, other hooks remain as default - assert system.training_step == super(System, system).training_step - assert system.train_dataloader == super(System, system).train_dataloader - assert system.on_train_start == 
super(System, system).on_train_start - assert system.on_train_end == super(System, system).on_train_end - - assert system.validation_step != super(System, system).validation_step - assert system.val_dataloader != super(System, system).val_dataloader - assert system.on_validation_start != super(System, system).on_validation_start - assert system.on_validation_end != super(System, system).on_validation_end - - assert system.test_step == super(System, system).test_step - assert system.test_dataloader == super(System, system).test_dataloader - assert system.on_test_start == super(System, system).on_test_start - assert system.on_test_end == super(System, system).on_test_end - - assert system.predict_step == super(System, system).predict_step - assert system.predict_dataloader == super(System, system).predict_dataloader - assert system.on_predict_start == super(System, system).on_predict_start - assert system.on_predict_end == super(System, system).on_predict_end - - # Test case 4: Only test dataloader is provided - model = SimpleModel() - optimizer = SGD(model.parameters(), lr=0.01) - system = System( - model=model, - dataloaders={"test": DataLoader(DummyDataset())}, - optimizer=optimizer, - scheduler=None, - criterion=nn.CrossEntropyLoss(), - metrics=None, - adapters=None, - inferer=None, - ) - - # Assert that only test hooks are overridden, other hooks remain as default - assert system.training_step == super(System, system).training_step - assert system.train_dataloader == super(System, system).train_dataloader - assert system.on_train_start == super(System, system).on_train_start - assert system.on_train_end == super(System, system).on_train_end - - assert system.validation_step == super(System, system).validation_step - assert system.val_dataloader == super(System, system).val_dataloader - assert system.on_validation_start == super(System, system).on_validation_start - assert system.on_validation_end == super(System, system).on_validation_end - - assert 
system.test_step != super(System, system).test_step - assert system.test_dataloader != super(System, system).test_dataloader - assert system.on_test_start != super(System, system).on_test_start - assert system.on_test_end != super(System, system).on_test_end - - assert system.predict_step == super(System, system).predict_step - assert system.predict_dataloader == super(System, system).predict_dataloader - assert system.on_predict_start == super(System, system).on_predict_start - assert system.on_predict_end == super(System, system).on_predict_end - - # Test case 5: Only predict dataloader is provided - model = SimpleModel() - optimizer = SGD(model.parameters(), lr=0.01) - system = System( - model=model, - dataloaders={"predict": DataLoader(DummyDataset())}, - optimizer=optimizer, - scheduler=None, - criterion=nn.CrossEntropyLoss(), - metrics=None, - adapters=None, - inferer=None, - ) - - # Assert that only predict hooks are overridden, other hooks remain as default - assert system.training_step == super(System, system).training_step - assert system.train_dataloader == super(System, system).train_dataloader - assert system.on_train_start == super(System, system).on_train_start - assert system.on_train_end == super(System, system).on_train_end - - assert system.validation_step == super(System, system).validation_step - assert system.val_dataloader == super(System, system).val_dataloader - assert system.on_validation_start == super(System, system).on_validation_start - assert system.on_validation_end == super(System, system).on_validation_end - - assert system.test_step == super(System, system).test_step - assert system.test_dataloader == super(System, system).test_dataloader - assert system.on_test_start == super(System, system).on_test_start - assert system.on_test_end == super(System, system).on_test_end - - assert system.predict_step != super(System, system).predict_step - assert system.predict_dataloader != super(System, system).predict_dataloader - assert 
system.on_predict_start != super(System, system).on_predict_start - assert system.on_predict_end != super(System, system).on_predict_end - - -def test_log_stats_without_logger(simple_system): - """Test _log_stats when trainer has no logger.""" - # Override the mock to test actual _log_stats behavior - simple_system._log_stats = System._log_stats.__get__(simple_system) - simple_system.trainer.logger = None - - # This should not raise any errors and should return early - simple_system._log_stats(torch.tensor(1.0), None, 0) - - -def test_log_stats_with_logger(simple_system): - """Test _log_stats with a logger.""" - # Override the mock to test actual _log_stats behavior - simple_system._log_stats = System._log_stats.__get__(simple_system) - simple_system.trainer.logger = MagicMock() - simple_system.log = MagicMock() - - # Test single loss value - simple_system.mode = Mode.TRAIN - simple_system._log_stats(torch.tensor(1.0), None, 0) - # Twice (on step, on epoch) for the loss, twice for the SGD optimizer (lr, momentum) - assert simple_system.log.call_count == 4 - - # Test dict loss values - simple_system.log.reset_mock() - loss_dict = {"total": torch.tensor(1.0), "aux": torch.tensor(0.5)} - simple_system._log_stats(loss_dict, None, 0) - # Twice (on step, on epoch) for each loss, twice for the SGD optimizer (lr, momentum) - assert simple_system.log.call_count == 6 - - # Test metrics - simple_system.log.reset_mock() - metrics = {"accuracy": torch.tensor(0.95)} - simple_system._log_stats(None, metrics, 0) - # Twice (on step, on epoch) for the metric, twice for the SGD optimizer (lr, momentum) - assert simple_system.log.call_count == 4 - - # Test optimizer stats (only in train mode, batch_idx=0) - simple_system.log.reset_mock() - simple_system._log_stats(None, None, 0) - # Twice for the SGD optimizer (lr, momentum) - assert simple_system.log.call_count == 2 diff --git a/tests/unit/test_utils_data.py b/tests/unit/test_utils_data.py index 2f5a7b45..f5b779ec 100644 --- 
a/tests/unit/test_utils_data.py +++ b/tests/unit/test_utils_data.py @@ -1,3 +1,5 @@ +import pytest + from lighter.utils.data import collate_replace_corrupted @@ -40,3 +42,65 @@ def test_collate_replace_corrupted_all_corrupted(): collated_all_corrupted = collate_replace_corrupted(all_corrupted_batch, dataset) assert len(collated_all_corrupted) == len(all_corrupted_batch) assert all(val in dataset for val in collated_all_corrupted) + + +def test_collate_replace_corrupted_max_retries(): + """Test that max_retries prevents infinite loops. + + Tests: + - Function raises RuntimeError when max_retries is exceeded + - Error message provides helpful information about corruption rate + - Function works correctly with custom max_retries parameter + """ + + # Create a dataset that always returns None (fully corrupted) + class CorruptedDataset: + def __getitem__(self, idx): + return None + + def __len__(self): + return 100 + + dataset = CorruptedDataset() + batch = [None, None, None] + + # Test with low max_retries to trigger the error quickly + with pytest.raises(RuntimeError) as exc_info: + collate_replace_corrupted(batch, dataset, max_retries=5) + + # Verify the error message contains helpful information + error_msg = str(exc_info.value) + assert "maximum retry limit (5)" in error_msg + assert "high corruption rate" in error_msg + assert "increasing max_retries" in error_msg + + +def test_collate_replace_corrupted_custom_max_retries(): + """Test that custom max_retries parameter works correctly. 
+ + Tests: + - Function respects custom max_retries value + - Function succeeds when valid samples are eventually found + """ + + # Create a dataset that returns None for first few accesses, then valid data + class PartiallyCorruptedDataset: + def __init__(self): + self.access_count = 0 + + def __getitem__(self, idx): + self.access_count += 1 + # Return None for first 10 accesses, then valid data + return None if self.access_count <= 10 else 42 + + def __len__(self): + return 100 + + dataset = PartiallyCorruptedDataset() + batch = [None, None] + + # This should eventually succeed with enough retries + result = collate_replace_corrupted(batch, dataset, max_retries=20) + assert len(result) == 2 + # At least some values should be 42 (the valid replacement) + assert any(val.item() == 42 for val in result) diff --git a/tests/unit/test_utils_dynamic_imports.py b/tests/unit/test_utils_dynamic_imports.py index c28ba8a2..760b7a41 100644 --- a/tests/unit/test_utils_dynamic_imports.py +++ b/tests/unit/test_utils_dynamic_imports.py @@ -1,158 +1,433 @@ +"""Tests for dynamic_imports module.""" + +import pickle import sys -from unittest.mock import MagicMock, patch +from io import BytesIO +from pathlib import Path +from unittest.mock import patch import pytest -from lighter.utils.dynamic_imports import OPTIONAL_IMPORTS, import_module_from_path - - -def test_import_module_from_path_nonexistent(): - """ - Test importing a module from a nonexistent path raises FileNotFoundError. - - This test verifies that attempting to import a module from a path that - doesn't exist results in a FileNotFoundError being raised. - - Raises: - FileNotFoundError: Expected to be raised when path doesn't exist - """ - with pytest.raises(FileNotFoundError): - import_module_from_path("non_existent_module", "non_existent_path") - - -def test_optional_imports_nonexistent(): - """ - Test accessing a nonexistent module from OPTIONAL_IMPORTS raises ImportError. 
- - This test verifies that attempting to access a module that isn't defined - in the OPTIONAL_IMPORTS dictionary raises an ImportError. - - Raises: - ImportError: Expected to be raised when accessing undefined module - """ - with pytest.raises(ImportError): - _ = OPTIONAL_IMPORTS["non_existent_module"] - - -def test_optional_imports_available(): - """ - Test successful retrieval of an available optional import. - - This test verifies that when a module exists in OPTIONAL_IMPORTS, it: - 1. Returns the correct mock module - 2. Calls optional_import with the correct module name - 3. Only calls optional_import once - - Setup: - - Creates a mock module - - Patches optional_import to return the mock module - """ - mock_module = MagicMock() - with patch("lighter.utils.dynamic_imports.optional_import", return_value=(mock_module, True)) as mock_import: - module = OPTIONAL_IMPORTS["existent_module"] - assert module is mock_module - mock_import.assert_called_once_with("existent_module") - - -def test_import_module_from_path_already_imported(): - """ - Test importing an already imported module returns the existing module. - - This test verifies that when attempting to import a module that's already - in sys.modules, the function returns the existing module instead of - reloading it. - - Setup: - - Creates a mock module - - Adds mock module to sys.modules - """ - mock_module = MagicMock() - with patch.dict(sys.modules, {"already_imported_module": mock_module}): - import_module_from_path("already_imported_module", "some_path") - assert sys.modules["already_imported_module"] is mock_module - - -def test_import_module_from_path_with_init(): - """ - Test successful module import from a valid path with __init__.py. - - This test verifies the complete module import process: - 1. Path resolution and validation - 2. Module spec creation - 3. Module creation from spec - 4. Module execution - 5. 
Module registration in sys.modules - - Setup: - - Patches Path for file validation - - Patches spec creation and module creation utilities - - Creates mock spec and module objects - - The test verifies all steps in the import process are called correctly - and the module is properly registered in sys.modules. - """ - mock_spec = MagicMock() - mock_module = MagicMock() - - with ( - patch("lighter.utils.dynamic_imports.Path") as mock_path, - patch("lighter.utils.dynamic_imports.importlib.util.spec_from_file_location") as mock_spec_from_file, - patch("lighter.utils.dynamic_imports.importlib.util.module_from_spec") as mock_module_from_spec, - ): - # Setup mocks - mock_path.return_value.resolve.return_value.__truediv__.return_value.is_file.return_value = True - mock_spec_from_file.return_value = mock_spec - mock_module_from_spec.return_value = mock_module - - # Execute function - import_module_from_path("valid_module", "valid_path") - - # Verify mock interactions - mock_path.assert_called_once() - mock_spec_from_file.assert_called_once() - mock_module_from_spec.assert_called_once_with(mock_spec) - mock_spec.loader.exec_module.assert_called_once_with(mock_module) - assert sys.modules["valid_module"] is mock_module - - -def test_optional_import_success(): - """ - Tests that an available module is imported successfully - and is stored in the 'imports' dictionary. - """ - from lighter.utils.dynamic_imports import OPTIONAL_IMPORTS - - # "sys" is a built-in module, guaranteed to be there. - mod = OPTIONAL_IMPORTS["sys"] - import sys - - assert mod is sys # We expect the returned object to be Python's 'sys' module - assert "sys" in OPTIONAL_IMPORTS.imports # The module name should now be in the 'imports' dictionary - - -def test_optional_import_already_imported(): - """ - Tests that requesting the same module a second time - does not re-import it but just returns the stored instance. 
- """ - from lighter.utils.dynamic_imports import OPTIONAL_IMPORTS - - # First call forces import. - mod1 = OPTIONAL_IMPORTS["sys"] - # Second call should simply return from 'imports' dictionary. - mod2 = OPTIONAL_IMPORTS["sys"] - - assert mod1 is mod2 # Should be the same object reference (from the dictionary) - - -def test_optional_import_failure(): - """ - Tests that attempting to import a non-existent module - raises an ImportError. - """ - from lighter.utils.dynamic_imports import OPTIONAL_IMPORTS - - with pytest.raises(ImportError) as exc_info: - OPTIONAL_IMPORTS["this_module_does_not_exist_12345"] - - assert "this_module_does_not_exist_12345" in str(exc_info.value) +from lighter.utils.dynamic_imports import ( + _DynamicModuleFinder, + _HybridPickler, + _ModuleRegistry, + import_module_from_path, +) + + +class TestImportModuleFromPath: + """Tests for import_module_from_path function.""" + + def test_nonexistent_path_raises_error(self): + """Importing from a path without __init__.py raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="No __init__.py"): + import_module_from_path("nonexistent", "/path/that/does/not/exist") + + def test_already_imported_same_path_returns_existing(self, tmp_path): + """If module is already imported from same path, return cached module.""" + + # Create a real package + pkg_dir = tmp_path / "test_cached_pkg" + pkg_dir.mkdir() + (pkg_dir / "__init__.py").write_text("VALUE = 99") + + module_name = "test_cached_module" + try: + # First import + result1 = import_module_from_path(module_name, pkg_dir) + # Second import from same path should return cached + result2 = import_module_from_path(module_name, pkg_dir) + assert result1 is result2 + finally: + sys.modules.pop(module_name, None) + + def test_already_imported_different_path_raises_error(self, tmp_path): + """If module is already imported from different path, raise ValueError.""" + # Create two different packages + pkg_dir1 = tmp_path / "pkg1" + 
pkg_dir1.mkdir() + (pkg_dir1 / "__init__.py").write_text("VALUE = 1") + + pkg_dir2 = tmp_path / "pkg2" + pkg_dir2.mkdir() + (pkg_dir2 / "__init__.py").write_text("VALUE = 2") + + module_name = "test_conflict_module" + try: + # First import + import_module_from_path(module_name, pkg_dir1) + # Second import from different path should raise + with pytest.raises(ValueError, match="already imported from"): + import_module_from_path(module_name, pkg_dir2) + finally: + sys.modules.pop(module_name, None) + + def test_successful_import(self, tmp_path): + """Successfully import a real package from filesystem.""" + # Create a real package + pkg_dir = tmp_path / "test_real_pkg" + pkg_dir.mkdir() + (pkg_dir / "__init__.py").write_text("VALUE = 42") + + module_name = "test_real_pkg_import" + try: + result = import_module_from_path(module_name, pkg_dir) + + assert result is not None + assert result.VALUE == 42 + assert module_name in sys.modules + finally: + # Cleanup + sys.modules.pop(module_name, None) + + def test_import_with_submodule(self, tmp_path): + """Import a package and verify submodules can be imported.""" + # Create package with submodule + pkg_dir = tmp_path / "test_pkg_with_sub" + pkg_dir.mkdir() + (pkg_dir / "__init__.py").write_text("") + (pkg_dir / "submod.py").write_text("SUBVALUE = 123") + + module_name = "test_pkg_with_sub_import" + try: + import_module_from_path(module_name, pkg_dir) + + # Import submodule using standard import + import importlib + + submod = importlib.import_module(f"{module_name}.submod") + assert submod.SUBVALUE == 123 + finally: + # Cleanup + sys.modules.pop(module_name, None) + sys.modules.pop(f"{module_name}.submod", None) + + +class TestModuleRegistry: + """Tests for _ModuleRegistry.""" + + def test_register_and_find_exact_match(self): + """Registry returns exact match for registered module.""" + registry = _ModuleRegistry() + test_path = Path("/test/path") + registry.register("mymodule", test_path) + + result = 
registry.find_root("mymodule") + assert result == ("mymodule", test_path) + + def test_find_submodule_returns_root(self): + """Registry returns root module for submodule queries.""" + registry = _ModuleRegistry() + test_path = Path("/test/path") + registry.register("mymodule", test_path) + + result = registry.find_root("mymodule.sub.deep") + assert result == ("mymodule", test_path) + + def test_find_unregistered_returns_none(self): + """Registry returns None for unregistered modules.""" + registry = _ModuleRegistry() + assert registry.find_root("unregistered") is None + + +class TestDynamicModuleFinder: + """Tests for _DynamicModuleFinder.""" + + def test_unregistered_module_returns_none(self): + """Finder returns None for modules not in registry.""" + finder = _DynamicModuleFinder() + # Use a unique name that won't be registered + result = finder.find_spec("completely_unknown_module_xyz", None, None) + assert result is None + + def test_missing_file_returns_none(self, tmp_path): + """Finder returns None when submodule file doesn't exist.""" + from lighter.utils.dynamic_imports import _registry + + # Create package without the submodule we'll look for + pkg_dir = tmp_path / "finder_test_pkg" + pkg_dir.mkdir() + (pkg_dir / "__init__.py").write_text("") + + module_name = "finder_test_pkg_missing" + _registry.register(module_name, pkg_dir) + + try: + finder = _DynamicModuleFinder() + result = finder.find_spec(f"{module_name}.nonexistent", None, None) + assert result is None + finally: + # Registry doesn't have unregister, but module names are unique per test + pass + + def test_finds_root_package(self, tmp_path): + """Finder creates spec for root package __init__.py.""" + from lighter.utils.dynamic_imports import _registry + + pkg_dir = tmp_path / "finder_root_pkg" + pkg_dir.mkdir() + (pkg_dir / "__init__.py").write_text("# root") + + module_name = "finder_root_pkg_test" + _registry.register(module_name, pkg_dir) + + finder = _DynamicModuleFinder() + result = 
finder.find_spec(module_name, None, None) + + assert result is not None + assert result.name == module_name + assert "__init__.py" in result.origin + + def test_finds_submodule(self, tmp_path): + """Finder creates spec for submodule .py file.""" + from lighter.utils.dynamic_imports import _registry + + pkg_dir = tmp_path / "finder_sub_pkg" + pkg_dir.mkdir() + (pkg_dir / "__init__.py").write_text("") + (pkg_dir / "child.py").write_text("X = 1") + + module_name = "finder_sub_pkg_test" + _registry.register(module_name, pkg_dir) + + finder = _DynamicModuleFinder() + result = finder.find_spec(f"{module_name}.child", None, None) + + assert result is not None + assert result.name == f"{module_name}.child" + assert "child.py" in result.origin + + +class TestHybridPickler: + """Tests for _HybridPickler.""" + + def test_pickles_basic_types(self): + """HybridPickler can pickle and unpickle basic Python types.""" + buffer = BytesIO() + pickler = _HybridPickler(buffer) + + data = {"key": [1, 2, 3], "nested": {"a": "b"}} + pickler.dump(data) + + buffer.seek(0) + result = pickle.load(buffer) + assert result == data + + def test_pickles_lambda(self): + """HybridPickler can pickle lambdas (via cloudpickle).""" + buffer = BytesIO() + pickler = _HybridPickler(buffer) + + fn = lambda x: x * 2 # noqa: E731 + pickler.dump(fn) + + buffer.seek(0) + result = pickle.load(buffer) + assert result(5) == 10 + + def test_reducer_override_handles_functions(self): + """reducer_override returns valid reduction for functions.""" + buffer = BytesIO() + pickler = _HybridPickler(buffer) + + fn = lambda x: x + 1 # noqa: E731 + result = pickler.reducer_override(fn) + + # Should return a reduction tuple (callable, args) or similar + # NotImplemented means "use default pickling" + assert result is not NotImplemented + + def test_reducer_override_defers_multiprocessing_objects(self): + """reducer_override returns NotImplemented for multiprocessing internals. 
+ + Multiprocessing objects (Queue, Pipe connections, etc.) have special reducers + in ForkingPickler._extra_reducers that must be preserved. HybridPickler's + reducer_override should return NotImplemented for these objects, allowing + the standard ForkingPickler dispatch to handle them correctly. + """ + import multiprocessing + from multiprocessing.connection import Connection + + buffer = BytesIO() + pickler = _HybridPickler(buffer) + + # Create a Pipe which gives us Connection objects - these are in _extra_reducers + recv_conn, send_conn = multiprocessing.Pipe(duplex=False) + + try: + # Verify Connection type is in _extra_reducers + from multiprocessing.reduction import ForkingPickler + + extra_reducers = getattr(ForkingPickler, "_extra_reducers", {}) + assert Connection in extra_reducers, "Connection should be in _extra_reducers" + + # reducer_override should return NotImplemented for Connection objects + result = pickler.reducer_override(recv_conn) + assert result is NotImplemented, "Should defer to ForkingPickler for Connection" + + result = pickler.reducer_override(send_conn) + assert result is NotImplemented, "Should defer to ForkingPickler for Connection" + + # Verify the object can still be pickled using HybridPickler + # (dispatch_table includes ForkingPickler's reducers) + buffer = BytesIO() + pickler = _HybridPickler(buffer) + pickler.dump(send_conn) + + # Verify serialization produced data + assert buffer.tell() > 0, "Should have written pickle data" + + finally: + recv_conn.close() + send_conn.close() + + def test_multiprocessing_objects_work_through_subprocess(self): + """Verify multiprocessing objects can be passed through actual child processes. + + This test ensures that the ForkingPickler._extra_reducers handling remains intact + by actually passing a multiprocessing Queue through a spawn-started child process. + The Queue uses Connection objects internally which are in _extra_reducers. 
+ """ + import multiprocessing + + # Use spawn context to match PyTorch DataLoader behavior + ctx = multiprocessing.get_context("spawn") + queue = ctx.Queue() + + # Define worker function outside to avoid pickling issues + def worker(q): + q.put("success") + + # Start a child process that uses the queue + process = ctx.Process(target=worker, args=(queue,)) + process.start() + process.join(timeout=10) + + # Verify the queue worked correctly through the subprocess + assert not queue.empty(), "Queue should have received data from child process" + result = queue.get(timeout=1) + assert result == "success", "Should receive correct data from child process" + + # Cleanup + queue.close() + queue.join_thread() + + +class TestImportModuleFromPathErrors: + """Tests for error cases in import_module_from_path.""" + + def test_spec_from_file_returns_none(self, tmp_path): + """Test ModuleNotFoundError when spec_from_file_location returns None.""" + import importlib.util + + # Create a valid package + pkg_dir = tmp_path / "test_spec_none_pkg" + pkg_dir.mkdir() + (pkg_dir / "__init__.py").write_text("VALUE = 1") + + module_name = "test_spec_none_module" + + # Mock spec_from_file_location to return None + with patch.object(importlib.util, "spec_from_file_location", return_value=None): + with pytest.raises(ModuleNotFoundError, match="Could not load"): + import_module_from_path(module_name, pkg_dir) + + def test_module_load_failure_cleans_up_state(self, tmp_path): + """Test that module load failure cleans up sys.modules and registry.""" + from lighter.utils.dynamic_imports import _registry + + # Create a package with a syntax error in __init__.py + pkg_dir = tmp_path / "broken_pkg" + pkg_dir.mkdir() + (pkg_dir / "__init__.py").write_text("def broken(\n") # Syntax error + + module_name = "test_broken_module_cleanup" + + # Verify module is not in sys.modules before + assert module_name not in sys.modules + + # Attempt import - should fail with SyntaxError + with 
pytest.raises(SyntaxError): + import_module_from_path(module_name, pkg_dir) + + # Verify cleanup: module should NOT be in sys.modules after failure + assert module_name not in sys.modules + + # Verify registry was not updated (find_root returns None for unregistered) + assert _registry.find_root(module_name) is None + + def test_module_runtime_error_cleans_up_state(self, tmp_path): + """Test that runtime error during module execution cleans up sys.modules and registry. + + This test triggers a module load failure via a runtime exception (not syntax error) + to verify the except block at lines 199-202 in dynamic_imports.py properly removes + the module from sys.modules and does not register it in _registry. + """ + from lighter.utils.dynamic_imports import _registry + + # Create a package that raises an error during execution + pkg_dir = tmp_path / "runtime_error_pkg" + pkg_dir.mkdir() + init_content = """ +# This module raises an error during import +VALUE = 1 +raise RuntimeError("Intentional failure during module load") +""" + (pkg_dir / "__init__.py").write_text(init_content) + + module_name = "test_runtime_error_module_cleanup" + + # Capture initial state + initial_modules = set(sys.modules.keys()) + assert module_name not in sys.modules, "Module should not exist before test" + assert _registry.find_root(module_name) is None, "Module should not be in registry before test" + + # Attempt import - should fail with RuntimeError + with pytest.raises(RuntimeError, match="Intentional failure"): + import_module_from_path(module_name, pkg_dir) + + # Verify cleanup: module should NOT be in sys.modules after failure + assert module_name not in sys.modules, "Failed module must be removed from sys.modules" + + # Verify no new modules were left behind (check module and potential submodules) + current_modules = set(sys.modules.keys()) + new_modules = current_modules - initial_modules + assert not any(m.startswith(module_name) for m in new_modules), ( + f"No modules starting 
with '{module_name}' should remain, but found: " + f"{[m for m in new_modules if m.startswith(module_name)]}" + ) + + # Verify registry was not updated + assert _registry.find_root(module_name) is None, "Failed module must not be registered in _registry" + + def test_module_import_error_cleans_up_state(self, tmp_path): + """Test that ImportError during module execution cleans up sys.modules and registry. + + This test triggers a failure via a missing import to verify cleanup works for + ImportError exceptions as well. + """ + from lighter.utils.dynamic_imports import _registry + + # Create a package that fails due to missing import + pkg_dir = tmp_path / "import_error_pkg" + pkg_dir.mkdir() + init_content = """ +# This module fails to import a nonexistent module +from nonexistent_module_xyz_12345 import something +""" + (pkg_dir / "__init__.py").write_text(init_content) + + module_name = "test_import_error_module_cleanup" + + # Verify preconditions + assert module_name not in sys.modules + assert _registry.find_root(module_name) is None + + # Attempt import - should fail with ModuleNotFoundError + with pytest.raises(ModuleNotFoundError): + import_module_from_path(module_name, pkg_dir) + + # Verify cleanup + assert module_name not in sys.modules, "Failed module must be removed from sys.modules" + assert _registry.find_root(module_name) is None, "Failed module must not be registered in _registry" diff --git a/tests/unit/test_utils_logging.py b/tests/unit/test_utils_logging.py index 02922785..59d74bfe 100644 --- a/tests/unit/test_utils_logging.py +++ b/tests/unit/test_utils_logging.py @@ -7,9 +7,9 @@ def test_setup_logging(): - """Test basic logging setup.""" + """Test basic logging setup completes without error.""" _setup_logging() - assert True # Just ensure no exceptions are raised + # Test passes if no exception is raised def test_warnings_handler(): diff --git a/tests/unit/test_utils_misc.py b/tests/unit/test_utils_misc.py index f460073f..2bc5e119 100644 --- 
a/tests/unit/test_utils_misc.py +++ b/tests/unit/test_utils_misc.py @@ -1,165 +1,132 @@ -from unittest.mock import MagicMock +"""Unit tests for utility functions in lighter/utils/misc.py""" -import pytest import torch from torch.optim import SGD, Adam -from lighter.utils.misc import ensure_list, get_name, get_optimizer_stats, hasarg, setattr_dot_notation +from lighter.utils.misc import ensure_list, get_name, get_optimizer_stats, hasarg -def test_ensure_list(): - """ - Test the ensure_list function which converts various input types to a list. +def test_ensure_list_with_list(): + """Test ensure_list returns list as-is.""" + input_list = [1, 2, 3] + assert ensure_list(input_list) == [1, 2, 3] + assert ensure_list(input_list) is input_list # Should return same object - Tests: - - Converting a single value to a single-item list - - Preserving an existing list - - Converting a tuple to a list - """ - assert ensure_list(1) == [1] - assert ensure_list([1, 2]) == [1, 2] - assert ensure_list((1, 2)) == [1, 2] # Test with tuple input +def test_ensure_list_with_tuple(): + """Test ensure_list converts tuple to list.""" + assert ensure_list((1, 2, 3)) == [1, 2, 3] -def test_setattr_dot_notation(): - """ - Test the setattr_dot_notation function which sets attributes using dot notation. - Tests: - - Setting a direct attribute on an object - - Setting a nested attribute using dot notation - - Verifying that attempting to set a non-existent attribute raises AttributeError +def test_ensure_list_with_none(): + """Test ensure_list returns empty list for None.""" + assert ensure_list(None) == [] - The function uses dummy classes to simulate nested object structures. 
- """ - class Dummy: - def __init__(self): - self.attr = MagicMock() +def test_ensure_list_with_single_value(): + """Test ensure_list wraps single value.""" + assert ensure_list(42) == [42] + assert ensure_list("string") == ["string"] - class NestedDummy: - def __init__(self): - self.inner = Dummy() - obj = Dummy() - nested_obj = NestedDummy() - setattr_dot_notation(obj, "attr", 10) - assert obj.attr == 10 +def test_hasarg_with_function(): + """Test hasarg with a simple function.""" - # Test with nested attribute using dot notation - setattr_dot_notation(nested_obj, "inner.attr", 20) - assert nested_obj.inner.attr == 20 + def test_func(a, b, c=10): + return a + b + c - with pytest.raises(AttributeError): - setattr_dot_notation(obj, "non_existent_attr", 10) + assert hasarg(test_func, "a") is True + assert hasarg(test_func, "b") is True + assert hasarg(test_func, "c") is True + assert hasarg(test_func, "d") is False -def test_hasarg(): - """ - Test the hasarg function which checks if a function has a specific argument. +def test_hasarg_with_method(): + """Test hasarg with a class method.""" - Tests: - - Verifying that an existing argument is correctly identified - - Verifying that a non-existent argument returns False + class TestClass: + def method(self, x, y): + return x + y - Uses a simple test function with two arguments to verify the functionality. - """ - - def func_with_args(a, b): - pass - - assert hasarg(func_with_args, "a") is True - assert hasarg(func_with_args, "c") is False + assert hasarg(TestClass.method, "self") is True + assert hasarg(TestClass.method, "x") is True + assert hasarg(TestClass.method, "y") is True + assert hasarg(TestClass.method, "z") is False -def test_get_name(): - """ - Test the get_name function which retrieves the name of a function or class. 
+def test_get_name_without_module(): + """Test get_name without module name.""" - Tests: - - Getting the name of a function - - Getting the name of a class - - Getting the fully qualified name (including module) of a function - - Verifies both simple name retrieval and module-included name retrieval. - """ - - def sample_function(): + def test_function(): pass - class Dummy: + class TestClass: pass - assert get_name(sample_function) == "sample_function" - assert get_name(Dummy) == "Dummy" - assert "test_utils_misc" in get_name(sample_function, include_module_name=True) + assert get_name(test_function) == "test_function" + assert get_name(TestClass) == "TestClass" + +def test_get_name_with_module(): + """Test get_name with module name.""" -def test_get_optimizer_stats(): - """ - Test the get_optimizer_stats function which extracts statistics from PyTorch optimizers. + def test_function(): + pass - Tests: - - Basic optimizer configuration with single parameter group - - Complex optimizer configuration with multiple parameter groups + # The test function's module is __main__ during testing + name = get_name(test_function, include_module_name=True) + assert "test_function" in name - Verifies: - - Correct extraction of learning rate and momentum values - - Proper handling of multiple parameter groups with distinct settings - - Correct formatting of stat names including group numbers - Uses SGD optimizer with both single and multiple parameter group configurations - to ensure comprehensive coverage of optimizer statistics extraction. 
- """ +def test_get_optimizer_stats_single_group(): + """Test get_optimizer_stats with single parameter group.""" model = torch.nn.Linear(10, 1) optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9) + stats = get_optimizer_stats(optimizer) + assert "optimizer/SGD/lr" in stats - assert stats["optimizer/SGD/lr"] == 0.01 assert "optimizer/SGD/momentum" in stats + assert stats["optimizer/SGD/lr"] == 0.01 assert stats["optimizer/SGD/momentum"] == 0.9 - # Test with multiple parameter groups - # Create separate parameter groups with distinct parameters + +def test_get_optimizer_stats_multiple_groups(): + """Test get_optimizer_stats with multiple parameter groups.""" model1 = torch.nn.Linear(10, 1) model2 = torch.nn.Linear(10, 1) + optimizer = SGD( [ {"params": model1.parameters(), "lr": 0.01, "momentum": 0.9}, {"params": model2.parameters(), "lr": 0.02, "momentum": 0.8}, ] ) + stats = get_optimizer_stats(optimizer) + assert "optimizer/SGD/lr/group1" in stats - assert stats["optimizer/SGD/lr/group1"] == 0.01 - assert "optimizer/SGD/momentum/group1" in stats - assert stats["optimizer/SGD/momentum/group1"] == 0.9 assert "optimizer/SGD/lr/group2" in stats - assert stats["optimizer/SGD/lr/group2"] == 0.02 + assert "optimizer/SGD/momentum/group1" in stats assert "optimizer/SGD/momentum/group2" in stats + assert stats["optimizer/SGD/lr/group1"] == 0.01 + assert stats["optimizer/SGD/lr/group2"] == 0.02 + assert stats["optimizer/SGD/momentum/group1"] == 0.9 assert stats["optimizer/SGD/momentum/group2"] == 0.8 def test_get_optimizer_stats_with_betas(): - """ - Test the get_optimizer_stats function with optimizers that use betas instead of momentum. 
- - Tests: - - Optimizer with betas parameter (e.g., Adam) - - Multiple parameter groups with different betas values - - Verifies: - - Correct extraction of learning rate and beta1 values - - Proper handling of multiple parameter groups - - Correct formatting of stat names - """ + """Test get_optimizer_stats with Adam optimizer (uses betas instead of momentum).""" model = torch.nn.Linear(10, 1) optimizer = Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999)) stats = get_optimizer_stats(optimizer) assert "optimizer/Adam/lr" in stats assert stats["optimizer/Adam/lr"] == 0.001 - assert "optimizer/Adam/momentum" in stats - assert stats["optimizer/Adam/momentum"] == 0.9 # beta1 value + # Adam reports beta1 and beta2, not momentum + assert "optimizer/Adam/beta1" in stats + assert "optimizer/Adam/beta2" in stats + assert stats["optimizer/Adam/beta1"] == 0.9 + assert stats["optimizer/Adam/beta2"] == 0.999 # Test with multiple parameter groups with different betas model1 = torch.nn.Linear(10, 1) @@ -172,10 +139,46 @@ def test_get_optimizer_stats_with_betas(): ) stats = get_optimizer_stats(optimizer) assert "optimizer/Adam/lr/group1" in stats - assert stats["optimizer/Adam/lr/group1"] == 0.001 - assert "optimizer/Adam/momentum/group1" in stats - assert stats["optimizer/Adam/momentum/group1"] == 0.9 assert "optimizer/Adam/lr/group2" in stats + assert "optimizer/Adam/beta1/group1" in stats + assert "optimizer/Adam/beta1/group2" in stats + assert stats["optimizer/Adam/lr/group1"] == 0.001 assert stats["optimizer/Adam/lr/group2"] == 0.002 - assert "optimizer/Adam/momentum/group2" in stats - assert stats["optimizer/Adam/momentum/group2"] == 0.8 + assert stats["optimizer/Adam/beta1/group1"] == 0.9 + assert stats["optimizer/Adam/beta1/group2"] == 0.8 + + +def test_get_optimizer_stats_no_momentum(): + """Test get_optimizer_stats with optimizer without momentum.""" + model = torch.nn.Linear(10, 1) + optimizer = SGD(model.parameters(), lr=0.01, momentum=0) # No momentum + + stats = 
get_optimizer_stats(optimizer) + + assert "optimizer/SGD/lr" in stats + assert stats["optimizer/SGD/lr"] == 0.01 + # Should still include momentum even if it's 0 + assert "optimizer/SGD/momentum" in stats + assert stats["optimizer/SGD/momentum"] == 0 + + +def test_get_optimizer_stats_with_weight_decay(): + """Test get_optimizer_stats includes weight decay when non-zero.""" + model = torch.nn.Linear(10, 1) + optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.01) + + stats = get_optimizer_stats(optimizer) + + assert "optimizer/Adam/weight_decay" in stats + assert stats["optimizer/Adam/weight_decay"] == 0.01 + + +def test_get_optimizer_stats_zero_weight_decay(): + """Test get_optimizer_stats excludes weight decay when zero.""" + model = torch.nn.Linear(10, 1) + optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0) + + stats = get_optimizer_stats(optimizer) + + # Weight decay should not be in stats when it's 0 + assert "optimizer/Adam/weight_decay" not in stats diff --git a/tests/unit/test_utils_model.py b/tests/unit/test_utils_model.py deleted file mode 100644 index b62e651c..00000000 --- a/tests/unit/test_utils_model.py +++ /dev/null @@ -1,400 +0,0 @@ -import pytest -import torch -from torch.nn import Identity, Linear, Sequential - -from lighter.utils.model import ( - adjust_prefix_and_load_state_dict, - remove_n_last_layers_sequentially, - replace_layer_with, - replace_layer_with_identity, -) - - -@pytest.fixture -def dummy_model(): - """ - Creates a simple PyTorch model with two linear layers for testing purposes. - - Returns: - torch.nn.Module: A model with two linear layers (10->10 dimensions each). 
- """ - - class DummyModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer1 = Linear(10, 10) - self.layer2 = Linear(10, 10) - - def forward(self, x): - return self.layer2(self.layer1(x)) - - return DummyModel() - - -@pytest.fixture -def sequential_model(): - """ - Creates a sequential PyTorch model with three linear layers. - - Returns: - torch.nn.Sequential: A sequential model with three 10->10 dimensional linear layers. - """ - return Sequential(Linear(10, 10), Linear(10, 10), Linear(10, 10)) - - -@pytest.fixture -def state_dict_file(tmp_path): - """ - Creates a temporary file containing a state dict for a simple model. - - Args: - tmp_path: Pytest fixture providing a temporary directory unique to each test. - - Returns: - tuple: (Path to saved state dict, The state dict dictionary) - """ - model = torch.nn.Module() - model.layer1 = Linear(10, 10) - model.layer2 = Linear(10, 10) - - state_dict = { - "layer1.weight": torch.randn(10, 10), - "layer1.bias": torch.randn(10), - "layer2.weight": torch.randn(10, 10), - "layer2.bias": torch.randn(10), - } - - path = tmp_path / "test_ckpt.pth" - torch.save(state_dict, path) # nosec B614 - return path, state_dict - - -@pytest.fixture -def state_dict_with_prefix_file(tmp_path): - """ - Creates a temporary file containing a state dict with 'model.' prefix. - - Args: - tmp_path: Pytest fixture providing a temporary directory unique to each test. 
- - Returns: - tuple: (Path to saved state dict, The state dict dictionary with prefixes) - """ - state_dict_with_prefix = { - "model.layer1.weight": torch.randn(10, 10), - "model.layer1.bias": torch.randn(10), - "model.layer2.weight": torch.randn(10, 10), - "model.layer2.bias": torch.randn(10), - } - - path = tmp_path / "test_ckpt_with_prefix.pth" - torch.save({"state_dict": state_dict_with_prefix}, path) # nosec B614 - return path, state_dict_with_prefix - - -@pytest.fixture -def state_dict_with_custom_prefix_file(tmp_path): - """ - Creates a temporary file containing a state dict with custom prefixes. - - Args: - tmp_path: Pytest fixture providing a temporary directory unique to each test. - - Returns: - tuple: (Path to saved state dict, The state dict dictionary with custom prefixes) - """ - state_dict_with_prefix = { - "prefix1.layer1.weight": torch.randn(10, 10), - "prefix1.layer1.bias": torch.randn(10), - "prefix2.layer2.weight": torch.randn(10, 10), - "prefix2.layer2.bias": torch.randn(10), - } - - path = tmp_path / "test_ckpt_with_prefix_adjustment.pth" - torch.save(state_dict_with_prefix, path) # nosec B614 - return path, state_dict_with_prefix - - -@pytest.fixture -def mismatched_state_dict_file(tmp_path): - """ - Creates a temporary file containing a state dict with completely different keys. - - Args: - tmp_path: Pytest fixture providing a temporary directory unique to each test. - - Returns: - str: Path to saved state dict - """ - state_dict = { - "completely.different.weight": torch.randn(10, 10), - "another.different.bias": torch.randn(10), - } - - path = tmp_path / "test_mismatched_ckpt.pth" - torch.save(state_dict, path) # nosec B614 - return path - - -@pytest.fixture -def empty_state_dict_file(tmp_path): - """ - Creates a temporary file containing an empty state dict. - - Args: - tmp_path: Pytest fixture providing a temporary directory unique to each test. 
- - Returns: - str: Path to saved state dict - """ - path = tmp_path / "test_empty_ckpt.pth" - torch.save({}, path) # nosec B614 - return path - - -@pytest.fixture -def perfect_match_state_dict_file(tmp_path): - """ - Creates a temporary file containing a state dict that perfectly matches the model. - - Args: - tmp_path: Pytest fixture providing a temporary directory unique to each test. - - Returns: - str: Path to saved state dict - """ - state_dict = { - "layer1.weight": torch.randn(10, 10), - "layer1.bias": torch.randn(10), - "layer2.weight": torch.randn(10, 10), - "layer2.bias": torch.randn(10), - } - - path = tmp_path / "test_perfect_match_ckpt.pth" - torch.save(state_dict, path) # nosec B614 - return path - - -def test_replace_layer_with(dummy_model): - """ - Tests if a layer in the model can be successfully replaced with a new layer. - - Args: - dummy_model: Fixture providing a simple model for testing. - """ - new_layer = Linear(10, 4) - replace_layer_with(dummy_model, "layer1", new_layer) - assert dummy_model.layer1 == new_layer - - -def test_replace_layer_with_identity(dummy_model): - """ - Tests if a layer can be successfully replaced with an Identity layer. - - Args: - dummy_model: Fixture providing a simple model for testing. - """ - replace_layer_with_identity(dummy_model, "layer1") - assert isinstance(dummy_model.layer1, Identity) - - -def test_remove_n_last_layers_sequentially(sequential_model): - """ - Tests removing the last n layers from a sequential model. - - Args: - sequential_model: Fixture providing a sequential model with three layers. 
- - Tests: - - Removing one layer reduces model length by 1 - - Removing two layers reduces model length by 2 - """ - new_model = remove_n_last_layers_sequentially(sequential_model, num_layers=1) - assert len(new_model) == 2 - - new_model = remove_n_last_layers_sequentially(sequential_model, num_layers=2) - assert len(new_model) == 1 - - -def test_adjust_prefix_and_load_state_dict_basic(dummy_model, state_dict_file): - """ - Tests basic functionality of loading a state dict into a model. - - Args: - dummy_model: Fixture providing a simple model for testing. - state_dict_file: Fixture providing a state dict file and its contents. - - Tests: - - Weights change after loading - - New weights match the state dict values - """ - path, state_dict = state_dict_file - - # Store initial random weights - initial_weights = { - "layer1.weight": dummy_model.layer1.weight.clone(), - "layer1.bias": dummy_model.layer1.bias.clone(), - "layer2.weight": dummy_model.layer2.weight.clone(), - "layer2.bias": dummy_model.layer2.bias.clone(), - } - - # Load state dict - adjust_prefix_and_load_state_dict(dummy_model, path) - - # Verify weights changed and match state dict - assert not torch.equal(dummy_model.layer1.weight, initial_weights["layer1.weight"]) - assert not torch.equal(dummy_model.layer1.bias, initial_weights["layer1.bias"]) - assert torch.equal(dummy_model.layer1.weight, state_dict["layer1.weight"]) - assert torch.equal(dummy_model.layer1.bias, state_dict["layer1.bias"]) - - -def test_adjust_prefix_and_load_state_dict_with_ignored_layers(dummy_model, state_dict_file): - """ - Tests loading a state dict while ignoring specific layers. - - Args: - dummy_model: Fixture providing a simple model for testing. - state_dict_file: Fixture providing a state dict file and its contents. 
- - Tests: - - Ignored layers maintain their original values - """ - path, _ = state_dict_file - - # Store original values of layer2 - original_weight = dummy_model.layer2.weight.clone() - original_bias = dummy_model.layer2.bias.clone() - - # Load state dict with ignored layers - adjust_prefix_and_load_state_dict(dummy_model, path, layers_to_ignore=["layer2.weight", "layer2.bias"]) - - # Verify ignored layers remain unchanged - assert torch.equal(dummy_model.layer2.weight, original_weight) - assert torch.equal(dummy_model.layer2.bias, original_bias) - - -def test_adjust_prefix_and_load_state_dict_with_model_prefix(dummy_model, state_dict_with_prefix_file): - """ - Tests loading a state dict that has 'model.' prefix in its keys. - - Args: - dummy_model: Fixture providing a simple model for testing. - state_dict_with_prefix_file: Fixture providing a state dict file with 'model.' prefixed keys. - - Tests: - - Model loads correctly despite prefix differences - - Weights match the state dict values after prefix adjustment - """ - path, state_dict = state_dict_with_prefix_file - - # Load state dict - adjust_prefix_and_load_state_dict(dummy_model, path) - - # Verify weights match state dict (without the "model." prefix) - assert torch.equal(dummy_model.layer1.weight, state_dict["model.layer1.weight"]) - assert torch.equal(dummy_model.layer1.bias, state_dict["model.layer1.bias"]) - - -def test_adjust_prefix_and_load_state_dict_with_custom_prefix(dummy_model, state_dict_with_custom_prefix_file): - """ - Tests loading a state dict with custom prefixes using prefix mapping. - - Args: - dummy_model: Fixture providing a simple model for testing. - state_dict_with_custom_prefix_file: Fixture providing a state dict file with custom prefixed keys. 
- - Tests: - - Model loads correctly with custom prefix mapping - - Weights match the state dict values after prefix adjustment - """ - path, state_dict = state_dict_with_custom_prefix_file - - # Load state dict with prefix adjustments - ckpt_to_model_prefix = {"prefix1": "", "prefix2": ""} - adjust_prefix_and_load_state_dict(dummy_model, path, ckpt_to_model_prefix=ckpt_to_model_prefix) - - # Verify weights match state dict - assert torch.equal(dummy_model.layer1.weight, state_dict["prefix1.layer1.weight"]) - assert torch.equal(dummy_model.layer1.bias, state_dict["prefix1.layer1.bias"]) - assert torch.equal(dummy_model.layer2.weight, state_dict["prefix2.layer2.weight"]) - assert torch.equal(dummy_model.layer2.bias, state_dict["prefix2.layer2.bias"]) - - -def test_adjust_prefix_and_load_state_dict_no_overlap(dummy_model, mismatched_state_dict_file): - """ - Tests that an error is raised when there is no overlap between checkpoint and model keys. - - Args: - dummy_model: Fixture providing a simple model for testing. - mismatched_state_dict_file: Fixture providing a state dict file with completely different keys. - - Tests: - - ValueError is raised when there's no overlap between checkpoint and model keys - - Error message contains information about both model and checkpoint keys - """ - with pytest.raises(ValueError) as exc_info: - adjust_prefix_and_load_state_dict(dummy_model, mismatched_state_dict_file) - - assert "There is no overlap between checkpoint's and model's state_dict" in str(exc_info.value) - assert "Model keys" in str(exc_info.value) - assert "Checkpoint keys" in str(exc_info.value) - - -def test_adjust_prefix_and_load_state_dict_empty_checkpoint(dummy_model, empty_state_dict_file): - """ - Tests that an error is raised when loading an empty checkpoint. - - Args: - dummy_model: Fixture providing a simple model for testing. - empty_state_dict_file: Fixture providing a state dict file with no keys. 
- - Tests: - - ValueError is raised when checkpoint is empty - - Error message shows empty checkpoint keys - """ - with pytest.raises(ValueError) as exc_info: - adjust_prefix_and_load_state_dict(dummy_model, empty_state_dict_file) - - assert "There is no overlap between checkpoint's and model's state_dict" in str(exc_info.value) - assert "Model keys" in str(exc_info.value) - assert "Checkpoint keys: []" in str(exc_info.value) - - -def test_adjust_prefix_and_load_state_dict_perfect_match(dummy_model, perfect_match_state_dict_file): - """ - Tests loading a state dict that perfectly matches the model structure. - - Args: - dummy_model: Fixture providing a simple model for testing. - perfect_match_state_dict_file: Fixture providing a state dict file with perfectly matching keys. - - Tests: - - Loading succeeds with no incompatible keys - - Success message is logged - """ - # Load state dict with perfect match - adjust_prefix_and_load_state_dict(dummy_model, perfect_match_state_dict_file) - # Note: We can't directly test the logger output, but the function should complete without raising exceptions - - -def test_adjust_prefix_and_load_state_dict_no_matching_prefix(dummy_model, state_dict_file): - """ - Tests loading a state dict with prefix mapping where no keys match the prefix. - - Args: - dummy_model: Fixture providing a simple model for testing. - state_dict_file: Fixture providing a state dict file and its contents. 
- - Tests: - - When no keys match the prefix mapping, the original state dict is used - - Loading succeeds with the original state dict - """ - path, state_dict = state_dict_file - - # Load state dict with prefix mapping that won't match any keys - ckpt_to_model_prefix = {"nonexistent_prefix": ""} - adjust_prefix_and_load_state_dict(dummy_model, path, ckpt_to_model_prefix=ckpt_to_model_prefix) - - # Verify weights match the original state dict - assert torch.equal(dummy_model.layer1.weight, state_dict["layer1.weight"]) - assert torch.equal(dummy_model.layer1.bias, state_dict["layer1.bias"]) diff --git a/tests/unit/test_utils_types_containers.py b/tests/unit/test_utils_types_containers.py deleted file mode 100644 index d03ab832..00000000 --- a/tests/unit/test_utils_types_containers.py +++ /dev/null @@ -1,138 +0,0 @@ -from dataclasses import dataclass, is_dataclass - -import pytest -from torchmetrics import Accuracy, MetricCollection - -from lighter.adapters import BatchAdapter, CriterionAdapter, LoggingAdapter, MetricsAdapter -from lighter.utils.types.containers import Metrics, Predict, Test, Train, Val, nested - - -# Define a nested dataclass for testing -@nested -@dataclass -class Inner: - value: int - - -@nested -@dataclass -class Outer: - inner: Inner - name: str = "default" - - -def test_nested_decorator_initialization(): - """Test that the nested decorator correctly initializes nested dataclasses.""" - data = {"inner": {"value": 10}, "name": "test"} - outer_instance = Outer(**data) - - assert isinstance(outer_instance, Outer) - assert isinstance(outer_instance.inner, Inner) - assert outer_instance.inner.value == 10 - assert outer_instance.name == "test" - - -def test_nested_decorator_default_values(): - """Test that the nested decorator respects default values.""" - data = {"inner": {"value": 5}} - outer_instance = Outer(**data) - - assert outer_instance.name == "default" - - -def test_nested_decorator_is_dataclass(): - """Test that the decorated class is 
still recognized as a dataclass.""" - assert is_dataclass(Outer) - assert is_dataclass(Inner) - - -def test_nested_decorator_missing_nested_data(): - """Test that missing nested data raises a TypeError.""" - with pytest.raises(TypeError): - Outer(name="test") - - -def test_metrics_convert_to_collection(): - """Test that _convert_to_collection converts non-MetricCollection to MetricCollection.""" - - # Create a Metrics instance with a single Metric - accuracy_metric = Accuracy(task="binary") - metrics_instance = Metrics(train=accuracy_metric) - - # Check if the train metric is converted to a MetricCollection - assert not isinstance(accuracy_metric, MetricCollection) - assert isinstance(metrics_instance.train, MetricCollection) - assert accuracy_metric in metrics_instance.train.values() - - -def test_metrics_convert_none_to_collection(): - """Test that _convert_to_collection handles None values correctly.""" - metrics_instance = Metrics(train=None, val=None, test=None) - - assert metrics_instance.train is None - assert metrics_instance.val is None - assert metrics_instance.test is None - - -def test_train_default_factory(): - """Test that the Train dataclass uses the correct default factories.""" - train_instance = Train() - - assert isinstance(train_instance.batch, BatchAdapter) - assert train_instance.batch.input_accessor == 0 - assert train_instance.batch.target_accessor == 1 - - assert isinstance(train_instance.criterion, CriterionAdapter) - assert train_instance.criterion.pred_argument == 0 - assert train_instance.criterion.target_argument == 1 - - assert isinstance(train_instance.metrics, MetricsAdapter) - assert train_instance.metrics.pred_argument == 0 - assert train_instance.metrics.target_argument == 1 - - assert isinstance(train_instance.logging, LoggingAdapter) - - -def test_val_default_factory(): - """Test that the Val dataclass uses the correct default factories.""" - val_instance = Val() - - assert isinstance(val_instance.batch, BatchAdapter) - assert 
val_instance.batch.input_accessor == 0 - assert val_instance.batch.target_accessor == 1 - - assert isinstance(val_instance.criterion, CriterionAdapter) - assert val_instance.criterion.pred_argument == 0 - assert val_instance.criterion.target_argument == 1 - - assert isinstance(val_instance.metrics, MetricsAdapter) - assert val_instance.metrics.pred_argument == 0 - assert val_instance.metrics.target_argument == 1 - - assert isinstance(val_instance.logging, LoggingAdapter) - - -def test_test_default_factory(): - """Test that the Test dataclass uses the correct default factories.""" - test_instance = Test() - - assert isinstance(test_instance.batch, BatchAdapter) - assert test_instance.batch.input_accessor == 0 - assert test_instance.batch.target_accessor == 1 - - assert isinstance(test_instance.metrics, MetricsAdapter) - assert test_instance.metrics.pred_argument == 0 - assert test_instance.metrics.target_argument == 1 - - assert isinstance(test_instance.logging, LoggingAdapter) - - -def test_predict_default_factory(): - """Test that the Predict dataclass uses the correct default factories.""" - predict_instance = Predict() - - assert isinstance(predict_instance.batch, BatchAdapter) - assert callable(predict_instance.batch.input_accessor) - assert predict_instance.batch.input_accessor("this is a batch") == "this is a batch" - - assert isinstance(predict_instance.logging, LoggingAdapter) diff --git a/uv.lock b/uv.lock index 5944a411..bc6651de 100644 --- a/uv.lock +++ b/uv.lock @@ -124,24 +124,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] -[[package]] -name = "appnope" -version = "0.1.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = 
"https://files.pythonhosted.org/packages/35/5d/752690df9ef5b76e169e68d6a129fa6d08a7100ca7f754c89495db3c6019/appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee", size = 4170, upload-time = "2024-02-06T09:43:11.258Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/81/29/5ecc3a15d5a33e31b26c11426c45c501e439cb865d0bff96315d86443b78/appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c", size = 4321, upload-time = "2024-02-06T09:43:09.663Z" }, -] - -[[package]] -name = "asttokens" -version = "3.0.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/4a/e7/82da0a03e7ba5141f05cce0d302e6eed121ae055e0456ca228bf693984bc/asttokens-3.0.0.tar.gz", hash = "sha256:0dcd8baa8d62b0c1d118b399b2ddba3c4aff271d0d7a9e0d4c1681c79035bbc7", size = 61978, upload-time = "2024-11-30T04:30:14.439Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918, upload-time = "2024-11-30T04:30:10.946Z" }, -] - [[package]] name = "async-timeout" version = "4.0.3" @@ -160,37 +142,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/21/5b6702a7f963e95456c0de2d495f67bf5fd62840ac655dc451586d23d39a/attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2", size = 63001, upload-time = "2024-08-06T14:37:36.958Z" }, ] -[[package]] -name = "av" -version = "12.0.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e7/45/282a04df72c17cc6cdd86df51b77754f61d1a8f3d68155e1821a8a76e399/av-12.0.0.tar.gz", hash = "sha256:bcf21ebb722d4538b4099e5a78f730d78814dd70003511c185941dba5651b14d", size = 3760491, 
upload-time = "2024-03-21T11:17:51.518Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/6f/ba/215bab7672c73575da3a4e25df27fcbc4dcb1c2cb715ac95de4a56eb6179/av-12.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b9d0890553951f76c479a9f2bb952aebae902b1c7d52feea614d37e1cd728a44", size = 27272332, upload-time = "2024-03-21T11:15:11.828Z" }, - { url = "https://files.pythonhosted.org/packages/d3/0a/88384d0a01a97cf7f56f4731599a57289bc743558230e7b62cc53bfb0e47/av-12.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5d7f229a253c2e3fea9682c09c5ae179bd6d5d2da38d89eb7f29ef7bed10cb2f", size = 20726464, upload-time = "2024-03-21T11:15:15.786Z" }, - { url = "https://files.pythonhosted.org/packages/22/28/3b3c6889f1db50b07353a4b3a1a0b5216804fa46fc5e9597330620b84897/av-12.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:61b3555d143aacf02e0446f6030319403538eba4dc713c18dfa653a2a23e7f9c", size = 31937450, upload-time = "2024-03-21T11:15:20.067Z" }, - { url = "https://files.pythonhosted.org/packages/69/cc/ed8489c00cdddec6c78706bf5bf092b9705337f5a2ac33a59926f8ac6548/av-12.0.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:607e13b2c2b26159a37525d7b6f647a32ce78711fccff23d146d3e255ffa115f", size = 31462934, upload-time = "2024-03-21T11:15:24.2Z" }, - { url = "https://files.pythonhosted.org/packages/50/c3/81e9efb59751b7a151c213f7976ce3bfac9a7786949947afbd60eee279df/av-12.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39f0b4cfb89f4f06b339c766f92648e798a96747d4163f2fa78660d1ab1f1b5e", size = 33807200, upload-time = "2024-03-21T11:15:28.189Z" }, - { url = "https://files.pythonhosted.org/packages/17/cb/f99d79bd829522a587e16e89409f494f40a99b77ca414c3e591ffad40489/av-12.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:41dcb8c269fa58a56edf3a3c814c32a0c69586827f132b4e395a951b0ce14fad", size = 26327832, upload-time = "2024-03-21T11:15:32.299Z" }, - { url = 
"https://files.pythonhosted.org/packages/5b/da/f4c9673be50c1ce01e0f81c74432cedb0666fb4a1cff1e151d9a069d7b32/av-12.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4fa78fbe0e4469226512380180063116105048c66cb12e18ab4b518466c57e6c", size = 27266665, upload-time = "2024-03-21T11:15:36.18Z" }, - { url = "https://files.pythonhosted.org/packages/15/54/b52aad45b95e266af8465fa405fb74c00f59faaf01ba88de69fac8e0ee6a/av-12.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:60a869be1d6af916e65ea461cb93922f5db0698655ed7a7eae7c3ecd4af4debb", size = 20720210, upload-time = "2024-03-21T11:15:39.351Z" }, - { url = "https://files.pythonhosted.org/packages/64/c3/bccea1e7254c4d5ede7169bbee335fcc4162ee4bf6f8d321d63bff2524c6/av-12.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df61811cc551c186f0a0e530d97b8b139453534d0f92c1790a923f666522ceda", size = 32810453, upload-time = "2024-03-21T11:15:42.713Z" }, - { url = "https://files.pythonhosted.org/packages/58/83/992c4da55e06783557191c90e9530b16149f07c58da741df379f5abdd989/av-12.0.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:99cd2fc53091ebfb9a2fa9dd3580267f5bd1c040d0efd99fbc1a162576b271cb", size = 32250599, upload-time = "2024-03-21T11:15:46.662Z" }, - { url = "https://files.pythonhosted.org/packages/51/95/d43a02410e4e1cb6397a675facf5a65285d670697c0762a0e29c0c3615b4/av-12.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a6d4f1e261df48932128e6495772faa4cc23f5dd1512eec73daab82ad9f3240", size = 34691295, upload-time = "2024-03-21T11:15:50.774Z" }, - { url = "https://files.pythonhosted.org/packages/8a/51/0e5f5e31d834061f5b17ebf53570539f8ad124412b3dbdd5d138f36721a0/av-12.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:6aec88e41a498b1e01e2dce5371557e20f9a51aae0c16decc5924ec0be2e22b6", size = 26328218, upload-time = "2024-03-21T11:15:54.196Z" }, - { url = 
"https://files.pythonhosted.org/packages/af/58/e74f30f35983339567ebd78a22e2fe710889c917ee7d0b179b17859c90ef/av-12.0.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:90eb8f2d548e96cbc6f78e89c911cdb15a3d80fd944f31111660ce45939cd037", size = 27273235, upload-time = "2024-03-21T11:15:57.522Z" }, - { url = "https://files.pythonhosted.org/packages/9d/89/e419cf754f9dcd07913c7f3aeb70f4f4b54ce04d028d0eaea2beb3d9c7ce/av-12.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d7f3a02910e77d750dbd516256a16db15030e5371530ff5a5ae902dc03d9005d", size = 20723302, upload-time = "2024-03-21T11:16:00.772Z" }, - { url = "https://files.pythonhosted.org/packages/50/94/8fec8a9bf4771c8c93ecdb1784c97215ec4aeb35e743e58113599f381a31/av-12.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2477cc51526aa50575313d66e5e8ad7ab944588469be5e557b360ed572ae536", size = 33101646, upload-time = "2024-03-21T11:16:04.614Z" }, - { url = "https://files.pythonhosted.org/packages/2a/15/cc0fad46f6cfd6a1e8378155c9b15b76c813fca9e8b8f3687133350d714a/av-12.0.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a2f47149d3ca6deb79f3e515b8bef50e27ebdb160813e6d67dba77278d2a7883", size = 32565329, upload-time = "2024-03-21T11:16:08.011Z" }, - { url = "https://files.pythonhosted.org/packages/67/24/a287acccddeec8d16cde6cff5a0e230b71c895e7b7169221f81b548a22c7/av-12.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3306e4a3ce8b5bfcc3075793d4ed3a2df69179d8fba22cb944a6164dc235dfb6", size = 35066735, upload-time = "2024-03-21T11:16:11.575Z" }, - { url = "https://files.pythonhosted.org/packages/51/30/b8d2f6a57108632c8a9922d063480171bf85b64dbad88c8c0f8c99d21f8c/av-12.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:dc1b742e7f6df1b499fb960bd6697d1dd8e7ada7484a041a8c20e70a87225f53", size = 26331009, upload-time = "2024-03-21T11:16:14.91Z" }, - { url = 
"https://files.pythonhosted.org/packages/3f/a7/52858fd3c201fd555280f672d2697a1499a4f588eba9bd5b0931e70f1232/av-12.0.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4819e3ef6c3a44ef6f75907229133a1ee7f688245b2cf49b6b8e969a81ca72c9", size = 26896823, upload-time = "2024-03-21T11:17:03.641Z" }, - { url = "https://files.pythonhosted.org/packages/4b/71/618d0f7597b9cc9eedeefe87ce74ec27b5eb2b21679243c96a7eb79a630a/av-12.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb16bb314cf1503b0250fc46b2c455ee196584231101be0123f4f78638227b62", size = 23790080, upload-time = "2024-03-21T11:17:07.722Z" }, - { url = "https://files.pythonhosted.org/packages/4f/37/04e18e21237d9a9cd72d07ba2c9c06e7243472fbac62cbc05aedcdcdc4b9/av-12.0.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f3e6a62bda9a1e144feeb59bbee046d7a2d98399634a30f57e4990197313c158", size = 23739697, upload-time = "2024-03-21T11:17:10.77Z" }, - { url = "https://files.pythonhosted.org/packages/ee/f3/9c16e120cf1573854d979400e6a147da359d98e9d2fc8cb4f7ae62fc8208/av-12.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e08175ffbafa3a70c7b2f81083e160e34122a208cdf70f150b8f5d02c2de6965", size = 25634269, upload-time = "2024-03-21T11:17:14.566Z" }, - { url = "https://files.pythonhosted.org/packages/88/3a/315677e9baf1b32c9ca10bd77536fab5a65741be9aab8530bc80665e3eb1/av-12.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e1d255be317b7c1ebdc4dae98935b9f3869161112dc829c625e54f90d8bdd7ab", size = 26118575, upload-time = "2024-03-21T11:17:17.956Z" }, -] - [[package]] name = "babel" version = "2.17.0" @@ -250,63 +201,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1c/d5/c84e1a17bf61d4df64ca866a1c9a913874b4e9bdc131ec689a0ad013fb36/certifi-2024.7.4-py3-none-any.whl", hash = "sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90", size = 162960, upload-time = "2024-07-04T01:36:09.038Z" 
}, ] -[[package]] -name = "cffi" -version = "1.17.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pycparser" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/fc/97/c783634659c2920c3fc70419e3af40972dbaf758daa229a7d6ea6135c90d/cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824", size = 516621, upload-time = "2024-09-04T20:45:21.852Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/90/07/f44ca684db4e4f08a3fdc6eeb9a0d15dc6883efc7b8c90357fdbf74e186c/cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14", size = 182191, upload-time = "2024-09-04T20:43:30.027Z" }, - { url = "https://files.pythonhosted.org/packages/08/fd/cc2fedbd887223f9f5d170c96e57cbf655df9831a6546c1727ae13fa977a/cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67", size = 178592, upload-time = "2024-09-04T20:43:32.108Z" }, - { url = "https://files.pythonhosted.org/packages/de/cc/4635c320081c78d6ffc2cab0a76025b691a91204f4aa317d568ff9280a2d/cffi-1.17.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382", size = 426024, upload-time = "2024-09-04T20:43:34.186Z" }, - { url = "https://files.pythonhosted.org/packages/b6/7b/3b2b250f3aab91abe5f8a51ada1b717935fdaec53f790ad4100fe2ec64d1/cffi-1.17.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702", size = 448188, upload-time = "2024-09-04T20:43:36.286Z" }, - { url = "https://files.pythonhosted.org/packages/d3/48/1b9283ebbf0ec065148d8de05d647a986c5f22586b18120020452fff8f5d/cffi-1.17.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3", size = 455571, upload-time = "2024-09-04T20:43:38.586Z" }, - { url = "https://files.pythonhosted.org/packages/40/87/3b8452525437b40f39ca7ff70276679772ee7e8b394934ff60e63b7b090c/cffi-1.17.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6", size = 436687, upload-time = "2024-09-04T20:43:40.084Z" }, - { url = "https://files.pythonhosted.org/packages/8d/fb/4da72871d177d63649ac449aec2e8a29efe0274035880c7af59101ca2232/cffi-1.17.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17", size = 446211, upload-time = "2024-09-04T20:43:41.526Z" }, - { url = "https://files.pythonhosted.org/packages/ab/a0/62f00bcb411332106c02b663b26f3545a9ef136f80d5df746c05878f8c4b/cffi-1.17.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8", size = 461325, upload-time = "2024-09-04T20:43:43.117Z" }, - { url = "https://files.pythonhosted.org/packages/36/83/76127035ed2e7e27b0787604d99da630ac3123bfb02d8e80c633f218a11d/cffi-1.17.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e", size = 438784, upload-time = "2024-09-04T20:43:45.256Z" }, - { url = "https://files.pythonhosted.org/packages/21/81/a6cd025db2f08ac88b901b745c163d884641909641f9b826e8cb87645942/cffi-1.17.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be", size = 461564, upload-time = "2024-09-04T20:43:46.779Z" }, - { url = "https://files.pythonhosted.org/packages/f8/fe/4d41c2f200c4a457933dbd98d3cf4e911870877bd94d9656cc0fcb390681/cffi-1.17.1-cp310-cp310-win32.whl", hash = "sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c", size = 171804, upload-time = 
"2024-09-04T20:43:48.186Z" }, - { url = "https://files.pythonhosted.org/packages/d1/b6/0b0f5ab93b0df4acc49cae758c81fe4e5ef26c3ae2e10cc69249dfd8b3ab/cffi-1.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15", size = 181299, upload-time = "2024-09-04T20:43:49.812Z" }, - { url = "https://files.pythonhosted.org/packages/6b/f4/927e3a8899e52a27fa57a48607ff7dc91a9ebe97399b357b85a0c7892e00/cffi-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401", size = 182264, upload-time = "2024-09-04T20:43:51.124Z" }, - { url = "https://files.pythonhosted.org/packages/6c/f5/6c3a8efe5f503175aaddcbea6ad0d2c96dad6f5abb205750d1b3df44ef29/cffi-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf", size = 178651, upload-time = "2024-09-04T20:43:52.872Z" }, - { url = "https://files.pythonhosted.org/packages/94/dd/a3f0118e688d1b1a57553da23b16bdade96d2f9bcda4d32e7d2838047ff7/cffi-1.17.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4", size = 445259, upload-time = "2024-09-04T20:43:56.123Z" }, - { url = "https://files.pythonhosted.org/packages/2e/ea/70ce63780f096e16ce8588efe039d3c4f91deb1dc01e9c73a287939c79a6/cffi-1.17.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41", size = 469200, upload-time = "2024-09-04T20:43:57.891Z" }, - { url = "https://files.pythonhosted.org/packages/1c/a0/a4fa9f4f781bda074c3ddd57a572b060fa0df7655d2a4247bbe277200146/cffi-1.17.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1", size = 477235, upload-time = "2024-09-04T20:44:00.18Z" }, - { url = 
"https://files.pythonhosted.org/packages/62/12/ce8710b5b8affbcdd5c6e367217c242524ad17a02fe5beec3ee339f69f85/cffi-1.17.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6", size = 459721, upload-time = "2024-09-04T20:44:01.585Z" }, - { url = "https://files.pythonhosted.org/packages/ff/6b/d45873c5e0242196f042d555526f92aa9e0c32355a1be1ff8c27f077fd37/cffi-1.17.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d", size = 467242, upload-time = "2024-09-04T20:44:03.467Z" }, - { url = "https://files.pythonhosted.org/packages/1a/52/d9a0e523a572fbccf2955f5abe883cfa8bcc570d7faeee06336fbd50c9fc/cffi-1.17.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6", size = 477999, upload-time = "2024-09-04T20:44:05.023Z" }, - { url = "https://files.pythonhosted.org/packages/44/74/f2a2460684a1a2d00ca799ad880d54652841a780c4c97b87754f660c7603/cffi-1.17.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f", size = 454242, upload-time = "2024-09-04T20:44:06.444Z" }, - { url = "https://files.pythonhosted.org/packages/f8/4a/34599cac7dfcd888ff54e801afe06a19c17787dfd94495ab0c8d35fe99fb/cffi-1.17.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b", size = 478604, upload-time = "2024-09-04T20:44:08.206Z" }, - { url = "https://files.pythonhosted.org/packages/34/33/e1b8a1ba29025adbdcda5fb3a36f94c03d771c1b7b12f726ff7fef2ebe36/cffi-1.17.1-cp311-cp311-win32.whl", hash = "sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655", size = 171727, upload-time = "2024-09-04T20:44:09.481Z" }, - { url = 
"https://files.pythonhosted.org/packages/3d/97/50228be003bb2802627d28ec0627837ac0bf35c90cf769812056f235b2d1/cffi-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0", size = 181400, upload-time = "2024-09-04T20:44:10.873Z" }, - { url = "https://files.pythonhosted.org/packages/5a/84/e94227139ee5fb4d600a7a4927f322e1d4aea6fdc50bd3fca8493caba23f/cffi-1.17.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4", size = 183178, upload-time = "2024-09-04T20:44:12.232Z" }, - { url = "https://files.pythonhosted.org/packages/da/ee/fb72c2b48656111c4ef27f0f91da355e130a923473bf5ee75c5643d00cca/cffi-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c", size = 178840, upload-time = "2024-09-04T20:44:13.739Z" }, - { url = "https://files.pythonhosted.org/packages/cc/b6/db007700f67d151abadf508cbfd6a1884f57eab90b1bb985c4c8c02b0f28/cffi-1.17.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36", size = 454803, upload-time = "2024-09-04T20:44:15.231Z" }, - { url = "https://files.pythonhosted.org/packages/1a/df/f8d151540d8c200eb1c6fba8cd0dfd40904f1b0682ea705c36e6c2e97ab3/cffi-1.17.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5", size = 478850, upload-time = "2024-09-04T20:44:17.188Z" }, - { url = "https://files.pythonhosted.org/packages/28/c0/b31116332a547fd2677ae5b78a2ef662dfc8023d67f41b2a83f7c2aa78b1/cffi-1.17.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff", size = 485729, upload-time = "2024-09-04T20:44:18.688Z" }, - { url = 
"https://files.pythonhosted.org/packages/91/2b/9a1ddfa5c7f13cab007a2c9cc295b70fbbda7cb10a286aa6810338e60ea1/cffi-1.17.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99", size = 471256, upload-time = "2024-09-04T20:44:20.248Z" }, - { url = "https://files.pythonhosted.org/packages/b2/d5/da47df7004cb17e4955df6a43d14b3b4ae77737dff8bf7f8f333196717bf/cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93", size = 479424, upload-time = "2024-09-04T20:44:21.673Z" }, - { url = "https://files.pythonhosted.org/packages/0b/ac/2a28bcf513e93a219c8a4e8e125534f4f6db03e3179ba1c45e949b76212c/cffi-1.17.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3", size = 484568, upload-time = "2024-09-04T20:44:23.245Z" }, - { url = "https://files.pythonhosted.org/packages/d4/38/ca8a4f639065f14ae0f1d9751e70447a261f1a30fa7547a828ae08142465/cffi-1.17.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8", size = 488736, upload-time = "2024-09-04T20:44:24.757Z" }, - { url = "https://files.pythonhosted.org/packages/86/c5/28b2d6f799ec0bdecf44dced2ec5ed43e0eb63097b0f58c293583b406582/cffi-1.17.1-cp312-cp312-win32.whl", hash = "sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65", size = 172448, upload-time = "2024-09-04T20:44:26.208Z" }, - { url = "https://files.pythonhosted.org/packages/50/b9/db34c4755a7bd1cb2d1603ac3863f22bcecbd1ba29e5ee841a4bc510b294/cffi-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903", size = 181976, upload-time = "2024-09-04T20:44:27.578Z" }, - { url = 
"https://files.pythonhosted.org/packages/8d/f8/dd6c246b148639254dad4d6803eb6a54e8c85c6e11ec9df2cffa87571dbe/cffi-1.17.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e", size = 182989, upload-time = "2024-09-04T20:44:28.956Z" }, - { url = "https://files.pythonhosted.org/packages/8b/f1/672d303ddf17c24fc83afd712316fda78dc6fce1cd53011b839483e1ecc8/cffi-1.17.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2", size = 178802, upload-time = "2024-09-04T20:44:30.289Z" }, - { url = "https://files.pythonhosted.org/packages/0e/2d/eab2e858a91fdff70533cab61dcff4a1f55ec60425832ddfdc9cd36bc8af/cffi-1.17.1-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3", size = 454792, upload-time = "2024-09-04T20:44:32.01Z" }, - { url = "https://files.pythonhosted.org/packages/75/b2/fbaec7c4455c604e29388d55599b99ebcc250a60050610fadde58932b7ee/cffi-1.17.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683", size = 478893, upload-time = "2024-09-04T20:44:33.606Z" }, - { url = "https://files.pythonhosted.org/packages/4f/b7/6e4a2162178bf1935c336d4da8a9352cccab4d3a5d7914065490f08c0690/cffi-1.17.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5", size = 485810, upload-time = "2024-09-04T20:44:35.191Z" }, - { url = "https://files.pythonhosted.org/packages/c7/8a/1d0e4a9c26e54746dc08c2c6c037889124d4f59dffd853a659fa545f1b40/cffi-1.17.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4", size = 471200, upload-time = "2024-09-04T20:44:36.743Z" }, - { url = 
"https://files.pythonhosted.org/packages/26/9f/1aab65a6c0db35f43c4d1b4f580e8df53914310afc10ae0397d29d697af4/cffi-1.17.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd", size = 479447, upload-time = "2024-09-04T20:44:38.492Z" }, - { url = "https://files.pythonhosted.org/packages/5f/e4/fb8b3dd8dc0e98edf1135ff067ae070bb32ef9d509d6cb0f538cd6f7483f/cffi-1.17.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed", size = 484358, upload-time = "2024-09-04T20:44:40.046Z" }, - { url = "https://files.pythonhosted.org/packages/f1/47/d7145bf2dc04684935d57d67dff9d6d795b2ba2796806bb109864be3a151/cffi-1.17.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9", size = 488469, upload-time = "2024-09-04T20:44:41.616Z" }, - { url = "https://files.pythonhosted.org/packages/bf/ee/f94057fa6426481d663b88637a9a10e859e492c73d0384514a17d78ee205/cffi-1.17.1-cp313-cp313-win32.whl", hash = "sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d", size = 172475, upload-time = "2024-09-04T20:44:43.733Z" }, - { url = "https://files.pythonhosted.org/packages/7c/fc/6a8cb64e5f0324877d503c854da15d76c1e50eb722e320b15345c4d0c6de/cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a", size = 182009, upload-time = "2024-09-04T20:44:45.309Z" }, -] - [[package]] name = "cfgv" version = "3.4.0" @@ -383,24 +277,21 @@ wheels = [ ] [[package]] -name = "colorama" -version = "0.4.6" +name = "cloudpickle" +version = "3.1.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 
27697, upload-time = "2022-10-25T02:36:22.414Z" } +sdist = { url = "https://files.pythonhosted.org/packages/27/fb/576f067976d320f5f0114a8d9fa1215425441bb35627b1993e5afd8111e5/cloudpickle-3.1.2.tar.gz", hash = "sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414", size = 22330, upload-time = "2025-11-03T09:25:26.604Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, + { url = "https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl", hash = "sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a", size = 22228, upload-time = "2025-11-03T09:25:25.534Z" }, ] [[package]] -name = "comm" -version = "0.2.2" +name = "colorama" +version = "0.4.6" source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "traitlets" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/e9/a8/fb783cb0abe2b5fded9f55e5703015cdf1c9c85b3669087c538dd15a6a86/comm-0.2.2.tar.gz", hash = "sha256:3fd7a84065306e07bea1773df6eb8282de51ba82f77c72f9c85716ab11fe980e", size = 6210, upload-time = "2024-03-12T16:53:41.133Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e6/75/49e5bfe642f71f272236b5b2d2691cf915a7283cc0ceda56357b61daa538/comm-0.2.2-py3-none-any.whl", hash = "sha256:e6fb86cb70ff661ee8c9c14e7d36d6de3b4066f1441be4063df9c5009f0a64d3", size = 7180, upload-time = "2024-03-12T16:53:39.226Z" }, + { 
url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] [[package]] @@ -467,53 +358,6 @@ toml = [ { name = "tomli", marker = "python_full_version <= '3.11'" }, ] -[[package]] -name = "coverage-badge" -version = "1.1.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "coverage" }, - { name = "setuptools" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/be/8f/e92b0a010c76b0da82709838b3f3ae9aec638d0c44dbfb1186a5751f5d2e/coverage_badge-1.1.2.tar.gz", hash = "sha256:fe7ed58a3b72dad85a553b64a99e963dea3847dcd0b8ddd2b38a00333618642c", size = 6335, upload-time = "2024-08-02T23:34:08.58Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/90/3d/5642a1a06191b2e1e0f87a2e824e6d3eb7c32c589a68ed4d1dcbd3324d63/coverage_badge-1.1.2-py2.py3-none-any.whl", hash = "sha256:d8413ce51c91043a1692b943616b450868cbeeb0ea6a0c54a32f8318c9c96ff7", size = 6493, upload-time = "2024-08-02T23:34:07.063Z" }, -] - -[[package]] -name = "debugpy" -version = "1.8.14" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bd/75/087fe07d40f490a78782ff3b0a30e3968936854105487decdb33446d4b0e/debugpy-1.8.14.tar.gz", hash = "sha256:7cd287184318416850aa8b60ac90105837bb1e59531898c07569d197d2ed5322", size = 1641444, upload-time = "2025-04-10T19:46:10.981Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/fc/df/156df75a41aaebd97cee9d3870fe68f8001b6c1c4ca023e221cfce69bece/debugpy-1.8.14-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:93fee753097e85623cab1c0e6a68c76308cd9f13ffdf44127e6fab4fbf024339", size = 2076510, upload-time = "2025-04-10T19:46:13.315Z" }, - { url = 
"https://files.pythonhosted.org/packages/69/cd/4fc391607bca0996db5f3658762106e3d2427beaef9bfd363fd370a3c054/debugpy-1.8.14-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d937d93ae4fa51cdc94d3e865f535f185d5f9748efb41d0d49e33bf3365bd79", size = 3559614, upload-time = "2025-04-10T19:46:14.647Z" }, - { url = "https://files.pythonhosted.org/packages/1a/42/4e6d2b9d63e002db79edfd0cb5656f1c403958915e0e73ab3e9220012eec/debugpy-1.8.14-cp310-cp310-win32.whl", hash = "sha256:c442f20577b38cc7a9aafecffe1094f78f07fb8423c3dddb384e6b8f49fd2987", size = 5208588, upload-time = "2025-04-10T19:46:16.233Z" }, - { url = "https://files.pythonhosted.org/packages/97/b1/cc9e4e5faadc9d00df1a64a3c2d5c5f4b9df28196c39ada06361c5141f89/debugpy-1.8.14-cp310-cp310-win_amd64.whl", hash = "sha256:f117dedda6d969c5c9483e23f573b38f4e39412845c7bc487b6f2648df30fe84", size = 5241043, upload-time = "2025-04-10T19:46:17.768Z" }, - { url = "https://files.pythonhosted.org/packages/67/e8/57fe0c86915671fd6a3d2d8746e40485fd55e8d9e682388fbb3a3d42b86f/debugpy-1.8.14-cp311-cp311-macosx_14_0_universal2.whl", hash = "sha256:1b2ac8c13b2645e0b1eaf30e816404990fbdb168e193322be8f545e8c01644a9", size = 2175064, upload-time = "2025-04-10T19:46:19.486Z" }, - { url = "https://files.pythonhosted.org/packages/3b/97/2b2fd1b1c9569c6764ccdb650a6f752e4ac31be465049563c9eb127a8487/debugpy-1.8.14-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf431c343a99384ac7eab2f763980724834f933a271e90496944195318c619e2", size = 3132359, upload-time = "2025-04-10T19:46:21.192Z" }, - { url = "https://files.pythonhosted.org/packages/c0/ee/b825c87ed06256ee2a7ed8bab8fb3bb5851293bf9465409fdffc6261c426/debugpy-1.8.14-cp311-cp311-win32.whl", hash = "sha256:c99295c76161ad8d507b413cd33422d7c542889fbb73035889420ac1fad354f2", size = 5133269, upload-time = "2025-04-10T19:46:23.047Z" }, - { url = 
"https://files.pythonhosted.org/packages/d5/a6/6c70cd15afa43d37839d60f324213843174c1d1e6bb616bd89f7c1341bac/debugpy-1.8.14-cp311-cp311-win_amd64.whl", hash = "sha256:7816acea4a46d7e4e50ad8d09d963a680ecc814ae31cdef3622eb05ccacf7b01", size = 5158156, upload-time = "2025-04-10T19:46:24.521Z" }, - { url = "https://files.pythonhosted.org/packages/d9/2a/ac2df0eda4898f29c46eb6713a5148e6f8b2b389c8ec9e425a4a1d67bf07/debugpy-1.8.14-cp312-cp312-macosx_14_0_universal2.whl", hash = "sha256:8899c17920d089cfa23e6005ad9f22582fd86f144b23acb9feeda59e84405b84", size = 2501268, upload-time = "2025-04-10T19:46:26.044Z" }, - { url = "https://files.pythonhosted.org/packages/10/53/0a0cb5d79dd9f7039169f8bf94a144ad3efa52cc519940b3b7dde23bcb89/debugpy-1.8.14-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6bb5c0dcf80ad5dbc7b7d6eac484e2af34bdacdf81df09b6a3e62792b722826", size = 4221077, upload-time = "2025-04-10T19:46:27.464Z" }, - { url = "https://files.pythonhosted.org/packages/f8/d5/84e01821f362327bf4828728aa31e907a2eca7c78cd7c6ec062780d249f8/debugpy-1.8.14-cp312-cp312-win32.whl", hash = "sha256:281d44d248a0e1791ad0eafdbbd2912ff0de9eec48022a5bfbc332957487ed3f", size = 5255127, upload-time = "2025-04-10T19:46:29.467Z" }, - { url = "https://files.pythonhosted.org/packages/33/16/1ed929d812c758295cac7f9cf3dab5c73439c83d9091f2d91871e648093e/debugpy-1.8.14-cp312-cp312-win_amd64.whl", hash = "sha256:5aa56ef8538893e4502a7d79047fe39b1dae08d9ae257074c6464a7b290b806f", size = 5297249, upload-time = "2025-04-10T19:46:31.538Z" }, - { url = "https://files.pythonhosted.org/packages/4d/e4/395c792b243f2367d84202dc33689aa3d910fb9826a7491ba20fc9e261f5/debugpy-1.8.14-cp313-cp313-macosx_14_0_universal2.whl", hash = "sha256:329a15d0660ee09fec6786acdb6e0443d595f64f5d096fc3e3ccf09a4259033f", size = 2485676, upload-time = "2025-04-10T19:46:32.96Z" }, - { url = 
"https://files.pythonhosted.org/packages/ba/f1/6f2ee3f991327ad9e4c2f8b82611a467052a0fb0e247390192580e89f7ff/debugpy-1.8.14-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f920c7f9af409d90f5fd26e313e119d908b0dd2952c2393cd3247a462331f15", size = 4217514, upload-time = "2025-04-10T19:46:34.336Z" }, - { url = "https://files.pythonhosted.org/packages/79/28/b9d146f8f2dc535c236ee09ad3e5ac899adb39d7a19b49f03ac95d216beb/debugpy-1.8.14-cp313-cp313-win32.whl", hash = "sha256:3784ec6e8600c66cbdd4ca2726c72d8ca781e94bce2f396cc606d458146f8f4e", size = 5254756, upload-time = "2025-04-10T19:46:36.199Z" }, - { url = "https://files.pythonhosted.org/packages/e0/62/a7b4a57013eac4ccaef6977966e6bec5c63906dd25a86e35f155952e29a1/debugpy-1.8.14-cp313-cp313-win_amd64.whl", hash = "sha256:684eaf43c95a3ec39a96f1f5195a7ff3d4144e4a18d69bb66beeb1a6de605d6e", size = 5297119, upload-time = "2025-04-10T19:46:38.141Z" }, - { url = "https://files.pythonhosted.org/packages/97/1a/481f33c37ee3ac8040d3d51fc4c4e4e7e61cb08b8bc8971d6032acc2279f/debugpy-1.8.14-py2.py3-none-any.whl", hash = "sha256:5cd9a579d553b6cb9759a7908a41988ee6280b961f24f63336835d9418216a20", size = 5256230, upload-time = "2025-04-10T19:46:54.077Z" }, -] - -[[package]] -name = "decorator" -version = "5.2.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/43/fa/6d96a0978d19e17b68d634497769987b16c8f4cd0a7a05048bec693caa6b/decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360", size = 56711, upload-time = "2025-02-24T04:41:34.073Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, -] - [[package]] name = "distlib" 
version = "0.3.9" @@ -532,15 +376,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453, upload-time = "2024-07-12T22:25:58.476Z" }, ] -[[package]] -name = "executing" -version = "2.2.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/91/50/a9d80c47ff289c611ff12e63f7c5d13942c65d68125160cefd768c73e6e4/executing-2.2.0.tar.gz", hash = "sha256:5d108c028108fe2551d1a7b2e8b713341e2cb4fc0aa7dcf966fa4327a5226755", size = 978693, upload-time = "2025-01-22T15:41:29.403Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702, upload-time = "2025-01-22T15:41:25.929Z" }, -] - [[package]] name = "filelock" version = "3.15.4" @@ -704,124 +539,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892, upload-time = "2023-01-07T11:08:09.864Z" }, ] -[[package]] -name = "ipykernel" -version = "6.29.5" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "appnope", marker = "sys_platform == 'darwin'" }, - { name = "comm" }, - { name = "debugpy" }, - { name = "ipython", version = "8.37.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "ipython", version = "9.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "jupyter-client" }, - { name = "jupyter-core" }, - { name = 
"matplotlib-inline" }, - { name = "nest-asyncio" }, - { name = "packaging" }, - { name = "psutil" }, - { name = "pyzmq" }, - { name = "tornado" }, - { name = "traitlets" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/e9/5c/67594cb0c7055dc50814b21731c22a601101ea3b1b50a9a1b090e11f5d0f/ipykernel-6.29.5.tar.gz", hash = "sha256:f093a22c4a40f8828f8e330a9c297cb93dcab13bd9678ded6de8e5cf81c56215", size = 163367, upload-time = "2024-07-01T14:07:22.543Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/94/5c/368ae6c01c7628438358e6d337c19b05425727fbb221d2a3c4303c372f42/ipykernel-6.29.5-py3-none-any.whl", hash = "sha256:afdb66ba5aa354b09b91379bac28ae4afebbb30e8b39510c9690afb7a10421b5", size = 117173, upload-time = "2024-07-01T14:07:19.603Z" }, -] - -[[package]] -name = "ipython" -version = "8.37.0" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.11'", -] -dependencies = [ - { name = "colorama", marker = "python_full_version < '3.11' and sys_platform == 'win32'" }, - { name = "decorator", marker = "python_full_version < '3.11'" }, - { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, - { name = "jedi", marker = "python_full_version < '3.11'" }, - { name = "matplotlib-inline", marker = "python_full_version < '3.11'" }, - { name = "pexpect", marker = "python_full_version < '3.11' and sys_platform != 'emscripten' and sys_platform != 'win32'" }, - { name = "prompt-toolkit", marker = "python_full_version < '3.11'" }, - { name = "pygments", marker = "python_full_version < '3.11'" }, - { name = "stack-data", marker = "python_full_version < '3.11'" }, - { name = "traitlets", marker = "python_full_version < '3.11'" }, - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/85/31/10ac88f3357fc276dc8a64e8880c82e80e7459326ae1d0a211b40abf6665/ipython-8.37.0.tar.gz", hash = 
"sha256:ca815841e1a41a1e6b73a0b08f3038af9b2252564d01fc405356d34033012216", size = 5606088, upload-time = "2025-05-31T16:39:09.613Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/91/d0/274fbf7b0b12643cbbc001ce13e6a5b1607ac4929d1b11c72460152c9fc3/ipython-8.37.0-py3-none-any.whl", hash = "sha256:ed87326596b878932dbcb171e3e698845434d8c61b8d8cd474bf663041a9dcf2", size = 831864, upload-time = "2025-05-31T16:39:06.38Z" }, -] - -[[package]] -name = "ipython" -version = "9.4.0" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.13'", - "python_full_version == '3.12.*'", - "python_full_version == '3.11.*'", -] -dependencies = [ - { name = "colorama", marker = "python_full_version >= '3.11' and sys_platform == 'win32'" }, - { name = "decorator", marker = "python_full_version >= '3.11'" }, - { name = "ipython-pygments-lexers", marker = "python_full_version >= '3.11'" }, - { name = "jedi", marker = "python_full_version >= '3.11'" }, - { name = "matplotlib-inline", marker = "python_full_version >= '3.11'" }, - { name = "pexpect", marker = "python_full_version >= '3.11' and sys_platform != 'emscripten' and sys_platform != 'win32'" }, - { name = "prompt-toolkit", marker = "python_full_version >= '3.11'" }, - { name = "pygments", marker = "python_full_version >= '3.11'" }, - { name = "stack-data", marker = "python_full_version >= '3.11'" }, - { name = "traitlets", marker = "python_full_version >= '3.11'" }, - { name = "typing-extensions", marker = "python_full_version == '3.11.*'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/54/80/406f9e3bde1c1fd9bf5a0be9d090f8ae623e401b7670d8f6fdf2ab679891/ipython-9.4.0.tar.gz", hash = "sha256:c033c6d4e7914c3d9768aabe76bbe87ba1dc66a92a05db6bfa1125d81f2ee270", size = 4385338, upload-time = "2025-07-01T11:11:30.606Z" } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/63/f8/0031ee2b906a15a33d6bfc12dd09c3dfa966b3cb5b284ecfb7549e6ac3c4/ipython-9.4.0-py3-none-any.whl", hash = "sha256:25850f025a446d9b359e8d296ba175a36aedd32e83ca9b5060430fe16801f066", size = 611021, upload-time = "2025-07-01T11:11:27.85Z" }, -] - -[[package]] -name = "ipython-pygments-lexers" -version = "1.1.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pygments", marker = "python_full_version >= '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ef/4c/5dd1d8af08107f88c7f741ead7a40854b8ac24ddf9ae850afbcf698aa552/ipython_pygments_lexers-1.1.1.tar.gz", hash = "sha256:09c0138009e56b6854f9535736f4171d855c8c08a563a0dcd8022f78355c7e81", size = 8393, upload-time = "2025-01-17T11:24:34.505Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074, upload-time = "2025-01-17T11:24:33.271Z" }, -] - -[[package]] -name = "ipywidgets" -version = "8.1.5" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "comm" }, - { name = "ipython", version = "8.37.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "ipython", version = "9.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "jupyterlab-widgets" }, - { name = "traitlets" }, - { name = "widgetsnbextension" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c7/4c/dab2a281b07596a5fc220d49827fe6c794c66f1493d7a74f1df0640f2cc5/ipywidgets-8.1.5.tar.gz", hash = "sha256:870e43b1a35656a80c18c9503bbf2d16802db1cb487eec6fab27d683381dde17", size = 116723, upload-time = "2024-08-22T12:19:51.302Z" } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/22/2d/9c0b76f2f9cc0ebede1b9371b6f317243028ed60b90705863d493bae622e/ipywidgets-8.1.5-py3-none-any.whl", hash = "sha256:3290f526f87ae6e77655555baba4f36681c555b8bdbbff430b70e52c34c86245", size = 139767, upload-time = "2024-08-22T12:19:49.494Z" }, -] - -[[package]] -name = "jedi" -version = "0.19.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "parso" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287, upload-time = "2024-11-11T01:41:42.873Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278, upload-time = "2024-11-11T01:41:40.175Z" }, -] - [[package]] name = "jinja2" version = "3.1.4" @@ -834,50 +551,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/31/80/3a54838c3fb461f6fec263ebf3a3a41771bd05190238de3486aae8540c36/jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d", size = 133271, upload-time = "2024-05-05T23:41:59.928Z" }, ] -[[package]] -name = "jupyter-client" -version = "8.6.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "jupyter-core" }, - { name = "python-dateutil" }, - { name = "pyzmq" }, - { name = "tornado" }, - { name = "traitlets" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/71/22/bf9f12fdaeae18019a468b68952a60fe6dbab5d67cd2a103cac7659b41ca/jupyter_client-8.6.3.tar.gz", hash = "sha256:35b3a0947c4a6e9d589eb97d7d4cd5e90f910ee73101611f01283732bd6d9419", size = 342019, upload-time = "2024-09-17T10:44:17.613Z" } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/11/85/b0394e0b6fcccd2c1eeefc230978a6f8cb0c5df1e4cd3e7625735a0d7d1e/jupyter_client-8.6.3-py3-none-any.whl", hash = "sha256:e8a19cc986cc45905ac3362915f410f3af85424b4c0905e94fa5f2cb08e8f23f", size = 106105, upload-time = "2024-09-17T10:44:15.218Z" }, -] - -[[package]] -name = "jupyter-core" -version = "5.8.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "platformdirs" }, - { name = "pywin32", marker = "platform_python_implementation != 'PyPy' and sys_platform == 'win32'" }, - { name = "traitlets" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/99/1b/72906d554acfeb588332eaaa6f61577705e9ec752ddb486f302dafa292d9/jupyter_core-5.8.1.tar.gz", hash = "sha256:0a5f9706f70e64786b75acba995988915ebd4601c8a52e534a40b51c95f59941", size = 88923, upload-time = "2025-05-27T07:38:16.655Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2f/57/6bffd4b20b88da3800c5d691e0337761576ee688eb01299eae865689d2df/jupyter_core-5.8.1-py3-none-any.whl", hash = "sha256:c28d268fc90fb53f1338ded2eb410704c5449a358406e8a948b75706e24863d0", size = 28880, upload-time = "2025-05-27T07:38:15.137Z" }, -] - -[[package]] -name = "jupyterlab-widgets" -version = "3.0.15" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b9/7d/160595ca88ee87ac6ba95d82177d29ec60aaa63821d3077babb22ce031a5/jupyterlab_widgets-3.0.15.tar.gz", hash = "sha256:2920888a0c2922351a9202817957a68c07d99673504d6cd37345299e971bb08b", size = 213149, upload-time = "2025-05-05T12:32:31.004Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/43/6a/ca128561b22b60bd5a0c4ea26649e68c8556b82bc70a0c396eebc977fe86/jupyterlab_widgets-3.0.15-py3-none-any.whl", hash = "sha256:d59023d7d7ef71400d51e6fee9a88867f6e65e10a4201605d2d7f3e8f012a31c", size = 216571, upload-time = "2025-05-05T12:32:29.534Z" }, -] - [[package]] name = "lighter" -version = "0.0.4" +version = "0.0.5" source = { 
editable = "." } dependencies = [ + { name = "cloudpickle" }, { name = "loguru" }, { name = "numpy" }, { name = "pandas" }, @@ -893,13 +572,8 @@ dependencies = [ [package.dev-dependencies] dev = [ - { name = "aiohttp" }, - { name = "av" }, { name = "bump-my-version" }, { name = "coverage" }, - { name = "coverage-badge" }, - { name = "ipykernel" }, - { name = "ipywidgets" }, { name = "mkdocs-autorefs" }, { name = "mkdocs-gen-files" }, { name = "mkdocs-literate-nav" }, @@ -910,9 +584,11 @@ dev = [ { name = "pre-commit-uv" }, { name = "pytest" }, { name = "pytest-cov" }, + { name = "pytest-github-actions-annotate-failures" }, { name = "pytest-html" }, { name = "pytest-metadata" }, { name = "ruff" }, + { name = "types-pyyaml" }, { name = "typing-extensions" }, ] doc = [ @@ -930,29 +606,29 @@ quality = [ { name = "ruff" }, ] test = [ - { name = "aiohttp" }, - { name = "av" }, { name = "coverage" }, - { name = "coverage-badge" }, { name = "pytest" }, { name = "pytest-cov" }, + { name = "pytest-github-actions-annotate-failures" }, { name = "pytest-html" }, { name = "pytest-metadata" }, ] types = [ { name = "mypy" }, + { name = "types-pyyaml" }, { name = "typing-extensions" }, ] [package.metadata] requires-dist = [ + { name = "cloudpickle", specifier = ">=3.0.0" }, { name = "loguru", specifier = ">=0.6.0" }, { name = "numpy", specifier = "<2.0.0" }, { name = "pandas", specifier = ">=1.5.3" }, { name = "pytorch-lightning", specifier = ">=2.1.3" }, { name = "requests", specifier = ">=2.31.0" }, { name = "rich", specifier = ">=13.7.0" }, - { name = "sparkwheel", specifier = ">=0.0.5" }, + { name = "sparkwheel", specifier = ">=0.0.9" }, { name = "tensorboard", specifier = ">=2.11.2" }, { name = "torch", specifier = ">=2.1.2" }, { name = "torchmetrics", specifier = ">=1.2.0" }, @@ -961,13 +637,8 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ - { name = "aiohttp", specifier = ">=3.8.3" }, - { name = "av", specifier = ">=12.0.0" }, { name = "bump-my-version", 
specifier = "==0.30.1" }, { name = "coverage", specifier = ">=7.0.5" }, - { name = "coverage-badge", specifier = ">=1.1.0" }, - { name = "ipykernel", specifier = "==6.29.5" }, - { name = "ipywidgets", specifier = "==8.1.5" }, { name = "mkdocs-autorefs", specifier = ">=1.4.1" }, { name = "mkdocs-gen-files", specifier = ">=0.5.0" }, { name = "mkdocs-literate-nav", specifier = ">=0.6.2" }, @@ -978,9 +649,11 @@ dev = [ { name = "pre-commit-uv", specifier = "==4.1.4" }, { name = "pytest", specifier = ">=7.4.0" }, { name = "pytest-cov", specifier = ">=4.0.0" }, + { name = "pytest-github-actions-annotate-failures", specifier = ">=0.2.0" }, { name = "pytest-html", specifier = ">=3.2.0" }, { name = "pytest-metadata", specifier = ">=3.1.1" }, { name = "ruff", specifier = ">=0.5.0" }, + { name = "types-pyyaml", specifier = ">=6.0.0" }, { name = "typing-extensions", specifier = ">=4.4.0" }, ] doc = [ @@ -994,17 +667,16 @@ doc = [ maintain = [{ name = "bump-my-version", specifier = "==0.30.1" }] quality = [{ name = "ruff", specifier = ">=0.5.0" }] test = [ - { name = "aiohttp", specifier = ">=3.8.3" }, - { name = "av", specifier = ">=12.0.0" }, { name = "coverage", specifier = ">=7.0.5" }, - { name = "coverage-badge", specifier = ">=1.1.0" }, { name = "pytest", specifier = ">=7.4.0" }, { name = "pytest-cov", specifier = ">=4.0.0" }, + { name = "pytest-github-actions-annotate-failures", specifier = ">=0.2.0" }, { name = "pytest-html", specifier = ">=3.2.0" }, { name = "pytest-metadata", specifier = ">=3.1.1" }, ] types = [ { name = "mypy", specifier = ">=1.14.1" }, + { name = "types-pyyaml", specifier = ">=6.0.0" }, { name = "typing-extensions", specifier = ">=4.4.0" }, ] @@ -1094,18 +766,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/14/c3554d512d5f9100a95e737502f4a2323a1959f6d0d01e0d0997b35f7b10/MarkupSafe-2.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb", size = 17127, upload-time = 
"2024-02-02T16:30:44.418Z" }, ] -[[package]] -name = "matplotlib-inline" -version = "0.1.7" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "traitlets" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/99/5b/a36a337438a14116b16480db471ad061c36c3694df7c2084a0da7ba538b7/matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90", size = 8159, upload-time = "2024-04-15T13:44:44.803Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899, upload-time = "2024-04-15T13:44:43.265Z" }, -] - [[package]] name = "mdurl" version = "0.1.2" @@ -1391,15 +1051,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, ] -[[package]] -name = "nest-asyncio" -version = "1.6.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/83/f8/51569ac65d696c8ecbee95938f89d4abf00f47d58d48f6fbabfe8f0baefe/nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe", size = 7418, upload-time = "2024-01-21T14:25:19.227Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" }, -] - [[package]] name = "networkx" version = "3.2.1" @@ -1580,36 +1231,63 @@ wheels = [ [[package]] name = "pandas" 
-version = "1.5.3" +version = "2.3.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, { name = "python-dateutil" }, { name = "pytz" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/74/ee/146cab1ff6d575b54ace8a6a5994048380dc94879b0125b25e62edcb9e52/pandas-1.5.3.tar.gz", hash = "sha256:74a3fd7e5a7ec052f183273dc7b0acd3a863edf7520f5d3a1765c04ffdb3b0b1", size = 5203060, upload-time = "2023-01-19T08:31:39.615Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a9/cd/34f6b0780301be81be804d7aa71d571457369e6131e2b330af2b0fed1aad/pandas-1.5.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3749077d86e3a2f0ed51367f30bf5b82e131cc0f14260c4d3e499186fccc4406", size = 18619230, upload-time = "2023-01-19T08:29:07.301Z" }, - { url = "https://files.pythonhosted.org/packages/5f/34/b7858bb7d6d6bf4d9df1dde777a11fcf3ff370e1d1b3956e3d0fcca8322c/pandas-1.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:972d8a45395f2a2d26733eb8d0f629b2f90bebe8e8eddbb8829b180c09639572", size = 11982991, upload-time = "2023-01-19T08:29:15.383Z" }, - { url = "https://files.pythonhosted.org/packages/b8/6c/005bd604994f7cbede4d7bf030614ef49a2213f76bc3d738ecf5b0dcc810/pandas-1.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:50869a35cbb0f2e0cd5ec04b191e7b12ed688874bd05dd777c19b28cbea90996", size = 10927131, upload-time = "2023-01-19T08:29:20.342Z" }, - { url = "https://files.pythonhosted.org/packages/27/c7/35b81ce5f680f2dac55eac14d103245cd8cf656ae4a2ff3be2e69fd1d330/pandas-1.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3ac844a0fe00bfaeb2c9b51ab1424e5c8744f89860b138434a363b1f620f354", size = 11368188, upload-time = "2023-01-19T08:29:25.807Z" }, - { url = "https://files.pythonhosted.org/packages/49/e2/79e46612dc25ebc7603dc11c560baa7266c90f9e48537ecf1a02a0dd6bff/pandas-1.5.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:7a0a56cef15fd1586726dace5616db75ebcfec9179a3a55e78f72c5639fa2a23", size = 12062104, upload-time = "2023-01-19T08:29:30.695Z" }, - { url = "https://files.pythonhosted.org/packages/d9/cd/f27c2992cbe05a3e39937f73a4be635a9ec149ec3ca4467d8cf039718994/pandas-1.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:478ff646ca42b20376e4ed3fa2e8d7341e8a63105586efe54fa2508ee087f328", size = 10362473, upload-time = "2023-01-19T08:29:37.506Z" }, - { url = "https://files.pythonhosted.org/packages/e2/24/a26af514113fd5eca2d8fe41ba4f22f70dfe6afefde4a6beb6a203570935/pandas-1.5.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6973549c01ca91ec96199e940495219c887ea815b2083722821f1d7abfa2b4dc", size = 18387750, upload-time = "2023-01-19T08:29:43.119Z" }, - { url = "https://files.pythonhosted.org/packages/53/c9/d2f910dace7ef849b626980d0fd033b9cded36568949c8d560c9630ad2e0/pandas-1.5.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c39a8da13cede5adcd3be1182883aea1c925476f4e84b2807a46e2775306305d", size = 11868668, upload-time = "2023-01-19T08:29:48.733Z" }, - { url = "https://files.pythonhosted.org/packages/b0/be/1843b9aff84b98899663e7cad9f45513dfdd11d69cb5bd85c648aaf6a8d4/pandas-1.5.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f76d097d12c82a535fda9dfe5e8dd4127952b45fea9b0276cb30cca5ea313fbc", size = 10814036, upload-time = "2023-01-19T08:29:54.886Z" }, - { url = "https://files.pythonhosted.org/packages/63/8d/c2bd356b9d4baf1c5cf8d7e251fb4540e87083072c905430da48c2bb31eb/pandas-1.5.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e474390e60ed609cec869b0da796ad94f420bb057d86784191eefc62b65819ae", size = 11374218, upload-time = "2023-01-19T08:30:00.5Z" }, - { url = "https://files.pythonhosted.org/packages/56/73/3351beeb807dca69fcc3c4966bcccc51552bd01549a9b13c04ab00a43f21/pandas-1.5.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f2b952406a1588ad4cad5b3f55f520e82e902388a6d5a4a91baa8d38d23c7f6", size = 12017319, 
upload-time = "2023-01-19T08:30:06.097Z" }, - { url = "https://files.pythonhosted.org/packages/da/6d/1235da14daddaa6e47f74ba0c255358f0ce7a6ee05da8bf8eb49161aa6b5/pandas-1.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:bc4c368f42b551bf72fac35c5128963a171b40dce866fb066540eeaf46faa003", size = 10303385, upload-time = "2023-01-19T08:30:11.148Z" }, -] - -[[package]] -name = "parso" -version = "0.8.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/66/94/68e2e17afaa9169cf6412ab0f28623903be73d1b32e208d9e8e541bb086d/parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d", size = 400609, upload-time = "2024-04-05T09:43:55.897Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650, upload-time = "2024-04-05T09:43:53.299Z" }, + { name = "tzdata" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/01/d40b85317f86cf08d853a4f495195c73815fdf205eef3993821720274518/pandas-2.3.3.tar.gz", hash = "sha256:e05e1af93b977f7eafa636d043f9f94c7ee3ac81af99c13508215942e64c993b", size = 4495223, upload-time = "2025-09-29T23:34:51.853Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3d/f7/f425a00df4fcc22b292c6895c6831c0c8ae1d9fac1e024d16f98a9ce8749/pandas-2.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:376c6446ae31770764215a6c937f72d917f214b43560603cd60da6408f183b6c", size = 11555763, upload-time = "2025-09-29T23:16:53.287Z" }, + { url = "https://files.pythonhosted.org/packages/13/4f/66d99628ff8ce7857aca52fed8f0066ce209f96be2fede6cef9f84e8d04f/pandas-2.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e19d192383eab2f4ceb30b412b22ea30690c9e618f78870357ae1d682912015a", size = 10801217, upload-time = "2025-09-29T23:17:04.522Z" }, + { url = 
"https://files.pythonhosted.org/packages/1d/03/3fc4a529a7710f890a239cc496fc6d50ad4a0995657dccc1d64695adb9f4/pandas-2.3.3-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5caf26f64126b6c7aec964f74266f435afef1c1b13da3b0636c7518a1fa3e2b1", size = 12148791, upload-time = "2025-09-29T23:17:18.444Z" }, + { url = "https://files.pythonhosted.org/packages/40/a8/4dac1f8f8235e5d25b9955d02ff6f29396191d4e665d71122c3722ca83c5/pandas-2.3.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dd7478f1463441ae4ca7308a70e90b33470fa593429f9d4c578dd00d1fa78838", size = 12769373, upload-time = "2025-09-29T23:17:35.846Z" }, + { url = "https://files.pythonhosted.org/packages/df/91/82cc5169b6b25440a7fc0ef3a694582418d875c8e3ebf796a6d6470aa578/pandas-2.3.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4793891684806ae50d1288c9bae9330293ab4e083ccd1c5e383c34549c6e4250", size = 13200444, upload-time = "2025-09-29T23:17:49.341Z" }, + { url = "https://files.pythonhosted.org/packages/10/ae/89b3283800ab58f7af2952704078555fa60c807fff764395bb57ea0b0dbd/pandas-2.3.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:28083c648d9a99a5dd035ec125d42439c6c1c525098c58af0fc38dd1a7a1b3d4", size = 13858459, upload-time = "2025-09-29T23:18:03.722Z" }, + { url = "https://files.pythonhosted.org/packages/85/72/530900610650f54a35a19476eca5104f38555afccda1aa11a92ee14cb21d/pandas-2.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:503cf027cf9940d2ceaa1a93cfb5f8c8c7e6e90720a2850378f0b3f3b1e06826", size = 11346086, upload-time = "2025-09-29T23:18:18.505Z" }, + { url = "https://files.pythonhosted.org/packages/c1/fa/7ac648108144a095b4fb6aa3de1954689f7af60a14cf25583f4960ecb878/pandas-2.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:602b8615ebcc4a0c1751e71840428ddebeb142ec02c786e8ad6b1ce3c8dec523", size = 11578790, upload-time = "2025-09-29T23:18:30.065Z" }, + { url = 
"https://files.pythonhosted.org/packages/9b/35/74442388c6cf008882d4d4bdfc4109be87e9b8b7ccd097ad1e7f006e2e95/pandas-2.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8fe25fc7b623b0ef6b5009149627e34d2a4657e880948ec3c840e9402e5c1b45", size = 10833831, upload-time = "2025-09-29T23:38:56.071Z" }, + { url = "https://files.pythonhosted.org/packages/fe/e4/de154cbfeee13383ad58d23017da99390b91d73f8c11856f2095e813201b/pandas-2.3.3-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b468d3dad6ff947df92dcb32ede5b7bd41a9b3cceef0a30ed925f6d01fb8fa66", size = 12199267, upload-time = "2025-09-29T23:18:41.627Z" }, + { url = "https://files.pythonhosted.org/packages/bf/c9/63f8d545568d9ab91476b1818b4741f521646cbdd151c6efebf40d6de6f7/pandas-2.3.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b98560e98cb334799c0b07ca7967ac361a47326e9b4e5a7dfb5ab2b1c9d35a1b", size = 12789281, upload-time = "2025-09-29T23:18:56.834Z" }, + { url = "https://files.pythonhosted.org/packages/f2/00/a5ac8c7a0e67fd1a6059e40aa08fa1c52cc00709077d2300e210c3ce0322/pandas-2.3.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1d37b5848ba49824e5c30bedb9c830ab9b7751fd049bc7914533e01c65f79791", size = 13240453, upload-time = "2025-09-29T23:19:09.247Z" }, + { url = "https://files.pythonhosted.org/packages/27/4d/5c23a5bc7bd209231618dd9e606ce076272c9bc4f12023a70e03a86b4067/pandas-2.3.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:db4301b2d1f926ae677a751eb2bd0e8c5f5319c9cb3f88b0becbbb0b07b34151", size = 13890361, upload-time = "2025-09-29T23:19:25.342Z" }, + { url = "https://files.pythonhosted.org/packages/8e/59/712db1d7040520de7a4965df15b774348980e6df45c129b8c64d0dbe74ef/pandas-2.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:f086f6fe114e19d92014a1966f43a3e62285109afe874f067f5abbdcbb10e59c", size = 11348702, upload-time = "2025-09-29T23:19:38.296Z" }, + { url = 
"https://files.pythonhosted.org/packages/9c/fb/231d89e8637c808b997d172b18e9d4a4bc7bf31296196c260526055d1ea0/pandas-2.3.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6d21f6d74eb1725c2efaa71a2bfc661a0689579b58e9c0ca58a739ff0b002b53", size = 11597846, upload-time = "2025-09-29T23:19:48.856Z" }, + { url = "https://files.pythonhosted.org/packages/5c/bd/bf8064d9cfa214294356c2d6702b716d3cf3bb24be59287a6a21e24cae6b/pandas-2.3.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3fd2f887589c7aa868e02632612ba39acb0b8948faf5cc58f0850e165bd46f35", size = 10729618, upload-time = "2025-09-29T23:39:08.659Z" }, + { url = "https://files.pythonhosted.org/packages/57/56/cf2dbe1a3f5271370669475ead12ce77c61726ffd19a35546e31aa8edf4e/pandas-2.3.3-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ecaf1e12bdc03c86ad4a7ea848d66c685cb6851d807a26aa245ca3d2017a1908", size = 11737212, upload-time = "2025-09-29T23:19:59.765Z" }, + { url = "https://files.pythonhosted.org/packages/e5/63/cd7d615331b328e287d8233ba9fdf191a9c2d11b6af0c7a59cfcec23de68/pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b3d11d2fda7eb164ef27ffc14b4fcab16a80e1ce67e9f57e19ec0afaf715ba89", size = 12362693, upload-time = "2025-09-29T23:20:14.098Z" }, + { url = "https://files.pythonhosted.org/packages/a6/de/8b1895b107277d52f2b42d3a6806e69cfef0d5cf1d0ba343470b9d8e0a04/pandas-2.3.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a68e15f780eddf2b07d242e17a04aa187a7ee12b40b930bfdd78070556550e98", size = 12771002, upload-time = "2025-09-29T23:20:26.76Z" }, + { url = "https://files.pythonhosted.org/packages/87/21/84072af3187a677c5893b170ba2c8fbe450a6ff911234916da889b698220/pandas-2.3.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:371a4ab48e950033bcf52b6527eccb564f52dc826c02afd9a1bc0ab731bba084", size = 13450971, upload-time = "2025-09-29T23:20:41.344Z" }, + { url = 
"https://files.pythonhosted.org/packages/86/41/585a168330ff063014880a80d744219dbf1dd7a1c706e75ab3425a987384/pandas-2.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:a16dcec078a01eeef8ee61bf64074b4e524a2a3f4b3be9326420cabe59c4778b", size = 10992722, upload-time = "2025-09-29T23:20:54.139Z" }, + { url = "https://files.pythonhosted.org/packages/cd/4b/18b035ee18f97c1040d94debd8f2e737000ad70ccc8f5513f4eefad75f4b/pandas-2.3.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:56851a737e3470de7fa88e6131f41281ed440d29a9268dcbf0002da5ac366713", size = 11544671, upload-time = "2025-09-29T23:21:05.024Z" }, + { url = "https://files.pythonhosted.org/packages/31/94/72fac03573102779920099bcac1c3b05975c2cb5f01eac609faf34bed1ca/pandas-2.3.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bdcd9d1167f4885211e401b3036c0c8d9e274eee67ea8d0758a256d60704cfe8", size = 10680807, upload-time = "2025-09-29T23:21:15.979Z" }, + { url = "https://files.pythonhosted.org/packages/16/87/9472cf4a487d848476865321de18cc8c920b8cab98453ab79dbbc98db63a/pandas-2.3.3-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e32e7cc9af0f1cc15548288a51a3b681cc2a219faa838e995f7dc53dbab1062d", size = 11709872, upload-time = "2025-09-29T23:21:27.165Z" }, + { url = "https://files.pythonhosted.org/packages/15/07/284f757f63f8a8d69ed4472bfd85122bd086e637bf4ed09de572d575a693/pandas-2.3.3-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:318d77e0e42a628c04dc56bcef4b40de67918f7041c2b061af1da41dcff670ac", size = 12306371, upload-time = "2025-09-29T23:21:40.532Z" }, + { url = "https://files.pythonhosted.org/packages/33/81/a3afc88fca4aa925804a27d2676d22dcd2031c2ebe08aabd0ae55b9ff282/pandas-2.3.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4e0a175408804d566144e170d0476b15d78458795bb18f1304fb94160cabf40c", size = 12765333, upload-time = "2025-09-29T23:21:55.77Z" }, + { url = 
"https://files.pythonhosted.org/packages/8d/0f/b4d4ae743a83742f1153464cf1a8ecfafc3ac59722a0b5c8602310cb7158/pandas-2.3.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:93c2d9ab0fc11822b5eece72ec9587e172f63cff87c00b062f6e37448ced4493", size = 13418120, upload-time = "2025-09-29T23:22:10.109Z" }, + { url = "https://files.pythonhosted.org/packages/4f/c7/e54682c96a895d0c808453269e0b5928a07a127a15704fedb643e9b0a4c8/pandas-2.3.3-cp313-cp313-win_amd64.whl", hash = "sha256:f8bfc0e12dc78f777f323f55c58649591b2cd0c43534e8355c51d3fede5f4dee", size = 10993991, upload-time = "2025-09-29T23:25:04.889Z" }, + { url = "https://files.pythonhosted.org/packages/f9/ca/3f8d4f49740799189e1395812f3bf23b5e8fc7c190827d55a610da72ce55/pandas-2.3.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:75ea25f9529fdec2d2e93a42c523962261e567d250b0013b16210e1d40d7c2e5", size = 12048227, upload-time = "2025-09-29T23:22:24.343Z" }, + { url = "https://files.pythonhosted.org/packages/0e/5a/f43efec3e8c0cc92c4663ccad372dbdff72b60bdb56b2749f04aa1d07d7e/pandas-2.3.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:74ecdf1d301e812db96a465a525952f4dde225fdb6d8e5a521d47e1f42041e21", size = 11411056, upload-time = "2025-09-29T23:22:37.762Z" }, + { url = "https://files.pythonhosted.org/packages/46/b1/85331edfc591208c9d1a63a06baa67b21d332e63b7a591a5ba42a10bb507/pandas-2.3.3-cp313-cp313t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6435cb949cb34ec11cc9860246ccb2fdc9ecd742c12d3304989017d53f039a78", size = 11645189, upload-time = "2025-09-29T23:22:51.688Z" }, + { url = "https://files.pythonhosted.org/packages/44/23/78d645adc35d94d1ac4f2a3c4112ab6f5b8999f4898b8cdf01252f8df4a9/pandas-2.3.3-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:900f47d8f20860de523a1ac881c4c36d65efcb2eb850e6948140fa781736e110", size = 12121912, upload-time = "2025-09-29T23:23:05.042Z" }, + { url = 
"https://files.pythonhosted.org/packages/53/da/d10013df5e6aaef6b425aa0c32e1fc1f3e431e4bcabd420517dceadce354/pandas-2.3.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:a45c765238e2ed7d7c608fc5bc4a6f88b642f2f01e70c0c23d2224dd21829d86", size = 12712160, upload-time = "2025-09-29T23:23:28.57Z" }, + { url = "https://files.pythonhosted.org/packages/bd/17/e756653095a083d8a37cbd816cb87148debcfcd920129b25f99dd8d04271/pandas-2.3.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:c4fc4c21971a1a9f4bdb4c73978c7f7256caa3e62b323f70d6cb80db583350bc", size = 13199233, upload-time = "2025-09-29T23:24:24.876Z" }, + { url = "https://files.pythonhosted.org/packages/04/fd/74903979833db8390b73b3a8a7d30d146d710bd32703724dd9083950386f/pandas-2.3.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:ee15f284898e7b246df8087fc82b87b01686f98ee67d85a17b7ab44143a3a9a0", size = 11540635, upload-time = "2025-09-29T23:25:52.486Z" }, + { url = "https://files.pythonhosted.org/packages/21/00/266d6b357ad5e6d3ad55093a7e8efc7dd245f5a842b584db9f30b0f0a287/pandas-2.3.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:1611aedd912e1ff81ff41c745822980c49ce4a7907537be8692c8dbc31924593", size = 10759079, upload-time = "2025-09-29T23:26:33.204Z" }, + { url = "https://files.pythonhosted.org/packages/ca/05/d01ef80a7a3a12b2f8bbf16daba1e17c98a2f039cbc8e2f77a2c5a63d382/pandas-2.3.3-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6d2cefc361461662ac48810cb14365a365ce864afe85ef1f447ff5a1e99ea81c", size = 11814049, upload-time = "2025-09-29T23:27:15.384Z" }, + { url = "https://files.pythonhosted.org/packages/15/b2/0e62f78c0c5ba7e3d2c5945a82456f4fac76c480940f805e0b97fcbc2f65/pandas-2.3.3-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ee67acbbf05014ea6c763beb097e03cd629961c8a632075eeb34247120abcb4b", size = 12332638, upload-time = "2025-09-29T23:27:51.625Z" }, + { url = 
"https://files.pythonhosted.org/packages/c5/33/dd70400631b62b9b29c3c93d2feee1d0964dc2bae2e5ad7a6c73a7f25325/pandas-2.3.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c46467899aaa4da076d5abc11084634e2d197e9460643dd455ac3db5856b24d6", size = 12886834, upload-time = "2025-09-29T23:28:21.289Z" }, + { url = "https://files.pythonhosted.org/packages/d3/18/b5d48f55821228d0d2692b34fd5034bb185e854bdb592e9c640f6290e012/pandas-2.3.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:6253c72c6a1d990a410bc7de641d34053364ef8bcd3126f7e7450125887dffe3", size = 13409925, upload-time = "2025-09-29T23:28:58.261Z" }, + { url = "https://files.pythonhosted.org/packages/a6/3d/124ac75fcd0ecc09b8fdccb0246ef65e35b012030defb0e0eba2cbbbe948/pandas-2.3.3-cp314-cp314-win_amd64.whl", hash = "sha256:1b07204a219b3b7350abaae088f451860223a52cfb8a6c53358e7948735158e5", size = 11109071, upload-time = "2025-09-29T23:32:27.484Z" }, + { url = "https://files.pythonhosted.org/packages/89/9c/0e21c895c38a157e0faa1fb64587a9226d6dd46452cac4532d80c3c4a244/pandas-2.3.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:2462b1a365b6109d275250baaae7b760fd25c726aaca0054649286bcfbb3e8ec", size = 12048504, upload-time = "2025-09-29T23:29:31.47Z" }, + { url = "https://files.pythonhosted.org/packages/d7/82/b69a1c95df796858777b68fbe6a81d37443a33319761d7c652ce77797475/pandas-2.3.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:0242fe9a49aa8b4d78a4fa03acb397a58833ef6199e9aa40a95f027bb3a1b6e7", size = 11410702, upload-time = "2025-09-29T23:29:54.591Z" }, + { url = "https://files.pythonhosted.org/packages/f9/88/702bde3ba0a94b8c73a0181e05144b10f13f29ebfc2150c3a79062a8195d/pandas-2.3.3-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a21d830e78df0a515db2b3d2f5570610f5e6bd2e27749770e8bb7b524b89b450", size = 11634535, upload-time = "2025-09-29T23:30:21.003Z" }, + { url = 
"https://files.pythonhosted.org/packages/a4/1e/1bac1a839d12e6a82ec6cb40cda2edde64a2013a66963293696bbf31fbbb/pandas-2.3.3-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2e3ebdb170b5ef78f19bfb71b0dc5dc58775032361fa188e814959b74d726dd5", size = 12121582, upload-time = "2025-09-29T23:30:43.391Z" }, + { url = "https://files.pythonhosted.org/packages/44/91/483de934193e12a3b1d6ae7c8645d083ff88dec75f46e827562f1e4b4da6/pandas-2.3.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:d051c0e065b94b7a3cea50eb1ec32e912cd96dba41647eb24104b6c6c14c5788", size = 12699963, upload-time = "2025-09-29T23:31:10.009Z" }, + { url = "https://files.pythonhosted.org/packages/70/44/5191d2e4026f86a2a109053e194d3ba7a31a2d10a9c2348368c63ed4e85a/pandas-2.3.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3869faf4bd07b3b66a9f462417d0ca3a9df29a9f6abd5d0d0dbab15dac7abe87", size = 13202175, upload-time = "2025-09-29T23:31:59.173Z" }, ] [[package]] @@ -1621,18 +1299,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" }, ] -[[package]] -name = "pexpect" -version = "4.9.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "ptyprocess" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450, upload-time = "2023-11-25T09:07:26.339Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772, 
upload-time = "2023-11-25T06:56:14.81Z" }, -] - [[package]] name = "pillow" version = "10.4.0" @@ -1765,39 +1431,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e1/94/d77bd282d3d53155147166c2bbd156f540009b0d7be24330f76286668b90/protobuf-5.27.3-py3-none-any.whl", hash = "sha256:8572c6533e544ebf6899c360e91d6bcbbee2549251643d32c52cf8a5de295ba5", size = 164778, upload-time = "2024-07-31T16:29:39.791Z" }, ] -[[package]] -name = "psutil" -version = "7.0.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/2a/80/336820c1ad9286a4ded7e845b2eccfcb27851ab8ac6abece774a6ff4d3de/psutil-7.0.0.tar.gz", hash = "sha256:7be9c3eba38beccb6495ea33afd982a44074b78f28c434a1f51cc07fd315c456", size = 497003, upload-time = "2025-02-13T21:54:07.946Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ed/e6/2d26234410f8b8abdbf891c9da62bee396583f713fb9f3325a4760875d22/psutil-7.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25", size = 238051, upload-time = "2025-02-13T21:54:12.36Z" }, - { url = "https://files.pythonhosted.org/packages/04/8b/30f930733afe425e3cbfc0e1468a30a18942350c1a8816acfade80c005c4/psutil-7.0.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da", size = 239535, upload-time = "2025-02-13T21:54:16.07Z" }, - { url = "https://files.pythonhosted.org/packages/2a/ed/d362e84620dd22876b55389248e522338ed1bf134a5edd3b8231d7207f6d/psutil-7.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fcee592b4c6f146991ca55919ea3d1f8926497a713ed7faaf8225e174581e91", size = 275004, upload-time = "2025-02-13T21:54:18.662Z" }, - { url = 
"https://files.pythonhosted.org/packages/bf/b9/b0eb3f3cbcb734d930fdf839431606844a825b23eaf9a6ab371edac8162c/psutil-7.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b1388a4f6875d7e2aff5c4ca1cc16c545ed41dd8bb596cefea80111db353a34", size = 277986, upload-time = "2025-02-13T21:54:21.811Z" }, - { url = "https://files.pythonhosted.org/packages/eb/a2/709e0fe2f093556c17fbafda93ac032257242cabcc7ff3369e2cb76a97aa/psutil-7.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5f098451abc2828f7dc6b58d44b532b22f2088f4999a937557b603ce72b1993", size = 279544, upload-time = "2025-02-13T21:54:24.68Z" }, - { url = "https://files.pythonhosted.org/packages/50/e6/eecf58810b9d12e6427369784efe814a1eec0f492084ce8eb8f4d89d6d61/psutil-7.0.0-cp37-abi3-win32.whl", hash = "sha256:ba3fcef7523064a6c9da440fc4d6bd07da93ac726b5733c29027d7dc95b39d99", size = 241053, upload-time = "2025-02-13T21:54:34.31Z" }, - { url = "https://files.pythonhosted.org/packages/50/1b/6921afe68c74868b4c9fa424dad3be35b095e16687989ebbb50ce4fceb7c/psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553", size = 244885, upload-time = "2025-02-13T21:54:37.486Z" }, -] - -[[package]] -name = "ptyprocess" -version = "0.7.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/20/e5/16ff212c1e452235a90aeb09066144d0c5a6a8c0834397e03f5224495c4e/ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220", size = 70762, upload-time = "2020-12-28T15:15:30.155Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35", size = 13993, upload-time = "2020-12-28T15:15:28.35Z" }, -] - 
-[[package]] -name = "pure-eval" -version = "0.2.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/cd/05/0a34433a064256a578f1783a10da6df098ceaa4a57bbeaa96a6c0352786b/pure_eval-0.2.3.tar.gz", hash = "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42", size = 19752, upload-time = "2024-07-21T12:58:21.801Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842, upload-time = "2024-07-21T12:58:20.04Z" }, -] - [[package]] name = "py" version = "1.11.0" @@ -1807,15 +1440,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f6/f0/10642828a8dfb741e5f3fbaac830550a518a775c7fff6f04a007259b0548/py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378", size = 98708, upload-time = "2021-11-04T17:17:00.152Z" }, ] -[[package]] -name = "pycparser" -version = "2.22" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/1d/b2/31537cf4b1ca988837256c910a668b553fceb8f069bedc4b1c826024b52c/pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6", size = 172736, upload-time = "2024-03-30T13:22:22.564Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/13/a3/a812df4e2dd5696d1f351d58b8fe16a405b234ad2886a0dab9183fb78109/pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc", size = 117552, upload-time = "2024-03-30T13:22:20.476Z" }, -] - [[package]] name = "pydantic" version = "2.11.7" @@ -1984,6 +1608,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/a7/4b/8b78d126e275efa2379b1c2e09dc52cf70df16fc3b90613ef82531499d73/pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a", size = 21949, upload-time = "2023-05-24T18:44:54.079Z" }, ] +[[package]] +name = "pytest-github-actions-annotate-failures" +version = "0.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/39/d4/c54ee6a871eee4a7468e3a8c0dead28e634c0bc2110c694309dcb7563a66/pytest_github_actions_annotate_failures-0.3.0.tar.gz", hash = "sha256:d4c3177c98046c3900a7f8ddebb22ea54b9f6822201b5d3ab8fcdea51e010db7", size = 11248, upload-time = "2025-01-17T22:39:32.722Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6d/73/7b0b15cb8605ee967b34aa1d949737ab664f94e6b0f1534e8339d9e64ab2/pytest_github_actions_annotate_failures-0.3.0-py3-none-any.whl", hash = "sha256:41ea558ba10c332c0bfc053daeee0c85187507b2034e990f21e4f7e5fef044cf", size = 6030, upload-time = "2025-01-17T22:39:31.701Z" }, +] + [[package]] name = "pytest-html" version = "3.2.0" @@ -2059,25 +1695,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9c/3d/a121f284241f08268b21359bd425f7d4825cffc5ac5cd0e1b3d82ffd2b10/pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319", size = 505474, upload-time = "2024-02-02T01:18:37.283Z" }, ] -[[package]] -name = "pywin32" -version = "310" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/95/da/a5f38fffbba2fb99aa4aa905480ac4b8e83ca486659ac8c95bce47fb5276/pywin32-310-cp310-cp310-win32.whl", hash = "sha256:6dd97011efc8bf51d6793a82292419eba2c71cf8e7250cfac03bba284454abc1", size = 8848240, upload-time = "2025-03-17T00:55:46.783Z" }, - { url = 
"https://files.pythonhosted.org/packages/aa/fe/d873a773324fa565619ba555a82c9dabd677301720f3660a731a5d07e49a/pywin32-310-cp310-cp310-win_amd64.whl", hash = "sha256:c3e78706e4229b915a0821941a84e7ef420bf2b77e08c9dae3c76fd03fd2ae3d", size = 9601854, upload-time = "2025-03-17T00:55:48.783Z" }, - { url = "https://files.pythonhosted.org/packages/3c/84/1a8e3d7a15490d28a5d816efa229ecb4999cdc51a7c30dd8914f669093b8/pywin32-310-cp310-cp310-win_arm64.whl", hash = "sha256:33babed0cf0c92a6f94cc6cc13546ab24ee13e3e800e61ed87609ab91e4c8213", size = 8522963, upload-time = "2025-03-17T00:55:50.969Z" }, - { url = "https://files.pythonhosted.org/packages/f7/b1/68aa2986129fb1011dabbe95f0136f44509afaf072b12b8f815905a39f33/pywin32-310-cp311-cp311-win32.whl", hash = "sha256:1e765f9564e83011a63321bb9d27ec456a0ed90d3732c4b2e312b855365ed8bd", size = 8784284, upload-time = "2025-03-17T00:55:53.124Z" }, - { url = "https://files.pythonhosted.org/packages/b3/bd/d1592635992dd8db5bb8ace0551bc3a769de1ac8850200cfa517e72739fb/pywin32-310-cp311-cp311-win_amd64.whl", hash = "sha256:126298077a9d7c95c53823934f000599f66ec9296b09167810eb24875f32689c", size = 9520748, upload-time = "2025-03-17T00:55:55.203Z" }, - { url = "https://files.pythonhosted.org/packages/90/b1/ac8b1ffce6603849eb45a91cf126c0fa5431f186c2e768bf56889c46f51c/pywin32-310-cp311-cp311-win_arm64.whl", hash = "sha256:19ec5fc9b1d51c4350be7bb00760ffce46e6c95eaf2f0b2f1150657b1a43c582", size = 8455941, upload-time = "2025-03-17T00:55:57.048Z" }, - { url = "https://files.pythonhosted.org/packages/6b/ec/4fdbe47932f671d6e348474ea35ed94227fb5df56a7c30cbbb42cd396ed0/pywin32-310-cp312-cp312-win32.whl", hash = "sha256:8a75a5cc3893e83a108c05d82198880704c44bbaee4d06e442e471d3c9ea4f3d", size = 8796239, upload-time = "2025-03-17T00:55:58.807Z" }, - { url = "https://files.pythonhosted.org/packages/e3/e5/b0627f8bb84e06991bea89ad8153a9e50ace40b2e1195d68e9dff6b03d0f/pywin32-310-cp312-cp312-win_amd64.whl", hash = 
"sha256:bf5c397c9a9a19a6f62f3fb821fbf36cac08f03770056711f765ec1503972060", size = 9503839, upload-time = "2025-03-17T00:56:00.8Z" }, - { url = "https://files.pythonhosted.org/packages/1f/32/9ccf53748df72301a89713936645a664ec001abd35ecc8578beda593d37d/pywin32-310-cp312-cp312-win_arm64.whl", hash = "sha256:2349cc906eae872d0663d4d6290d13b90621eaf78964bb1578632ff20e152966", size = 8459470, upload-time = "2025-03-17T00:56:02.601Z" }, - { url = "https://files.pythonhosted.org/packages/1c/09/9c1b978ffc4ae53999e89c19c77ba882d9fce476729f23ef55211ea1c034/pywin32-310-cp313-cp313-win32.whl", hash = "sha256:5d241a659c496ada3253cd01cfaa779b048e90ce4b2b38cd44168ad555ce74ab", size = 8794384, upload-time = "2025-03-17T00:56:04.383Z" }, - { url = "https://files.pythonhosted.org/packages/45/3c/b4640f740ffebadd5d34df35fecba0e1cfef8fde9f3e594df91c28ad9b50/pywin32-310-cp313-cp313-win_amd64.whl", hash = "sha256:667827eb3a90208ddbdcc9e860c81bde63a135710e21e4cb3348968e4bd5249e", size = 9503039, upload-time = "2025-03-17T00:56:06.207Z" }, - { url = "https://files.pythonhosted.org/packages/b4/f4/f785020090fb050e7fb6d34b780f2231f302609dc964672f72bfaeb59a28/pywin32-310-cp313-cp313-win_arm64.whl", hash = "sha256:e308f831de771482b7cf692a1f308f8fca701b2d8f9dde6cc440c7da17e47b33", size = 8458152, upload-time = "2025-03-17T00:56:07.819Z" }, -] - [[package]] name = "pyyaml" version = "6.0.2" @@ -2134,66 +1751,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/66/bbb1dd374f5c870f59c5bb1db0e18cbe7fa739415a24cbd95b2d1f5ae0c4/pyyaml_env_tag-0.1-py3-none-any.whl", hash = "sha256:af31106dec8a4d68c60207c1886031cbf839b68aa7abccdb19868200532c2069", size = 3911, upload-time = "2020-11-12T02:38:24.638Z" }, ] -[[package]] -name = "pyzmq" -version = "27.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cffi", marker = "implementation_name == 'pypy'" }, -] -sdist = { url = 
"https://files.pythonhosted.org/packages/f1/06/50a4e9648b3e8b992bef8eb632e457307553a89d294103213cfd47b3da69/pyzmq-27.0.0.tar.gz", hash = "sha256:b1f08eeb9ce1510e6939b6e5dcd46a17765e2333daae78ecf4606808442e52cf", size = 280478, upload-time = "2025-06-13T14:09:07.087Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9c/09/1681d4b047626d352c083770618ac29655ab1f5c20eee31dc94c000b9b7b/pyzmq-27.0.0-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:b973ee650e8f442ce482c1d99ca7ab537c69098d53a3d046676a484fd710c87a", size = 1329291, upload-time = "2025-06-13T14:06:57.945Z" }, - { url = "https://files.pythonhosted.org/packages/9d/b2/9c9385225fdd54db9506ed8accbb9ea63ca813ba59d43d7f282a6a16a30b/pyzmq-27.0.0-cp310-cp310-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:661942bc7cd0223d569d808f2e5696d9cc120acc73bf3e88a1f1be7ab648a7e4", size = 905952, upload-time = "2025-06-13T14:07:03.232Z" }, - { url = "https://files.pythonhosted.org/packages/41/73/333c72c7ec182cdffe25649e3da1c3b9f3cf1cede63cfdc23d1384d4a601/pyzmq-27.0.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:50360fb2a056ffd16e5f4177eee67f1dd1017332ea53fb095fe7b5bf29c70246", size = 666165, upload-time = "2025-06-13T14:07:04.667Z" }, - { url = "https://files.pythonhosted.org/packages/a5/fe/fc7b9c1a50981928e25635a926653cb755364316db59ccd6e79cfb9a0b4f/pyzmq-27.0.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cf209a6dc4b420ed32a7093642843cbf8703ed0a7d86c16c0b98af46762ebefb", size = 853755, upload-time = "2025-06-13T14:07:06.93Z" }, - { url = "https://files.pythonhosted.org/packages/8c/4c/740ed4b6e8fa160cd19dc5abec8db68f440564b2d5b79c1d697d9862a2f7/pyzmq-27.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:c2dace4a7041cca2fba5357a2d7c97c5effdf52f63a1ef252cfa496875a3762d", size = 1654868, upload-time = "2025-06-13T14:07:08.224Z" }, - { url = 
"https://files.pythonhosted.org/packages/97/00/875b2ecfcfc78ab962a59bd384995186818524ea957dc8ad3144611fae12/pyzmq-27.0.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:63af72b2955fc77caf0a77444baa2431fcabb4370219da38e1a9f8d12aaebe28", size = 2033443, upload-time = "2025-06-13T14:07:09.653Z" }, - { url = "https://files.pythonhosted.org/packages/60/55/6dd9c470c42d713297c5f2a56f7903dc1ebdb4ab2edda996445c21651900/pyzmq-27.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e8c4adce8e37e75c4215297d7745551b8dcfa5f728f23ce09bf4e678a9399413", size = 1891288, upload-time = "2025-06-13T14:07:11.099Z" }, - { url = "https://files.pythonhosted.org/packages/28/5d/54b0ef50d40d7c65a627f4a4b4127024ba9820f2af8acd933a4d30ae192e/pyzmq-27.0.0-cp310-cp310-win32.whl", hash = "sha256:5d5ef4718ecab24f785794e0e7536436698b459bfbc19a1650ef55280119d93b", size = 567936, upload-time = "2025-06-13T14:07:12.468Z" }, - { url = "https://files.pythonhosted.org/packages/18/ea/dedca4321de748ca48d3bcdb72274d4d54e8d84ea49088d3de174bd45d88/pyzmq-27.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:e40609380480b3d12c30f841323f42451c755b8fece84235236f5fe5ffca8c1c", size = 628686, upload-time = "2025-06-13T14:07:14.051Z" }, - { url = "https://files.pythonhosted.org/packages/d4/a7/fcdeedc306e71e94ac262cba2d02337d885f5cdb7e8efced8e5ffe327808/pyzmq-27.0.0-cp310-cp310-win_arm64.whl", hash = "sha256:6b0397b0be277b46762956f576e04dc06ced265759e8c2ff41a0ee1aa0064198", size = 559039, upload-time = "2025-06-13T14:07:15.289Z" }, - { url = "https://files.pythonhosted.org/packages/44/df/84c630654106d9bd9339cdb564aa941ed41b023a0264251d6743766bb50e/pyzmq-27.0.0-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:21457825249b2a53834fa969c69713f8b5a79583689387a5e7aed880963ac564", size = 1332718, upload-time = "2025-06-13T14:07:16.555Z" }, - { url = 
"https://files.pythonhosted.org/packages/c1/8e/f6a5461a07654d9840d256476434ae0ff08340bba562a455f231969772cb/pyzmq-27.0.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:1958947983fef513e6e98eff9cb487b60bf14f588dc0e6bf35fa13751d2c8251", size = 908248, upload-time = "2025-06-13T14:07:18.033Z" }, - { url = "https://files.pythonhosted.org/packages/7c/93/82863e8d695a9a3ae424b63662733ae204a295a2627d52af2f62c2cd8af9/pyzmq-27.0.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0dc628b5493f9a8cd9844b8bee9732ef587ab00002157c9329e4fc0ef4d3afa", size = 668647, upload-time = "2025-06-13T14:07:19.378Z" }, - { url = "https://files.pythonhosted.org/packages/f3/85/15278769b348121eacdbfcbd8c4d40f1102f32fa6af5be1ffc032ed684be/pyzmq-27.0.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f7bbe9e1ed2c8d3da736a15694d87c12493e54cc9dc9790796f0321794bbc91f", size = 856600, upload-time = "2025-06-13T14:07:20.906Z" }, - { url = "https://files.pythonhosted.org/packages/d4/af/1c469b3d479bd095edb28e27f12eee10b8f00b356acbefa6aeb14dd295d1/pyzmq-27.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dc1091f59143b471d19eb64f54bae4f54bcf2a466ffb66fe45d94d8d734eb495", size = 1657748, upload-time = "2025-06-13T14:07:22.549Z" }, - { url = "https://files.pythonhosted.org/packages/8c/f4/17f965d0ee6380b1d6326da842a50e4b8b9699745161207945f3745e8cb5/pyzmq-27.0.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:7011ade88c8e535cf140f8d1a59428676fbbce7c6e54fefce58bf117aefb6667", size = 2034311, upload-time = "2025-06-13T14:07:23.966Z" }, - { url = "https://files.pythonhosted.org/packages/e0/6e/7c391d81fa3149fd759de45d298003de6cfab343fb03e92c099821c448db/pyzmq-27.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:2c386339d7e3f064213aede5d03d054b237937fbca6dd2197ac8cf3b25a6b14e", size = 1893630, upload-time = "2025-06-13T14:07:25.899Z" }, - { url = 
"https://files.pythonhosted.org/packages/0e/e0/eaffe7a86f60e556399e224229e7769b717f72fec0706b70ab2c03aa04cb/pyzmq-27.0.0-cp311-cp311-win32.whl", hash = "sha256:0546a720c1f407b2172cb04b6b094a78773491497e3644863cf5c96c42df8cff", size = 567706, upload-time = "2025-06-13T14:07:27.595Z" }, - { url = "https://files.pythonhosted.org/packages/c9/05/89354a8cffdcce6e547d48adaaf7be17007fc75572123ff4ca90a4ca04fc/pyzmq-27.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:15f39d50bd6c9091c67315ceb878a4f531957b121d2a05ebd077eb35ddc5efed", size = 630322, upload-time = "2025-06-13T14:07:28.938Z" }, - { url = "https://files.pythonhosted.org/packages/fa/07/4ab976d5e1e63976719389cc4f3bfd248a7f5f2bb2ebe727542363c61b5f/pyzmq-27.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:c5817641eebb391a2268c27fecd4162448e03538387093cdbd8bf3510c316b38", size = 558435, upload-time = "2025-06-13T14:07:30.256Z" }, - { url = "https://files.pythonhosted.org/packages/93/a7/9ad68f55b8834ede477842214feba6a4c786d936c022a67625497aacf61d/pyzmq-27.0.0-cp312-abi3-macosx_10_15_universal2.whl", hash = "sha256:cbabc59dcfaac66655c040dfcb8118f133fb5dde185e5fc152628354c1598e52", size = 1305438, upload-time = "2025-06-13T14:07:31.676Z" }, - { url = "https://files.pythonhosted.org/packages/ba/ee/26aa0f98665a22bc90ebe12dced1de5f3eaca05363b717f6fb229b3421b3/pyzmq-27.0.0-cp312-abi3-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:cb0ac5179cba4b2f94f1aa208fbb77b62c4c9bf24dd446278b8b602cf85fcda3", size = 895095, upload-time = "2025-06-13T14:07:33.104Z" }, - { url = "https://files.pythonhosted.org/packages/cf/85/c57e7ab216ecd8aa4cc7e3b83b06cc4e9cf45c87b0afc095f10cd5ce87c1/pyzmq-27.0.0-cp312-abi3-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:53a48f0228eab6cbf69fde3aa3c03cbe04e50e623ef92ae395fce47ef8a76152", size = 651826, upload-time = "2025-06-13T14:07:34.831Z" }, - { url = 
"https://files.pythonhosted.org/packages/69/9a/9ea7e230feda9400fb0ae0d61d7d6ddda635e718d941c44eeab22a179d34/pyzmq-27.0.0-cp312-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:111db5f395e09f7e775f759d598f43cb815fc58e0147623c4816486e1a39dc22", size = 839750, upload-time = "2025-06-13T14:07:36.553Z" }, - { url = "https://files.pythonhosted.org/packages/08/66/4cebfbe71f3dfbd417011daca267539f62ed0fbc68105357b68bbb1a25b7/pyzmq-27.0.0-cp312-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:c8878011653dcdc27cc2c57e04ff96f0471e797f5c19ac3d7813a245bcb24371", size = 1641357, upload-time = "2025-06-13T14:07:38.21Z" }, - { url = "https://files.pythonhosted.org/packages/ac/f6/b0f62578c08d2471c791287149cb8c2aaea414ae98c6e995c7dbe008adfb/pyzmq-27.0.0-cp312-abi3-musllinux_1_2_i686.whl", hash = "sha256:c0ed2c1f335ba55b5fdc964622254917d6b782311c50e138863eda409fbb3b6d", size = 2020281, upload-time = "2025-06-13T14:07:39.599Z" }, - { url = "https://files.pythonhosted.org/packages/37/b9/4f670b15c7498495da9159edc374ec09c88a86d9cd5a47d892f69df23450/pyzmq-27.0.0-cp312-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e918d70862d4cfd4b1c187310015646a14e1f5917922ab45b29f28f345eeb6be", size = 1877110, upload-time = "2025-06-13T14:07:41.027Z" }, - { url = "https://files.pythonhosted.org/packages/66/31/9dee25c226295b740609f0d46db2fe972b23b6f5cf786360980524a3ba92/pyzmq-27.0.0-cp312-abi3-win32.whl", hash = "sha256:88b4e43cab04c3c0f0d55df3b1eef62df2b629a1a369b5289a58f6fa8b07c4f4", size = 559297, upload-time = "2025-06-13T14:07:42.533Z" }, - { url = "https://files.pythonhosted.org/packages/9b/12/52da5509800f7ff2d287b2f2b4e636e7ea0f001181cba6964ff6c1537778/pyzmq-27.0.0-cp312-abi3-win_amd64.whl", hash = "sha256:dce4199bf5f648a902ce37e7b3afa286f305cd2ef7a8b6ec907470ccb6c8b371", size = 619203, upload-time = "2025-06-13T14:07:43.843Z" }, - { url = 
"https://files.pythonhosted.org/packages/93/6d/7f2e53b19d1edb1eb4f09ec7c3a1f945ca0aac272099eab757d15699202b/pyzmq-27.0.0-cp312-abi3-win_arm64.whl", hash = "sha256:56e46bbb85d52c1072b3f809cc1ce77251d560bc036d3a312b96db1afe76db2e", size = 551927, upload-time = "2025-06-13T14:07:45.51Z" }, - { url = "https://files.pythonhosted.org/packages/19/62/876b27c4ff777db4ceba1c69ea90d3c825bb4f8d5e7cd987ce5802e33c55/pyzmq-27.0.0-cp313-cp313t-macosx_10_15_universal2.whl", hash = "sha256:c36ad534c0c29b4afa088dc53543c525b23c0797e01b69fef59b1a9c0e38b688", size = 1340826, upload-time = "2025-06-13T14:07:46.881Z" }, - { url = "https://files.pythonhosted.org/packages/43/69/58ef8f4f59d3bcd505260c73bee87b008850f45edca40ddaba54273c35f4/pyzmq-27.0.0-cp313-cp313t-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:67855c14173aec36395d7777aaba3cc527b393821f30143fd20b98e1ff31fd38", size = 897283, upload-time = "2025-06-13T14:07:49.562Z" }, - { url = "https://files.pythonhosted.org/packages/43/15/93a0d0396700a60475ad3c5d42c5f1c308d3570bc94626b86c71ef9953e0/pyzmq-27.0.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8617c7d43cd8ccdb62aebe984bfed77ca8f036e6c3e46dd3dddda64b10f0ab7a", size = 660567, upload-time = "2025-06-13T14:07:51.364Z" }, - { url = "https://files.pythonhosted.org/packages/0e/b3/fe055513e498ca32f64509abae19b9c9eb4d7c829e02bd8997dd51b029eb/pyzmq-27.0.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:67bfbcbd0a04c575e8103a6061d03e393d9f80ffdb9beb3189261e9e9bc5d5e9", size = 847681, upload-time = "2025-06-13T14:07:52.77Z" }, - { url = "https://files.pythonhosted.org/packages/b6/4f/ff15300b00b5b602191f3df06bbc8dd4164e805fdd65bb77ffbb9c5facdc/pyzmq-27.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:5cd11d46d7b7e5958121b3eaf4cd8638eff3a720ec527692132f05a57f14341d", size = 1650148, upload-time = "2025-06-13T14:07:54.178Z" }, - { url = 
"https://files.pythonhosted.org/packages/c4/6f/84bdfff2a224a6f26a24249a342e5906993c50b0761e311e81b39aef52a7/pyzmq-27.0.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:b801c2e40c5aa6072c2f4876de8dccd100af6d9918d4d0d7aa54a1d982fd4f44", size = 2023768, upload-time = "2025-06-13T14:07:55.714Z" }, - { url = "https://files.pythonhosted.org/packages/64/39/dc2db178c26a42228c5ac94a9cc595030458aa64c8d796a7727947afbf55/pyzmq-27.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:20d5cb29e8c5f76a127c75b6e7a77e846bc4b655c373baa098c26a61b7ecd0ef", size = 1885199, upload-time = "2025-06-13T14:07:57.166Z" }, - { url = "https://files.pythonhosted.org/packages/c7/21/dae7b06a1f8cdee5d8e7a63d99c5d129c401acc40410bef2cbf42025e26f/pyzmq-27.0.0-cp313-cp313t-win32.whl", hash = "sha256:a20528da85c7ac7a19b7384e8c3f8fa707841fd85afc4ed56eda59d93e3d98ad", size = 575439, upload-time = "2025-06-13T14:07:58.959Z" }, - { url = "https://files.pythonhosted.org/packages/eb/bc/1709dc55f0970cf4cb8259e435e6773f9946f41a045c2cb90e870b7072da/pyzmq-27.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:d8229f2efece6a660ee211d74d91dbc2a76b95544d46c74c615e491900dc107f", size = 639933, upload-time = "2025-06-13T14:08:00.777Z" }, - { url = "https://files.pythonhosted.org/packages/09/6f/be6523a7f3821c0b5370912ef02822c028611360e0d206dd945bdbf9eaef/pyzmq-27.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:656c1866505a5735d0660b7da6d7147174bbf59d4975fc2b7f09f43c9bc25745", size = 835950, upload-time = "2025-06-13T14:08:35Z" }, - { url = "https://files.pythonhosted.org/packages/c6/1e/a50fdd5c15018de07ab82a61bc460841be967ee7bbe7abee3b714d66f7ac/pyzmq-27.0.0-pp310-pypy310_pp73-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:74175b9e12779382432dd1d1f5960ebe7465d36649b98a06c6b26be24d173fab", size = 799876, upload-time = "2025-06-13T14:08:36.849Z" }, - { url = 
"https://files.pythonhosted.org/packages/88/a1/89eb5b71f5a504f8f887aceb8e1eb3626e00c00aa8085381cdff475440dc/pyzmq-27.0.0-pp310-pypy310_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d8c6de908465697a8708e4d6843a1e884f567962fc61eb1706856545141d0cbb", size = 567400, upload-time = "2025-06-13T14:08:38.95Z" }, - { url = "https://files.pythonhosted.org/packages/56/aa/4571dbcff56cfb034bac73fde8294e123c975ce3eea89aff31bf6dc6382b/pyzmq-27.0.0-pp310-pypy310_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c644aaacc01d0df5c7072826df45e67301f191c55f68d7b2916d83a9ddc1b551", size = 747031, upload-time = "2025-06-13T14:08:40.413Z" }, - { url = "https://files.pythonhosted.org/packages/46/e0/d25f30fe0991293c5b2f5ef3b070d35fa6d57c0c7428898c3ab4913d0297/pyzmq-27.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:10f70c1d9a446a85013a36871a296007f6fe4232b530aa254baf9da3f8328bc0", size = 544726, upload-time = "2025-06-13T14:08:41.997Z" }, - { url = "https://files.pythonhosted.org/packages/98/a6/92394373b8dbc1edc9d53c951e8d3989d518185174ee54492ec27711779d/pyzmq-27.0.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:cd1dc59763effd1576f8368047c9c31468fce0af89d76b5067641137506792ae", size = 835948, upload-time = "2025-06-13T14:08:43.516Z" }, - { url = "https://files.pythonhosted.org/packages/56/f3/4dc38d75d9995bfc18773df3e41f2a2ca9b740b06f1a15dbf404077e7588/pyzmq-27.0.0-pp311-pypy311_pp73-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:60e8cc82d968174650c1860d7b716366caab9973787a1c060cf8043130f7d0f7", size = 799874, upload-time = "2025-06-13T14:08:45.017Z" }, - { url = "https://files.pythonhosted.org/packages/ab/ba/64af397e0f421453dc68e31d5e0784d554bf39013a2de0872056e96e58af/pyzmq-27.0.0-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:14fe7aaac86e4e93ea779a821967360c781d7ac5115b3f1a171ced77065a0174", size = 567400, upload-time = "2025-06-13T14:08:46.855Z" }, - { url = 
"https://files.pythonhosted.org/packages/63/87/ec956cbe98809270b59a22891d5758edae147a258e658bf3024a8254c855/pyzmq-27.0.0-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6ad0562d4e6abb785be3e4dd68599c41be821b521da38c402bc9ab2a8e7ebc7e", size = 747031, upload-time = "2025-06-13T14:08:48.419Z" }, - { url = "https://files.pythonhosted.org/packages/be/8a/4a3764a68abc02e2fbb0668d225b6fda5cd39586dd099cee8b2ed6ab0452/pyzmq-27.0.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:9df43a2459cd3a3563404c1456b2c4c69564daa7dbaf15724c09821a3329ce46", size = 544726, upload-time = "2025-06-13T14:08:49.903Z" }, -] - [[package]] name = "questionary" version = "2.1.0" @@ -2294,28 +1851,14 @@ wheels = [ [[package]] name = "sparkwheel" -version = "0.0.5" +version = "0.0.9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pyyaml" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/10/f0/4c538e8b5fbea48c999313cd11c73ae52bfcb48235a4bf157a04f32cbcff/sparkwheel-0.0.5.tar.gz", hash = "sha256:c5c4117b1989eea78e99f6663d0c421573f5d6dfcabf8633e77225710d0243fe", size = 43530, upload-time = "2025-11-14T02:32:26.727Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e5/3e/eea7646716b39fe523c42e7672df5c0a4d3087351c809ba6cf59e3ce09e8/sparkwheel-0.0.9.tar.gz", hash = "sha256:604cded3ecc6c8dceb5b769e9eb273e15e0dc206598b549ad18595f842ab80bc", size = 49543, upload-time = "2025-11-29T03:30:25.302Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/aa/51/9252dcb284bca884965becbc9b6cce7279011cc99d2a1f04110f1038cfc9/sparkwheel-0.0.5-py3-none-any.whl", hash = "sha256:363c60c516ffa9a4d7b8aab978181ee0277a05a32182f8118d706e87a9ea3fe5", size = 53447, upload-time = "2025-11-14T02:32:25.647Z" }, -] - -[[package]] -name = "stack-data" -version = "0.6.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "asttokens" }, - { name = "executing" }, - { name = "pure-eval" }, -] -sdist = { url = 
"https://files.pythonhosted.org/packages/28/e3/55dcc2cfbc3ca9c29519eb6884dd1415ecb53b0e934862d3559ddcb7e20b/stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9", size = 44707, upload-time = "2023-09-30T13:58:05.479Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521, upload-time = "2023-09-30T13:58:03.53Z" }, + { url = "https://files.pythonhosted.org/packages/b6/a5/d03dce442b1226905837a07668399fa9232cc73a2684041df3b0cedc0229/sparkwheel-0.0.9-py3-none-any.whl", hash = "sha256:d18f6eef804ead414c1dad3fab055ee4e701e67d6f91461b8ce08f71c9cd7b28", size = 59158, upload-time = "2025-11-29T03:30:26.465Z" }, ] [[package]] @@ -2491,25 +2034,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/b4/b0247c2a953322e1ac3fe4c31aad4a39530ae5f60128f3cdb760136386e3/torchvision-0.20.0-cp312-cp312-win_amd64.whl", hash = "sha256:bb0da0950d2034a0412c251a3a9117ff9612157f45177d37ba1b20b472c0864b", size = 1567343, upload-time = "2024-10-17T14:48:19.793Z" }, ] -[[package]] -name = "tornado" -version = "6.5.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/51/89/c72771c81d25d53fe33e3dca61c233b665b2780f21820ba6fd2c6793c12b/tornado-6.5.1.tar.gz", hash = "sha256:84ceece391e8eb9b2b95578db65e920d2a61070260594819589609ba9bc6308c", size = 509934, upload-time = "2025-05-22T18:15:38.788Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/77/89/f4532dee6843c9e0ebc4e28d4be04c67f54f60813e4bf73d595fe7567452/tornado-6.5.1-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:d50065ba7fd11d3bd41bcad0825227cc9a95154bad83239357094c36708001f7", size = 441948, upload-time = "2025-05-22T18:15:20.862Z" }, - { url = 
"https://files.pythonhosted.org/packages/15/9a/557406b62cffa395d18772e0cdcf03bed2fff03b374677348eef9f6a3792/tornado-6.5.1-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:9e9ca370f717997cb85606d074b0e5b247282cf5e2e1611568b8821afe0342d6", size = 440112, upload-time = "2025-05-22T18:15:22.591Z" }, - { url = "https://files.pythonhosted.org/packages/55/82/7721b7319013a3cf881f4dffa4f60ceff07b31b394e459984e7a36dc99ec/tornado-6.5.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b77e9dfa7ed69754a54c89d82ef746398be82f749df69c4d3abe75c4d1ff4888", size = 443672, upload-time = "2025-05-22T18:15:24.027Z" }, - { url = "https://files.pythonhosted.org/packages/7d/42/d11c4376e7d101171b94e03cef0cbce43e823ed6567ceda571f54cf6e3ce/tornado-6.5.1-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:253b76040ee3bab8bcf7ba9feb136436a3787208717a1fb9f2c16b744fba7331", size = 443019, upload-time = "2025-05-22T18:15:25.735Z" }, - { url = "https://files.pythonhosted.org/packages/7d/f7/0c48ba992d875521ac761e6e04b0a1750f8150ae42ea26df1852d6a98942/tornado-6.5.1-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:308473f4cc5a76227157cdf904de33ac268af770b2c5f05ca6c1161d82fdd95e", size = 443252, upload-time = "2025-05-22T18:15:27.499Z" }, - { url = "https://files.pythonhosted.org/packages/89/46/d8d7413d11987e316df4ad42e16023cd62666a3c0dfa1518ffa30b8df06c/tornado-6.5.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:caec6314ce8a81cf69bd89909f4b633b9f523834dc1a352021775d45e51d9401", size = 443930, upload-time = "2025-05-22T18:15:29.299Z" }, - { url = "https://files.pythonhosted.org/packages/78/b2/f8049221c96a06df89bed68260e8ca94beca5ea532ffc63b1175ad31f9cc/tornado-6.5.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:13ce6e3396c24e2808774741331638ee6c2f50b114b97a55c5b442df65fd9692", size = 443351, upload-time = "2025-05-22T18:15:31.038Z" }, - { url = 
"https://files.pythonhosted.org/packages/76/ff/6a0079e65b326cc222a54720a748e04a4db246870c4da54ece4577bfa702/tornado-6.5.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5cae6145f4cdf5ab24744526cc0f55a17d76f02c98f4cff9daa08ae9a217448a", size = 443328, upload-time = "2025-05-22T18:15:32.426Z" }, - { url = "https://files.pythonhosted.org/packages/49/18/e3f902a1d21f14035b5bc6246a8c0f51e0eef562ace3a2cea403c1fb7021/tornado-6.5.1-cp39-abi3-win32.whl", hash = "sha256:e0a36e1bc684dca10b1aa75a31df8bdfed656831489bc1e6a6ebed05dc1ec365", size = 444396, upload-time = "2025-05-22T18:15:34.205Z" }, - { url = "https://files.pythonhosted.org/packages/7b/09/6526e32bf1049ee7de3bebba81572673b19a2a8541f795d887e92af1a8bc/tornado-6.5.1-cp39-abi3-win_amd64.whl", hash = "sha256:908e7d64567cecd4c2b458075589a775063453aeb1d2a1853eedb806922f568b", size = 444840, upload-time = "2025-05-22T18:15:36.1Z" }, - { url = "https://files.pythonhosted.org/packages/55/a7/535c44c7bea4578e48281d83c615219f3ab19e6abc67625ef637c73987be/tornado-6.5.1-cp39-abi3-win_arm64.whl", hash = "sha256:02420a0eb7bf617257b9935e2b754d1b63897525d8a289c9d65690d580b4dcf7", size = 443596, upload-time = "2025-05-22T18:15:37.433Z" }, -] - [[package]] name = "tqdm" version = "4.66.5" @@ -2522,15 +2046,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/5d/acf5905c36149bbaec41ccf7f2b68814647347b72075ac0b1fe3022fdc73/tqdm-4.66.5-py3-none-any.whl", hash = "sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd", size = 78351, upload-time = "2024-08-03T22:35:36.644Z" }, ] -[[package]] -name = "traitlets" -version = "5.14.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/eb/79/72064e6a701c2183016abbbfedaba506d81e30e232a68c9f0d6f6fcd1574/traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7", size = 161621, upload-time = "2024-04-19T11:11:49.746Z" } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359, upload-time = "2024-04-19T11:11:46.763Z" }, -] - [[package]] name = "triton" version = "3.1.0" @@ -2544,6 +2059,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/eb/65f5ba83c2a123f6498a3097746607e5b2f16add29e36765305e4ac7fdd8/triton-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8182f42fd8080a7d39d666814fa36c5e30cc00ea7eeeb1a2983dbb4c99a0fdc", size = 209551444, upload-time = "2024-10-14T16:05:53.433Z" }, ] +[[package]] +name = "types-pyyaml" +version = "6.0.12.20250915" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7e/69/3c51b36d04da19b92f9e815be12753125bd8bc247ba0470a982e6979e71c/types_pyyaml-6.0.12.20250915.tar.gz", hash = "sha256:0f8b54a528c303f0e6f7165687dd33fafa81c807fcac23f632b63aa624ced1d3", size = 17522, upload-time = "2025-09-15T03:01:00.728Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bd/e0/1eed384f02555dde685fff1a1ac805c1c7dcb6dd019c916fe659b1c1f9ec/types_pyyaml-6.0.12.20250915-py3-none-any.whl", hash = "sha256:e7d4d9e064e89a3b3cae120b4990cd370874d2bf12fa5f46c97018dd5d3c9ab6", size = 20338, upload-time = "2025-09-15T03:00:59.218Z" }, +] + [[package]] name = "typing-extensions" version = "4.12.2" @@ -2565,6 +2089,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/17/69/cd203477f944c353c31bade965f880aa1061fd6bf05ded0726ca845b6ff7/typing_inspection-0.4.1-py3-none-any.whl", hash = "sha256:389055682238f53b04f7badcb49b989835495a96700ced5dab2d8feae4b26f51", size = 14552, upload-time = "2025-05-21T18:55:22.152Z" }, ] +[[package]] +name = "tzdata" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/95/32/1a225d6164441be760d75c2c42e2780dc0873fe382da3e98a2e1e48361e5/tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9", size = 196380, upload-time = "2025-03-23T13:54:43.652Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839, upload-time = "2025-03-23T13:54:41.845Z" }, +] + [[package]] name = "urllib3" version = "2.2.2" @@ -2678,15 +2211,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9d/6e/e792999e816d19d7fcbfa94c730936750036d65656a76a5a688b57a656c4/werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8", size = 227274, upload-time = "2024-05-05T23:10:29.567Z" }, ] -[[package]] -name = "widgetsnbextension" -version = "4.0.14" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/41/53/2e0253c5efd69c9656b1843892052a31c36d37ad42812b5da45c62191f7e/widgetsnbextension-4.0.14.tar.gz", hash = "sha256:a3629b04e3edb893212df862038c7232f62973373869db5084aed739b437b5af", size = 1097428, upload-time = "2025-04-10T13:01:25.628Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ca/51/5447876806d1088a0f8f71e16542bf350918128d0a69437df26047c8e46f/widgetsnbextension-4.0.14-py3-none-any.whl", hash = "sha256:4875a9eaf72fbf5079dc372a51a9f268fc38d46f767cbf85c43a36da5cb9b575", size = 2196503, upload-time = "2025-04-10T13:01:23.086Z" }, -] - [[package]] name = "win32-setctime" version = "1.1.0"