diff --git a/.azure-pipelines/gpu-tests.yml b/.azure-pipelines/gpu-tests.yml index 6cafd7ee..a1471382 100644 --- a/.azure-pipelines/gpu-tests.yml +++ b/.azure-pipelines/gpu-tests.yml @@ -48,7 +48,7 @@ jobs: strategy: matrix: PyTorch_latest: - image: speediedan/interpretune:py3.12-pt2.9.0-azpl-init + image: speediedan/interpretune:py3.12-pt2.9.1-azpl-init scope: "" timeoutInMinutes: 100 cancelTimeoutInMinutes: 2 @@ -146,31 +146,52 @@ jobs: displayName: 'Maybe reset HF cache' - bash: | - . /tmp/venvs/it_dev/bin/activate - python -m pip install --upgrade pip setuptools setuptools-scm wheel build - python -m pip install -r requirements/ci/requirements.txt -r requirements/ci/platform_dependent.txt --no-warn-script-location - python -m pip install -e '.[test,examples,lightning]' --no-warn-script-location - if ([ "${APPLY_POST_UPGRADES:-}" = "1" ] || [ "${APPLY_POST_UPGRADES:-}" = "true" ]) && [ -s requirements/ci/post_upgrades.txt ]; then - echo "Applying post-upgrades (requirements/ci/post_upgrades.txt)..." - python -m pip install --upgrade -r requirements/ci/post_upgrades.txt --cache-dir "$PIP_CACHE_DIR" - else - echo "Skipping post-upgrades (either disabled or file empty)." + set -e # Exit on any error + source /tmp/venvs/it_dev/bin/activate + + echo "=== Installing interpretune in editable mode with git dependencies ===" + if ! uv pip install -e . --group git-deps; then + echo "ERROR: Failed to install interpretune in editable mode" + exit 1 + fi + echo "✓ Interpretune installation completed" + + echo "=== Installing locked CI requirements ===" + if ! uv pip install -r requirements/ci/requirements.txt; then + echo "ERROR: Failed to install locked CI requirements" + echo "Checking for permission issues in venv..." + find /tmp/venvs/it_dev -type d -name __pycache__ ! 
-writable 2>/dev/null | head -20 + exit 1 fi - python -m pip list + echo "✓ CI requirements installation completed" + + echo "=== Installed packages ===" + uv pip list + + echo "=== Verifying critical packages ===" + for pkg in coverage pytest torch transformers; do + if uv pip show "$pkg" >/dev/null 2>&1; then + echo "✓ $pkg is installed" + else + echo "✗ ERROR: $pkg is NOT installed" + exit 1 + fi + done + echo "✓ All critical packages verified" displayName: 'Install dependencies' - bash: | - . /tmp/venvs/it_dev/bin/activate + source /tmp/venvs/it_dev/bin/activate python requirements/utils/collect_env_details.py displayName: 'Env details and package versions' - bash: | - . /tmp/venvs/it_dev/bin/activate + source /tmp/venvs/it_dev/bin/activate python -m coverage run --append --source src/interpretune -m pytest src/interpretune tests -v --junitxml=$(Build.Repository.LocalPath)/test-results.xml --durations=50 displayName: 'Testing: standard' - bash: | - . /tmp/venvs/it_dev/bin/activate + source /tmp/venvs/it_dev/bin/activate export HF_GATED_PUBLIC_REPO_AUTH_KEY=$HF_GATED_PUBLIC_REPO_AUTH_KEY export HF_TOKEN=$HF_TOKEN bash ./tests/special_tests.sh --mark_type=standalone @@ -180,12 +201,12 @@ jobs: HF_TOKEN: $(HF_TOKEN) - bash: | - . /tmp/venvs/it_dev/bin/activate + source /tmp/venvs/it_dev/bin/activate bash ./tests/special_tests.sh --mark_type=profile_ci displayName: 'Testing: CI Profiling' - bash: | - . 
/tmp/venvs/it_dev/bin/activate + source /tmp/venvs/it_dev/bin/activate python -m coverage report python -m coverage xml python -m coverage html diff --git a/.github/actions/install-ci-dependencies/action.yml b/.github/actions/install-ci-dependencies/action.yml index 923d3c90..1afe4a56 100644 --- a/.github/actions/install-ci-dependencies/action.yml +++ b/.github/actions/install-ci-dependencies/action.yml @@ -1,56 +1,38 @@ name: "Install CI Dependencies" -description: "Install Python dependencies for CI workflows including platform-dependent packages and post-upgrades" +description: "Install Python dependencies for CI workflows" inputs: + python_version: + description: "Python version to use" + required: false + default: "3.12" show_pip_list: - description: "Whether to show pip list output after installations" + description: "Whether to show package list output after installations" required: false default: "false" - apply_post_upgrades: - description: "Whether to apply post-upgrade packages" - required: false - default: "true" runs: using: "composite" steps: - - name: Set up venv and install ci dependencies - shell: bash - run: | - python -m pip install --upgrade pip setuptools setuptools-scm wheel build --cache-dir "$PIP_CACHE_DIR" - # Prefer CI pinned requirements if present - if [ -f requirements/ci/requirements.txt ]; then - pip install -r requirements/ci/requirements.txt --cache-dir "$PIP_CACHE_DIR" - python -m pip install -e '.[test,examples,lightning]' --cache-dir "$PIP_CACHE_DIR" - else - python -m pip install -e '.[test,examples,lightning]' --cache-dir "$PIP_CACHE_DIR" - fi - if [ "${{ inputs.show_pip_list }}" = "true" ]; then - pip list - fi + - name: Install uv and set Python version + uses: astral-sh/setup-uv@v7 + with: + python-version: ${{ inputs.python_version }} + activate-environment: true + enable-cache: true - - name: Install platform-dependent packages + - name: Install project dependencies shell: bash run: | - # Install platform-dependent 
packages with flexible constraints to allow platform-specific resolution - if [ -f requirements/ci/platform_dependent.txt ] && [ -s requirements/ci/platform_dependent.txt ]; then - echo "Installing platform-dependent packages from requirements/ci/platform_dependent.txt..." - python -m pip install -r requirements/ci/platform_dependent.txt --cache-dir "$PIP_CACHE_DIR" || echo "Some platform-dependent packages may not be available on this platform, continuing..." - else - echo "No platform-dependent packages to install." - fi + # interpretune editable with git-deps group (uv doesn't currently support url deps in locked reqs) + echo "Installing interpretune in editable mode with any git url dependencies..." + uv pip install -e . --group git-deps + echo "Installing locked CI requirements..." + uv pip install -r requirements/ci/requirements.txt + - - name: Optional post-upgrades (datasets/fsspec etc) + - name: Show package list + if: inputs.show_pip_list == 'true' shell: bash - env: - APPLY_POST_UPGRADES: ${{ inputs.apply_post_upgrades }} run: | - if ([ "${APPLY_POST_UPGRADES}" = "1" ] || [ "${APPLY_POST_UPGRADES}" = "true" ]) && [ -s requirements/ci/post_upgrades.txt ]; then - echo "Applying post-upgrades (requirements/ci/post_upgrades.txt)..." - python -m pip install --upgrade -r requirements/ci/post_upgrades.txt --cache-dir "$PIP_CACHE_DIR" - if [ "${{ inputs.show_pip_list }}" = "true" ]; then - pip list - fi - else - echo "Skipping post-upgrades (either disabled or file empty)." 
- fi + uv pip list diff --git a/.github/actions/regen-ci-reqs/action.yml b/.github/actions/regen-ci-reqs/action.yml index 8acae805..2df5bbcb 100644 --- a/.github/actions/regen-ci-reqs/action.yml +++ b/.github/actions/regen-ci-reqs/action.yml @@ -12,7 +12,7 @@ inputs: compare_paths: description: Space-separated list of files/paths to compare with git diff required: false - default: "requirements/ci/requirements.txt requirements/ci/post_upgrades.txt requirements/ci/platform_dependent.txt" + default: "requirements/ci/requirements.txt" patch_path: description: Path to write the patch file required: false @@ -20,21 +20,16 @@ inputs: runs: using: composite steps: - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: ${{ inputs.python_version }} - - - name: Install regen deps + - name: Install uv shell: bash run: | - python -m pip install --upgrade pip - python -m pip install pip-tools toml + curl -LsSf https://astral.sh/uv/install.sh | sh + echo "$HOME/.cargo/bin" >> $GITHUB_PATH - - name: Run regen (pip-compile) + - name: Run lock regeneration shell: bash run: | - python requirements/utils/regen_reqfiles.py --mode pip-compile --ci-output-dir=${{ inputs.ci_output_dir }} + bash requirements/utils/lock_ci_requirements.sh - id: check name: Check for diffs and write patch diff --git a/.github/actions/run-pytest-instrumented/action.yml b/.github/actions/run-pytest-instrumented/action.yml index 06fe5601..a51f529a 100644 --- a/.github/actions/run-pytest-instrumented/action.yml +++ b/.github/actions/run-pytest-instrumented/action.yml @@ -99,7 +99,6 @@ runs: PYTHONPATH: ${{ inputs.workspace }}/src PYTEST_ADDOPTS: "--log-cli-level=${{ inputs.it_ci_log_level }} --log-cli-format='%(asctime)s [%(levelname)8s] %(name)s: %(message)s' --capture=no" PYTEST_FILTER_PATTERN: ${{ inputs.pytest_filter_pattern }} - PIP_CACHE_DIR: ${{ inputs.pip_cache_dir }} run: | if [ -n "$PYTEST_FILTER_PATTERN" ]; then echo "PYTEST_FILTER_PATTERN is set to 
'$PYTEST_FILTER_PATTERN'. Disabling coverage collection and running filtered tests only." diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 1bed66f0..e3093379 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -31,36 +31,112 @@ cd /home/runner/work/interpretune/interpretune && python -m pytest src/interpret ## Build and Validation Commands ### Environment Setup -Always install dependencies in order to avoid conflicts: +Development environment uses `uv` for fast, reliable dependency management: ```bash -# Basic development setup -python -m pip install --upgrade pip setuptools setuptools-scm wheel build -python -m pip install -r requirements/ci/requirements.txt -r requirements/ci/platform_dependent.txt -python -m pip install -e '.[test,examples,lightning]' - -# If circuit-tracer install fails, use the built-in tool after basic install: -pip install interpretune[examples] -interpretune-install-circuit-tracer -``` +# Install uv (one-time setup) +curl -LsSf https://astral.sh/uv/install.sh | sh + +# Create development environment (creates traditional venv at ~/.venvs/it_latest) +./scripts/build_it_env.sh --repo_home=${PWD} --target_env_name=it_latest + +# Activate the environment +source ~/.venvs/it_latest/bin/activate -**⚠️ Known Issue:** Full dependency install may timeout due to large ML packages. Install basic deps first, then add extras incrementally. 
+# Run commands directly (no need for 'uv run')
+python --version
+python -m pytest tests/
+```

 ### Development Environment Scripts
-For complex setups, use the provided build script:
+Use the provided build script for automated setup:

 ```bash
 # Standard development build (recommended for dev work)
 ./scripts/build_it_env.sh --repo_home=${PWD} --target_env_name=it_latest

