diff --git a/.azure-pipelines/gpu-tests.yml b/.azure-pipelines/gpu-tests.yml index 6142c3a3..a941c232 100644 --- a/.azure-pipelines/gpu-tests.yml +++ b/.azure-pipelines/gpu-tests.yml @@ -18,9 +18,9 @@ trigger: pr: branches: include: - - '*' # testing all branches that are ready for review (may revert to below if resource constraints dictate) - # - main - # - release/* + # - '*' + - main + - release/* paths: include: - setup.* diff --git a/.github/actions/check-file-diffs/action.yml b/.github/actions/check-file-diffs/action.yml new file mode 100644 index 00000000..7f7917ee --- /dev/null +++ b/.github/actions/check-file-diffs/action.yml @@ -0,0 +1,49 @@ +name: Check file diffs +description: Compare files with git diff and produce a patch if they differ +inputs: + compare_paths: + description: Space-separated list of files/paths to compare with git diff + required: true + patch_path: + description: Path to write the patch file + required: false + default: file_diff.patch + error_message: + description: Custom error message to display if files differ + required: false + default: "Files have changed" +runs: + using: composite + steps: + - id: check + name: Check for diffs and write patch + shell: bash + run: | + set -e + rc=0 + git --no-pager diff --exit-code ${{ inputs.compare_paths }} > ${{ inputs.patch_path }} 2> file_diff.err || rc=$? + if [ "$rc" -eq 0 ]; then + echo "changed=false" >> $GITHUB_OUTPUT + rm -f file_diff.err || true + elif [ "$rc" -eq 1 ]; then + echo "changed=true" >> $GITHUB_OUTPUT + echo "Patch size: $(wc -c < ${{ inputs.patch_path }}) bytes" + echo "${{ inputs.error_message }}" + echo "--- Diff ---" + cat ${{ inputs.patch_path }} || true + rm -f file_diff.err || true + else + echo "git diff failed with exit code $rc" >&2 + echo "--- git diff stderr ---" >&2 + cat file_diff.err >&2 || true + exit $rc + fi + echo "patch_path=${{ inputs.patch_path }}" >> $GITHUB_OUTPUT + +outputs: + changed: + description: 'true if files differ' + value: ${{ steps.check.outputs.changed }} + patch_path: + description: 'path to generated patch file' + value: ${{ steps.check.outputs.patch_path }} diff --git a/.github/actions/install-ci-dependencies/action.yml b/.github/actions/install-ci-dependencies/action.yml new file mode 100644 index 00000000..b8afd1dd --- /dev/null +++ b/.github/actions/install-ci-dependencies/action.yml @@ -0,0 +1,56 @@ +name: "Install CI Dependencies" +description: "Install Python dependencies for CI workflows including platform-dependent packages and post-upgrades" + +inputs: + show_pip_list: + description: "Whether to show pip list output after installations" + required: false + default: "false" + apply_post_upgrades: + description: "Whether to apply post-upgrade packages" + required: false + default: "true" + +runs: + using: "composite" + steps: + - name: Set up venv and install ci dependencies + shell: bash + run: | + python -m pip install --upgrade pip setuptools wheel build --cache-dir "$PIP_CACHE_DIR" + # Prefer CI pinned requirements if present + if [ -f requirements/ci/requirements.txt ]; then + pip install -r requirements/ci/requirements.txt --cache-dir "$PIP_CACHE_DIR" + python -m pip install -e '.[test,examples,lightning]' --cache-dir "$PIP_CACHE_DIR" + else + python -m pip install -e '.[test,examples,lightning]' -c requirements/ci_constraints.txt --cache-dir "$PIP_CACHE_DIR" + fi + if [ "${{ inputs.show_pip_list }}" = "true" ]; then + pip list + fi + + - name: Install platform-dependent packages + shell: bash + run: | + # Install platform-dependent packages with flexible constraints to allow platform-specific resolution + if [ -f requirements/platform_dependent.txt ] && [ -s requirements/platform_dependent.txt ]; then + echo "Installing platform-dependent packages..." + python -m pip install -r requirements/platform_dependent.txt --cache-dir "$PIP_CACHE_DIR" || echo "Some platform-dependent packages may not be available on this platform, continuing..." + else + echo "No platform-dependent packages to install." + fi + + - name: Optional post-upgrades (datasets/fsspec etc) + shell: bash + env: + APPLY_POST_UPGRADES: ${{ inputs.apply_post_upgrades }} + run: | + if ([ "${APPLY_POST_UPGRADES}" = "1" ] || [ "${APPLY_POST_UPGRADES}" = "true" ]) && [ -s requirements/post_upgrades.txt ]; then + echo "Applying post-upgrades..." + python -m pip install --upgrade -r requirements/post_upgrades.txt --cache-dir "$PIP_CACHE_DIR" + if [ "${{ inputs.show_pip_list }}" = "true" ]; then + pip list + fi + else + echo "Skipping post-upgrades (either disabled or file empty)." + fi diff --git a/.github/actions/regen-ci-reqs/action.yml b/.github/actions/regen-ci-reqs/action.yml index d1441e89..212fdb60 100644 --- a/.github/actions/regen-ci-reqs/action.yml +++ b/.github/actions/regen-ci-reqs/action.yml @@ -38,25 +38,11 @@ runs: - id: check name: Check for diffs and write patch - shell: bash - run: | - set -e - rc=0 - git --no-pager diff --exit-code ${{ inputs.compare_paths }} > ${{ inputs.patch_path }} 2> regen_diff.err || rc=$? - if [ "$rc" -eq 0 ]; then - echo "changed=false" >> $GITHUB_OUTPUT - rm -f regen_diff.err || true - elif [ "$rc" -eq 1 ]; then - echo "changed=true" >> $GITHUB_OUTPUT - echo "Patch size: $(wc -c < ${{ inputs.patch_path }}) bytes" - rm -f regen_diff.err || true - else - echo "git diff failed with exit code $rc" >&2 - echo "--- git diff stderr ---" >&2 - cat regen_diff.err >&2 || true - exit $rc - fi - echo "patch_path=${{ inputs.patch_path }}" >> $GITHUB_OUTPUT + uses: ./.github/actions/check-file-diffs + with: + compare_paths: ${{ inputs.compare_paths }} + patch_path: ${{ inputs.patch_path }} + error_message: "CI requirements have changed and need to be regenerated" outputs: changed: diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 40012675..6e89cb35 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -175,42 +175,10 @@ jobs: restore-keys: | ${{ runner.os }}-pip-wheels-${{ steps.set_time_period.outputs.TIME_PERIOD }}-py${{ matrix.python-version }}- - - name: Set up venv and install ci dependencies (all OS) - shell: bash - run: | - python -m pip install --upgrade pip setuptools wheel build --cache-dir "$PIP_CACHE_DIR" - # Prefer CI pinned requirements if present - if [ -f requirements/ci/requirements.txt ]; then - pip install -r requirements/ci/requirements.txt --cache-dir "$PIP_CACHE_DIR" - python -m pip install -e '.[test,examples,lightning]' --cache-dir "$PIP_CACHE_DIR" - else - python -m pip install -e '.[test,examples,lightning]' -c requirements/ci_constraints.txt --cache-dir "$PIP_CACHE_DIR" - fi - pip list - - - name: Install platform-dependent packages - shell: bash - run: | - # Install platform-dependent packages with flexible constraints to allow platform-specific resolution - if [ -f requirements/platform_dependent.txt ] && [ -s requirements/platform_dependent.txt ]; then - echo "Installing platform-dependent packages..." - python -m pip install -r requirements/platform_dependent.txt --cache-dir "$PIP_CACHE_DIR" || echo "Some platform-dependent packages may not be available on this platform, continuing..." - else - echo "No platform-dependent packages to install." - fi - - - name: Optional post-upgrades (datasets/fsspec etc) - shell: bash - env: - APPLY_POST_UPGRADES: ${{ vars.APPLY_POST_UPGRADES || '1' }} - run: | - if [ "${APPLY_POST_UPGRADES}" = "1" ] && [ -s requirements/post_upgrades.txt ]; then - echo "Applying post-upgrades..." - python -m pip install --upgrade -r requirements/post_upgrades.txt --cache-dir "$PIP_CACHE_DIR" - pip list - else - echo "Skipping post-upgrades (either disabled or file empty)." - fi + - name: Install CI dependencies + uses: ./.github/actions/install-ci-dependencies + with: + show_pip_list: "true" - name: Run pytest coverage w/ configured instrumentation id: run-pytest-coverage diff --git a/.github/workflows/copilot-setup-steps.yml b/.github/workflows/copilot-setup-steps.yml index b9c7942f..161e9260 100644 --- a/.github/workflows/copilot-setup-steps.yml +++ b/.github/workflows/copilot-setup-steps.yml @@ -31,23 +31,32 @@ jobs: uses: actions/setup-python@v5 with: python-version: '3.12' - - name: Set up venv and install dependencies + + - name: Reset caching + id: set_time_period + run: | + python -c "import time; days = time.time() / 60 / 60 / 24; print(f'TIME_PERIOD=d{int(days / 7) * 7}')" >> $GITHUB_OUTPUT + + - name: Get pip cache dir + id: pip-cache shell: bash run: | - python -m pip install --upgrade pip setuptools wheel build - # Prefer pinned CI requirements if present - if [ -f requirements/ci/requirements.txt ]; then - python -m pip install -r requirements/ci/requirements.txt -r requirements/platform_dependent.txt - python -m pip install -e '.[test,examples,lightning]' - else - python -m pip install -e '.[test,examples,lightning]' -c requirements/ci_constraints.txt - fi - # Optional post-upgrades (disabled by default) - if [ "${APPLY_POST_UPGRADES:-1}" = "1" ] && [ -s requirements/post_upgrades.txt ]; then - pip install --upgrade -r requirements/post_upgrades.txt - fi - pip list - # No env injection here; the shell uses ${APPLY_POST_UPGRADES:-1} to default when unset. + echo "PIP_CACHE_DIR=$(pip cache dir)" >> $GITHUB_ENV + + - name: pip cache + uses: actions/cache@v4 + with: + path: ${{ env.PIP_CACHE_DIR }}/wheels + key: ubuntu-22.04-pip-wheels-${{ steps.set_time_period.outputs.TIME_PERIOD }}-py3.12-${{ hashFiles('requirements/ci/requirements.txt') }}-${{ hashFiles('requirements/post_upgrades.txt') }}-${{ hashFiles('requirements/platform_dependent.txt') }} + restore-keys: | + ubuntu-22.04-pip-wheels-${{ steps.set_time_period.outputs.TIME_PERIOD }}-py3.12- + + - name: Install CI dependencies + uses: ./.github/actions/install-ci-dependencies + with: + show_pip_list: "true" + apply_post_upgrades: "true" + - name: Setup pyright, precommit and git lfs shell: bash run: | diff --git a/.github/workflows/type-check.yml b/.github/workflows/type-check.yml index 3513385b..40a97600 100644 --- a/.github/workflows/type-check.yml +++ b/.github/workflows/type-check.yml @@ -1,4 +1,4 @@ -name: Type Check +name: Stale Stubs and Type Checks on: push: @@ -9,6 +9,7 @@ on: - "pyproject.toml" - "src/**" - "requirements/**" + - "scripts/generate_op_stubs.py" - ".github/workflows/type-check.yml" # Exclude documentation-only files from triggering - "!docs/**" @@ -28,6 +29,7 @@ on: - "pyproject.toml" - "src/**" - "requirements/**" + - "scripts/generate_op_stubs.py" - ".github/workflows/type-check.yml" # Exclude documentation-only files from triggering - "!docs/**" @@ -75,42 +77,48 @@ jobs: restore-keys: | ubuntu-22.04-pip-wheels-${{ steps.set_time_period.outputs.TIME_PERIOD }}-py3.12- - - name: Set up venv and install ci dependencies - shell: bash - run: | - python -m pip install --upgrade pip setuptools wheel build --cache-dir "$PIP_CACHE_DIR" - # Prefer CI pinned requirements if present - if [ -f requirements/ci/requirements.txt ]; then - pip install -r requirements/ci/requirements.txt --cache-dir "$PIP_CACHE_DIR" - python -m pip install -e '.[test,examples,lightning]' --cache-dir "$PIP_CACHE_DIR" - else - python -m pip install -e '.[test,examples,lightning]' -c requirements/ci_constraints.txt --cache-dir "$PIP_CACHE_DIR" - fi + - name: Install CI dependencies + uses: ./.github/actions/install-ci-dependencies + with: + apply_post_upgrades: ${{ vars.APPLY_POST_UPGRADES || 'true' }} - - name: Install platform-dependent packages + - name: Check type stubs are up to date + id: check_stubs + continue-on-error: true shell: bash run: | - # Install platform-dependent packages with flexible constraints to allow platform-specific resolution - if [ -f requirements/platform_dependent.txt ] && [ -s requirements/platform_dependent.txt ]; then - echo "Installing platform-dependent packages..." - python -m pip install -r requirements/platform_dependent.txt --cache-dir "$PIP_CACHE_DIR" || echo "Some platform-dependent packages may not be available on this platform, continuing..." - else - echo "No platform-dependent packages to install." - fi + # Make a backup of current stubs + cp src/interpretune/__init__.pyi src/interpretune/__init__.pyi.backup - - name: Optional post-upgrades (datasets/fsspec etc) - shell: bash - env: - APPLY_POST_UPGRADES: ${{ vars.APPLY_POST_UPGRADES || '1' }} - run: | - if [ "${APPLY_POST_UPGRADES}" = "1" ] && [ -s requirements/post_upgrades.txt ]; then - echo "Applying post-upgrades..." - python -m pip install --upgrade -r requirements/post_upgrades.txt --cache-dir "$PIP_CACHE_DIR" - else - echo "Skipping post-upgrades (either disabled or file empty)." - fi + # Regenerate stubs + python scripts/generate_op_stubs.py + + # Check if stubs changed using our shared action + echo "stub_check_step=true" >> $GITHUB_OUTPUT + + - name: Check for stub differences + id: stub_diff + uses: ./.github/actions/check-file-diffs + with: + compare_paths: "src/interpretune/__init__.pyi" + patch_path: "stubs_diff.patch" + error_message: "Type stubs are stale and need to be regenerated with: python scripts/generate_op_stubs.py" + + - name: Upload stub diff (if present) + if: steps.stub_diff.outputs.changed == 'true' + uses: actions/upload-artifact@v4 + with: + name: type-stubs-diff + path: stubs_diff.patch - name: Run pyright type check shell: bash run: | pyright -p pyproject.toml + + - name: Fail workflow if stubs are stale + if: steps.stub_diff.outputs.changed == 'true' + shell: bash + run: | + echo "::error::Type stubs are stale. Please run 'python scripts/generate_op_stubs.py' and commit the updated __init__.pyi file." + exit 1 diff --git a/scripts/generate_op_stubs.py b/scripts/generate_op_stubs.py index fa1f1cef..826afe96 100644 --- a/scripts/generate_op_stubs.py +++ b/scripts/generate_op_stubs.py @@ -255,6 +255,11 @@ def generate_stubs(yaml_path: Path, output_path: Path) -> None: "from interpretune.base.datamodules import ITDataModule as ITDataModule", "from interpretune.base.components.mixins import MemProfilerHooks as MemProfilerHooks", "from interpretune.analysis.ops import AnalysisBatch as AnalysisBatch", + "from interpretune.analysis import (", + " AnalysisStore as AnalysisStore,", + " DISPATCHER as DISPATCHER,", + " SAEAnalysisTargets as SAEAnalysisTargets,", + ")", "from interpretune.config import (", " ITLensConfig as ITLensConfig,", " SAELensConfig as SAELensConfig,", @@ -264,7 +269,11 @@ def generate_stubs(yaml_path: Path, output_path: Path) -> None: " GenerativeClassificationConfig as GenerativeClassificationConfig,", " BaseGenerationConfig as BaseGenerationConfig,", " HFGenerationConfig as HFGenerationConfig,", + " SAELensFromPretrainedConfig as SAELensFromPretrainedConfig,", + " AnalysisCfg as AnalysisCfg,", ")", + "from interpretune.session import ITSessionConfig as ITSessionConfig, ITSession as ITSession", + "from interpretune.runners import AnalysisRunner as AnalysisRunner", "from interpretune.utils import rank_zero_warn as rank_zero_warn, sanitize_input_name as sanitize_input_name", "from interpretune.protocol import STEP_OUTPUT as STEP_OUTPUT", "", @@ -296,6 +305,39 @@ def generate_stubs(yaml_path: Path, output_path: Path) -> None: print(f"Stubs generated at {output_path}") + # Apply formatting to match pre-commit hooks + try: + import subprocess + + # Ensure pre-commit hooks are installed + # We run install every time since it's idempotent and fast if already installed + install_result = subprocess.run(["pre-commit", "install"], capture_output=True, text=True, cwd=project_root) + + if install_result.returncode != 0: + print(f"Warning: Failed to install pre-commit hooks: {install_result.stderr}") + print("Skipping formatting step") + return + + # Run pre-commit ruff formatting on the generated file to match pre-commit formatting + result = subprocess.run( + ["pre-commit", "run", "ruff-format", "--files", str(output_path)], + capture_output=True, + text=True, + cwd=project_root, + ) + + if result.returncode == 0: + print(f"Applied ruff formatting to {output_path}") + else: + # Pre-commit returns non-zero when it makes changes, which is expected + if "reformatted" in result.stdout or "Passed" in result.stdout: + print(f"Applied ruff formatting to {output_path}") + else: + print(f"Warning: ruff formatting may have failed: {result.stdout}") + + except Exception as e: + print(f"Warning: Could not apply formatting: {e}") + if __name__ == "__main__": yaml_path = project_root / "src" / "interpretune" / "analysis" / "ops" / "native_analysis_functions.yaml" diff --git a/src/interpretune/__init__.pyi b/src/interpretune/__init__.pyi index e129d21a..f031b53a 100644 --- a/src/interpretune/__init__.pyi +++ b/src/interpretune/__init__.pyi @@ -12,6 +12,11 @@ from interpretune.protocol import BaseAnalysisBatchProtocol, DefaultAnalysisBatc from interpretune.base.datamodules import ITDataModule as ITDataModule from interpretune.base.components.mixins import MemProfilerHooks as MemProfilerHooks from interpretune.analysis.ops import AnalysisBatch as AnalysisBatch +from interpretune.analysis import ( + AnalysisStore as AnalysisStore, + DISPATCHER as DISPATCHER, + SAEAnalysisTargets as SAEAnalysisTargets, +) from interpretune.config import ( ITLensConfig as ITLensConfig, SAELensConfig as SAELensConfig, @@ -21,7 +26,11 @@ from interpretune.config import ( GenerativeClassificationConfig as GenerativeClassificationConfig, BaseGenerationConfig as BaseGenerationConfig, HFGenerationConfig as HFGenerationConfig, + SAELensFromPretrainedConfig as SAELensFromPretrainedConfig, + AnalysisCfg as AnalysisCfg, ) +from interpretune.session import ITSessionConfig as ITSessionConfig, ITSession as ITSession +from interpretune.runners import AnalysisRunner as AnalysisRunner from interpretune.utils import rank_zero_warn as rank_zero_warn, sanitize_input_name as sanitize_input_name from interpretune.protocol import STEP_OUTPUT as STEP_OUTPUT