-# Build with a circuit-tracer commit pin
+# Quick editable install (without locked requirements)
+uv pip install -e ".[test,examples,lightning,profiling]" --group git-deps --group dev
+
+# Venv Location Options (for hardlink performance and standalone process wrappers):
+#
+# OPTION 1 (Recommended for standalone process wrappers): Use --venv-dir to set BASE directory
+# The venv will be created at: <venv-dir>/<target_env_name>
+# This is most robust when using with manage_standalone_processes.sh:
+./scripts/build_it_env.sh --repo_home=${PWD} --target_env_name=it_latest --venv-dir=/mnt/cache/username/.venvs
+# Creates venv at: /mnt/cache/username/.venvs/it_latest
+#
+# OPTION 2: Use IT_VENV_BASE environment variable to set base directory
+# This approach uses IT_VENV_BASE as base + target_env_name:
+export IT_VENV_BASE=/mnt/cache/username/.venvs
+./scripts/build_it_env.sh --repo_home=${PWD} --target_env_name=it_latest
+# Creates venv at: /mnt/cache/username/.venvs/it_latest
+#
+# OPTION 3: Use default (~/.venvs/) - simplest but may cause hardlink warnings
+# If UV cache is on different filesystem, you'll see "Failed to hardlink files" warnings
+# Creates venv at: ~/.venvs/it_latest
+#
+# Why placement matters: UV uses hardlinks for fast installs, but hardlinks only work within
+# the same filesystem. Placing venv on same filesystem as UV cache ensures fast installs and
+# no warnings.
Example UV cache location: /mnt/cache/username/.cache/uv + +# Build with specific PyTorch nightly version +./scripts/build_it_env.sh --repo_home=${PWD} --target_env_name=it_latest --torch_dev_ver=dev20240201 + +# Build with PyTorch test channel +./scripts/build_it_env.sh --repo_home=${PWD} --target_env_name=it_latest --torch_test_channel + +# Build with single package from source (no extras) +./scripts/build_it_env.sh --repo_home=${PWD} --target_env_name=it_latest --from-source="circuit_tracer:${HOME}/repos/circuit-tracer" + +# Build with package from source with extras +./scripts/build_it_env.sh --repo_home=${PWD} --target_env_name=it_latest --from-source="finetuning_scheduler:${HOME}/repos/finetuning-scheduler:all" + +# Build with package from source with extras and environment variable +./scripts/build_it_env.sh --repo_home=${PWD} --target_env_name=it_latest --from-source="finetuning_scheduler:${HOME}/repos/finetuning-scheduler:all:USE_CI_COMMIT_PIN=1" + +# Build with multiple packages from source (using multiple --from-source flags - cleaner!) 
+./scripts/build_it_env.sh --repo_home=${PWD} --target_env_name=it_latest \ + --from-source="finetuning_scheduler:${HOME}/repos/finetuning-scheduler:all:USE_CI_COMMIT_PIN=1" \ + --from-source="circuit_tracer:${HOME}/repos/circuit-tracer" + +# Build with multiple packages from source (using semicolon separator - also supported) +./scripts/build_it_env.sh --repo_home=${PWD} --target_env_name=it_latest --from-source="finetuning_scheduler:${HOME}/repos/finetuning-scheduler:all:USE_CI_COMMIT_PIN=1;circuit_tracer:${HOME}/repos/circuit-tracer" + +# Important: When using with manage_standalone_processes.sh wrapper, use --venv-dir: +~/repos/interpretune/scripts/manage_standalone_processes.sh --use-nohup scripts/build_it_env.sh \ + --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest \ + --venv-dir=/mnt/cache/speediedan/.venvs/it_latest \ + --from-source="finetuning_scheduler:~/repos/finetuning-scheduler:all:USE_CI_COMMIT_PIN=1" ``` +**Important: Git Dependency Caching and override-dependencies** + +When installing from-source packages that specify git dependencies (e.g., finetuning-scheduler with USE_CI_COMMIT_PIN=1 pinning Lightning to a specific commit), UV's caching ensures correct behavior: + +1. From-source packages are installed FIRST with all their dependencies +2. UV caches git dependencies by their fully-resolved commit hash +3. When interpretune is subsequently installed, UV respects the cached commit-pinned versions +4. The [tool.uv] override-dependencies in pyproject.toml replaces interpretune's git URL dependencies with version constraints, allowing editable installations to satisfy requirements + +See [UV's dependency caching docs](https://docs.astral.sh/uv/concepts/cache/#dependency-caching) for details on git dependency caching behavior. 
+ +**From-Source Package Version Requirements:** + +When installing packages from source (especially transformer-lens), ensure the package version in the source repo satisfies dependent package requirements: +- circuit-tracer requires `transformer-lens>=v2.16.0` +- TransformerLens repo default version (with the old v2 poetry install) is 0.0.0 in pyproject.toml (set by CI pipeline on release) +- For local development, update TransformerLens version to 2.16.1 or higher: `sed -i 's/version="0\.0\.0"/version="2.16.1"/' ~/repos/TransformerLens/pyproject.toml` +- This ensures circuit-tracer's dependency is satisfied without UV upgrading transformer-lens to PyPI version +- This should not be necessary when installing transformer_lens from source with versions >= 3.0.0 as uv is used + ### Linting and Code Quality -**Always run linting before committing:** +**Always run linting before committing (assumes activated venv):** ```bash +# Activate your environment first +source ~/.venvs/it_latest/bin/activate + # Run ruff linting (configured in pyproject.toml) # we don't have ruff installed as a separate package but use it via pre-commit (with the --fix flag) # there are two phases, the check and format, run each separately @@ -74,8 +150,11 @@ pre-commit run --all-files **Expected Ruff Issues:** The `tests/*_parity/` directories contain imported research code with many linting violations - these are intentionally excluded from pre-commit checks. 
### Testing -**Test command:** +**Test command (assumes activated venv):** ```bash +# Activate your environment first +source ~/.venvs/it_latest/bin/activate + # Basic test run (requires full dependencies) cd /home/runner/work/interpretune/interpretune && python -m pytest src/interpretune tests -v @@ -130,7 +209,6 @@ src/it_examples/ # Example experiments ### Key Entry Points - Console script: `interpretune` → `interpretune.base.components.cli:bootstrap_cli` -- Circuit-tracer installer: `interpretune-install-circuit-tracer` ## CI and Validation Pipeline @@ -142,10 +220,32 @@ src/it_examples/ # Example experiments **Timeout:** 90 minutes **CI Process:** -1. Install dependencies with constraints -2. Run pytest with coverage -3. Resource monitoring (Linux only) -4. Upload artifacts on failure +1. Install interpretune in editable mode with git dependencies +2. Install locked CI requirements (all PyPI packages) +3. Run pytest with coverage +4. Resource monitoring (Linux only) +5. Upload artifacts on failure + +**CI Installation Flow:** +```bash +# Step 1: Install interpretune editable + git dependencies +uv pip install -e . --group git-deps + +# Step 2: Install all locked PyPI dependencies +uv pip install -r requirements/ci/requirements.txt +``` + +**Development Installation Flow (build_it_env.sh):** +```bash +# Step 1: Install interpretune editable + git dependencies +uv pip install -e . 
--group git-deps + +# Step 2: Install locked CI requirements +uv pip install -r requirements/ci/requirements.txt + +# Step 3: Install from-source packages (if specified) +# These override any PyPI/git versions for development +``` **Environment Variables for CI:** - `IT_CI_LOG_LEVEL` - Defaults to "INFO", set to "DEBUG" for verbose logging @@ -162,6 +262,10 @@ Note: the GPU pipeline runs only when a PR is ready for review and an admin appr ### Manual Validation Steps +```bash +# Activate your environment first +source ~/.venvs/it_latest/bin/activate + # Run ruff linting (configured in pyproject.toml) # we don't have ruff installed as a separate package but use it via pre-commit (with the --fix flag) # there are two phases, the check and format, run each separately @@ -170,19 +274,34 @@ pre-commit run ruff-format --all-files # Run pre-commit hooks (includes ruff, docformatter, yaml checks) pre-commit run --all-files +``` -### Regenerating stable CI dependency pins - -When updating top-level requirements or periodically refreshing CI pins, use the repository helper to regenerate and compile the CI requirement files. This workflow updates `requirements/*` and writes compiled CI pins to `requirements/ci`. +### Updating dependencies -Run these commands from your repo home after activating ensuring you've activated any relevant venv (e.g. 
`source ~/.venvs/${target_env_name}/bin/activate`): +When updating dependencies, edit `pyproject.toml` and regenerate locked requirements: ```bash -python requirements/utils/regen_reqfiles.py --mode pip-compile --ci-output-dir=requirements/ci +# Edit pyproject.toml to update version constraints + +# Regenerate locked CI requirements +./requirements/utils/lock_ci_requirements.sh + +# Rebuild your development environment +./scripts/build_it_env.sh --repo_home=${PWD} --target_env_name=it_latest + +# Or update manually in an activated environment +source ~/.venvs/it_latest/bin/activate +uv pip install --upgrade + +# After updating, test thoroughly +python -m pytest tests/ -v ``` Notes: -- Regenerating pins may change CI dependency resolution — run the full CI (or at least the CPU GitHub Actions CI) after updating pins to validate. Don't update pins aggressively, this is done periodically anyway, focus mostly on the issue at hand without changing the CI pins unless you think it is related to the issue. +- Dependencies are specified in `pyproject.toml` with optional extras and dependency groups +- CI uses locked requirements (requirements/ci/requirements.txt) for reproducibility +- Development can use either locked requirements (via build script) or direct installation +- Always run the full CI after dependency changes to validate compatibility across platforms ### Type-checking caveat @@ -191,17 +310,11 @@ Full repository type-checking is a work in progress. Current local checks may on ## Special Dependencies and Known Issues ### Circuit-Tracer Dependency -**Issue:** circuit-tracer is not on PyPI, requires git-based install - -**Solutions:** -1. Use built-in installer: `interpretune-install-circuit-tracer` -2. Manual install: `pip install git+https://github.com/speediedan/circuit-tracer.git@` -3. Environment variable control: `IT_USE_CT_COMMIT_PIN=1` +**Note:** circuit-tracer is installed directly from git as specified in `pyproject.toml`. 
The commit is pinned in the examples optional dependencies: `circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@004f1b28...`. When you run `uv pip install -e ".[examples]"`, uv resolves and installs this git dependency automatically. ### Dependency Constraints - **torch** requires 2.7.1+ for newer features -- **setuptools** requires 77.0.0+ for PEP 639 support -- **pip** requires < 25.3 to avoid issues with pip-tools https://github.com/jazzband/pip-tools/issues/2252 +- **setuptools** requires 77.0.0+ for PEP 639 support (used in build system) ### Import Dependencies - `transformer_lens` and `sae_lens` have complex initialization requirements diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 751a823d..820b5c18 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -132,7 +132,6 @@ jobs: python-version: ["3.12"] timeout-minutes: 90 env: - IT_USE_CT_COMMIT_PIN: "1" WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} HF_GATED_PUBLIC_REPO_AUTH_KEY: ${{ secrets.HF_GATED_PUBLIC_REPO_AUTH_KEY }} ### we need to set the below secrets/vars for Grafana Alloy monitoring in our composite action (which doesn't have direct access to those workflow inputs) @@ -147,33 +146,12 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Reset caching - id: set_time_period - run: | - python -c "import time; days = time.time() / 60 / 60 / 24; print(f'TIME_PERIOD=d{int(days / 7) * 7}')" >> $GITHUB_OUTPUT - - name: Install libpq (macOS only) if: runner.os == 'macOS' run: | brew reinstall libpq brew link --force libpq - # Note: This uses an internal pip API and may not always work - # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow - - name: Get pip cache dir - id: pip-cache - shell: bash - run: | - echo "PIP_CACHE_DIR=$(pip cache dir)" >> $GITHUB_ENV - - - name: pip cache - uses: actions/cache@v4 - with: - path: ${{ env.PIP_CACHE_DIR }}/wheels - key: ${{ 
runner.os }}-pip-wheels-${{ steps.set_time_period.outputs.TIME_PERIOD }}-py${{ matrix.python-version }}-${{ hashFiles('requirements/ci/requirements.txt') }}-${{ hashFiles('requirements/ci/post_upgrades.txt') }}-${{ hashFiles('requirements/ci/platform_dependent.txt') }} - restore-keys: | - ${{ runner.os }}-pip-wheels-${{ steps.set_time_period.outputs.TIME_PERIOD }}-py${{ matrix.python-version }}- - - name: Install CI dependencies uses: ./.github/actions/install-ci-dependencies with: diff --git a/.github/workflows/copilot-setup-steps.yml b/.github/workflows/copilot-setup-steps.yml index b7f39427..7c8b75ee 100644 --- a/.github/workflows/copilot-setup-steps.yml +++ b/.github/workflows/copilot-setup-steps.yml @@ -32,36 +32,16 @@ jobs: with: python-version: '3.12' - - name: Reset caching - id: set_time_period - run: | - python -c "import time; days = time.time() / 60 / 60 / 24; print(f'TIME_PERIOD=d{int(days / 7) * 7}')" >> $GITHUB_OUTPUT - - - name: Get pip cache dir - id: pip-cache - shell: bash - run: | - echo "PIP_CACHE_DIR=$(pip cache dir)" >> $GITHUB_ENV - - - name: pip cache - uses: actions/cache@v4 - with: - path: ${{ env.PIP_CACHE_DIR }}/wheels - key: ubuntu-22.04-pip-wheels-${{ steps.set_time_period.outputs.TIME_PERIOD }}-py3.12-${{ hashFiles('requirements/ci/requirements.txt') }}-${{ hashFiles('requirements/ci/post_upgrades.txt') }}-${{ hashFiles('requirements/ci/platform_dependent.txt') }} - restore-keys: | - ubuntu-22.04-pip-wheels-${{ steps.set_time_period.outputs.TIME_PERIOD }}-py3.12- - - name: Install CI dependencies uses: ./.github/actions/install-ci-dependencies with: show_pip_list: "true" - apply_post_upgrades: "true" - - name: Setup pyright, precommit and git lfs + - name: Setup development tools shell: bash run: | - # Install pyright, pre-commit, and git-lfs - pip install --upgrade pyright pre-commit git-lfs + # Install development tools via dev dependency group + uv pip install --group dev # pyright -p pyproject.toml pre-commit install git lfs 
install diff --git a/.github/workflows/regen-ci-req-check.yml b/.github/workflows/regen-ci-req-check.yml index af6a87cc..e8f0d125 100644 --- a/.github/workflows/regen-ci-req-check.yml +++ b/.github/workflows/regen-ci-req-check.yml @@ -21,7 +21,7 @@ jobs: with: python_version: '3.12' ci_output_dir: requirements/ci - compare_paths: "requirements/ci/requirements.txt requirements/ci/post_upgrades.txt requirements/ci/platform_dependent.txt" + compare_paths: "requirements/ci/requirements.txt" patch_path: /tmp/regen_diff.patch - name: Create PR with updated pins diff --git a/.github/workflows/regen-ci-req-report.yml b/.github/workflows/regen-ci-req-report.yml index 46e27722..49e5ae88 100644 --- a/.github/workflows/regen-ci-req-report.yml +++ b/.github/workflows/regen-ci-req-report.yml @@ -62,7 +62,7 @@ jobs: with: python_version: '3.12' ci_output_dir: requirements/ci - compare_paths: "requirements/ci/requirements.txt requirements/ci/post_upgrades.txt requirements/ci/platform_dependent.txt" + compare_paths: "requirements/ci/requirements.txt" patch_path: /tmp/regen_diff.patch - name: Upload regen diff (if present) diff --git a/.github/workflows/type-check.yml b/.github/workflows/type-check.yml index cdd2ca9b..6d9f65da 100644 --- a/.github/workflows/type-check.yml +++ b/.github/workflows/type-check.yml @@ -75,7 +75,7 @@ jobs: uses: actions/cache@v4 with: path: ${{ env.PIP_CACHE_DIR }}/wheels - key: ubuntu-22.04-pip-wheels-${{ steps.set_time_period.outputs.TIME_PERIOD }}-py3.12-${{ hashFiles('requirements/ci/requirements.txt') }}-${{ hashFiles('requirements/ci/post_upgrades.txt') }}-${{ hashFiles('requirements/ci/platform_dependent.txt') }} + key: ubuntu-22.04-pip-wheels-${{ steps.set_time_period.outputs.TIME_PERIOD }}-py3.12-${{ hashFiles('requirements/ci/requirements.txt') }} restore-keys: | ubuntu-22.04-pip-wheels-${{ steps.set_time_period.outputs.TIME_PERIOD }}-py3.12- diff --git a/.gitignore b/.gitignore index 3b9ad781..dd437610 100644 --- a/.gitignore +++ b/.gitignore 
@@ -128,6 +128,9 @@ env/ venv/ ENV/ +# uv +.python-version + # Spyder project settings .spyderproject .spyproject diff --git a/Makefile b/Makefile index 64c9848b..78b9786f 100644 --- a/Makefile +++ b/Makefile @@ -3,8 +3,6 @@ .PHONY: help lint test build sdist wheel clean export SPHINX_MOCK_REQUIREMENTS=1 -# TODO: remove this once CT is available via PyPI or our own package index -export IT_USE_CT_COMMIT_PIN="1" help: @echo "Makefile commands:" @@ -26,19 +24,17 @@ clean: rm -rf ./docs/source/api lint: - # run ruff via pre-commit where available + # run ruff via pre-commit where available (assumes activated venv) pre-commit run ruff-check --all-files pre-commit run ruff-format --all-files test: clean - pip install -r requirements/devel.txt - # TODO: remove this once no longer necessary - @echo "Using IT_USE_CT_COMMIT_PIN for circuit-tracer installation" - interpretune-install-circuit-tracer # run tests with coverage (cpu-only, running gpu standalone tests required for full coverage) + # assumes activated venv with interpretune installed python -m coverage run --append --source src/interpretune -m pytest src/interpretune tests -v python -m coverage report docs: clean - pip install --quiet -r requirements/docs.txt + # assumes activated venv + uv pip install --quiet -e . --group docs python -m sphinx -b html -W --keep-going docs/source docs/build diff --git a/README.md b/README.md index b90b1d09..3f99b0c9 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,62 @@ Interpretune aims to provide tools and infrastructure for: - Collaborative research with shareable/composable experimentation - Flexible tuning and interpretability workflows +## Installation + +Interpretune uses [uv](https://github.com/astral-sh/uv) for fast, reliable dependency management. 
+ +### Quick Start + +```bash +# Install uv (one-time setup) +curl -LsSf https://astral.sh/uv/install.sh | sh + +# Clone the repository +git clone https://github.com/speediedan/interpretune.git +cd interpretune + +# Create and activate a virtual environment +uv venv +source .venv/bin/activate # On Windows: .venv\Scripts\activate + +# Install interpretune in editable mode with all dependencies +# Note: git-deps group is optional once circuit-tracer is published on PyPI +uv pip install -e ".[test,examples,lightning,profiling]" --group git-deps --group dev + +# Run tests +pytest tests/ -v +``` + +### Development Setup + +For advanced development workflows, use the provided build script which supports locked CI requirements and from-source packages: + +```bash +# Standard development build (uses locked CI requirements) +./scripts/build_it_env.sh --repo_home=${PWD} --target_env_name=it_latest + +# Activate the created environment +source ~/.venvs/it_latest/bin/activate + +# Build with packages from source (useful for development) +./scripts/build_it_env.sh --repo_home=${PWD} --target_env_name=it_latest \ + --from-source="finetuning_scheduler:${HOME}/repos/finetuning-scheduler:all" \ + --from-source="circuit_tracer:${HOME}/repos/circuit-tracer" +``` + +### Locked Requirements for CI + +CI workflows use locked requirements for reproducibility: + +```bash +# Install using locked CI requirements (CI approach) +uv pip install -e . --group git-deps +uv pip install -r requirements/ci/requirements.txt + +# Regenerate locked requirements (after updating pyproject.toml) +./requirements/utils/lock_ci_requirements.sh +``` + ## Project Status This project is in the **pre-MVP** stage. Features and APIs are subject to change. Contributions and feedback are welcome as the framework evolves.
diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index ff167309..42d78918 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -17,7 +17,7 @@ ARG OS_VER=ubuntu22.04 FROM nvidia/cuda:${CUDA_VERSION}-devel-${OS_VER} ARG PYTHON_VERSION=3.12 -ARG PYTORCH_VERSION=2.9.0 +ARG PYTORCH_VERSION=2.9.1 ARG CUST_BUILD=0 ARG MKL_THREADING_LAYER=GNU @@ -31,8 +31,9 @@ ENV \ CUDA_TOOLKIT_ROOT_DIR="/usr/local/cuda" \ TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;9.0;10.0;12.0" \ MKL_THREADING_LAYER=${MKL_THREADING_LAYER} \ - MAKEFLAGS="-j2" -RUN apt-get update -qq --fix-missing && \ + MAKEFLAGS="-j2" \ + UV_NO_CACHE=1 + RUN apt-get update -qq --fix-missing && \ apt-get install -y --no-install-recommends \ build-essential \ pkg-config \ @@ -52,80 +53,72 @@ RUN apt-get update -qq --fix-missing && \ apt-get install -y \ python${PYTHON_VERSION} \ python${PYTHON_VERSION}-dev \ - python3-pip \ && \ update-alternatives --install /usr/bin/python${PYTHON_VERSION%%.*} python${PYTHON_VERSION%%.*} /usr/bin/python${PYTHON_VERSION} 1 && \ update-alternatives --install /usr/bin/python python /usr/bin/python${PYTHON_VERSION} 1 && \ + # Install uv for fast, reliable package management (system-wide) + curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR=/usr/local/bin sh && \ # Cleaning apt-get autoremove -y && \ apt-get clean && \ rm -rf /root/.cache && \ rm -rf /var/lib/apt/lists/* -COPY ./requirements.txt requirements.txt COPY ./requirements/ ./requirements/ ENV PYTHONPATH=/usr/lib/python${PYTHON_VERSION}/site-packages -# ENV USE_CI_COMMIT_PIN="1" RUN \ - wget https://bootstrap.pypa.io/get-pip.py --progress=bar:force:noscroll --no-check-certificate | python${PYTHON_VERSION} && \ - python${PYTHON_VERSION} get-pip.py && \ - rm get-pip.py && \ - python${PYTHON_VERSION} -m pip install --upgrade pip setuptools && \ - # Disable cache - pip config set global.cache-dir false && \ - pip install virtualenv && \ - mkdir /tmp/venvs && \ - python -m 
virtualenv -p python${PYTHON_VERSION} /tmp/venvs/it_dev && \ + # Set umask to make all new files/dirs world-writable by default + umask 000 && \ + # Create virtual environment with uv + uv venv /tmp/venvs/it_dev --python python${PYTHON_VERSION} && \ . /tmp/venvs/it_dev/bin/activate && \ # set particular PyTorch version by default if [[ "${CUST_BUILD}" -eq 0 ]]; then \ CUDA_VERSION_MM=${CUDA_VERSION%.*}; \ - pip install torch --no-cache-dir \ + uv pip install torch --no-cache \ --find-links="https://download.pytorch.org/whl/cu${CUDA_VERSION_MM//'.'/''}/torch_stable.html" \ --find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/torch" \ --find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/pytorch-triton"; \ else \ # or target a specific cuda build, by specifying a particular index url w/... # ... default channel - pip install torch --index-url https://download.pytorch.org/whl/cu128; \ + uv pip install torch --index-url https://download.pytorch.org/whl/cu128; \ # ... pytorch nightly dev version - # pip install --pre torch==2.9.0.dev20250811 --index-url https://download.pytorch.org/whl/nightly/cu128; \ + # uv pip install --pre torch==2.9.0.dev20250811 --index-url https://download.pytorch.org/whl/nightly/cu128; \ # ... test channel - # pip install --pre torch==2.9.0 --index-url https://download.pytorch.org/whl/test/cu128; \ + # uv pip install --pre torch==2.9.0 --index-url https://download.pytorch.org/whl/test/cu128; \ fi && \ - # We avoid installing Lightning and other dependencies here as they are usually upgraded anyway later in - # CI but we may re-enable in the future. 
- # LIGHTNING_COMMIT=$(cat ./requirements/lightning_pin.txt) && \ - # pip install "lightning @ git+https://github.com/Lightning-AI/lightning.git@${LIGHTNING_COMMIT}#egg=lightning" --no-cache-dir && \ - # Install all requirements - # pip install -r requirements/devel.txt --no-cache-dir && \ - # Update six - # pip install -U six --no-cache-dir && \ - chmod -R 777 /tmp/venvs/it_dev && \ - rm -rf requirements.* requirements/ + rm -rf requirements/ + +# Set VIRTUAL_ENV to activate the venv for subsequent RUN commands +ENV VIRTUAL_ENV=/tmp/venvs/it_dev +ENV PATH="${VIRTUAL_ENV}/bin:$PATH" +# Separate RUN to verify permissions persist across Docker layers RUN \ - set -x && \ - . /tmp/venvs/it_dev/bin/activate && \ - echo "Checking CUDA version:" && \ - CUDA_VERSION_MAJOR=$(python -c "import torch; print(torch.version.cuda.split('.')[0])") && \ - echo "CUDA Version Major: ${CUDA_VERSION_MAJOR}" && \ - echo "Checking Python version compatibility:" && \ - py_ver=$(python -c "print(int('$PYTHON_VERSION'.split('.') >= '3.12'.split('.')))") && \ - echo "Python version check result: ${py_ver}" + echo "=== Post-ENV permission check ===" && \ + ls -ld /tmp/venvs/it_dev && \ + echo "Checking torch __pycache__ permissions:" && \ + find /tmp/venvs/it_dev -type d -path "*/torch/__pycache__" | head -3 | xargs ls -ld || echo "No torch __pycache__ found" && \ + echo "=== Post-ENV check complete ===" RUN \ set -x && \ echo "============= Environment Information =============" && \ . 
/tmp/venvs/it_dev/bin/activate && \ - echo "Pip version:" && \ - pip --version && \ + echo "UV version:" && \ + uv --version && \ echo "Installed packages:" && \ - pip list && \ + uv pip list && \ echo "Verifying Python version:" && \ python -c "import sys; ver = sys.version_info; print(f'Python {ver.major}.{ver.minor} detected'); assert f'{ver.major}.{ver.minor}' == '$PYTHON_VERSION', ver" && \ echo "Verifying PyTorch version:" && \ python -c "import torch; print(f'PyTorch {torch.__version__} detected'); assert torch.__version__.startswith('$PYTORCH_VERSION'), torch.__version__" && \ - echo "============= Environment Verification Complete =============" + echo "============= Environment Verification Complete =============" && \ + # Apply recursive chmod to allow userns-remapped rootless Docker users to update the venv + chmod -R 777 /tmp/venvs/it_dev && \ + echo "Verify permission state:" && \ + find /tmp/venvs/it_dev -type d -path "*/torch/__pycache__" | head -3 | xargs ls -ld || echo "No torch __pycache__ found" && \ + echo "=== Post-env-check chmod complete ===" diff --git a/dockers/docker_images_main.sh b/dockers/docker_images_main.sh index 82b97b3d..f0d584c1 100755 --- a/dockers/docker_images_main.sh +++ b/dockers/docker_images_main.sh @@ -43,7 +43,7 @@ maybe_build(){ build_eval(){ # latest PyTorch image supported by release # see CUDA_ARCHES_FULL_VERSION for the full version of the pytorch-provided toolkit - declare -A iv=(["cuda"]="12.8.1" ["python"]="3.12" ["pytorch"]="2.9.0" ["cust_build"]="1") + declare -A iv=(["cuda"]="12.8.1" ["python"]="3.12" ["pytorch"]="2.9.1" ["cust_build"]="1") export latest_pt="base-cu${iv["cuda"]}-py${iv["python"]}-pt${iv["pytorch"]}" export latest_azpl="py${iv["python"]}-pt${iv["pytorch"]}-azpl-init" maybe_build iv "${latest_pt}" "${latest_azpl}" diff --git a/dockers/docker_images_release.sh b/dockers/docker_images_release.sh index 50c5afe3..363e4fa7 100755 --- a/dockers/docker_images_release.sh +++ 
b/dockers/docker_images_release.sh @@ -41,7 +41,7 @@ maybe_build(){ build_eval(){ # latest PyTorch image supported by release - declare -A iv=(["cuda"]="12.8.1" ["python"]="3.12" ["pytorch"]="2.9.0" ["cust_build"]="0") + declare -A iv=(["cuda"]="12.8.1" ["python"]="3.12" ["pytorch"]="2.9.1" ["cust_build"]="0") export latest_pt="base-cu${iv["cuda"]}-py${iv["python"]}-pt${iv["pytorch"]}" export latest_azpl="py${iv["python"]}-pt${iv["pytorch"]}-azpl-init" maybe_build iv "${latest_pt}" "${latest_azpl}" diff --git a/dockers/it-az-base/.bashrc b/dockers/it-az-base/.bashrc index 701edcdd..fd19db9e 100644 --- a/dockers/it-az-base/.bashrc +++ b/dockers/it-az-base/.bashrc @@ -29,7 +29,7 @@ shopt -s checkwinsize #shopt -s globstar # make less more friendly for non-text input files, see lesspipe(1) -[ -x /usr/bin/lesspipe ] && eval "$(SHELL=/bin/sh lesspipe)" +[ -x /usr/bin/lesspipe ] && eval "$(SHELL=/bin/bash lesspipe)" # set variable identifying the chroot you work in (used in the prompt below) if [ -z "${debian_chroot:-}" ] && [ -r /etc/debian_chroot ]; then diff --git a/dockers/it-az-base/Dockerfile b/dockers/it-az-base/Dockerfile index aab2e820..b7c77af7 100644 --- a/dockers/it-az-base/Dockerfile +++ b/dockers/it-az-base/Dockerfile @@ -11,7 +11,7 @@ # limitations under the License. 
ARG PYTHON_VERSION=3.12 -ARG PYTORCH_VERSION=2.9.0 +ARG PYTORCH_VERSION=2.9.1 ARG CUST_BASE FROM speediedan/interpretune:base-${CUST_BASE}py${PYTHON_VERSION}-pt${PYTORCH_VERSION} diff --git a/pyproject.toml b/pyproject.toml index 72412866..774a39ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,69 +53,83 @@ interpretune = "interpretune.base.components.cli:bootstrap_cli" [project.optional-dependencies] lightning = [ - # TODO: revert this FTS commit pin and/or old FTS version once lightning publishes a new release - # "finetuning-scheduler @ git+https://github.com/speediedan/finetuning-scheduler.git@44a8e62fdc0fa9b08ef770160e2fa25a89f389f4", - #"finetuning-scheduler[examples, cli, extra] >= 2.5.0", "finetuning-scheduler >= 2.5.0", - "bitsandbytes", + "bitsandbytes; platform_system != 'Darwin'", "peft", ] examples = [ -"wandb", -"torch-tb-profiler", -"notebook", -"jupyterlab", -"ipywidgets", -"jupytext >= 1.10", # converting notebook source .py to .ipynb -"nbval >= 0.9.6", # testing the notebook -"python-dotenv", -"plotly", -"matplotlib", -"gdown", -"evaluate", -"scikit-learn", -"neuronpedia", -# TODO: add our packaged circuit-tracer dep (either pypi or pypi fork) once it is available and remove the -# `install_circuit_tracer` tool -# "circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87", + "wandb", + "torch-tb-profiler", + "notebook", + "jupyterlab", + "ipywidgets", + "jupytext >= 1.10", # converting notebook source .py to .ipynb + "nbval >= 0.9.6", # testing the notebook + "python-dotenv", + "plotly", + "matplotlib", + "gdown", + "evaluate", + "scikit-learn", + "neuronpedia", + # "circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87", +] + +[dependency-groups] +# PEP 735 dependency groups for development and CI + +# Git URL dependencies - packages installed from git repositories +# These are separated because they cannot be included in 
universal lock files +# Update commit hashes when cutting release branches to pin to tested versions +git-deps = [ + "circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87", +] + +dev = [ + # Core development tools + "uv >= 0.5.0", + "pre-commit >= 1.0", + "pyright >= 1.1.365", + "toml", ] -docs = [ -"sphinx >= 4.0", -"myst-parser >= 0.18.1", -"nbsphinx >= 0.8.5", -"pandoc >= 1.0", -"docutils >= 0.16", -"sphinxcontrib-fulltoc >= 1.0", -"sphinxcontrib-mockautodoc", -"sphinx-autodoc-typehints >= 1.16", -"sphinx-paramlinks >= 0.5.1", -"sphinx-togglebutton >= 0.2", -"sphinx-copybutton >= 0.3", -"typing-extensions", -"jinja2 >= 3.0.0,<3.1.0", -"pt_lightning_sphinx_theme @ git+https://github.com/speediedan/lightning_sphinx_theme.git@057f4c3e669948bc618eec1688b016f07140cc0d", +test = [ + { include-group = "dev" }, + # Testing framework and tools + "coverage >= 6.4", + "pytest >= 6.0", + "pytest-rerunfailures >= 10.2", + # Additional test dependencies + "twine >= 3.2", + "psycopg", + "huggingface_hub[hf_xet]", + # Notebook testing + "nbmake >= 1.5.0", + "papermill >= 2.4.0", ] -test = [ -"coverage >= 6.4", -"pytest >= 6.0", -"pytest-rerunfailures >= 10.2", -"twine >= 3.2", -"pyright >= 1.1.365", -"pre-commit >= 1.0", -"psycopg", -"toml", -"pip-tools >= 7.5.1", -"pip < 25.3", # temporarily need to avoid pip 25.3 due to pip-tools https://github.com/jazzband/pip-tools/issues/2252 -"huggingface_hub[hf_xet]", -"nbmake >= 1.5.0", -"papermill >= 2.4.0", +docs = [ + # Documentation generation + "sphinx >= 4.0", + "myst-parser >= 0.18.1", + "nbsphinx >= 0.8.5", + "pandoc >= 1.0", + "docutils >= 0.16", + "sphinxcontrib-fulltoc >= 1.0", + "sphinxcontrib-mockautodoc", + "sphinx-autodoc-typehints >= 1.16", + "sphinx-paramlinks >= 0.5.1", + "sphinx-togglebutton >= 0.2", + "sphinx-copybutton >= 0.3", + "typing-extensions", + "jinja2 >= 3.0.0,<3.1.0", + "pt_lightning_sphinx_theme @ 
git+https://github.com/speediedan/lightning_sphinx_theme.git@057f4c3e669948bc618eec1688b016f07140cc0d", ] profiling = [ -"py-spy", + # Performance profiling tools + "py-spy", ] [tool.setuptools] @@ -173,7 +187,6 @@ order-by-type = false max-complexity = 10 [tool.pyright] -autoSearchPaths=false typeCheckingMode="standard" include = [ "src/", @@ -232,19 +245,3 @@ notebook_metadata_filter = "-all" [tool.jupytext.formats] "notebooks/" = "ipynb" "scripts/" = "py" - -[tool.ci_pinning] -# default pinning behavior for packages not listed explicitly: "major", "minor", "exact", or "none" -default = "major" - -# packages to never relax (keep exact/strict behavior) -strict = ["torch"] - -# packages with platform-dependent versions that should be excluded from pinning -# these will be installed separately with flexible constraints to allow platform-specific resolution -# only applies to direct dependencies we specify in pyproject.toml -platform_dependent = ["bitsandbytes"] - -# packages that should be applied as post-upgrades (package -> desired_version) -[tool.ci_pinning.post_upgrades] -# no packages currently require post installation upgrades diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 38a362e8..00000000 --- a/requirements.txt +++ /dev/null @@ -1 +0,0 @@ --r ./requirements/base.txt diff --git a/requirements/base.txt b/requirements/base.txt deleted file mode 100644 index c7665c23..00000000 --- a/requirements/base.txt +++ /dev/null @@ -1,6 +0,0 @@ -transformer_lens >= 2.15.4 -sae_lens >= 6.3.1 -torch >=2.7.1 -tabulate >= 0.9.0 -datasets >= 4.0.0 -jsonargparse[signatures] >= 4.35.0,<4.42.0 diff --git a/requirements/ci/circuit_tracer_pin.txt b/requirements/ci/circuit_tracer_pin.txt deleted file mode 100644 index 2f48da40..00000000 --- a/requirements/ci/circuit_tracer_pin.txt +++ /dev/null @@ -1 +0,0 @@ -004f1b2822eca3f0c1ddd2389e9105b3abffde87 diff --git a/requirements/ci/platform_dependent.txt b/requirements/ci/platform_dependent.txt deleted 
file mode 100644 index 38cb1102..00000000 --- a/requirements/ci/platform_dependent.txt +++ /dev/null @@ -1 +0,0 @@ -bitsandbytes diff --git a/requirements/ci/post_upgrades.txt b/requirements/ci/post_upgrades.txt deleted file mode 100644 index e69de29b..00000000 diff --git a/requirements/ci/requirements.in b/requirements/ci/requirements.in deleted file mode 100644 index faaaf7b8..00000000 --- a/requirements/ci/requirements.in +++ /dev/null @@ -1,36 +0,0 @@ -transformer_lens >= 2.15.4 -sae_lens >= 6.3.1 -torch >=2.7.1 -tabulate >= 0.9.0 -datasets >= 4.0.0 -jsonargparse[signatures] >= 4.35.0,<4.42.0 -finetuning-scheduler >= 2.5.0 -peft -wandb -torch-tb-profiler -notebook -jupyterlab -ipywidgets -jupytext >= 1.10 -nbval >= 0.9.6 -python-dotenv -plotly -matplotlib -gdown -evaluate -scikit-learn -neuronpedia -coverage >= 6.4 -pytest >= 6.0 -pytest-rerunfailures >= 10.2 -twine >= 3.2 -pyright >= 1.1.365 -pre-commit >= 1.0 -psycopg -toml -pip-tools >= 7.5.1 -pip < 25.3 -huggingface_hub[hf_xet] -nbmake >= 1.5.0 -papermill >= 2.4.0 -git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87#egg=circuit-tracer diff --git a/requirements/ci/requirements.txt b/requirements/ci/requirements.txt index 335c4eda..a0edef70 100644 --- a/requirements/ci/requirements.txt +++ b/requirements/ci/requirements.txt @@ -1,102 +1,768 @@ -# -# This file is autogenerated by pip-compile with Python 3.12 -# by the following command: -# -# pip-compile --no-strip-extras --output-file=requirements/ci/requirements.txt requirements/ci/requirements.in -# -circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87 - # via -r requirements/ci/requirements.in -coverage==7.11.0 - # via - # -r requirements/ci/requirements.in +# This file was autogenerated by uv via the following command: +# uv pip compile /home/speediedan/repos/interpretune/pyproject.toml --extra examples --extra lightning --group dev --group test --group 
profiling --output-file /home/speediedan/repos/interpretune/requirements/ci/requirements.txt --no-strip-extras --universal +absl-py==2.3.1 + # via tensorboard +accelerate==1.11.0 + # via + # peft + # transformer-lens +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.13.2 + # via + # fsspec + # papermill +aiosignal==1.4.0 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +ansicolors==1.1.8 + # via papermill +anyio==4.11.0 + # via + # httpx + # jupyter-server +appnope==0.1.4 ; sys_platform == 'darwin' + # via ipykernel +argon2-cffi==25.1.0 + # via jupyter-server +argon2-cffi-bindings==25.1.0 + # via argon2-cffi +arrow==1.4.0 + # via isoduration +asttokens==3.0.0 + # via stack-data +async-lru==2.0.5 + # via jupyterlab +attrs==25.4.0 + # via + # aiohttp + # jsonschema + # referencing +babe==0.0.7 + # via sae-lens +babel==2.17.0 + # via jupyterlab-server +beartype==0.14.1 + # via transformer-lens +beautifulsoup4==4.14.2 + # via + # gdown + # nbconvert +better-abc==0.0.3 + # via transformer-lens +bitsandbytes==0.48.2 ; platform_machine != 'arm64' or sys_platform != 'darwin' + # via interpretune (pyproject.toml) +bleach[css]==6.3.0 + # via nbconvert +certifi==2025.11.12 + # via + # httpcore + # httpx + # requests + # sentry-sdk +cffi==2.0.0 + # via + # argon2-cffi-bindings + # cryptography + # pyzmq +cfgv==3.4.0 + # via pre-commit +charset-normalizer==3.4.4 + # via requests +click==8.3.0 + # via + # nltk + # papermill + # wandb +colorama==0.4.6 ; sys_platform == 'win32' + # via + # click + # ipython + # pytest + # tqdm +comm==0.2.3 + # via + # ipykernel + # ipywidgets +config2py==0.1.44 + # via py2store +contourpy==1.3.3 + # via matplotlib +coverage==7.11.3 + # via + # interpretune (pyproject.toml:test) # nbval -datasets==4.3.0 +cryptography==46.0.3 ; platform_machine != 'ppc64le' and platform_machine != 's390x' and sys_platform == 'linux' + # via secretstorage +cycler==0.12.1 + # via matplotlib +datasets==4.4.1 # via - # -r requirements/ci/requirements.in + # 
interpretune (pyproject.toml) # evaluate # sae-lens # transformer-lens +debugpy==1.8.17 + # via ipykernel +decorator==5.2.1 + # via ipython +defusedxml==0.7.1 + # via nbconvert +dill==0.4.0 + # via + # datasets + # evaluate + # multiprocess +distlib==0.4.0 + # via virtualenv +docstring-parser==0.17.0 + # via + # jsonargparse + # simple-parsing +docutils==0.22.3 + # via readme-renderer +dol==0.3.31 + # via + # config2py + # graze + # py2store +einops==0.8.1 + # via transformer-lens +entrypoints==0.4 + # via papermill evaluate==0.4.6 - # via -r requirements/ci/requirements.in -finetuning-scheduler==2.9.0 - # via -r requirements/ci/requirements.in + # via interpretune (pyproject.toml) +executing==2.2.1 + # via stack-data +fancy-einsum==0.0.3 + # via transformer-lens +fastjsonschema==2.21.2 + # via nbformat +filelock==3.20.0 + # via + # datasets + # gdown + # huggingface-hub + # torch + # transformers + # virtualenv +finetuning-scheduler==2.9.1 + # via interpretune (pyproject.toml) +fonttools==4.60.1 + # via matplotlib +fqdn==1.5.1 + # via jsonschema +frozenlist==1.8.0 + # via + # aiohttp + # aiosignal +fsspec[http]==2025.10.0 + # via + # datasets + # evaluate + # huggingface-hub + # lightning + # pytorch-lightning + # torch gdown==5.2.0 - # via -r requirements/ci/requirements.in + # via interpretune (pyproject.toml) +gitdb==4.0.12 + # via gitpython +gitpython==3.1.45 + # via wandb +graze==0.1.39 + # via babe +grpcio==1.76.0 + # via tensorboard +h11==0.16.0 + # via httpcore +hf-xet==1.2.0 + # via huggingface-hub +httpcore==1.0.9 + # via httpx +httpx==0.28.1 + # via + # datasets + # jupyterlab huggingface-hub[hf-xet]==0.36.0 # via - # -r requirements/ci/requirements.in + # interpretune (pyproject.toml:test) # accelerate - # circuit-tracer # datasets # evaluate # peft # tokenizers # transformers -ipywidgets==8.1.7 +i2==0.1.60 + # via config2py +id==1.5.0 + # via twine +identify==2.6.15 + # via pre-commit +idna==3.11 + # via + # anyio + # httpx + # jsonschema + # requests 
+ # yarl +importlib-resources==6.5.2 + # via + # py2store + # typeshed-client +iniconfig==2.3.0 + # via pytest +ipykernel==7.1.0 + # via + # jupyterlab + # nbmake + # nbval +ipython==9.7.0 + # via + # ipykernel + # ipywidgets +ipython-pygments-lexers==1.1.1 + # via ipython +ipywidgets==8.1.8 + # via interpretune (pyproject.toml) +isoduration==20.11.0 + # via jsonschema +jaraco-classes==3.4.0 ; platform_machine != 'ppc64le' and platform_machine != 's390x' + # via keyring +jaraco-context==6.0.1 ; platform_machine != 'ppc64le' and platform_machine != 's390x' + # via keyring +jaraco-functools==4.3.0 ; platform_machine != 'ppc64le' and platform_machine != 's390x' + # via keyring +jaxtyping==0.3.3 + # via transformer-lens +jedi==0.19.2 + # via ipython +jeepney==0.9.0 ; platform_machine != 'ppc64le' and platform_machine != 's390x' and sys_platform == 'linux' + # via + # keyring + # secretstorage +jinja2==3.1.6 + # via + # jupyter-server + # jupyterlab + # jupyterlab-server + # nbconvert + # torch +joblib==1.5.2 + # via + # nltk + # scikit-learn +json5==0.12.1 + # via jupyterlab-server +jsonargparse[signatures]==4.41.0 + # via interpretune (pyproject.toml) +jsonpointer==3.0.0 + # via jsonschema +jsonschema[format-nongpl]==4.25.1 + # via + # jupyter-events + # jupyterlab-server + # nbformat +jsonschema-specifications==2025.9.1 + # via jsonschema +jupyter-client==8.6.3 # via - # -r requirements/ci/requirements.in - # circuit-tracer -jsonargparse[signatures,typing-extensions]==4.41.0 - # via -r requirements/ci/requirements.in + # ipykernel + # jupyter-server + # nbclient + # nbval +jupyter-core==5.9.1 + # via + # ipykernel + # jupyter-client + # jupyter-server + # jupyterlab + # nbclient + # nbconvert + # nbformat +jupyter-events==0.12.0 + # via jupyter-server +jupyter-lsp==2.3.0 + # via jupyterlab +jupyter-server==2.17.0 + # via + # jupyter-lsp + # jupyterlab + # jupyterlab-server + # notebook + # notebook-shim +jupyter-server-terminals==0.5.3 + # via jupyter-server 
jupyterlab==4.4.10 # via - # -r requirements/ci/requirements.in + # interpretune (pyproject.toml) # notebook +jupyterlab-pygments==0.3.0 + # via nbconvert +jupyterlab-server==2.28.0 + # via + # jupyterlab + # notebook +jupyterlab-widgets==3.0.16 + # via ipywidgets jupytext==1.18.1 - # via -r requirements/ci/requirements.in + # via interpretune (pyproject.toml) +keyring==25.6.0 ; platform_machine != 'ppc64le' and platform_machine != 's390x' + # via twine +kiwisolver==1.4.9 + # via matplotlib +lark==1.3.1 + # via rfc3987-syntax +lightning==2.5.6 + # via finetuning-scheduler +lightning-utilities==0.15.2 + # via + # lightning + # pytorch-lightning + # torchmetrics +markdown==3.10 + # via tensorboard +markdown-it-py==4.0.0 + # via + # jupytext + # mdit-py-plugins + # rich +markupsafe==3.0.3 + # via + # jinja2 + # nbconvert + # werkzeug matplotlib==3.10.7 - # via -r requirements/ci/requirements.in + # via interpretune (pyproject.toml) +matplotlib-inline==0.2.1 + # via + # ipykernel + # ipython +mdit-py-plugins==0.5.0 + # via jupytext +mdurl==0.1.2 + # via markdown-it-py +mistune==3.1.4 + # via nbconvert +more-itertools==10.8.0 ; platform_machine != 'ppc64le' and platform_machine != 's390x' + # via + # jaraco-classes + # jaraco-functools +mpmath==1.3.0 + # via sympy +multidict==6.7.0 + # via + # aiohttp + # yarl +multiprocess==0.70.18 + # via + # datasets + # evaluate +narwhals==2.11.0 + # via plotly +nbclient==0.10.2 + # via + # nbconvert + # nbmake + # papermill +nbconvert==7.16.6 + # via jupyter-server +nbformat==5.10.4 + # via + # jupyter-server + # jupytext + # nbclient + # nbconvert + # nbmake + # nbval + # papermill nbmake==1.5.5 - # via -r requirements/ci/requirements.in + # via interpretune (pyproject.toml:test) nbval==0.11.0 - # via -r requirements/ci/requirements.in -neuronpedia==1.0.22 - # via -r requirements/ci/requirements.in + # via interpretune (pyproject.toml) +nest-asyncio==1.6.0 + # via ipykernel +networkx==3.5 + # via torch +neuronpedia==1.0.23 + # via 
interpretune (pyproject.toml) +nh3==0.3.2 + # via readme-renderer +nltk==3.9.2 + # via sae-lens +nodeenv==1.9.1 + # via + # pre-commit + # pyright notebook==7.4.7 - # via -r requirements/ci/requirements.in + # via interpretune (pyproject.toml) +notebook-shim==0.2.4 + # via + # jupyterlab + # notebook +numpy==1.26.4 + # via + # accelerate + # bitsandbytes + # contourpy + # datasets + # evaluate + # matplotlib + # pandas + # patsy + # peft + # plotly-express + # scikit-learn + # scipy + # statsmodels + # tensorboard + # torchmetrics + # transformer-lens + # transformers +nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-cuda-nvrtc-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-cuda-runtime-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-cudnn-cu12==9.10.2.21 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-cufft-cu12==11.3.3.83 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-cufile-cu12==1.13.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-curand-cu12==10.3.9.90 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-cusolver-cu12==11.7.3.90 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-cusparse-cu12==12.5.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via + # nvidia-cusolver-cu12 + # torch +nvidia-cusparselt-cu12==0.7.1 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-nccl-cu12==2.27.5 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-nvjitlink-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux' + 
# via + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + # torch +nvidia-nvshmem-cu12==3.3.20 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +packaging==25.0 + # via + # accelerate + # bitsandbytes + # datasets + # evaluate + # huggingface-hub + # ipykernel + # jupyter-events + # jupyter-server + # jupyterlab + # jupyterlab-server + # jupytext + # lightning + # lightning-utilities + # matplotlib + # nbconvert + # peft + # plotly + # pytest + # pytest-rerunfailures + # pytorch-lightning + # statsmodels + # tensorboard + # torchmetrics + # transformers + # twine + # wandb +pandas==2.3.3 + # via + # babe + # datasets + # evaluate + # plotly-express + # statsmodels + # torch-tb-profiler + # transformer-lens +pandocfilters==1.5.1 + # via nbconvert papermill==2.6.0 - # via -r requirements/ci/requirements.in -peft==0.17.1 - # via -r requirements/ci/requirements.in -pip-tools==7.5.1 - # via -r requirements/ci/requirements.in -plotly==6.3.1 - # via - # -r requirements/ci/requirements.in + # via interpretune (pyproject.toml:test) +parso==0.8.5 + # via jedi +patsy==1.0.2 + # via + # plotly-express + # statsmodels +peft==0.18.0 + # via interpretune (pyproject.toml) +pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32' + # via ipython +pillow==12.0.0 + # via + # matplotlib + # tensorboard +platformdirs==4.5.0 + # via + # jupyter-core + # virtualenv + # wandb +plotly==6.4.0 + # via + # interpretune (pyproject.toml) # plotly-express # sae-lens -pre-commit==4.3.0 - # via -r requirements/ci/requirements.in +plotly-express==0.4.1 + # via sae-lens +pluggy==1.6.0 + # via pytest +pre-commit==4.4.0 + # via + # interpretune (pyproject.toml:dev) + # interpretune (pyproject.toml:test) +prometheus-client==0.23.1 + # via jupyter-server +prompt-toolkit==3.0.52 + # via ipython +propcache==0.4.1 + # via + # aiohttp + # yarl 
+protobuf==6.33.1 + # via + # tensorboard + # wandb +psutil==7.1.3 + # via + # accelerate + # ipykernel + # peft psycopg==3.2.12 - # via -r requirements/ci/requirements.in + # via interpretune (pyproject.toml:test) +ptyprocess==0.7.0 ; os_name != 'nt' or (sys_platform != 'emscripten' and sys_platform != 'win32') + # via + # pexpect + # terminado +pure-eval==0.2.3 + # via stack-data +py-spy==0.4.1 + # via interpretune (pyproject.toml:profiling) +py2store==0.1.22 + # via babe +pyarrow==22.0.0 + # via datasets +pycparser==2.23 ; implementation_name != 'PyPy' + # via cffi +pydantic==2.12.4 + # via wandb +pydantic-core==2.41.5 + # via pydantic +pygments==2.19.2 + # via + # ipython + # ipython-pygments-lexers + # nbconvert + # nbmake + # pytest + # readme-renderer + # rich +pyparsing==3.2.5 + # via matplotlib pyright==1.1.407 - # via -r requirements/ci/requirements.in -pytest==8.4.2 # via - # -r requirements/ci/requirements.in + # interpretune (pyproject.toml:dev) + # interpretune (pyproject.toml:test) +pysocks==1.7.1 + # via requests +pytest==9.0.1 + # via + # interpretune (pyproject.toml:test) # nbmake # nbval # pytest-rerunfailures pytest-rerunfailures==16.1 - # via -r requirements/ci/requirements.in + # via interpretune (pyproject.toml:test) +python-dateutil==2.9.0.post0 + # via + # arrow + # jupyter-client + # matplotlib + # pandas python-dotenv==1.2.1 # via - # -r requirements/ci/requirements.in + # interpretune (pyproject.toml) # neuronpedia # sae-lens -sae-lens==6.20.1 - # via -r requirements/ci/requirements.in +python-json-logger==4.0.0 + # via jupyter-events +pytorch-lightning==2.5.6 + # via lightning +pytz==2025.2 + # via pandas +pywin32-ctypes==0.2.3 ; platform_machine != 'ppc64le' and platform_machine != 's390x' and sys_platform == 'win32' + # via keyring +pywinpty==3.0.2 ; os_name == 'nt' and sys_platform != 'linux' + # via + # jupyter-server + # jupyter-server-terminals + # terminado +pyyaml==6.0.3 + # via + # accelerate + # datasets + # huggingface-hub + 
# jsonargparse + # jupyter-events + # jupytext + # lightning + # papermill + # peft + # pre-commit + # pytorch-lightning + # sae-lens + # transformers + # wandb +pyzmq==27.1.0 + # via + # ipykernel + # jupyter-client + # jupyter-server +readme-renderer==44.0 + # via twine +referencing==0.37.0 + # via + # jsonschema + # jsonschema-specifications + # jupyter-events +regex==2025.11.3 + # via + # nltk + # transformers +requests[socks]==2.32.5 + # via + # datasets + # evaluate + # gdown + # graze + # huggingface-hub + # id + # jupyterlab-server + # neuronpedia + # papermill + # requests-toolbelt + # transformers + # twine + # wandb +requests-toolbelt==1.0.0 + # via twine +rfc3339-validator==0.1.4 + # via + # jsonschema + # jupyter-events +rfc3986==2.0.0 + # via twine +rfc3986-validator==0.1.1 + # via + # jsonschema + # jupyter-events +rfc3987-syntax==1.1.0 + # via jsonschema +rich==14.2.0 + # via + # transformer-lens + # twine +rpds-py==0.28.0 + # via + # jsonschema + # referencing +sae-lens==6.22.0 + # via interpretune (pyproject.toml) +safetensors==0.6.2 + # via + # accelerate + # peft + # sae-lens + # transformers scikit-learn==1.7.2 - # via -r requirements/ci/requirements.in + # via interpretune (pyproject.toml) +scipy==1.16.3 + # via + # plotly-express + # scikit-learn + # statsmodels +secretstorage==3.4.1 ; platform_machine != 'ppc64le' and platform_machine != 's390x' and sys_platform == 'linux' + # via keyring +send2trash==1.8.3 + # via jupyter-server +sentencepiece==0.2.1 + # via transformer-lens +sentry-sdk==2.44.0 + # via wandb +setuptools==80.9.0 + # via + # jupyterlab + # lightning-utilities + # tensorboard + # torch +simple-parsing==0.1.7 + # via sae-lens +six==1.17.0 + # via + # python-dateutil + # rfc3339-validator +smmap==5.0.2 + # via gitdb +sniffio==1.3.1 + # via anyio +soupsieve==2.8 + # via beautifulsoup4 +stack-data==0.6.3 + # via ipython +statsmodels==0.14.5 + # via plotly-express +sympy==1.14.0 + # via torch tabulate==0.9.0 - # via -r 
requirements/ci/requirements.in + # via interpretune (pyproject.toml) +tenacity==9.1.2 + # via + # papermill + # sae-lens +tensorboard==2.20.0 + # via torch-tb-profiler +tensorboard-data-server==0.7.2 + # via tensorboard +terminado==0.18.1 + # via + # jupyter-server + # jupyter-server-terminals +threadpoolctl==3.6.0 + # via scikit-learn +tinycss2==1.4.0 + # via bleach +tokenizers==0.22.1 + # via transformers toml==0.10.2 - # via -r requirements/ci/requirements.in -torch==2.9.0 # via - # -r requirements/ci/requirements.in + # interpretune (pyproject.toml:dev) + # interpretune (pyproject.toml:test) +torch==2.9.1 + # via + # interpretune (pyproject.toml) # accelerate - # circuit-tracer + # bitsandbytes # finetuning-scheduler # lightning # peft @@ -104,19 +770,132 @@ torch==2.9.0 # torchmetrics # transformer-lens torch-tb-profiler==0.4.3 - # via -r requirements/ci/requirements.in + # via interpretune (pyproject.toml) +torchmetrics==1.8.2 + # via + # lightning + # pytorch-lightning +tornado==6.5.2 + # via + # ipykernel + # jupyter-client + # jupyter-server + # jupyterlab + # notebook + # terminado +tqdm==4.67.1 + # via + # datasets + # evaluate + # gdown + # huggingface-hub + # lightning + # nltk + # papermill + # peft + # pytorch-lightning + # transformer-lens + # transformers +traitlets==5.14.3 + # via + # ipykernel + # ipython + # ipywidgets + # jupyter-client + # jupyter-core + # jupyter-events + # jupyter-server + # jupyterlab + # matplotlib-inline + # nbclient + # nbconvert + # nbformat transformer-lens==2.16.1 # via - # -r requirements/ci/requirements.in - # circuit-tracer + # interpretune (pyproject.toml) # sae-lens +transformers==4.57.1 + # via + # peft + # sae-lens + # transformer-lens + # transformers-stream-generator +transformers-stream-generator==0.0.5 + # via transformer-lens +triton==3.5.1 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch twine==6.2.0 - # via -r requirements/ci/requirements.in -wandb==0.22.2 + # via interpretune 
(pyproject.toml:test) +typeguard==4.4.4 + # via transformer-lens +typeshed-client==2.8.2 + # via jsonargparse +typing-extensions==4.15.0 # via - # -r requirements/ci/requirements.in + # aiosignal + # anyio + # beautifulsoup4 + # grpcio + # huggingface-hub + # lightning + # lightning-utilities + # psycopg + # pydantic + # pydantic-core + # pyright + # pytorch-lightning + # referencing + # sae-lens + # simple-parsing + # torch + # transformer-lens + # typeguard + # typeshed-client + # typing-inspection + # wandb +typing-inspection==0.4.2 + # via pydantic +tzdata==2025.2 + # via + # arrow + # pandas + # psycopg +uri-template==1.3.0 + # via jsonschema +urllib3==2.5.0 + # via + # requests + # sentry-sdk + # twine +uv==0.9.9 + # via + # interpretune (pyproject.toml:dev) + # interpretune (pyproject.toml:test) +virtualenv==20.35.4 + # via pre-commit +wadler-lindig==0.1.7 + # via jaxtyping +wandb==0.23.0 + # via + # interpretune (pyproject.toml) # transformer-lens - -# The following packages are considered to be unsafe in a requirements file: -# pip -# setuptools +wcwidth==0.2.14 + # via prompt-toolkit +webcolors==25.10.0 + # via jsonschema +webencodings==0.5.1 + # via + # bleach + # tinycss2 +websocket-client==1.9.0 + # via jupyter-server +werkzeug==3.1.3 + # via tensorboard +widgetsnbextension==4.0.15 + # via ipywidgets +xxhash==3.6.0 + # via + # datasets + # evaluate +yarl==1.22.0 + # via aiohttp diff --git a/requirements/devel.txt b/requirements/devel.txt deleted file mode 100644 index 9f31f6e5..00000000 --- a/requirements/devel.txt +++ /dev/null @@ -1,14 +0,0 @@ -# install all mandatory dependencies --r ./base.txt - -# extended list of dependencies for development and run lint and tests --r ./test.txt - -# install all extra dependencies for running examples --r ./examples.txt - -# install lightning adapter additional deps --r ./lightning.txt - -# install profiling additional deps --r ./profiling.txt diff --git a/requirements/docs.txt b/requirements/docs.txt deleted 
file mode 100644 index 1ddd9b82..00000000 --- a/requirements/docs.txt +++ /dev/null @@ -1,14 +0,0 @@ -sphinx >= 4.0 -myst-parser >= 0.18.1 -nbsphinx >= 0.8.5 -pandoc >= 1.0 -docutils >= 0.16 -sphinxcontrib-fulltoc >= 1.0 -sphinxcontrib-mockautodoc -sphinx-autodoc-typehints >= 1.16 -sphinx-paramlinks >= 0.5.1 -sphinx-togglebutton >= 0.2 -sphinx-copybutton >= 0.3 -typing-extensions -jinja2 >= 3.0.0,<3.1.0 -pt_lightning_sphinx_theme @ git+https://github.com/speediedan/lightning_sphinx_theme.git@057f4c3e669948bc618eec1688b016f07140cc0d diff --git a/requirements/examples.txt b/requirements/examples.txt deleted file mode 100644 index 687e7540..00000000 --- a/requirements/examples.txt +++ /dev/null @@ -1,14 +0,0 @@ -wandb -torch-tb-profiler -notebook -jupyterlab -ipywidgets -jupytext >= 1.10 -nbval >= 0.9.6 -python-dotenv -plotly -matplotlib -gdown -evaluate -scikit-learn -neuronpedia diff --git a/requirements/lightning.txt b/requirements/lightning.txt deleted file mode 100644 index b57307b3..00000000 --- a/requirements/lightning.txt +++ /dev/null @@ -1,3 +0,0 @@ -finetuning-scheduler >= 2.5.0 -bitsandbytes -peft diff --git a/requirements/profiling.txt b/requirements/profiling.txt deleted file mode 100644 index 63c284d3..00000000 --- a/requirements/profiling.txt +++ /dev/null @@ -1 +0,0 @@ -py-spy diff --git a/requirements/test.txt b/requirements/test.txt deleted file mode 100644 index 25e839be..00000000 --- a/requirements/test.txt +++ /dev/null @@ -1,13 +0,0 @@ -coverage >= 6.4 -pytest >= 6.0 -pytest-rerunfailures >= 10.2 -twine >= 3.2 -pyright >= 1.1.365 -pre-commit >= 1.0 -psycopg -toml -pip-tools >= 7.5.1 -pip < 25.3 -huggingface_hub[hf_xet] -nbmake >= 1.5.0 -papermill >= 2.4.0 diff --git a/requirements/utils/lock_ci_requirements.sh b/requirements/utils/lock_ci_requirements.sh new file mode 100755 index 00000000..69ad886b --- /dev/null +++ b/requirements/utils/lock_ci_requirements.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# Simple wrapper around uv pip compile for CI 
requirements locking +# This replaces the complex regen_reqfiles.py with a straightforward uv-based approach +set -eo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" +CI_DIR="${REPO_ROOT}/requirements/ci" + +# Ensure output directory exists +mkdir -p "${CI_DIR}" + +echo "Generating locked CI requirements from pyproject.toml..." + +# uv pip compile can read directly from pyproject.toml +# We include: +# - Base dependencies (always included from [project.dependencies]) +# - Optional dependencies: examples, lightning +# - Dependency groups: test, profiling, dev +# +# Note: git-deps group is excluded from locking because git URLs cannot be +# included in universal lock files. It will be installed separately. + +uv pip compile \ + "${REPO_ROOT}/pyproject.toml" \ + --extra examples \ + --extra lightning \ + --group dev \ + --group test \ + --group profiling \ + --output-file "${CI_DIR}/requirements.txt" \ + --upgrade \ + --no-strip-extras \ + --universal + +echo "✓ Generated ${CI_DIR}/requirements.txt" +echo "" +echo "Note: git-deps group (git URL dependencies) is installed separately in CI" diff --git a/requirements/utils/regen_reqfiles.py b/requirements/utils/regen_reqfiles.py deleted file mode 100644 index ce968090..00000000 --- a/requirements/utils/regen_reqfiles.py +++ /dev/null @@ -1,400 +0,0 @@ -import argparse -import fnmatch -import os -import shlex -import subprocess -import shutil -import toml -import re -from dataclasses import dataclass -from typing import Dict, List - -# Commit-based pin selection logic -# -------------------------------- -# This module supports a small workflow for installing important helper -# packages either from a version bound declared in `pyproject.toml` (preferred) -# or from a specific commit SHA recorded in a pin file under -# `requirements/ci/`. The selection rules are: -# 1. If the package-specific environment variable (e.g. 
`IT_USE_CT_COMMIT_PIN`) -# is set to a truthy value ("1", "true", "yes"), the pin-file-based -# commit installation is selected. -# 2. Otherwise, if the package appears in the pyproject location declared by -# its mapping (for example the `examples` optional extra), the pyproject -# requirement (including any version bounds) is used. -# 3. If neither of the above applies, the code falls back to the pin file (if -# present) as a best-effort final option. -# -# The mapping `DEP_COMMIT_PINS` at the top of this file controls which -# packages participate in this flow; add new entries to enable the same -# behavior for other packages. - -# Paths -REQ_DIR = os.path.dirname(os.path.abspath(__file__)) -REPO_ROOT = os.path.dirname(os.path.dirname(REQ_DIR)) -PYPROJECT_PATH = os.path.join(REPO_ROOT, "pyproject.toml") -CI_REQ_DIR = os.path.join(REPO_ROOT, "requirements", "ci") -POST_UPGRADES_PATH = os.path.join(CI_REQ_DIR, "post_upgrades.txt") - - -@dataclass -class DepCommitPin: - package_name: str - env_var: str - dep_def_loc: str # e.g. 'examples' for examples extra or 'dependencies' for a base dependency - pin_filename: str - repo_base_url: str - - -# Mapping for packages that support commit-pin installation. Add new entries -# here to enable the same behavior for other packages. 
-DEP_COMMIT_PINS: Dict[str, DepCommitPin] = { - "circuit-tracer": DepCommitPin( - package_name="circuit-tracer", - env_var="IT_USE_CT_COMMIT_PIN", - dep_def_loc="examples", - pin_filename="circuit_tracer_pin.txt", - repo_base_url="https://github.com/speediedan/circuit-tracer.git", - ), -} - -os.makedirs(REQ_DIR, exist_ok=True) -os.makedirs(CI_REQ_DIR, exist_ok=True) - - -def write_file(path, lines): - with open(path, "w") as f: - for line in lines: - f.write(line.rstrip() + "\n") - - -def load_pyproject(): - with open(PYPROJECT_PATH, "r") as f: - return toml.load(f) - - -def convert_pin_file(pin_file_path: str, repo_base_url: str, pkg_name: str) -> List[str]: - """Read a pin file and convert lines into pip-installable requirement strings. - - Each non-empty, non-comment line may be: - - a bare commit hash (40 or 64 hex chars) -> convert to git+{repo}@{hash}#egg={pkg} - - a git+ URL or a name@rev entry -> return as-is - - any other string -> returned as-is - """ - if not os.path.exists(pin_file_path): - return [] - out: List[str] = [] - with open(pin_file_path, "r") as f: - for line in f: - s = line.strip() - if not s or s.startswith("#"): - continue - if all(c in "0123456789abcdef" for c in s.lower()) and len(s) in (40, 64): - out.append(f"git+{repo_base_url}@{s}#egg={pkg_name}") - elif s.startswith("git+") or "@" in s: - out.append(s) - else: - out.append(s) - return out - - -def generate_top_level_files(pyproject, output_dir=REQ_DIR): - project = pyproject.get("project", {}) - core_reqs = project.get("dependencies", []) - # Write top-level requirement files into the repository-level `requirements/` directory - repo_requirements_dir = os.path.join(REPO_ROOT, "requirements") - os.makedirs(repo_requirements_dir, exist_ok=True) - - write_file(os.path.join(repo_requirements_dir, "base.txt"), core_reqs) - opt_deps = project.get("optional-dependencies", {}) - for group, reqs in opt_deps.items(): - write_file(os.path.join(repo_requirements_dir, f"{group}.txt"), reqs) - - 
-def generate_pip_compile_inputs(pyproject, ci_output_dir=CI_REQ_DIR): - project = pyproject.get("project", {}) - tool_cfg = pyproject.get("tool", {}).get("ci_pinning", {}) - post_upgrades = tool_cfg.get("post_upgrades", {}) or {} - platform_dependent = tool_cfg.get("platform_dependent", []) or [] - - req_in_lines = [] - platform_dependent_lines = [] - direct_packages = [] - - def normalize_package_name(name): - return name.lower().replace("_", "-") - - def add_lines_from(list_or_none): - if not list_or_none: - return - for r in list_or_none: - parts = re.split(r"[\s\[\]=<>!;@]", r) - pkg_name = parts[0].lower() if parts and parts[0] else "" - if normalize_package_name(pkg_name) in {normalize_package_name(k) for k in post_upgrades}: - continue - is_platform_pkg = any( - fnmatch.fnmatch(normalize_package_name(pkg_name), normalize_package_name(pattern)) - for pattern in platform_dependent - ) - if is_platform_pkg: - platform_dependent_lines.append(r) - direct_packages.append(pkg_name) - continue - - req_in_lines.append(r) - direct_packages.append(pkg_name) - - add_lines_from(project.get("dependencies", [])) - - opt_deps = project.get("optional-dependencies", {}) - groups_to_include_completely = ["test", "examples", "lightning"] - - for group, reqs in opt_deps.items(): - if not reqs: - continue - - if group in groups_to_include_completely: - add_lines_from(reqs) - else: - for req in reqs: - parts = re.split(r"[\s\[\]=<>!;@]", req) - pkg_name = parts[0].lower() if parts and parts[0] else "" - post_upgrade_names = {normalize_package_name(k) for k in post_upgrades} - if normalize_package_name(pkg_name) in post_upgrade_names: - continue - is_platform_pkg = any( - fnmatch.fnmatch(normalize_package_name(pkg_name), normalize_package_name(pattern)) - for pattern in platform_dependent - ) - if is_platform_pkg: - platform_dependent_lines.append(req) - direct_packages.append(pkg_name) - continue - continue - - def determine_dep_commit_lines(dep_key: str, pyproject: dict) -> 
List[str]: - """Generalized selection logic for a dependency that supports commit-pin installs. - - Returns a list of requirement strings to add to requirements.in (may be empty). - """ - if dep_key not in DEP_COMMIT_PINS: - return [] - - cfg = DEP_COMMIT_PINS[dep_key] - env_flag = os.getenv(cfg.env_var, "").lower() - if env_flag in ("1", "true", "yes"): - print(f"{cfg.env_var} is set -> using pin file {cfg.pin_filename} for {cfg.package_name}") - pin_path = os.path.join(CI_REQ_DIR, cfg.pin_filename) - return convert_pin_file(pin_path, cfg.repo_base_url, cfg.package_name) - - # Look for the dependency in the declared pyproject extra/base location - project = pyproject.get("project", {}) - if cfg.dep_def_loc == "dependencies": - candidates = project.get("dependencies", []) or [] - else: - opt_deps = project.get("optional-dependencies", {}) or {} - candidates = opt_deps.get(cfg.dep_def_loc, []) or [] - - for req in candidates: - if cfg.package_name in req.lower() or cfg.package_name.replace("-", "_") in req.lower(): - print( - f"Found {cfg.package_name} in pyproject.{cfg.dep_def_loc} ->", - "using pyproject-specified requirement:", - req, - ) - return [req] - - # fallback to pin file - print( - f"{cfg.package_name} not found in pyproject.{cfg.dep_def_loc};", - f"falling back to pin file {cfg.pin_filename} if present", - ) - pin_path = os.path.join(CI_REQ_DIR, cfg.pin_filename) - return convert_pin_file(pin_path, cfg.repo_base_url, cfg.package_name) - - # Ascertain commit-pin dependencies via the DEP_COMMIT_PINS mapping. 
- for dep_key, cfg in DEP_COMMIT_PINS.items(): - dep_lines = determine_dep_commit_lines(dep_key, pyproject) - if not dep_lines: - continue - req_in_lines.extend(dep_lines) - # Track the package as a direct package when appropriate - for line in dep_lines: - if cfg.package_name in line or cfg.package_name.replace("-", "_") in line: - direct_packages.append(cfg.package_name) - - in_path = os.path.join(ci_output_dir, "requirements.in") - write_file(in_path, req_in_lines) - - post_lines = [] - for pkg, spec in post_upgrades.items(): - spec_str = spec.strip() - if re.match(r"^[<>=!].+", spec_str): - post_lines.append(f"{pkg}{spec_str}") - else: - post_lines.append(f"{pkg}=={spec_str}") - write_file(POST_UPGRADES_PATH, post_lines) - - platform_path = os.path.join(CI_REQ_DIR, "platform_dependent.txt") - write_file(platform_path, platform_dependent_lines) - - return in_path, POST_UPGRADES_PATH, platform_path, direct_packages - - -def run_pip_compile(req_in_path, output_path): - pip_compile = shutil.which("pip-compile") - if not pip_compile: - print("pip-compile not found in PATH; install pip-tools to generate full pinned requirements.txt") - return False - cmd = [pip_compile, "--output-file", output_path, req_in_path, "--upgrade", "--no-strip-extras"] - print("Running:", " ".join(shlex.quote(c) for c in cmd)) - subprocess.check_call(cmd) - return True - - -def normalize_pip_compile_comments(requirements_path: str, repo_root: str) -> None: - """Normalize absolute repository-root-prefixed paths found in pip-compile comment lines. - - This replaces occurrences of the absolute `repo_root` path with a repository-relative - path (for example, "/home/me/repos/interpretune/requirements/ci/requirements.in" -> - "requirements/ci/requirements.in") inside comment lines (lines that start with "#"). - - This keeps the rest of the file intact and only rewrites comment text so generated - pinned files are consistent across environments. 
- """ - if not os.path.exists(requirements_path): - return - - with open(requirements_path, "r") as f: - lines = f.readlines() - - prefix = repo_root if repo_root.endswith(os.sep) else repo_root + os.sep - changed = False - out_lines: List[str] = [] - - for line in lines: - stripped = line.lstrip() - if stripped.startswith("#"): - # Replace any occurrence of the absolute repo root path with a relative path - new_line = line.replace(prefix, "") - # Normalize backslashes just in case (Windows paths in cross-envs) - new_line = new_line.replace("\\", "/") - if new_line != line: - changed = True - out_lines.append(new_line) - else: - out_lines.append(line) - - if changed: - with open(requirements_path, "w") as f: - f.writelines(out_lines) - - -def post_process_pinned_requirements(requirements_path, platform_dependent_path, platform_patterns, direct_packages): - if not os.path.exists(requirements_path): - return - - with open(requirements_path, "r") as f: - lines = f.readlines() - - requirements_lines = [] - platform_dependent_lines = [] - - existing_platform_deps = [] - if os.path.exists(platform_dependent_path): - with open(platform_dependent_path, "r") as f: - existing_platform_deps = [line.strip() for line in f if line.strip() and not line.startswith("#")] - - def normalize_package_name(name): - return name.lower().replace("_", "-") - - direct_packages_normalized = {normalize_package_name(pkg) for pkg in direct_packages} - - i = 0 - while i < len(lines): - line = lines[i].strip() - if not line or line.startswith("#"): - requirements_lines.append(lines[i]) - i += 1 - continue - - if " @ " in line: - pkg_name = line.split(" @ ")[0].strip().lower() - else: - parts = re.split(r"[\[\]=<>!;]", line) - pkg_name = parts[0].strip().lower() if parts else "" - - pkg_name_normalized = normalize_package_name(pkg_name) - - is_platform_dependent = any( - fnmatch.fnmatch(pkg_name_normalized, pattern.replace("_", "-")) for pattern in platform_patterns - ) - - is_direct_dependency = 
pkg_name_normalized in direct_packages_normalized - - if is_platform_dependent: - flexible_req = pkg_name_normalized - platform_dependent_lines.append(flexible_req) - i += 1 - while i < len(lines) and lines[i].strip().startswith("#"): - i += 1 - elif is_direct_dependency: - requirements_lines.append(lines[i]) - i += 1 - while i < len(lines) and lines[i].strip().startswith("#"): - requirements_lines.append(lines[i]) - i += 1 - else: - i += 1 - while i < len(lines) and lines[i].strip().startswith("#"): - i += 1 - - with open(requirements_path, "w") as f: - for line in requirements_lines: - f.write(line.rstrip() + "\n") - - all_platform_deps = list(set(existing_platform_deps + platform_dependent_lines)) - all_platform_deps.sort() - - with open(platform_dependent_path, "w") as f: - for pkg in all_platform_deps: - f.write(pkg.rstrip() + "\n") - - -def main(): - parser = argparse.ArgumentParser(description="Regenerate requirements files from pyproject.toml") - parser.add_argument("--mode", choices=["top-level", "pip-compile"], default="top-level") - parser.add_argument("--ci-output-dir", default=CI_REQ_DIR) - args = parser.parse_args() - - pyproject = load_pyproject() - - generate_top_level_files(pyproject) - - if args.mode == "pip-compile": - in_path, post_path, platform_path, direct_packages = generate_pip_compile_inputs(pyproject, args.ci_output_dir) - out_path = os.path.join(args.ci_output_dir, "requirements.txt") - try: - success = run_pip_compile(in_path, out_path) - if success: - tool_cfg = pyproject.get("tool", {}).get("ci_pinning", {}) - platform_dependent = tool_cfg.get("platform_dependent", []) or [] - post_process_pinned_requirements(out_path, platform_path, platform_dependent, direct_packages) - print(f"Generated pinned requirements at {out_path}") - print(f"Generated post-upgrades at {post_path}") - print(f"Generated platform-dependent packages at {platform_path}") - else: - print(f"Generated {in_path}, {post_path}, and {platform_path}.") - print("To create 
a pinned requirements.txt, install pip-tools and run:") - print(f" pip-compile {in_path} --output-file {out_path}") - except subprocess.CalledProcessError as e: - print("pip-compile failed:", e) - print( - f"Generated inputs at {in_path}, post-upgrades at {post_path}, " - f"and platform-dependent at {platform_path}" - ) - else: - print("Wrote top-level base and optional group requirement files in requirements/ (no pip-compile run).") - - -if __name__ == "__main__": - main() diff --git a/scripts/build_it_env.sh b/scripts/build_it_env.sh index 9e64f096..d2d085f2 100755 --- a/scripts/build_it_env.sh +++ b/scripts/build_it_env.sh @@ -1,64 +1,95 @@ #!/bin/bash # -# Utility script to build IT environments +# Interpretune environment builder using uv +# Uses uv pip with traditional venv activation for maximum control +# # Usage examples: -# build latest: # ./build_it_env.sh --repo_home=${HOME}/repos/interpretune --target_env_name=it_latest -# build latest with specific pytorch nightly: # ./build_it_env.sh --repo_home=${HOME}/repos/interpretune --target_env_name=it_latest --torch_dev_ver=dev20240201 -# build latest with torch test channel: -# ./build_it_env.sh --repo_home=${HOME}/repos/interpretune --target_env_name=it_latest --torch_test_channel +# ./build_it_env.sh --repo_home=${HOME}/repos/interpretune --target_env_name=it_latest --from-source="finetuning_scheduler:${HOME}/repos/finetuning-scheduler" set -eo pipefail +# Source shared infrastructure utilities +source "$(dirname "${BASH_SOURCE[0]}")/infra_utils.sh" + unset repo_home unset target_env_name unset torch_dev_ver unset torch_test_channel -unset fts_from_source -unset ct_from_source -unset pip_install_flags -unset regen_with_pip_compile -unset apply_post_upgrades -unset no_ci_reqs +unset uv_install_flags +unset from_source_spec +unset venv_dir +declare -a from_source_specs +declare -A from_source_packages usage(){ >&2 cat << EOF Usage: $0 [ --repo-home input] [ --target-env-name input ] + [ --venv-dir input ] [ 
--torch-dev-ver input ] [ --torch-test-channel ] - [ --fts-from-source "path" ] - [ --ct-from-source "path" ] - [ --pip-install-flags "flags" ] - [ --no-ci-reqs ] - [ --regen-with-pip-compile ] - [ --apply-post-upgrades ] + [ --from-source "package:path[:extras][:env_var=value...]" ] (can be specified multiple times) + [ --uv-install-flags "flags" ] [ --help ] + + The --from-source flag can be specified multiple times for clarity, or use semicolons to separate specs. + Format: "package:path[:extras][:env_var=value...]" + - extras: optional, e.g., "all" or "dev,test" + - env_var=value: optional, multiple env vars separated by colons + Package names should use underscores (e.g., finetuning_scheduler, circuit_tracer, transformer_lens). + Paths will be expanded if they start with ~. + Environment variables are set only during that package's installation and unset afterward. + + Venv Directory: + - Use --venv-dir to explicitly set the venv BASE directory (recommended when using with manage_standalone_processes.sh) + - The venv will be created at: / + - If --venv-dir not set, uses IT_VENV_BASE environment variable as base (default: ~/.venvs) + - Place venvs on same filesystem as UV cache to avoid hardlink warnings and improve performance + + Environment Variables: + - IT_VENV_BASE: Base directory for venvs when --venv-dir not specified (default: ~/.venvs) + Example: export IT_VENV_BASE=/mnt/cache/username/.venvs + Examples: # build latest: # ./build_it_env.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest + # build latest with specific pytorch nightly: # ./build_it_env.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --torch-dev-ver=dev20240201 + # build latest with torch test channel: # ./build_it_env.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --torch-test-channel - # build latest with FTS from source: - # ./build_it_env.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest 
--fts-from-source=${HOME}/repos/finetuning-scheduler - # build latest with circuit-tracer from source: - # ./build_it_env.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --ct-from-source=${HOME}/repos/circuit-tracer - # build latest with no cache directory: - # ./build_it_env.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --pip-install-flags="--no-cache-dir" - # build latest without using CT commit pinning: - # ./build_it_env.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest - # build latest and regenerate CI pinned requirements: - # ./build_it_env.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --regen-with-pip-compile - # build latest and apply post-upgrades: - # ./build_it_env.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --apply-post-upgrades + + # build latest with single package from source (no extras): + # ./build_it_env.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --from-source="circuit_tracer:${HOME}/repos/circuit-tracer" + + # build latest with package from source with extras: + # ./build_it_env.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --from-source="finetuning_scheduler:${HOME}/repos/finetuning-scheduler:all" + + # build latest with package from source with extras and env var: + # ./build_it_env.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --from-source="finetuning_scheduler:${HOME}/repos/finetuning-scheduler:all:USE_CI_COMMIT_PIN=1" + + # build latest with package from source with env var but no extras: + # ./build_it_env.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --from-source="finetuning_scheduler:${HOME}/repos/finetuning-scheduler::USE_CI_COMMIT_PIN=1" + + # build latest with multiple packages from source (using multiple --from-source flags): + # ./build_it_env.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest 
--from-source="finetuning_scheduler:${HOME}/repos/finetuning-scheduler:all:USE_CI_COMMIT_PIN=1" --from-source="circuit_tracer:${HOME}/repos/circuit-tracer" + + # build latest with multiple packages from source (using semicolon separator): + # ./build_it_env.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --from-source="finetuning_scheduler:${HOME}/repos/finetuning-scheduler:all:USE_CI_COMMIT_PIN=1;circuit_tracer:${HOME}/repos/circuit-tracer" + + # build latest with transformer_lens from source: + # ./build_it_env.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --from-source="transformer_lens:${HOME}/repos/TransformerLens" + + # build latest with no cache: + # ./build_it_env.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --uv-install-flags="--no-cache" EOF exit 1 } -args=$(getopt -o '' --long repo-home:,target-env-name:,torch-dev-ver:,torch-test-channel,fts-from-source:,ct-from-source:,pip-install-flags:,no-ci-reqs,regen-with-pip-compile,apply-post-upgrades,help -- "$@") +args=$(getopt -o '' --long repo-home:,target-env-name:,venv-dir:,torch-dev-ver:,torch-test-channel,from-source:,uv-install-flags:,help -- "$@") if [[ $? 
-gt 0 ]]; then usage fi @@ -69,14 +100,11 @@ do case $1 in --repo-home) repo_home=$2 ; shift 2 ;; --target-env-name) target_env_name=$2 ; shift 2 ;; + --venv-dir) venv_dir=$2 ; shift 2 ;; --torch-dev-ver) torch_dev_ver=$2 ; shift 2 ;; --torch-test-channel) torch_test_channel=1 ; shift ;; - --fts-from-source) fts_from_source=$2 ; shift 2 ;; - --ct-from-source) ct_from_source=$2 ; shift 2 ;; - --pip-install-flags) pip_install_flags=$2 ; shift 2 ;; - --no-ci-reqs) no_ci_reqs=1 ; shift ;; - --regen-with-pip-compile) regen_with_pip_compile=1 ; shift ;; - --apply-post-upgrades) apply_post_upgrades=1 ; shift ;; + --from-source) from_source_specs+=("$2") ; shift 2 ;; + --uv-install-flags) uv_install_flags=$2 ; shift 2 ;; --help) usage ; shift ;; # -- means the end of the arguments; drop this, and break out of the while loop --) shift; break ;; @@ -85,32 +113,43 @@ do esac done -# Use pip_install_flags in pip commands -pip_install_flags=${pip_install_flags:-""} +# Combine multiple --from-source flags into single spec with semicolon separator +if [[ ${#from_source_specs[@]} -gt 0 ]]; then + from_source_spec=$(IFS=';'; echo "${from_source_specs[*]}") +fi -# Expand leading ~ in common path arguments so users can pass --repo-home=~/repos/... -expand_tilde(){ - local p="$1" - if [[ -n "$p" ]] && [[ "$p" == ~* ]]; then - # Use eval to expand ~ reliably - eval echo "$p" - else - echo "$p" - fi -} +# Parse from-source specifications using shared infra_utils function +if [[ -n ${from_source_spec} ]]; then + parse_from_source_specs "${from_source_spec}" from_source_packages || exit 1 +fi + +# Use uv_install_flags in uv pip commands +uv_install_flags=${uv_install_flags:-""} +# Expand leading ~ in common path arguments so users can pass --repo-home=~/repos/... 
repo_home=$(expand_tilde "${repo_home}") -fts_from_source=$(expand_tilde "${fts_from_source}") -ct_from_source=$(expand_tilde "${ct_from_source}") +venv_dir=$(expand_tilde "${venv_dir}") + +# Expand tilde in from_source_packages paths (which are in format "path|extras|env_vars") +for pkg in "${!from_source_packages[@]}"; do + pkg_spec="${from_source_packages[$pkg]}" + IFS='|' read -r pkg_path pkg_extras pkg_env_vars <<< "${pkg_spec}" + pkg_path=$(expand_tilde "${pkg_path}") + from_source_packages[$pkg]="${pkg_path}|${pkg_extras}|${pkg_env_vars}" +done -# Source common utility functions -source ${repo_home}/scripts/infra_utils.sh +# Determine venv path using centralized function from infra_utils.sh +# Priority: 1) --venv-dir flag, 2) IT_VENV_BASE env var, 3) default ~/.venvs +# Placing venvs on same filesystem as UV cache avoids hardlink warnings and improves performance +venv_path=$(determine_venv_path "${venv_dir}" "${target_env_name}") clear_activate_env(){ - $1 -m venv --clear ~/.venvs/${target_env_name} - source ~/.venvs/${target_env_name}/bin/activate + local python_cmd=$1 + echo "Creating/clearing venv at ${venv_path} with ${python_cmd}" + uv venv --clear "${venv_path}" --python ${python_cmd} + source "${venv_path}/bin/activate" echo "Current venv prompt is now ${VIRTUAL_ENV_PROMPT}" - pip install ${pip_install_flags} --upgrade pip + uv pip install ${uv_install_flags} --upgrade pip setuptools wheel } base_env_build(){ @@ -119,16 +158,16 @@ base_env_build(){ clear_activate_env python3.12 if [[ -n ${torch_dev_ver} ]]; then # temporarily remove torchvision until it supports cu128 in nightly binary - pip install ${pip_install_flags} --pre torch==2.9.0.${torch_dev_ver} --index-url https://download.pytorch.org/whl/nightly/cu128 - elif [[ $torch_test_channel -eq 1 ]]; then - pip install ${pip_install_flags} --pre torch==2.9.0 --index-url https://download.pytorch.org/whl/test/cu128 + uv pip install ${uv_install_flags} --pre torch==2.10.0.${torch_dev_ver} --index-url 
https://download.pytorch.org/whl/nightly/cu128 + elif [[ ${torch_test_channel} -eq 1 ]]; then + uv pip install ${uv_install_flags} --pre torch==2.10.0 --index-url https://download.pytorch.org/whl/test/cu128 else - pip install ${pip_install_flags} torch --index-url https://download.pytorch.org/whl/cu128 + uv pip install ${uv_install_flags} torch --index-url https://download.pytorch.org/whl/cu128 fi ;; it_release) clear_activate_env python3.12 - pip install ${pip_install_flags} torch --index-url https://download.pytorch.org/whl/cu128 + uv pip install ${uv_install_flags} torch --index-url https://download.pytorch.org/whl/cu128 ;; *) echo "no matching environment found, exiting..." @@ -138,74 +177,43 @@ base_env_build(){ } it_install(){ - source ~/.venvs/${target_env_name}/bin/activate - unset PACKAGE_NAME - if [[ -n ${fts_from_source} ]]; then - export USE_CI_COMMIT_PIN="1" - echo "Installing FTS from source at ${fts_from_source}" - cd ${fts_from_source} - python -m pip install ${pip_install_flags} -e ".[all]" -r requirements/docs.txt - unset USE_CI_COMMIT_PIN - fi + source "${venv_path}/bin/activate" cd ${repo_home} - # Optionally regenerate CI pinned requirements (pip-compile mode) if requested - if [[ -n ${regen_with_pip_compile} ]]; then - python -m pip install ${pip_install_flags} toml pip-tools - # "pip < 25.3" temporarily needed due to pip-tools https://github.com/jazzband/pip-tools/issues/2252 - python -m pip install ${pip_install_flags} "pip<25.3" - echo "Regenerating CI pinned requirements (pip-compile mode)" - python ${repo_home}/requirements/utils/regen_reqfiles.py --mode pip-compile --ci-output-dir ${repo_home}/requirements/ci - fi - # If CI pinned requirements don't exist and user did not disable ci-reqs, regenerate them - if [[ -z ${no_ci_reqs} ]] && [[ ! 
-f ${repo_home}/requirements/ci/requirements.txt ]]; then - python -m pip install ${pip_install_flags} toml pip-tools - # "pip < 25.3" temporarily needed due to pip-tools https://github.com/jazzband/pip-tools/issues/2252 - python -m pip install ${pip_install_flags} "pip<25.3" - echo "CI pinned requirements not found; regenerating requirements.in and post_upgrades." - python ${repo_home}/requirements/utils/regen_reqfiles.py --mode pip-compile --ci-output-dir ${repo_home}/requirements/ci - fi + # installation strategy: locked CI reqs → git-deps → from-source + ci_reqs_file="${repo_home}/requirements/ci/requirements.txt" - # Install project and extras; prefer CI pinned requirements if available - if [[ -f ${repo_home}/requirements/ci/requirements.txt ]] && [[ -z ${no_ci_reqs} ]]; then - # Install pinned requirements, then install editable package so CLI modules (interpretune.*) are importable - python -m pip install ${pip_install_flags} -r ${repo_home}/requirements/ci/requirements.txt -r requirements/docs.txt || true - # Ensure interpretune package is installed (editable install recommended during dev) - python -m pip install ${pip_install_flags} -e ".[test,examples,lightning,profiling]" - else - python -m pip install ${pip_install_flags} -e ".[test,examples,lightning,profiling]" -r requirements/docs.txt + if [[ ! -f "${ci_reqs_file}" ]]; then + echo "⚠ ERROR: Locked CI requirements not found at ${ci_reqs_file}" + echo "Please regenerate with: bash requirements/utils/lock_ci_requirements.sh" + exit 1 fi - cd ${repo_home} - # Optionally apply post-upgrades if requested and file exists - if [[ -n ${apply_post_upgrades} ]] && [[ -s ${repo_home}/requirements/ci/post_upgrades.txt ]]; then - echo "Applying post-upgrades from requirements/ci/post_upgrades.txt" - pip install --upgrade -r ${repo_home}/requirements/ci/post_upgrades.txt || true - else - echo "Skipping post-upgrades (flag not set or file empty)." - fi + echo "Using locked CI requirements from ${ci_reqs_file}..." 
- if [[ -n ${ct_from_source} ]]; then - echo "Installing circuit-tracer from source at ${ct_from_source}" - # Try uninstalling both import-style package name and hyphenated distribution name to avoid conflicts - python -m pip uninstall -y circuit_tracer || true - python -m pip uninstall -y circuit-tracer || true - cd ${ct_from_source} - python -m pip install ${pip_install_flags} -e . - - # Verify only the editable source installation is installed - echo "Verifying circuit-tracer installation..." - if pip show circuit_tracer 2>/dev/null | grep -q "Editable project location:"; then - echo "✓ circuit_tracer is installed in editable mode" - else - echo "✗ circuit_tracer is not installed in editable mode" - exit 1 - fi + # 1. Install interpretune in editable mode + git-deps group (uv doesn't currently support url deps in locked reqs) + echo "Installing interpretune in editable mode..." + uv pip install ${uv_install_flags} -e . --group git-deps + + # 2. Install locked CI requirements (all PyPI packages) + echo "Installing locked dependencies..." + uv pip install ${uv_install_flags} -r "${ci_reqs_file}" + + # 3. Install from-source packages (override any PyPI/git versions) + if [[ ${#from_source_packages[@]} -gt 0 ]]; then + echo "Installing from-source packages (these will override any PyPI/git versions)..." + install_from_source_packages from_source_packages "${venv_path}" "${uv_install_flags}" fi - pyright -p pyproject.toml + # 4. Setup git hooks and type checking + cd ${repo_home} + echo "Setting up git hooks and running type checks..." + pyright -p pyproject.toml || echo "⚠ pyright check had issues, continuing..." pre-commit install git lfs install + + # 5. Display environment info + echo "Collecting environment details..." 
python ${repo_home}/requirements/utils/collect_env_details.py --packages-only } diff --git a/scripts/gen_it_coverage.sh b/scripts/gen_it_coverage.sh index cb64320e..776355b2 100755 --- a/scripts/gen_it_coverage.sh +++ b/scripts/gen_it_coverage.sh @@ -3,18 +3,22 @@ # Utility script to generate local IT coverage for a given environment set -eo pipefail +# Source shared infrastructure utilities +source "$(dirname "${BASH_SOURCE[0]}")/infra_utils.sh" + unset repo_home unset target_env_name unset torch_dev_ver unset torch_test_channel unset no_rebuild_base -unset fts_from_source -unset ct_from_source +unset from_source_spec unset run_all_and_examples unset no_export_cov_xml unset pip_install_flags unset self_test_only unset it_build_flags +unset venv_dir +declare -a from_source_specs usage(){ >&2 cat << EOF @@ -24,34 +28,40 @@ Usage: $0 [ --torch-dev-ver input ] [ --torch-test-channel ] [ --no-rebuild-base ] - [ --fts-from-source "path" ] - [ --ct-from-source "path" ] + [ --from-source "package:path[:extras][:env_var=value...]" ] (can be specified multiple times) + [ --venv-dir "/path/to/venv/base" ] [ --run-all-and-examples ] [ --no-export-cov-xml ] [ --pip-install-flags "flags" ] [ --self-test-only ] [ --it-build-flags "flags" ] [ --help ] + + The --from-source flag can be specified multiple times for clarity, or use semicolons to separate specs. 
+ Format: "package:path[:extras][:env_var=value...]" + - extras: optional, e.g., "all" or "dev,test" + - env_var=value: optional, multiple env vars separated by colons + Examples: # generate it_latest coverage without rebuilding the it_latest base environment: - # ./gen_it_coverage.sh --repo_home=${HOME}/repos/interpretune --target_env_name=it_latest --no_rebuild_base + # ./gen_it_coverage.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --no-rebuild-base # generate it_latest coverage with a given torch_dev_version, rebuilding base it_latest and with FTS from source: - # ./gen_it_coverage.sh --repo_home=${HOME}/repos/interpretune --target_env_name=it_latest --torch_dev_ver=dev20240201 --fts_from_source=${HOME}/repos/finetuning-scheduler - # generate it_latest coverage, rebuilding base it_latest with PyTorch test channel and FTS from source: - # ./gen_it_coverage.sh --repo_home=${HOME}/repos/interpretune --target_env_name=it_latest --torch_test_channel --fts_from_source=${HOME}/repos/finetuning-scheduler - # generate it_latest coverage with circuit-tracer from source: - # ./gen_it_coverage.sh --repo_home=${HOME}/repos/interpretune --target_env_name=it_latest --ct_from_source=${HOME}/repos/circuit-tracer + # ./gen_it_coverage.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --torch-dev-ver=dev20240201 --from-source="finetuning_scheduler:${HOME}/repos/finetuning-scheduler:all:USE_CI_COMMIT_PIN=1" + # generate it_latest coverage with multiple packages from source (multiple --from-source flags): + # ./gen_it_coverage.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --from-source="finetuning_scheduler:${HOME}/repos/finetuning-scheduler:all:USE_CI_COMMIT_PIN=1" --from-source="circuit_tracer:${HOME}/repos/circuit-tracer" --from-source="transformer_lens:${HOME}/repos/TransformerLens" + # generate it_latest coverage with multiple packages from source (semicolon separator): + # ./gen_it_coverage.sh 
--repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --from-source="finetuning_scheduler:${HOME}/repos/finetuning-scheduler:all:USE_CI_COMMIT_PIN=1;circuit_tracer:${HOME}/repos/circuit-tracer;transformer_lens:${HOME}/repos/TransformerLens" + # Build with custom venv directory: + # ./gen_it_coverage.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --venv-dir=/mnt/cache/${USER}/.venvs --from-source="finetuning_scheduler:${HOME}/repos/finetuning-scheduler:all:USE_CI_COMMIT_PIN=1;circuit_tracer:${HOME}/repos/circuit-tracer;transformer_lens:${HOME}/repos/TransformerLens" # generate it_latest coverage with no pip cache: - # ./gen_it_coverage.sh --repo_home=${HOME}/repos/interpretune --target_env_name=it_latest --pip_install_flags="--no-cache-dir" + # ./gen_it_coverage.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --pip-install-flags="--no-cache-dir" # generate it_latest coverage with self_test_only: - # ./gen_it_coverage.sh --repo_home=${HOME}/repos/interpretune --target_env_name=it_latest --self_test_only - # generate it_latest coverage using CT commit pinning: - # ./gen_it_coverage.sh --repo_home=${HOME}/repos/interpretune --target_env_name=it_latest --ct_commit_pin + # ./gen_it_coverage.sh --repo-home=${HOME}/repos/interpretune --target-env-name=it_latest --self-test-only EOF exit 1 } -args=$(getopt -o '' --long repo-home:,target-env-name:,torch-dev-ver:,torchvision-dev-ver:,torch-test-channel,no-rebuild-base,fts-from-source:,ct-from-source:,run-all-and-examples,no-export-cov-xml,pip-install-flags:,self-test-only,it-build-flags:,help -- "$@") +args=$(getopt -o '' --long repo-home:,target-env-name:,torch-dev-ver:,torchvision-dev-ver:,torch-test-channel,no-rebuild-base,from-source:,venv-dir:,run-all-and-examples,no-export-cov-xml,pip-install-flags:,self-test-only,it-build-flags:,help -- "$@") if [[ $? 
-gt 0 ]]; then usage fi @@ -65,8 +75,8 @@ do --torch-dev-ver) torch_dev_ver=$2 ; shift 2 ;; --torch-test-channel) torch_test_channel=1 ; shift ;; --no-rebuild-base) no_rebuild_base=1 ; shift ;; - --fts-from-source) fts_from_source=$2 ; shift 2 ;; - --ct-from-source) ct_from_source=$2 ; shift 2 ;; + --from-source) from_source_specs+=("$2") ; shift 2 ;; + --venv-dir) venv_dir=$2 ; shift 2 ;; --run-all-and-examples) run_all_and_examples=1 ; shift ;; --no-export-cov-xml) no_export_cov_xml=1 ; shift ;; --pip-install-flags) pip_install_flags=$2 ; shift 2 ;; @@ -79,26 +89,27 @@ do esac done +# Combine multiple --from-source flags into single spec with semicolon separator +if [[ ${#from_source_specs[@]} -gt 0 ]]; then + from_source_spec=$(IFS=';'; echo "${from_source_specs[*]}") +fi + d=`date +%Y%m%d%H%M%S` tmp_coverage_dir="/tmp" coverage_session_log="${tmp_coverage_dir}/gen_it_coverage_${target_env_name}_${d}.log" echo "Use 'tail -f ${coverage_session_log}' to monitor progress" -# Robustness: strip leading/trailing quotes from path variables if present -strip_quotes() { - # Remove both single and double quotes from start/end - local val="$1" - val="${val%\' }"; val="${val#\' }" - val="${val%\"}"; val="${val#\"}" - echo "$val" -} +# Expand leading ~ in common path arguments +repo_home=$(expand_tilde "${repo_home}") +venv_dir=$(expand_tilde "${venv_dir}") -# Only strip if set -if [[ -n "${ct_from_source}" ]]; then - ct_from_source=$(strip_quotes "$ct_from_source") -fi -if [[ -n "${fts_from_source}" ]]; then - fts_from_source=$(strip_quotes "$fts_from_source") +# Determine venv path +# Priority: --venv-dir > IT_VENV_BASE > default ~/.venvs +venv_path=$(determine_venv_path "${venv_dir}" "${target_env_name}") + +# Strip leading/trailing quotes from string variables if present +if [[ -n "${from_source_spec}" ]]; then + from_source_spec=$(strip_quotes "$from_source_spec") fi if [[ -n "${it_build_flags}" ]]; then it_build_flags=$(strip_quotes "$it_build_flags") @@ -114,70 
+125,53 @@ check_self_test_only(){ } env_rebuild(){ - # Prepare pip_install_flags parameter if set - pip_flags_param="" - if [[ -n "${pip_install_flags}" ]]; then - pip_flags_param="--pip-install-flags=\"${pip_install_flags}\"" - fi - - fts_from_source_param="" - if [[ -n "${fts_from_source}" ]]; then - fts_from_source_param="--fts-from-source=${fts_from_source}" - fi - - ct_from_source_param="" - if [[ -n "${ct_from_source}" ]]; then - ct_from_source_param="--ct-from-source=${ct_from_source}" - fi - - it_build_flags_params="" - if [[ -n "${it_build_flags}" ]]; then - it_build_flags_params="${it_build_flags}" - fi + cd ${repo_home} case $1 in - it_latest ) - if [[ -n ${torch_dev_ver} ]]; then - ${repo_home}/scripts/build_it_env.sh --repo-home=${repo_home} --target-env-name=$1 --torch-dev-ver=${torch_dev_ver} ${fts_from_source_param} ${ct_from_source_param} ${pip_flags_param} ${ct_commit_pin_param} ${apply_post_upgrades_param} - elif [[ $torch_test_channel -eq 1 ]]; then - ${repo_home}/scripts/build_it_env.sh --repo-home=${repo_home} --target-env-name=$1 --torch-test-channel ${fts_from_source_param} ${ct_from_source_param} ${pip_flags_param} ${ct_commit_pin_param} ${apply_post_upgrades_param} - else - ${repo_home}/scripts/build_it_env.sh --repo-home=${repo_home} --target-env-name=$1 ${fts_from_source_param} ${ct_from_source_param} ${pip_flags_param} ${ct_commit_pin_param} ${it_build_flags_params} + it_latest | it_release ) + echo "Rebuilding environment with build_it_env.sh..." 
>> $coverage_session_log + # Build command with conditional flags + build_cmd="${repo_home}/scripts/build_it_env.sh --repo-home=${repo_home} --target-env-name=$1" + [[ -n ${venv_dir} ]] && build_cmd="${build_cmd} --venv-dir=${venv_dir}" + [[ -n ${torch_dev_ver} ]] && build_cmd="${build_cmd} --torch-dev-ver=${torch_dev_ver}" + [[ ${torch_test_channel} -eq 1 ]] && build_cmd="${build_cmd} --torch-test-channel" + + # Handle multiple --from-source flags + if [[ ${#from_source_specs[@]} -gt 0 ]]; then + for spec in "${from_source_specs[@]}"; do + build_cmd="${build_cmd} --from-source='${spec}'" + done fi - ;; - it_release ) - ${repo_home}/scripts/build_it_env.sh --repo-home=${repo_home} --target-env-name=$1 ${fts_from_source_param} ${ct_from_source_param} ${pip_flags_param} ${ct_commit_pin_param} ${it_build_flags_params} + + [[ -n ${pip_install_flags} ]] && build_cmd="${build_cmd} --uv-install-flags='${pip_install_flags}'" + + echo "Running: ${build_cmd}" >> $coverage_session_log + eval ${build_cmd} >> $coverage_session_log 2>&1 ;; *) echo "no matching environment found, exiting..." >> $coverage_session_log exit 1 ;; esac - } collect_env_coverage(){ temp_special_log="${tmp_coverage_dir}/special_test_output_$1_${d}.log" cd ${repo_home} - source ./scripts/infra_utils.sh maybe_deactivate - source ~/.venvs/$1/bin/activate + source ${venv_path}/bin/activate + case $1 in it_latest | it_latest_pt_2_4 ) check_self_test_only "Skipping all tests and examples." && return python -m coverage erase if [[ $run_all_and_examples -eq 1 ]]; then - #check_self_test_only "Skipping all tests and examples." 
&& return python -m coverage run --source src/interpretune -m pytest src/interpretune src/it_examples tests -v 2>&1 >> $coverage_session_log (./tests/special_tests.sh --mark_type=standalone --log_file=${coverage_session_log} 2>&1 >> ${temp_special_log}) > /dev/null (./tests/special_tests.sh --mark_type=profile_ci --log_file=${coverage_session_log} 2>&1 >> ${temp_special_log}) > /dev/null (./tests/special_tests.sh --mark_type=profile --log_file=${coverage_session_log} 2>&1 >> ${temp_special_log}) > /dev/null (./tests/special_tests.sh --mark_type=optional --log_file=${coverage_session_log} 2>&1 >> ${temp_special_log}) > /dev/null else - #check_self_test_only "Skipping all tests and examples." && return - # if check_self_test_only "Skipping CI tests."; then - # return - # fi python -m coverage run --append --source src/interpretune -m pytest tests -v 2>&1 >> $coverage_session_log (./tests/special_tests.sh --mark_type=standalone --log_file=${coverage_session_log} 2>&1 >> ${temp_special_log}) > /dev/null (./tests/special_tests.sh --mark_type=profile_ci --log_file=${coverage_session_log} 2>&1 >> ${temp_special_log}) > /dev/null @@ -196,9 +190,6 @@ env_rebuild_collect(){ else echo "Beginning IT env rebuild for $1" >> $coverage_session_log check_self_test_only "Skipping all tests and examples." 
&& return - # if check_self_test_only "Skipping rebuild."; then - # return - # fi env_rebuild "$1" fi echo "Collecting coverage for the IT env $1" >> $coverage_session_log @@ -228,5 +219,7 @@ case ${target_env_name} in ;; esac echo "Writing collected coverage stats for IT env ${target_env_name}" >> $coverage_session_log +# Reactivate the environment for coverage report (in case it was deactivated) +source ${venv_path}/bin/activate python -m coverage report -m >> $coverage_session_log show_elapsed_time $coverage_session_log "IT coverage collection" diff --git a/scripts/infra_utils.sh b/scripts/infra_utils.sh index 50465f27..19e865fd 100755 --- a/scripts/infra_utils.sh +++ b/scripts/infra_utils.sh @@ -21,3 +21,193 @@ show_elapsed_time(){ maybe_deactivate(){ deactivate 2>/dev/null || true } + +# Function to strip leading and trailing quotes from a string +# Usage: cleaned=$(strip_quotes "$variable") +strip_quotes(){ + local val="$1" + # Remove both single and double quotes from start/end + val="${val%\' }"; val="${val#\' }" + val="${val%\"}"; val="${val#\"}" + echo "$val" +} + +# Function to expand tilde in paths +# Usage: expanded_path=$(expand_tilde "~/some/path") +expand_tilde(){ + local p="$1" + if [[ -n "$p" ]] && [[ "$p" == ~* ]]; then + # Use eval to expand ~ reliably + eval echo "$p" + else + echo "$p" + fi +} + +# Determine venv path based on priority: --venv-dir > IT_VENV_BASE > default ~/.venvs +# Usage: venv_path=$(determine_venv_path "$venv_dir" "$target_env_name") +# Arguments: +# $1: venv_dir (optional) - explicit base directory from --venv-dir flag +# $2: target_env_name (required) - environment name to append to base +# Returns: Full venv path (e.g., /mnt/cache/user/.venvs/it_latest) +determine_venv_path(){ + local venv_dir="$1" + local target_env_name="$2" + local venv_base + + if [[ -z "${target_env_name}" ]]; then + echo "Error: target_env_name is required" >&2 + return 1 + fi + + if [[ -n "${venv_dir}" ]]; then + # Explicit --venv-dir provided 
(base directory) + venv_base="${venv_dir}" + elif [[ -n "${IT_VENV_BASE}" ]]; then + # Use IT_VENV_BASE environment variable + venv_base="${IT_VENV_BASE}" + else + # Use default + venv_base="~/.venvs" + fi + + # Expand tilde if present + venv_base=$(expand_tilde "${venv_base}") + + # Return full path + echo "${venv_base}/${target_env_name}" +} + +# Parse from-source specifications into an associative array +# Format: package:path[:extras][:env_var=value...] +# Usage: parse_from_source_specs "spec1;spec2;..." from_source_packages_array_name +# Example: +# declare -A from_source_packages +# parse_from_source_specs "$from_source_spec" from_source_packages +parse_from_source_specs(){ + local from_source_spec="$1" + local -n pkg_array=$2 # nameref to associative array + + if [[ -z ${from_source_spec} ]]; then + return 0 + fi + + IFS=';' read -ra PAIRS <<< "${from_source_spec}" + for pair in "${PAIRS[@]}"; do + # Split on colons to get all fields + IFS=':' read -ra FIELDS <<< "${pair}" + + if [[ ${#FIELDS[@]} -lt 2 ]]; then + echo "Error: Invalid from-source format: '$pair'" >&2 + echo "Expected format: package:path[:extras][:env_var=value...]" >&2 + return 1 + fi + + local pkg_name="${FIELDS[0]}" + local pkg_path="${FIELDS[1]}" + local pkg_extras="" + local pkg_env_vars="" + + # Process remaining fields - first non-env field is extras, rest are env vars + for ((i=2; i<${#FIELDS[@]}; i++)); do + local field="${FIELDS[i]}" + if [[ $field =~ ^[A-Z_][A-Z0-9_]*=.*$ ]]; then + # This is an env var (contains =) + if [[ -n ${pkg_env_vars} ]]; then + pkg_env_vars="${pkg_env_vars}|${field}" + else + pkg_env_vars="${field}" + fi + elif [[ -z ${pkg_extras} && -n ${field} ]]; then + # First non-env, non-empty field is extras + pkg_extras="${field}" + fi + done + + # Normalize package name (convert underscores to hyphens for consistency) + pkg_name="${pkg_name//_/-}" + # Store as "path|extras|env_vars" so we can split later + 
pkg_array[$pkg_name]="${pkg_path}|${pkg_extras}|${pkg_env_vars}" + done +} + +# Install packages from source with optional extras and environment variables +# Installs packages with all their dependencies - UV's git dependency caching ensures +# that commit-pinned dependencies (e.g., Lightning pinned by finetuning-scheduler) will +# be respected by subsequent installations. +# Usage: install_from_source_packages from_source_packages_array_name venv_path [uv_install_flags] +# Example: +# declare -A from_source_packages +# parse_from_source_specs "$specs" from_source_packages +# install_from_source_packages from_source_packages "/path/to/venv" "--no-cache" +# Note: The venv_path parameter should be the full path to the venv (not just the name) +install_from_source_packages(){ + local -n pkg_array=$1 # nameref to associative array + local venv_path="$2" + local uv_install_flags="${3:-}" + + if [[ ${#pkg_array[@]} -eq 0 ]]; then + return 0 + fi + + source "${venv_path}/bin/activate" + + # Install packages from source if requested (do this before main package to avoid conflicts) + for pkg in "${!pkg_array[@]}"; do + local pkg_spec="${pkg_array[$pkg]}" + # Split on pipe delimiter to get path, extras, and env vars + IFS='|' read -r pkg_path pkg_extras pkg_env_vars <<< "${pkg_spec}" + + # Uninstall any existing installations (try both package name formats) + local pkg_underscore="${pkg//-/_}" + uv pip uninstall -y "${pkg}" "${pkg_underscore}" 2>/dev/null || true + + cd "${pkg_path}" + + # Build install target with optional extras + if [[ -n ${pkg_extras} ]]; then + local install_target=".[${pkg_extras}]" + echo "Installing ${pkg} from source at ${pkg_path} with extras: [${pkg_extras}]" + else + local install_target="." 
+ echo "Installing ${pkg} from source at ${pkg_path} (no extras)" + fi + + # Set environment variables if specified + local env_vars_set=() + if [[ -n ${pkg_env_vars} ]]; then + echo "Setting environment variables for ${pkg} installation:" + IFS='|' read -ra ENV_VARS <<< "${pkg_env_vars}" + for env_var in "${ENV_VARS[@]}"; do + if [[ $env_var =~ ^([^=]+)=(.*)$ ]]; then + local var_name="${BASH_REMATCH[1]}" + local var_value="${BASH_REMATCH[2]}" + echo " export ${var_name}=${var_value}" + export "${var_name}=${var_value}" + env_vars_set+=("${var_name}") + fi + done + fi + + uv pip install ${uv_install_flags} -e "${install_target}" + + # Unset environment variables after installation + if [[ ${#env_vars_set[@]} -gt 0 ]]; then + echo "Unsetting temporary environment variables for ${pkg}:" + for var_name in "${env_vars_set[@]}"; do + echo " unset ${var_name}" + unset "${var_name}" + done + fi + + # Verify editable installation + echo "Verifying ${pkg} installation..." + if uv pip show "${pkg_underscore}" 2>/dev/null | grep -q "Editable project location:"; then + echo "✓ ${pkg} is installed in editable mode" + elif uv pip show "${pkg}" 2>/dev/null | grep -q "Editable project location:"; then + echo "✓ ${pkg} is installed in editable mode" + else + echo "⚠ Warning: ${pkg} may not be installed in editable mode" + fi + done +} diff --git a/src/interpretune/analysis/core.py b/src/interpretune/analysis/core.py index 79f84cdc..73848694 100644 --- a/src/interpretune/analysis/core.py +++ b/src/interpretune/analysis/core.py @@ -351,7 +351,7 @@ class AnalysisStore: def __init__( self, # dataset: can be a path or a loaded Hugging Face dataset - dataset: HfDataset | StrOrPath | None = None, + dataset: HfDataset | StrOrPath | os.PathLike | None = None, op_output_dataset_path: str | None = None, cache_dir: str | None = None, streaming: bool = False, diff --git a/src/interpretune/config/shared.py b/src/interpretune/config/shared.py index c2a6ba73..512793ac 100644 --- 
a/src/interpretune/config/shared.py +++ b/src/interpretune/config/shared.py @@ -4,6 +4,7 @@ import logging import os import sys +from pathlib import PosixPath, WindowsPath import yaml from transformers import PreTrainedTokenizerBase @@ -52,7 +53,25 @@ def __repr__(self) -> str: # TODO: add custom constructors and representers for core IT object types @dataclass(kw_only=True) -class ITSerializableCfg(yaml.YAMLObject): ... +class ITSerializableCfg(yaml.YAMLObject): + """Base class for serializable Interpretune configs. + + Automatically registers subclasses and Path types as safe globals for PyTorch checkpoint loading. + """ + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + # Auto-register all ITSerializableCfg subclasses as safe for pickle deserialization + # This is required when loading checkpoints with weights_only=True + # Also register Path types to allow Path objects in serialized configs + try: + import torch.serialization + + # Register the config class and both platform-specific Path types + torch.serialization.add_safe_globals([cls, PosixPath, WindowsPath]) + except (ImportError, AttributeError): + # torch.serialization.add_safe_globals not available in older PyTorch versions + pass @dataclass(kw_only=True) diff --git a/src/interpretune/protocol.py b/src/interpretune/protocol.py index 062babc7..19207623 100644 --- a/src/interpretune/protocol.py +++ b/src/interpretune/protocol.py @@ -17,11 +17,11 @@ TypedDict, TypeVar, ) -from os import PathLike from pathlib import Path from types import UnionType from enum import auto, Enum, EnumMeta from dataclasses import dataclass +from os import PathLike import inspect import torch @@ -47,7 +47,7 @@ # Interpretune helper types ################################################################################ -StrOrPath: TypeAlias = Union[str, PathLike, Path] +StrOrPath: TypeAlias = Union[str, Path] ################################################################################ # 
Interpretune Enhanced Enums @@ -493,7 +493,7 @@ def apply_op_by_sae(self, operation: Callable | str, *args, **kwargs) -> "SAEDic class AnalysisStoreProtocol(Protocol): """Protocol verifying core analysis store functionality.""" - dataset: Union[HfDataset, StrOrPath, None] + dataset: Union[HfDataset, StrOrPath, PathLike, None] streaming: bool cache_dir: str | None diff --git a/src/interpretune/utils/import_utils.py b/src/interpretune/utils/import_utils.py index f5d2ebad..13187727 100644 --- a/src/interpretune/utils/import_utils.py +++ b/src/interpretune/utils/import_utils.py @@ -2,9 +2,9 @@ import importlib from functools import lru_cache from importlib.util import find_spec +from importlib.metadata import version as get_version, PackageNotFoundError import operator import torch -import pkg_resources from packaging.version import Version from interpretune.utils import MisconfigurationException @@ -137,7 +137,7 @@ def module_available(module_path: str) -> bool: return True -def compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool: +def compare_version(package: str, op: Callable, version_str: str, use_base_version: bool = False) -> bool: """Compare package version with some requirements. 
>>> compare_version("torch", operator.ge, "0.1") @@ -147,20 +147,20 @@ def compare_version(package: str, op: Callable, version: str, use_base_version: """ try: pkg = importlib.import_module(package) - except (ImportError, pkg_resources.DistributionNotFound): + except (ImportError, PackageNotFoundError): return False try: if hasattr(pkg, "__version__"): pkg_version = Version(pkg.__version__) else: - # try pkg_resources to infer version - pkg_version = Version(pkg_resources.get_distribution(package).version) - except TypeError: + # try importlib.metadata to infer version + pkg_version = Version(get_version(package)) + except (TypeError, PackageNotFoundError): # this is mocked by Sphinx, so it should return True to generate all summaries return True if use_base_version: pkg_version = Version(pkg_version.base_version) - return op(pkg_version, Version(version)) + return op(pkg_version, Version(version_str)) ################################################################################ diff --git a/src/interpretune/utils/logging.py b/src/interpretune/utils/logging.py index 5a927000..698b98fb 100644 --- a/src/interpretune/utils/logging.py +++ b/src/interpretune/utils/logging.py @@ -53,8 +53,14 @@ def collect_env_info() -> Dict: "torch.utils.collect_env", "get_cuda_module_loading_config", CUDA_MAY_BE_INIT_MSG ) sys_dict = sys_info._asdict() - pip_dict = {name: ver for name, ver in [p.split("==") for p in sys_info._asdict()["pip_packages"].split("\n")]} - sys_dict["pip_packages"] = pip_dict + # TODO: since we now use uv via the pip interface, we should consider adding uv pip package versions here if torch + # does not start doing so soon + pip_packages = sys_dict.get("pip_packages") + if pip_packages: + pip_dict = {name: ver for name, ver in [p.split("==", 1) for p in pip_packages.split("\n") if "==" in p]} + sys_dict["pip_packages"] = pip_dict + else: + sys_dict["pip_packages"] = {} return sys_dict diff --git a/src/it_examples/utils/raw_graph_analysis.py 
b/src/it_examples/utils/raw_graph_analysis.py index 2b81bed0..07a6ebe3 100644 --- a/src/it_examples/utils/raw_graph_analysis.py +++ b/src/it_examples/utils/raw_graph_analysis.py @@ -322,7 +322,7 @@ def get_logit_indices_for_tokens( the adjacency matrix for the logit nodes of those tokens. Args: - graph: The graph object containing logit_tokens and adjacency_matrix. + graph: The graph object containing logit_token_ids and adjacency_matrix. token_ids (torch.Tensor, optional): Tensor of token ids to inspect. token_strings (list, optional): List of token strings to inspect. tokenizer (transformers.PreTrainedTokenizer, optional): Tokenizer to convert strings to ids. @@ -331,19 +331,19 @@ def get_logit_indices_for_tokens( torch.Tensor: Indices in the adjacency matrix for the logit nodes of the specified tokens. """ if token_strings is not None and tokenizer is not None: - inspect_ids = torch.tensor(tokenizer.convert_tokens_to_ids(token_strings), device=graph.logit_tokens.device) + inspect_ids = torch.tensor(tokenizer.convert_tokens_to_ids(token_strings), device=graph.logit_token_ids.device) elif token_ids is not None: - inspect_ids = token_ids.to(graph.logit_tokens.device) + inspect_ids = token_ids.to(graph.logit_token_ids.device) else: raise ValueError("Either token_ids or (token_strings and tokenizer) must be provided.") - lmask = (graph.logit_tokens.unsqueeze(1) == inspect_ids).any(dim=1) + lmask = (graph.logit_token_ids.unsqueeze(1) == inspect_ids).any(dim=1) indices = torch.nonzero(lmask, as_tuple=False).cpu().squeeze() if indices.numel() == 0: return torch.tensor([], dtype=torch.long) if indices.dim() == 0: indices = indices.unsqueeze(0) - adj_offset = graph.adjacency_matrix.shape[0] - len(graph.logit_tokens) + adj_offset = graph.adjacency_matrix.shape[0] - len(graph.logit_token_ids) final_logit_idxs = adj_offset + indices return final_logit_idxs, indices @@ -370,8 +370,8 @@ def generate_topk_node_mapping(graph, node_mask, topk_feats_to_translate=None, c # If 
cumulative_scores is provided, use its length for logit_end_idx, else infer from graph if cumulative_scores is not None: logit_end_idx = len(cumulative_scores) - elif hasattr(graph, "logit_tokens"): - logit_end_idx = logit_start_idx + len(graph.logit_tokens) + elif hasattr(graph, "logit_token_ids"): + logit_end_idx = logit_start_idx + len(graph.logit_token_ids) else: logit_end_idx = logit_start_idx @@ -413,7 +413,7 @@ def generate_topk_node_mapping(graph, node_mask, topk_feats_to_translate=None, c node_ids[node_idx] = node_id elif node_idx in range(logit_start_idx, logit_end_idx): pos = node_idx - logit_start_idx - vocab_idx = graph.logit_tokens[pos] + vocab_idx = graph.logit_token_ids[pos] layer = str(layers + 1) node_id = f"{layer}_{vocab_idx}_{pos}" node_ids[node_idx] = node_id diff --git a/src/it_examples/utils/raw_graph_analysis_example_incomplete.py b/src/it_examples/utils/raw_graph_analysis_example_incomplete.py index 9dff7270..dc403295 100644 --- a/src/it_examples/utils/raw_graph_analysis_example_incomplete.py +++ b/src/it_examples/utils/raw_graph_analysis_example_incomplete.py @@ -1,102 +1,104 @@ -""" -Raw Graph Analysis Example - INCOMPLETE - -TODO: This example demonstrates manual raw graph analysis using local artifacts that are not yet -publicly available. Complete this example with publicly accessible demo data or parametrize it -for user-provided graph data. - -The code below shows how to use the raw_graph_analysis utility functions to inspect -Circuit Tracer graphs manually, but requires local data files that don't exist in CI/testing environments. 
-""" - -import os -from pathlib import Path - -import torch - -from circuit_tracer.graph import Graph, prune_graph -from it_examples.utils.raw_graph_analysis import ( - generate_topk_node_mapping, - gen_raw_graph_overview, - get_logit_indices_for_tokens, - get_node_ids_for_adj_matrix_indices, - get_topk_edges_for_node_range, - load_graph_json, - unpack_objs_from_pt_dict, -) - -# TODO: Replace these hardcoded paths with configurable parameters or publicly available demo data -node_threshold = 0.8 -edge_threshold = 0.98 -OS_HOME = os.environ.get("HOME") -if OS_HOME is None: - raise RuntimeError("HOME environment variable is not set") - -local_it_demo_graph_data = Path(OS_HOME) / "repos" / "local_it_demo_graph_data" -target_example_dir = "ct_attribution_analysis_example" -target_example_raw_data_file = "it_circuit_tracer_compute_specific_logits_demo_1_20250929_102245.pt" -target_example_graph_file = "it_circuit_tracer_compute_specific_logits_demo_1_20250929_102245.json" -raw_graph_inspect = local_it_demo_graph_data / target_example_dir / target_example_raw_data_file -raw_graph_data = torch.load(raw_graph_inspect, weights_only=False, map_location="cpu") -graph_json_path = local_it_demo_graph_data / target_example_dir / target_example_graph_file -graph_dict = load_graph_json(graph_json_path) -locals().update(unpack_objs_from_pt_dict(raw_graph_data)) - -graph = Graph.from_pt(str(raw_graph_inspect)) -device = "cuda" if torch.cuda.is_available() else "cpu" -graph.to(device) -node_mask, edge_mask, cumulative_scores = (el.cpu() for el in prune_graph(graph, node_threshold, edge_threshold)) -graph.to("cpu") - - -# Examining logit node edges in the adjacency matrix directly -# Set our target_token_ids either by str (with tokenizer) or manually -target_token_ids = torch.tensor([26865, 22605], device=graph.logit_tokens.device) - -# Generate our raw graph overview -raw_graph_overview = gen_raw_graph_overview(k=5, target_token_ids=target_token_ids, graph=graph, node_mask=node_mask) 
- -# Explore as desired -# raw_graph_overview.first_order_node_ids -# ( -# ['20_15589_7', 'E_26865_6', '0_24_7', '21_5943_7', '23_12237_7'], -# ['E_26865_6', '20_15589_7', '21_5943_7', '14_2268_6', '16_25_6'] +# TODO: disabling this incomplete example for now while refactor is underway +# """ +# Raw Graph Analysis Example - INCOMPLETE + +# TODO: This example demonstrates manual raw graph analysis using local artifacts that are not yet +# publicly available. Complete this example with publicly accessible demo data or parametrize it +# for user-provided graph data. + +# The code below shows how to use the raw_graph_analysis utility functions to inspect +# Circuit Tracer graphs manually, but requires local data files that don't exist in CI/testing environments. +# """ + +# import os +# from pathlib import Path + +# import torch + +# from circuit_tracer.graph import Graph, prune_graph +# from it_examples.utils.raw_graph_analysis import ( +# generate_topk_node_mapping, +# gen_raw_graph_overview, +# get_logit_indices_for_tokens, +# get_node_ids_for_adj_matrix_indices, +# get_topk_edges_for_node_range, +# load_graph_json, +# unpack_objs_from_pt_dict, # ) -# raw_graph_overview.first_order_values -# tensor([[6.0000, 5.9062, 3.6719, 3.5000, 2.8594], -# [9.6875, 5.5000, 3.8906, 2.8281, 2.7812]]) -# raw_graph_overview.idxs_to_node_ids(6588) -# ['E_26865_6'] - -# Or indvidually analyze the adjacency matrix -# generate our node mapping and ranges -node_mapping, node_ranges = generate_topk_node_mapping(graph, node_mask) - -# Get our topk edges for a given node range -topk_logit_vals, topk_logit_indices = get_topk_edges_for_node_range(node_ranges["logit_nodes"], graph.adjacency_matrix) - -# Get our target logit indices into both the adjacency matrix and our logit_probabilities/logit_tokens vector -adj_matrix_target_logit_idxs, target_logit_vec_idxs = get_logit_indices_for_tokens(graph, target_token_ids) - -# Gather our target logit topk edge values using the full adj_matrix logit 
indices -target_topk_logit_vals = torch.gather( - graph.adjacency_matrix[adj_matrix_target_logit_idxs], 1, topk_logit_indices[target_logit_vec_idxs] -) - -# Get node_ids for the target logit indices in the adjacency matrix -node_ids_for_target_logit_nodes = get_node_ids_for_adj_matrix_indices(adj_matrix_target_logit_idxs, node_mapping) - -# Example output: -# node_ids_for_target_logit_nodes -# ['27_22605_0', '27_26865_5'] - -# Get the node_ids for the topk edges for our target logit nodes -node_ids_for_topk_edges_of_target_logit_nodes = get_node_ids_for_adj_matrix_indices( - topk_logit_indices[target_logit_vec_idxs], node_mapping -) - -# Example output: -# node_ids_for_topk_edges_of_target_logit_nodes[0] -# ['20_15589_7', 'E_26865_6', '0_24_7', '21_5943_7', '23_12237_7'] -# node_ids_for_topk_edges_of_target_logit_nodes[1] -# ['E_26865_6', '20_15589_7', '21_5943_7', '14_2268_6', '16_25_6'] + +# # TODO: Replace these hardcoded paths with configurable parameters or publicly available demo data +# node_threshold = 0.8 +# edge_threshold = 0.98 +# OS_HOME = os.environ.get("HOME") +# if OS_HOME is None: +# raise RuntimeError("HOME environment variable is not set") + +# local_it_demo_graph_data = Path(OS_HOME) / "repos" / "local_it_demo_graph_data" +# target_example_dir = "ct_attribution_analysis_example" +# target_example_raw_data_file = "it_circuit_tracer_compute_specific_logits_demo_1_20250929_102245.pt" +# target_example_graph_file = "it_circuit_tracer_compute_specific_logits_demo_1_20250929_102245.json" +# raw_graph_inspect = local_it_demo_graph_data / target_example_dir / target_example_raw_data_file +# raw_graph_data = torch.load(raw_graph_inspect, weights_only=False, map_location="cpu") +# graph_json_path = local_it_demo_graph_data / target_example_dir / target_example_graph_file +# graph_dict = load_graph_json(graph_json_path) +# locals().update(unpack_objs_from_pt_dict(raw_graph_data)) + +# graph = Graph.from_pt(str(raw_graph_inspect)) +# device = "cuda" if 
torch.cuda.is_available() else "cpu" +# graph.to(device) +# node_mask, edge_mask, cumulative_scores = (el.cpu() for el in prune_graph(graph, node_threshold, edge_threshold)) +# graph.to("cpu") + + +# # Examining logit node edges in the adjacency matrix directly +# # Set our target_token_ids either by str (with tokenizer) or manually +# target_token_ids = torch.tensor([26865, 22605], device=graph.logit_token_ids.device) + +# # Generate our raw graph overview +# raw_graph_overview = gen_raw_graph_overview(k=5, target_token_ids=target_token_ids, graph=graph, node_mask=node_mask) + +# # Explore as desired +# # raw_graph_overview.first_order_node_ids +# # ( +# # ['20_15589_7', 'E_26865_6', '0_24_7', '21_5943_7', '23_12237_7'], +# # ['E_26865_6', '20_15589_7', '21_5943_7', '14_2268_6', '16_25_6'] +# # ) +# # raw_graph_overview.first_order_values +# # tensor([[6.0000, 5.9062, 3.6719, 3.5000, 2.8594], +# # [9.6875, 5.5000, 3.8906, 2.8281, 2.7812]]) +# # raw_graph_overview.idxs_to_node_ids(6588) +# # ['E_26865_6'] + +# # Or indvidually analyze the adjacency matrix +# # generate our node mapping and ranges +# node_mapping, node_ranges = generate_topk_node_mapping(graph, node_mask) + +# # Get our topk edges for a given node range +# topk_logit_vals, topk_logit_indices = get_topk_edges_for_node_range(node_ranges["logit_nodes"], +# graph.adjacency_matrix) + +# # Get our target logit indices into both the adjacency matrix and our logit_probabilities/logit_tokens vector +# adj_matrix_target_logit_idxs, target_logit_vec_idxs = get_logit_indices_for_tokens(graph, target_token_ids) + +# # Gather our target logit topk edge values using the full adj_matrix logit indices +# target_topk_logit_vals = torch.gather( +# graph.adjacency_matrix[adj_matrix_target_logit_idxs], 1, topk_logit_indices[target_logit_vec_idxs] +# ) + +# # Get node_ids for the target logit indices in the adjacency matrix +# node_ids_for_target_logit_nodes = 
get_node_ids_for_adj_matrix_indices(adj_matrix_target_logit_idxs, node_mapping) + +# # Example output: +# # node_ids_for_target_logit_nodes +# # ['27_22605_0', '27_26865_5'] + +# # Get the node_ids for the topk edges for our target logit nodes +# node_ids_for_topk_edges_of_target_logit_nodes = get_node_ids_for_adj_matrix_indices( +# topk_logit_indices[target_logit_vec_idxs], node_mapping +# ) + +# # Example output: +# # node_ids_for_topk_edges_of_target_logit_nodes[0] +# # ['20_15589_7', 'E_26865_6', '0_24_7', '21_5943_7', '23_12237_7'] +# # node_ids_for_topk_edges_of_target_logit_nodes[1] +# # ['E_26865_6', '20_15589_7', '21_5943_7', '14_2268_6', '16_25_6'] diff --git a/tests/core/test_regen_reqfiles.py b/tests/core/test_regen_reqfiles.py deleted file mode 100644 index b7f30853..00000000 --- a/tests/core/test_regen_reqfiles.py +++ /dev/null @@ -1,219 +0,0 @@ -from pathlib import Path -import importlib.util - - -def load_regen_module(): - spec = importlib.util.spec_from_file_location( - "regen_reqfiles", - Path(__file__).resolve().parents[2] / "requirements" / "utils" / "regen_reqfiles.py", - ) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - return mod - - -def test_generate_pip_compile_inputs_writes_files(tmp_path): - regen = load_regen_module() - - # build a minimal pyproject dict with dependencies, post_upgrades mapping, and platform_dependent packages - pyproject = { - "project": { - "dependencies": [ - "packageA >=1.0", - "datasets >= 2.0", - "fsspec >= 2023.1", - ], - "optional-dependencies": { - "examples": ["example_pkg >=0.1", "datasets >= 2.0", "some_transitive_dep"], - "test": ["test_pkg >=1.0", "coverage >= 6.0"], - "lightning": ["bitsandbytes", "peft", "finetuning-scheduler >= 2.5.0"], - }, - }, - "tool": { - "ci_pinning": { - "post_upgrades": {"datasets": "4.0.0", "fsspec": "2025.3.0"}, - "platform_dependent": ["bitsandbytes"], - } - }, - } - - ci_out = tmp_path / "ci" - ci_out.mkdir() - - # Redirect regen module file 
outputs to tmp paths so tests don't modify repo files - regen.REQ_DIR = str(tmp_path) - regen.CI_REQ_DIR = str(ci_out) - regen.POST_UPGRADES_PATH = str(ci_out / "post_upgrades.txt") - regen.CIRCUIT_TRACER_PIN = str(ci_out / "circuit_tracer_pin.txt") - - req_in_path, post_path, platform_path, direct_packages = regen.generate_pip_compile_inputs(pyproject, str(ci_out)) - - # Check that direct_packages contains the expected packages - assert "packagea" in direct_packages # core dependency (normalized to lowercase) - assert "peft" in direct_packages # key package from lightning group - assert "finetuning-scheduler" in direct_packages # key package from lightning group - assert "bitsandbytes" in direct_packages # platform-dependent but still tracked as direct - assert "example_pkg" in direct_packages # from examples group (included completely) - assert "some_transitive_dep" in direct_packages # from examples group (included completely) - assert "test_pkg" in direct_packages # from test group (included completely) - assert "coverage" in direct_packages # from test group (included completely) - # packages excluded due to post_upgrades should not be in direct_packages - assert "datasets" not in direct_packages - assert "fsspec" not in direct_packages - - # requirements.in should be created and should contain core deps and key optional deps - req_in = (ci_out / "requirements.in").read_text() - assert "packageA" in req_in # core dependency - assert "peft" in req_in # key package from lightning group - assert "finetuning-scheduler" in req_in # key package from lightning group - assert "datasets" not in req_in # excluded as post_upgrade - assert "fsspec" not in req_in # excluded as post_upgrade - assert "bitsandbytes" not in req_in # excluded as platform_dependent - # All packages from examples and test groups should be included (test and examples are included completely) - assert "example_pkg" in req_in # from examples group (included completely) - assert "some_transitive_dep" 
in req_in # from examples group (included completely) - assert "test_pkg" in req_in # from test group (included completely) - assert "coverage" in req_in # from test group (included completely) - - # post_upgrades.txt should exist and pin the specified versions - post_text = Path(post_path).read_text() - assert "datasets==4.0.0" in post_text - assert "fsspec==2025.3.0" in post_text - - # platform_dependent.txt should exist and contain bitsandbytes - platform_text = Path(platform_path).read_text() - assert "bitsandbytes" in platform_text - - -def test_post_process_pinned_requirements(tmp_path): - regen = load_regen_module() - - # Create a mock requirements.txt with direct dependencies and transitive dependencies - requirements_path = tmp_path / "requirements.txt" - requirements_content = """# This is a generated file -absl-py==2.3.1 - # via transformers -torch==2.8.0 - # via -r requirements.in -transformers==4.55.2 - # via -r requirements.in -peft==0.17.0 - # via -r requirements.in -aiosignal==1.4.0 - # via aiohttp -aiohttp==3.12.15 - # via boostedblob -numpy==1.26.4 - # via torch -""" - requirements_path.write_text(requirements_content) - - # Create existing platform_dependent.txt with bitsandbytes - platform_path = tmp_path / "platform_dependent.txt" - platform_path.write_text("bitsandbytes\n") - - # Platform patterns to match - platform_patterns = ["bitsandbytes"] - - # Direct packages list (only the packages we explicitly specify) - direct_packages = ["torch", "transformers", "peft"] - # Redirect regen module file outputs to tmp paths so tests don't modify repo files - ci_out = tmp_path / "ci" - ci_out.mkdir() - regen.REQ_DIR = str(tmp_path) - regen.POST_UPGRADES_PATH = str(ci_out / "post_upgrades.txt") - regen.CIRCUIT_TRACER_PIN = str(ci_out / "circuit_tracer_pin.txt") - - # Run post-processing using the direct_packages defined above - regen.post_process_pinned_requirements( - str(requirements_path), str(platform_path), platform_patterns, direct_packages - ) - 
- # Check that only direct dependencies remain in requirements.txt - requirements_final = requirements_path.read_text() - assert "torch==2.8.0" in requirements_final # direct dependency should remain - assert "transformers==4.55.2" in requirements_final # direct dependency should remain - assert "peft==0.17.0" in requirements_final # direct dependency should remain - - # Transitive dependencies should be removed - assert "absl-py==2.3.1" not in requirements_final # transitive dependency should be removed - assert "aiosignal==1.4.0" not in requirements_final # transitive dependency should be removed - assert "aiohttp==3.12.15" not in requirements_final # transitive dependency should be removed - assert "numpy==1.26.4" not in requirements_final # transitive dependency should be removed - - # Check that platform_dependent.txt still contains bitsandbytes - platform_final = platform_path.read_text() - assert "bitsandbytes" in platform_final # existing package should remain - - -def test_post_upgrades_comparators_and_malformed(tmp_path): - regen = load_regen_module() - - pyproject = { - "project": { - "dependencies": [ - "packageA >=1.0", - ], - "optional-dependencies": { - "lightning": ["bitsandbytes", "peft", "finetuning-scheduler >= 2.5.0"], - }, - }, - "tool": { - "ci_pinning": { - # comparator-style specs and one malformed spec - "post_upgrades": {"datasets": "==4.0.0", "fsspec": ">=2025.3.0", "weirdpkg": "=>1.2.3"}, - "platform_dependent": ["bitsandbytes"], - } - }, - } - - ci_out = tmp_path / "ci" - ci_out.mkdir() - - # Redirect regen module file outputs to tmp paths so tests don't modify repo files - regen.REQ_DIR = str(tmp_path) - regen.CI_REQ_DIR = str(ci_out) - regen.POST_UPGRADES_PATH = str(ci_out / "post_upgrades.txt") - regen.CIRCUIT_TRACER_PIN = str(ci_out / "circuit_tracer_pin.txt") - - req_in_path, post_path, platform_path, direct_packages = regen.generate_pip_compile_inputs(pyproject, str(ci_out)) - - post_lines = 
Path(post_path).read_text().splitlines() - # comparator-style specs should be written verbatim - assert "datasets==4.0.0" in post_lines - assert "fsspec>=2025.3.0" in post_lines - # malformed spec should be preserved as-is (regenerator does not validate comparator syntax) - assert "weirdpkg=>1.2.3" in post_lines - - -def test_normalize_rewrites_absolute_paths_in_comments(tmp_path): - regen = load_regen_module() - normalize = regen.normalize_pip_compile_comments - repo_root = regen.REPO_ROOT - # example absolute path inside comments - abs_line = f"# -r {repo_root}/requirements/ci/requirements.in\n" - non_comment = "package==1.2.3\n" - - tmp = tmp_path / "reqs.txt" - tmp.write_text(abs_line + non_comment) - - normalize(str(tmp), repo_root) - - out = tmp.read_text().splitlines() - assert out[0].startswith("#") - assert "requirements/ci/requirements.in" in out[0] - # absolute root should not be present - assert repo_root not in out[0] - # non-comment should remain unchanged - assert out[1] == non_comment.strip() - - -def test_idempotent_when_no_absolute_paths(tmp_path): - regen = load_regen_module() - normalize = regen.normalize_pip_compile_comments - content = "# some comment without path\npackage==0.1.0\n" - tmp = tmp_path / "reqs2.txt" - tmp.write_text(content) - - normalize(str(tmp), regen.REPO_ROOT) - assert tmp.read_text() == content diff --git a/tests/runif.py b/tests/runif.py index caec49bd..e293ad91 100644 --- a/tests/runif.py +++ b/tests/runif.py @@ -19,7 +19,7 @@ import torch from interpretune.utils import _LIGHTNING_AVAILABLE, _BNB_AVAILABLE, _FTS_AVAILABLE from packaging.version import Version -from pkg_resources import get_distribution +from importlib.metadata import version as get_version from it_examples.patching.dep_patch_shim import ExpPatch, _ACTIVE_PATCHES EXTENDED_VER_PAT = re.compile(r"([0-9]+\.){2}[0-9]+") @@ -151,13 +151,13 @@ def __new__( kwargs["min_cuda_gpus"] = True if min_torch: - torch_version = get_distribution("torch").version + 
torch_version = get_version("torch") extended_torch_ver = EXTENDED_VER_PAT.match(torch_version).group() or torch_version conditions.append(Version(extended_torch_ver) < Version(min_torch)) reasons.append(f"torch>={min_torch}, {extended_torch_ver} installed.") if max_torch: - torch_version = get_distribution("torch").version + torch_version = get_version("torch") extended_torch_ver = EXTENDED_VER_PAT.match(torch_version).group() or torch_version conditions.append(Version(extended_torch_ver) > Version(max_torch)) reasons.append(f"torch<={max_torch}, {extended_torch_ver} installed.") diff --git a/tests/warns.py b/tests/warns.py index ffdcb8b5..2d2cbb35 100644 --- a/tests/warns.py +++ b/tests/warns.py @@ -3,7 +3,7 @@ from typing import List, Optional from warnings import WarningMessage from packaging.version import Version -from pkg_resources import get_distribution +from importlib.metadata import version as get_version from interpretune.utils import dummy_method_warn_fingerprint from interpretune.protocol import Adapter @@ -79,11 +79,11 @@ (Adapter.core, Adapter.core, Adapter.transformer_lens): TL_CTX_WARNS, } -MIN_VERSION_WARNS = "2.2" -MAX_VERSION_WARNS = "2.5" +MIN_VERSION_WARNS = "2.7" +MAX_VERSION_WARNS = "2.9" # torch version-specific warns go here EXPECTED_VERSION_WARNS = {MIN_VERSION_WARNS: [], MAX_VERSION_WARNS: []} -torch_version = get_distribution("torch").version +torch_version = get_version("torch") extended_torch_ver = EXTENDED_VER_PAT.match(torch_version).group() or torch_version if Version(extended_torch_ver) < Version(MAX_VERSION_WARNS): EXPECTED_WARNS.extend(EXPECTED_VERSION_WARNS[MIN_VERSION_WARNS])