diff --git a/.azure-pipelines/gpu-tests.yml b/.azure-pipelines/gpu-tests.yml index 7b6290f9..5fb580f4 100644 --- a/.azure-pipelines/gpu-tests.yml +++ b/.azure-pipelines/gpu-tests.yml @@ -48,7 +48,7 @@ jobs: strategy: matrix: PyTorch_latest: - image: speediedan/interpretune:py3.12-pt2.9.1-azpl-init + image: speediedan/interpretune:py3.13-pt2.9.1-azpl-init scope: "" timeoutInMinutes: 100 cancelTimeoutInMinutes: 2 diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 45db93e3..727121ff 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -44,7 +44,7 @@ You can also fill out the list below manually. - TransformerLens Version (e.g., 2.16.1): - SAELens Version (e.g., 6.3.1): - Lightning Version (e.g., 2.5.1): -- Python version (e.g., 3.12): +- Python version (e.g., 3.13): - OS (e.g., Linux): - CUDA/cuDNN version: - GPU models and configuration: diff --git a/.github/actions/install-ci-dependencies/action.yml b/.github/actions/install-ci-dependencies/action.yml index 1afe4a56..ddf4ad64 100644 --- a/.github/actions/install-ci-dependencies/action.yml +++ b/.github/actions/install-ci-dependencies/action.yml @@ -5,7 +5,7 @@ inputs: python_version: description: "Python version to use" required: false - default: "3.12" + default: "3.13" show_pip_list: description: "Whether to show package list output after installations" required: false diff --git a/.github/actions/regen-ci-reqs/action.yml b/.github/actions/regen-ci-reqs/action.yml index 2df5bbcb..363084f2 100644 --- a/.github/actions/regen-ci-reqs/action.yml +++ b/.github/actions/regen-ci-reqs/action.yml @@ -4,7 +4,7 @@ inputs: python_version: description: 'Python version to setup' required: false - default: '3.12' + default: '3.13' ci_output_dir: description: Directory where regen writes CI output (defaults to requirements/ci) required: false diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 45ec6e6a..80601280 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -5,7 +5,7 @@ **Interpretune** is a flexible, powerful framework for collaborative AI world model analysis and tuning. This project is in **pre-MVP** stage - features and APIs are subject to change. **Key Technologies:** -- Python 3.10+ (CI tests on 3.12) +- Python 3.10+ (CI tests on 3.13) - PyTorch 2.7.1+ with transformers ecosystem - Core deps: transformer_lens >= 3.0.0 (TransformerBridge support), sae_lens, datasets, jsonargparse - Optional: PyTorch Lightning, W&B, circuit-tracer, neuronpedia @@ -260,7 +260,7 @@ src/it_examples/ # Example experiments **File:** `.github/workflows/ci_test-full.yml` **Triggers:** Push/PR to main, changes to source/test files -**Platforms:** Ubuntu 22.04, Windows 2022, macOS 14 (Python 3.12) +**Platforms:** Ubuntu 22.04, Windows 2022, macOS 14 (Python 3.13) **Timeout:** 90 minutes **CI Process:** diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 820b5c18..9eadc41a 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -129,7 +129,7 @@ jobs: windows-2022, macos-14 ] - python-version: ["3.12"] + python-version: ["3.13"] timeout-minutes: 90 env: WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} diff --git a/.github/workflows/copilot-setup-steps.yml b/.github/workflows/copilot-setup-steps.yml index 7c8b75ee..de968122 100644 --- a/.github/workflows/copilot-setup-steps.yml +++ b/.github/workflows/copilot-setup-steps.yml @@ -30,7 +30,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.12' + python-version: '3.13' - name: Install CI dependencies uses: ./.github/actions/install-ci-dependencies diff --git a/.github/workflows/regen-ci-req-check.yml b/.github/workflows/regen-ci-req-check.yml index e8f0d125..14fd2812 100644 --- a/.github/workflows/regen-ci-req-check.yml +++ b/.github/workflows/regen-ci-req-check.yml @@ -19,7 +19,7 @@ jobs: uses: ./.github/actions/regen-ci-reqs id: regen_ci_reqs with: - python_version: '3.12' + python_version: '3.13' ci_output_dir: requirements/ci compare_paths: "requirements/ci/requirements.txt" patch_path: /tmp/regen_diff.patch diff --git a/.github/workflows/regen-ci-req-report.yml b/.github/workflows/regen-ci-req-report.yml index 49e5ae88..b32e65e3 100644 --- a/.github/workflows/regen-ci-req-report.yml +++ b/.github/workflows/regen-ci-req-report.yml @@ -60,7 +60,7 @@ jobs: uses: ./.github/actions/regen-ci-reqs id: regen_ci_reqs with: - python_version: '3.12' + python_version: '3.13' ci_output_dir: requirements/ci compare_paths: "requirements/ci/requirements.txt" patch_path: /tmp/regen_diff.patch diff --git a/.github/workflows/release_pypi.yml b/.github/workflows/release_pypi.yml index 35be85bc..497cc6dc 100644 --- a/.github/workflows/release_pypi.yml +++ b/.github/workflows/release_pypi.yml @@ -65,7 +65,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: '3.12' + python-version: '3.13' - name: Install dependencies run: >- diff --git a/.github/workflows/type-check.yml b/.github/workflows/type-check.yml index 6d9f65da..732634bc 100644 --- a/.github/workflows/type-check.yml +++ b/.github/workflows/type-check.yml @@ -55,10 +55,10 @@ jobs: IT_USE_CT_COMMIT_PIN: "1" steps: - uses: actions/checkout@v4 - - name: Set up Python 3.12 + - name: Set up Python 3.13 uses: actions/setup-python@v5 with: - python-version: "3.12" + python-version: "3.13" - name: Reset caching id: set_time_period @@ -75,9 +75,9 @@ jobs: uses: actions/cache@v4 with: path: ${{ env.PIP_CACHE_DIR }}/wheels - key: ubuntu-22.04-pip-wheels-${{ steps.set_time_period.outputs.TIME_PERIOD }}-py3.12-${{ hashFiles('requirements/ci/requirements.txt') }} + key: ubuntu-22.04-pip-wheels-${{ steps.set_time_period.outputs.TIME_PERIOD }}-py3.13-${{ hashFiles('requirements/ci/requirements.txt') }} restore-keys: | - ubuntu-22.04-pip-wheels-${{ steps.set_time_period.outputs.TIME_PERIOD }}-py3.12- + ubuntu-22.04-pip-wheels-${{ steps.set_time_period.outputs.TIME_PERIOD }}-py3.13- - name: Install CI dependencies uses: ./.github/actions/install-ci-dependencies diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index 69fc5b57..3fe428c7 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -16,7 +16,7 @@ ARG OS_VER=ubuntu22.04 FROM nvidia/cuda:${CUDA_VERSION}-devel-${OS_VER} -ARG PYTHON_VERSION=3.12 +ARG PYTHON_VERSION=3.13 ARG PYTORCH_VERSION=2.9.1 ARG CUST_BUILD=0 ARG MKL_THREADING_LAYER=GNU diff --git a/dockers/docker_images_main.sh b/dockers/docker_images_main.sh index f0d584c1..0c5d4343 100755 --- a/dockers/docker_images_main.sh +++ b/dockers/docker_images_main.sh @@ -43,7 +43,7 @@ maybe_build(){ build_eval(){ # latest PyTorch image supported by release # see CUDA_ARCHES_FULL_VERSION for the full version of the pytorch-provided toolkit - declare -A iv=(["cuda"]="12.8.1" ["python"]="3.12" ["pytorch"]="2.9.1" ["cust_build"]="1") + declare -A iv=(["cuda"]="12.8.1" ["python"]="3.13" ["pytorch"]="2.9.1" ["cust_build"]="1") export latest_pt="base-cu${iv["cuda"]}-py${iv["python"]}-pt${iv["pytorch"]}" export latest_azpl="py${iv["python"]}-pt${iv["pytorch"]}-azpl-init" maybe_build iv "${latest_pt}" "${latest_azpl}" diff --git a/dockers/docker_images_release.sh b/dockers/docker_images_release.sh index 363e4fa7..19c25c6d 100755 --- a/dockers/docker_images_release.sh +++ b/dockers/docker_images_release.sh @@ -41,7 +41,7 @@ maybe_build(){ build_eval(){ # latest PyTorch image supported by release - declare -A iv=(["cuda"]="12.8.1" ["python"]="3.12" ["pytorch"]="2.9.1" ["cust_build"]="0") + declare -A iv=(["cuda"]="12.8.1" ["python"]="3.13" ["pytorch"]="2.9.1" ["cust_build"]="0") export latest_pt="base-cu${iv["cuda"]}-py${iv["python"]}-pt${iv["pytorch"]}" export latest_azpl="py${iv["python"]}-pt${iv["pytorch"]}-azpl-init" maybe_build iv "${latest_pt}" "${latest_azpl}" diff --git a/dockers/it-az-base/Dockerfile b/dockers/it-az-base/Dockerfile index b7c77af7..9b8ab712 100644 --- a/dockers/it-az-base/Dockerfile +++ b/dockers/it-az-base/Dockerfile @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -ARG PYTHON_VERSION=3.12 +ARG PYTHON_VERSION=3.13 ARG PYTORCH_VERSION=2.9.1 ARG CUST_BASE diff --git a/pyproject.toml b/pyproject.toml index 95f2928b..8e48c6c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] description = "A package to support LLM reasoning and interpretability experiments at a level of abstraction that is both powerfully flexible and convenient" readme = "README.md" @@ -36,7 +37,7 @@ dependencies = [ "transformers>=4.57.1", "tabulate >= 0.9.0", "datasets >= 4.0.0", - "jsonargparse[signatures] >= 4.35.0,<4.42.0", # upstream regression with 4.42, replace this req once adding fts req + "jsonargparse[signatures] >= 4.35.0,<4.42.0", # TODO: investigate upstream regression w/ 4.42, remove if fts is req # "finetuning-scheduler[possible_future_it_plugin] >= 2.5.0", ] @@ -93,7 +94,7 @@ examples = [ git-deps = [ "circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87", "transformer-lens @ git+https://github.com/speediedan/TransformerLens.git@d35d01feb9cc076a091e41255e1c9f92de2af236", - "finetuning-scheduler @ git+https://github.com/speediedan/finetuning-scheduler.git@4aa64032c07acd34493a5c05929845bd011426c9", + "finetuning-scheduler @ git+https://github.com/speediedan/finetuning-scheduler.git@a96d8158ee92dddb3eb5203095ce9db98c137fd0", "sae_lens >= 6.3.1", ] diff --git a/requirements/ci/requirements.txt b/requirements/ci/requirements.txt index 54ea994d..1793cdee 100644 --- a/requirements/ci/requirements.txt +++ b/requirements/ci/requirements.txt @@ -7,9 +7,7 @@ accelerate==1.12.0 aiohappyeyeballs==2.6.1 # via aiohttp aiohttp==3.13.3 - # via - # fsspec - # papermill + # via fsspec aiosignal==1.4.0 # via aiohttp annotated-types==0.7.0 @@ -86,7 +84,7 @@ cryptography==46.0.3 ; platform_machine != 'ppc64le' and platform_machine != 's3 # via secretstorage cycler==0.12.1 # via matplotlib -datasets==4.4.2 +datasets==4.5.0 # via # interpretune (pyproject.toml) # evaluate @@ -137,7 +135,7 @@ fsspec[http]==2025.10.0 # evaluate # huggingface-hub # torch -gdown==5.2.0 +gdown==5.2.1 # via interpretune (pyproject.toml) gitdb==4.0.12 # via gitpython @@ -166,7 +164,7 @@ huggingface-hub[hf-xet]==0.36.0 # transformers id==1.5.0 # via twine -identify==2.6.15 +identify==2.6.16 # via pre-commit idna==3.11 # via @@ -196,7 +194,7 @@ isoduration==20.11.0 # via jsonschema jaraco-classes==3.4.0 ; platform_machine != 'ppc64le' and platform_machine != 's390x' # via keyring -jaraco-context==6.0.2 ; platform_machine != 'ppc64le' and platform_machine != 's390x' +jaraco-context==6.1.0 ; platform_machine != 'ppc64le' and platform_machine != 's390x' # via keyring jaraco-functools==4.4.0 ; platform_machine != 'ppc64le' and platform_machine != 's390x' # via keyring @@ -254,9 +252,9 @@ jupyter-server==2.17.0 # jupyterlab-server # notebook # notebook-shim -jupyter-server-terminals==0.5.3 +jupyter-server-terminals==0.5.4 # via jupyter-server -jupyterlab==4.5.1 +jupyterlab==4.5.2 # via # interpretune (pyproject.toml) # notebook @@ -348,13 +346,13 @@ nodeenv==1.10.0 # via # pre-commit # pyright -notebook==7.5.1 +notebook==7.5.2 # via interpretune (pyproject.toml) notebook-shim==0.2.4 # via # jupyterlab # notebook -numpy==2.4.0 +numpy==2.4.1 # via # accelerate # bitsandbytes @@ -454,7 +452,7 @@ platformdirs==4.5.1 # jupyter-core # virtualenv # wandb -plotly==6.5.1 +plotly==6.5.2 # via interpretune (pyproject.toml) pluggy==1.6.0 # via pytest @@ -462,7 +460,7 @@ pre-commit==4.5.1 # via # interpretune (pyproject.toml:dev) # interpretune (pyproject.toml:test) -prometheus-client==0.23.1 +prometheus-client==0.24.1 # via jupyter-server prompt-toolkit==3.0.52 # via ipython @@ -470,7 +468,7 @@ propcache==0.4.1 # via # aiohttp # yarl -protobuf==6.33.2 +protobuf==6.33.4 # via # tensorboard # wandb @@ -568,7 +566,7 @@ referencing==0.37.0 # jsonschema # jsonschema-specifications # jupyter-events -regex==2025.11.3 +regex==2026.1.15 # via transformers requests[socks]==2.32.5 # via @@ -611,11 +609,11 @@ safetensors==0.7.0 # transformers scikit-learn==1.8.0 # via interpretune (pyproject.toml) -scipy==1.16.3 +scipy==1.17.0 # via scikit-learn secretstorage==3.5.0 ; platform_machine != 'ppc64le' and platform_machine != 's390x' and sys_platform == 'linux' # via keyring -send2trash==2.0.0 +send2trash==2.1.0 # via jupyter-server sentry-sdk==2.49.0 # via wandb @@ -697,7 +695,7 @@ traitlets==5.14.3 # nbclient # nbconvert # nbformat -transformers==4.57.3 +transformers==4.57.5 # via # interpretune (pyproject.toml) # peft @@ -709,16 +707,12 @@ typeshed-client==2.8.2 # via jsonargparse typing-extensions==4.15.0 # via - # aiosignal - # anyio # beautifulsoup4 # grpcio # huggingface-hub - # psycopg # pydantic # pydantic-core # pyright - # referencing # torch # typeshed-client # typing-inspection @@ -737,13 +731,13 @@ urllib3==2.6.3 # requests # sentry-sdk # twine -uv==0.9.22 +uv==0.9.26 # via # interpretune (pyproject.toml:dev) # interpretune (pyproject.toml:test) virtualenv==20.36.1 # via pre-commit -wandb==0.23.1 +wandb==0.24.0 # via interpretune (pyproject.toml) wcwidth==0.2.14 # via prompt-toolkit diff --git a/scripts/build_it_env.sh b/scripts/build_it_env.sh index d2d085f2..8c7e4767 100755 --- a/scripts/build_it_env.sh +++ b/scripts/build_it_env.sh @@ -155,7 +155,7 @@ clear_activate_env(){ base_env_build(){ case ${target_env_name} in it_latest) - clear_activate_env python3.12 + clear_activate_env python3.13 if [[ -n ${torch_dev_ver} ]]; then # temporarily remove torchvision until it supports cu128 in nightly binary uv pip install ${uv_install_flags} --pre torch==2.10.0.${torch_dev_ver} --index-url https://download.pytorch.org/whl/nightly/cu128 @@ -166,7 +166,7 @@ base_env_build(){ fi ;; it_release) - clear_activate_env python3.12 + clear_activate_env python3.13 uv pip install ${uv_install_flags} torch --index-url https://download.pytorch.org/whl/cu128 ;; *) diff --git a/src/interpretune/adapter_registry.py b/src/interpretune/adapter_registry.py index dd9eedd6..504e5d5e 100644 --- a/src/interpretune/adapter_registry.py +++ b/src/interpretune/adapter_registry.py @@ -9,7 +9,7 @@ import logging import threading -from typing import Any, Optional +from typing import Any from interpretune.adapters.registration import CompositionRegistry @@ -25,7 +25,7 @@ class LazyCompositionRegistry: """ def __init__(self) -> None: - self._registry: Optional[CompositionRegistry] = None + self._registry: CompositionRegistry | None = None self._lock = threading.RLock() def _ensure_initialized(self) -> None: diff --git a/src/interpretune/adapters/_light_register.py b/src/interpretune/adapters/_light_register.py index 9e444cc5..5603617c 100644 --- a/src/interpretune/adapters/_light_register.py +++ b/src/interpretune/adapters/_light_register.py @@ -14,10 +14,10 @@ from importlib import import_module from types import ModuleType -from typing import Iterable, Optional +from typing import Iterable -def _import_adapter_module(module_path: str) -> Optional[ModuleType]: +def _import_adapter_module(module_path: str) -> ModuleType | None: """Import an adapter module and return the module object. We rely on adapter modules to avoid importing heavy third-party dependencies at module import time (they use diff --git a/src/interpretune/adapters/circuit_tracer.py b/src/interpretune/adapters/circuit_tracer.py index 4f6bdd08..3c464162 100644 --- a/src/interpretune/adapters/circuit_tracer.py +++ b/src/interpretune/adapters/circuit_tracer.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Optional, Dict, List, Union, Tuple +from typing import Any from dataclasses import dataclass, field from pathlib import Path from copy import deepcopy @@ -28,8 +28,8 @@ @dataclass(kw_only=True) class InstantiatedGraph: handle: Graph - graph_path: Optional[Path] = None - metadata: Dict[str, Any] = field(default_factory=dict) + graph_path: Path | None = None + metadata: dict[str, Any] = field(default_factory=dict) ################################################################################ @@ -58,8 +58,8 @@ def replacement_model(self) -> ReplacementModel | None: class BaseCircuitTracerModule(BaseITLensModule): def __init__(self, *args, **kwargs): # Initialize attributes that may be required in base init methods - self.attribution_graphs: List[InstantiatedGraph] = [] - self._replacement_model: Optional[ReplacementModel] = None + self.attribution_graphs: list[InstantiatedGraph] = [] + self._replacement_model: ReplacementModel | None = None super().__init__(*args, **kwargs) def _convert_hf_to_tl(self) -> None: @@ -80,7 +80,7 @@ def _convert_hf_to_tl(self) -> None: self._load_replacement_model(pretrained_kwargs=loaded_model_kwargs) self.model.config = hf_preconversion_config - def _load_replacement_model(self, pretrained_kwargs: Optional[dict] = None) -> None: + def _load_replacement_model(self, pretrained_kwargs: dict | None = None) -> None: """Load the ReplacementModel for circuit tracing.""" pretrained_kwargs = pretrained_kwargs or {} cfg = self.circuit_tracer_cfg @@ -115,7 +115,7 @@ def set_input_require_grads(self) -> None: # Circuit tracer handles gradient requirements internally rank_zero_info("Input gradient requirements handled by circuit tracer internally.") - def _get_attribution_targets(self) -> Optional[list | torch.Tensor]: + def _get_attribution_targets(self) -> list | torch.Tensor | None: """Determine the attribution_targets value based on CircuitTracerConfig. Returns: @@ -185,7 +185,7 @@ def generate_attribution_graph(self, prompt: str, **kwargs) -> Graph: return graph - def save_graph(self, graph: Graph, output_path: Union[str, Path]) -> Path: + def save_graph(self, graph: Graph, output_path: str | Path) -> Path: """Save attribution graph to file.""" output_path = Path(output_path) graph.to_pt(str(output_path)) @@ -195,7 +195,7 @@ def create_graph_visualization_files( self, graph: Graph, slug: str, - output_dir: Union[str, Path], + output_dir: str | Path, node_threshold: float = 0.8, edge_threshold: float = 0.98, ) -> None: @@ -288,10 +288,10 @@ class CircuitTracerAnalysisMixin: def save_graph( self, graph: Graph, - output_path: Union[str, Path], - slug: Optional[str] = None, - custom_metadata: Optional[Dict[str, Any]] = None, - use_neuronpedia: Optional[bool] = None, + output_path: str | Path, + slug: str | None = None, + custom_metadata: dict[str, Any] | None = None, + use_neuronpedia: bool | None = None, ) -> Path: """Save and optionally transform graph for Neuronpedia upload.""" # Convert output_path to directory for processing @@ -340,13 +340,13 @@ def save_graph( def generate_graph( self, prompt: str, - slug: Optional[str] = None, - custom_metadata: Optional[Dict[str, Any]] = None, + slug: str | None = None, + custom_metadata: dict[str, Any] | None = None, upload_to_np: bool = False, - output_dir: Optional[Union[str, Path]] = None, - use_neuronpedia: Optional[bool] = None, + output_dir: str | Path | None = None, + use_neuronpedia: bool | None = None, **generation_kwargs, - ) -> Tuple[Graph, Path, Any]: + ) -> tuple[Graph, Path, Any]: """Generate attribution graph and optionally upload to Neuronpedia.""" if use_neuronpedia is None: use_neuronpedia = ( diff --git a/src/interpretune/adapters/model_view.py b/src/interpretune/adapters/model_view.py index 6efcd4ee..f497b3ca 100644 --- a/src/interpretune/adapters/model_view.py +++ b/src/interpretune/adapters/model_view.py @@ -21,7 +21,7 @@ import os from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING import torch @@ -64,7 +64,7 @@ def build_param_mapping(self) -> None: pass @abstractmethod - def transform_to_canonical(self, param_names: List[str], inspect_only: bool = False) -> List[str]: + def transform_to_canonical(self, param_names: list[str], inspect_only: bool = False) -> list[str]: """Transform view-specific parameter names to canonical names. Used by FTS to convert schedule params to optimizer params. @@ -79,7 +79,7 @@ def transform_to_canonical(self, param_names: List[str], inspect_only: bool = Fa pass @abstractmethod - def transform_from_canonical(self, param_names: List[str]) -> List[str]: + def transform_from_canonical(self, param_names: list[str]) -> list[str]: """Transform canonical parameter names to view-specific names. Used for logging/reporting optimizer params in view naming. @@ -93,7 +93,7 @@ def transform_from_canonical(self, param_names: List[str]) -> List[str]: pass @abstractmethod - def get_named_params(self) -> Dict[str, torch.Tensor]: + def get_named_params(self) -> dict[str, torch.Tensor]: """Get model parameters using view-specific naming. Returns: @@ -102,7 +102,7 @@ def get_named_params(self) -> Dict[str, torch.Tensor]: pass @abstractmethod - def gen_schedule(self, dump_loc: Union[str, os.PathLike]) -> Optional[os.PathLike]: + def gen_schedule(self, dump_loc: str | os.PathLike) -> os.PathLike | None: """Generate implicit schedule using view-specific naming. Args: @@ -114,7 +114,7 @@ def gen_schedule(self, dump_loc: Union[str, os.PathLike]) -> Optional[os.PathLik pass @abstractmethod - def validate_schedule(self) -> Tuple[int, int]: + def validate_schedule(self) -> tuple[int, int]: """Validate schedule with optional view-specific diagnostics. Delegates to base StrategyAdapter implementation. @@ -149,19 +149,19 @@ def build_param_mapping(self) -> None: """No mapping needed for canonical naming.""" pass - def transform_to_canonical(self, param_names: List[str], inspect_only: bool = False) -> List[str]: + def transform_to_canonical(self, param_names: list[str], inspect_only: bool = False) -> list[str]: """Identity transformation - params already canonical.""" return param_names - def transform_from_canonical(self, param_names: List[str]) -> List[str]: + def transform_from_canonical(self, param_names: list[str]) -> list[str]: """Identity transformation - params already canonical.""" return param_names - def get_named_params(self) -> Dict[str, torch.Tensor]: + def get_named_params(self) -> dict[str, torch.Tensor]: """Get canonical parameter names.""" return dict(self.pl_module.named_parameters()) - def gen_schedule(self, dump_loc: Union[str, os.PathLike]) -> Optional[os.PathLike]: + def gen_schedule(self, dump_loc: str | os.PathLike) -> os.PathLike | None: """Generate schedule with canonical naming. Delegates to base StrategyAdapter implementation via the adapter reference. @@ -171,7 +171,7 @@ def gen_schedule(self, dump_loc: Union[str, os.PathLike]) -> Optional[os.PathLik return StrategyAdapter.gen_ft_schedule(self.adapter, dump_loc) - def validate_schedule(self) -> Tuple[int, int]: + def validate_schedule(self) -> tuple[int, int]: """Validate schedule with canonical naming. Delegates to base StrategyAdapter implementation via the adapter reference. diff --git a/src/interpretune/adapters/registration.py b/src/interpretune/adapters/registration.py index 9f248a24..3d7a97d8 100644 --- a/src/interpretune/adapters/registration.py +++ b/src/interpretune/adapters/registration.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Dict, Optional, Tuple, Callable, Type, Protocol, Set, runtime_checkable, List, Sequence, cast +from typing import Any, Tuple, Callable, Type, Protocol, Set, runtime_checkable, Sequence, cast from inspect import getmembers, isclass from typing_extensions import override from types import ModuleType @@ -17,9 +17,9 @@ def register( self, lead_adapter: Adapter, component_key: str, - adapter_combination: Tuple[Adapter | str], - composition_classes: Tuple[Callable[..., Any], ...], - description: Optional[str] = None, + adapter_combination: tuple[Adapter | str], + composition_classes: tuple[Callable[..., Any], ...], + description: str | None = None, ) -> None: """Registers valid component + adapter compositions mapped to composition keys with required metadata. @@ -27,10 +27,10 @@ def register( lead_adapter: The adapter registering this set of valid compositions (e.g. LightningAdapter) component_key: The name of the component (e.g. "datamodule") adapter_combination: tuple identifying the valid adapter composition - composition_classes: Tuple[Callable, ...], + composition_classes: tuple[Callable, ...], description : composition description """ - supported_composition: Dict[str | Adapter | Tuple[Adapter | str], Any] = {} + supported_composition: dict[str | Adapter | tuple[Adapter | str], Any] = {} composition_key = (component_key,) + self.canonicalize_composition(adapter_combination) supported_composition[composition_key] = composition_classes supported_composition["lead_adapter"] = Adapter[lead_adapter] if isinstance(lead_adapter, str) else lead_adapter @@ -39,8 +39,8 @@ def register( @staticmethod def resolve_adapter_filter( - adapter_filter: Optional[Sequence[Adapter | str] | Adapter | str] = None, - ) -> List[Adapter]: + adapter_filter: Sequence[Adapter | str] | Adapter | str | None = None, + ) -> list[Adapter]: unresolved_filters = [] if adapter_filter is None: return [] @@ -79,7 +79,7 @@ def canonicalize_composition(self, adapter_ctx: Sequence[Adapter | str]) -> Tupl return adapter_ctx @override - def get(self, composition_key: Tuple[Adapter | str], default: Any = None) -> Any: + def get(self, composition_key: tuple[Adapter | str], default: Any = None) -> Any: if composition_key in self: supported_composition = self[composition_key] return supported_composition[composition_key] @@ -94,11 +94,11 @@ def get(self, composition_key: Tuple[Adapter | str], default: Any = None) -> Any ) raise KeyError(err_msg) - def remove(self, composition_key: Tuple[Adapter | str]) -> None: + def remove(self, composition_key: tuple[Adapter | str]) -> None: """Removes the registered adapter composition by name.""" del self[composition_key] - def available_compositions(self, adapter_filter: Optional[Sequence[Adapter | str] | Adapter | str] = None) -> Set: + def available_compositions(self, adapter_filter: Sequence[Adapter | str] | Adapter | str | None = None) -> Set: """Returns a list of registered adapters, optionally filtering by the lead adapter that registered the valid composition.""" if adapter_filter is not None: diff --git a/src/interpretune/adapters/transformer_lens.py b/src/interpretune/adapters/transformer_lens.py index c0dbfa82..4e4f16d3 100644 --- a/src/interpretune/adapters/transformer_lens.py +++ b/src/interpretune/adapters/transformer_lens.py @@ -1,6 +1,6 @@ import os import re -from typing import Optional, Type, cast, Union, Dict, Any, List, Mapping, Tuple +from typing import Type, cast, Any, Mapping import inspect from functools import reduce, partial from copy import deepcopy @@ -33,7 +33,7 @@ class TLensAttributeMixin: @property - def tl_cfg(self) -> Optional[Union[HookedTransformerConfig, TransformerBridgeConfig]]: + def tl_cfg(self) -> HookedTransformerConfig | TransformerBridgeConfig | None: try: cfg = reduce(getattr, "model.cfg".split("."), self) except AttributeError as ae: @@ -43,8 +43,8 @@ def tl_cfg(self) -> Optional[Union[HookedTransformerConfig, TransformerBridgeCon # TODO: we aren't using IT's Property Composition feature for TLens yet, but might be worth enabling it @property - def device(self) -> Optional[torch.device]: - device: Optional[torch.device] = None + def device(self) -> torch.device | None: + device: torch.device | None = None try: device = ( getattr(self._it_state, "_device", None) # type: ignore[attr-defined] # provided by mixing class @@ -57,12 +57,12 @@ def device(self) -> Optional[torch.device]: return device @device.setter - def device(self, value: Optional[str | torch.device]) -> None: + def device(self, value: str | torch.device | None) -> None: if value is not None and not isinstance(value, torch.device): value = torch.device(value) self._it_state._device = value # type: ignore[attr-defined] # provided by mixing class - def get_tl_device(self) -> Optional[torch.device]: + def get_tl_device(self) -> torch.device | None: """Get the best available device based on TransformerLens config.""" try: if self.tl_cfg is None: @@ -76,11 +76,11 @@ def get_tl_device(self) -> Optional[torch.device]: return device @property - def output_device(self) -> Optional[torch.device]: + def output_device(self) -> torch.device | None: return self.get_tl_device() # type: ignore[attr-defined] # provided by mixing class @property - def input_device(self) -> Optional[torch.device]: + def input_device(self) -> torch.device | None: return self.get_tl_device() @@ -127,7 +127,7 @@ def hf_pretrained_model_init(self) -> None: self._convert_hf_to_tl() def hf_configured_model_init( - self, cust_config: HFPretrainedConfig, access_token: Optional[str] = None + self, cust_config: HFPretrainedConfig, access_token: str | None = None ) -> torch.nn.Module: # usually makes sense to init the HookedTransfomer (empty) and pretrained HF model weights on cpu # versus moving them both to GPU (may make sense to explore meta device usage for model definition @@ -179,7 +179,7 @@ def tl_config_model_init(self) -> None: tl_kwargs = {k: v for k, v in self.it_cfg.tl_cfg.__dict__.items() if k not in ["use_bridge"]} self.model = HookedTransformer(tokenizer=self.it_cfg.tokenizer, **tl_kwargs) - def _prune_tl_cfg_dict(self, prune_list: Optional[list] = None) -> dict: + def _prune_tl_cfg_dict(self, prune_list: list | None = None) -> dict: """Prunes the tl_cfg dictionary by removing IT-specific and HF-specific keys that shouldn't be passed to HookedTransformer/TransformerBridge constructors. @@ -400,8 +400,8 @@ class TransformerBridgeStrategyAdapter(StrategyAdapter): def __init__( self, - model_view: Union[None, str, Type[ModelView], ModelView] = None, - model_view_cfg: Optional[Dict[str, Any]] = None, + model_view: None | str | Type[ModelView] | ModelView = None, + model_view_cfg: dict[str, Any] | None = None, use_tl_names: bool = False, *args, **kwargs, @@ -415,10 +415,10 @@ def __init__( # Store initialization parameters for deferred model_view creation # (model_view needs access to adapter, so we create it in on_before_init_fts) # If nothing specified, we'll use CanonicalModelView (identity transformation) - self._model_view_init: Union[None, str, Type[ModelView], ModelView] = model_view - self._model_view_cfg: Dict[str, Any] = model_view_cfg or {} + self._model_view_init: None | str | Type[ModelView] | ModelView = model_view + self._model_view_cfg: dict[str, Any] = model_view_cfg or {} self._use_tl_names: bool = use_tl_names - self.model_view: Optional[ModelView] = None + self.model_view: ModelView | None = None # Always use translation function (even for canonical mode via CanonicalModelView) self.exec_ft_phase = partial(StrategyAdapter.base_ft_phase, translation_func=self.logical_param_translation) @@ -470,7 +470,7 @@ def on_before_init_fts(self) -> None: # Ensure model_view is initialized (may already be initialized if gen_ft_sched_only=True) self._ensure_model_view_initialized() - def fts_optim_transform(self, orig_pl: List[str], inspect_only: bool = False) -> List[str]: + def fts_optim_transform(self, orig_pl: list[str], inspect_only: bool = False) -> list[str]: """Transform parameter names to canonical names for optimizer. Delegates transformation to the active model view. @@ -485,7 +485,7 @@ def fts_optim_transform(self, orig_pl: List[str], inspect_only: bool = False) -> assert self.model_view is not None return self.model_view.transform_to_canonical(orig_pl, inspect_only=inspect_only) - def logical_param_translation(self, param_names: List[str]) -> List[str]: + def logical_param_translation(self, param_names: list[str]) -> list[str]: """Translate canonical parameter names to model view names. Delegates transformation to the active model view. @@ -499,25 +499,25 @@ def logical_param_translation(self, param_names: List[str]) -> List[str]: assert self.model_view is not None return self.model_view.transform_from_canonical(param_names) - def get_named_params_for_schedule_validation(self) -> Dict[str, torch.nn.Parameter]: + def get_named_params_for_schedule_validation(self) -> dict[str, torch.nn.Parameter]: """Get named parameters for schedule validation. Delegates to the active model view for parameter naming. Returns: - Dict[str, torch.nn.Parameter]: A dictionary mapping parameter names to parameter tensors. + dict[str, torch.nn.Parameter]: A dictionary mapping parameter names to parameter tensors. """ self._ensure_model_view_initialized() assert self.model_view is not None return self.model_view.get_named_params() # type: ignore[return-value] - def validate_ft_sched(self) -> Tuple[int, int]: + def validate_ft_sched(self) -> tuple[int, int]: """Validate the fine-tuning schedule. Delegates to the active model view's validation method. Returns: - Tuple[int, int]: A tuple of ints specifying: + tuple[int, int]: A tuple of ints specifying: 1. The depth of the final scheduled phase 2. The maximum epoch watermark explicitly specified in the schedule """ @@ -530,7 +530,7 @@ def validate_ft_sched(self) -> Tuple[int, int]: assert self.model_view is not None return self.model_view.validate_schedule() - def gen_ft_schedule(self, dump_loc: Union[str, os.PathLike]) -> Optional[os.PathLike]: + def gen_ft_schedule(self, dump_loc: str | os.PathLike) -> os.PathLike | None: """Generate fine-tuning schedule using active model view naming. Delegates to the active model view's generation method. @@ -583,9 +583,9 @@ class TLNamesModelView(ModelView): def __init__(self, adapter: "StrategyAdapter", implicit_ln_thaw: bool = True): super().__init__(adapter) self.implicit_ln_thaw = implicit_ln_thaw - self._tl_to_canonical_mapping: Optional[Dict[str, List[str]]] = None - self._canonical_to_tl_mapping: Optional[Dict[str, str]] = None - self._unmapped_canonical_params: Optional[set] = None + self._tl_to_canonical_mapping: dict[str, list[str]] | None = None + self._canonical_to_tl_mapping: dict[str, str] | None = None + self._unmapped_canonical_params: set | None = None def build_param_mapping(self) -> None: """Build bidirectional parameter name mappings using component structure tracing. @@ -617,7 +617,7 @@ def build_param_mapping(self) -> None: canonical_params = dict(self.pl_module.named_parameters()) # Build index of canonical params by data_ptr for efficient lookup - canonical_by_ptr: Dict[int, List[str]] = {} + canonical_by_ptr: dict[int, list[str]] = {} for name, tensor in canonical_params.items(): ptr = tensor.data_ptr() if ptr not in canonical_by_ptr: @@ -676,7 +676,7 @@ def build_param_mapping(self) -> None: f"{sum(len(v) for v in self._tl_to_canonical_mapping.values())} canonical parameters" ) - def transform_to_canonical(self, param_names: List[str], inspect_only: bool = False) -> List[str]: + def transform_to_canonical(self, param_names: list[str], inspect_only: bool = False) -> list[str]: """Transform TL-style parameter names to canonical names for optimizer. If implicit_ln_thaw=True, this method also appends LayerNorm parameters @@ -730,7 +730,7 @@ def transform_to_canonical(self, param_names: List[str], inspect_only: bool = Fa return canonical_params - def transform_from_canonical(self, param_names: List[str]) -> List[str]: + def transform_from_canonical(self, param_names: list[str]) -> list[str]: """Translate canonical parameter names to TL-style names. Args: @@ -764,17 +764,17 @@ def transform_from_canonical(self, param_names: List[str]) -> List[str]: return unique_tl_params - def get_named_params(self) -> Dict[str, torch.Tensor]: + def get_named_params(self) -> dict[str, torch.Tensor]: """Get named parameters for schedule validation. Returns TL-style parameter names from the TransformerBridge. Returns: - Dict[str, torch.Tensor]: A dictionary mapping TL-style names to parameter tensors. + dict[str, torch.Tensor]: A dictionary mapping TL-style names to parameter tensors. """ return dict(self.pl_module.model.tl_named_parameters()) # type: ignore[attr-defined] - def gen_schedule(self, dump_loc: Union[str, os.PathLike]) -> Optional[os.PathLike]: + def gen_schedule(self, dump_loc: str | os.PathLike) -> os.PathLike | None: """Generate fine-tuning schedule using TL-style parameter names. Generates schedule with clean TL-style names (e.g., blocks.9.attn.W_Q). @@ -791,8 +791,8 @@ def gen_schedule(self, dump_loc: Union[str, os.PathLike]) -> Optional[os.PathLik rank_zero_debug("TLNamesModelView.gen_schedule() called") rank_zero_info(f"Generating TL-style fine-tuning schedule for {self.pl_module.__class__.__name__}") - param_lists: List = [] - cur_group: List = [] + param_lists: list = [] + cur_group: list = [] # Use TL-style parameter names model_params = list(self.pl_module.model.tl_named_parameters())[::-1] # type: ignore[attr-defined] @@ -818,7 +818,7 @@ def gen_schedule(self, dump_loc: Union[str, os.PathLike]) -> Optional[os.PathLik assert dump_loc is not None return ScheduleImplMixin.save_schedule(schedule_name, layer_config, dump_loc) - def validate_schedule(self) -> Tuple[int, int]: + def validate_schedule(self) -> tuple[int, int]: """Validate the fine-tuning schedule with TL-style parameter mapping diagnostics. Logs diagnostic information about the parameter mappings before delegating @@ -826,7 +826,7 @@ def validate_schedule(self) -> Tuple[int, int]: which canonical LayerNorm params are unmapped. Returns: - Tuple[int, int]: A tuple of ints specifying: + tuple[int, int]: A tuple of ints specifying: 1. The depth of the final scheduled phase 2. The maximum epoch watermark explicitly specified in the schedule """ @@ -863,7 +863,7 @@ def validate_schedule(self) -> Tuple[int, int]: # Private helper methods for component structure tracing - def _get_underlying_component_tensor(self, tl_name: str, bridge: Any) -> Optional[torch.Tensor]: + def _get_underlying_component_tensor(self, tl_name: str, bridge: Any) -> torch.Tensor | None: """Get the underlying component tensor for a TL-style parameter name. Maps TL names like 'blocks.0.attn.W_Q' to the actual component tensor @@ -921,7 +921,7 @@ def _get_underlying_component_tensor(self, tl_name: str, bridge: Any) -> Optiona return None - def _get_attn_component_tensor(self, attn: Any, param_name: str) -> Optional[torch.Tensor]: + def _get_attn_component_tensor(self, attn: Any, param_name: str) -> torch.Tensor | None: """Get attention component tensor. Maps TL attention param names to underlying component tensors: @@ -952,7 +952,7 @@ def _get_attn_component_tensor(self, attn: Any, param_name: str) -> Optional[tor return None - def _get_mlp_component_tensor(self, mlp: Any, param_name: str) -> Optional[torch.Tensor]: + def _get_mlp_component_tensor(self, mlp: Any, param_name: str) -> torch.Tensor | None: """Get MLP component tensor. Maps TL MLP param names to underlying component tensors: @@ -983,7 +983,7 @@ def _get_mlp_component_tensor(self, mlp: Any, param_name: str) -> Optional[torch return None - def _get_ln_component_tensor(self, ln: Any, param_name: str) -> Optional[torch.Tensor]: + def _get_ln_component_tensor(self, ln: Any, param_name: str) -> torch.Tensor | None: """Get LayerNorm/RMSNorm component tensor. Maps TL LayerNorm param names to underlying component tensors: @@ -997,7 +997,7 @@ def _get_ln_component_tensor(self, ln: Any, param_name: str) -> Optional[torch.T return None - def _get_implicit_layernorm_params(self, tl_param_names: List[str]) -> List[str]: + def _get_implicit_layernorm_params(self, tl_param_names: list[str]) -> list[str]: """Get implicit LayerNorm canonical params for the layers referenced by TL params. TL-style nomenclature doesn't include LayerNorm parameters, but canonical training @@ -1020,7 +1020,7 @@ def _get_implicit_layernorm_params(self, tl_param_names: List[str]) -> List[str] if not self._unmapped_canonical_params: return [] - implicit_ln_params: List[str] = [] + implicit_ln_params: list[str] = [] layer_indices_seen: set = set() has_embed_params = False diff --git a/src/interpretune/analysis/cache.py b/src/interpretune/analysis/cache.py index c81be32e..d2124c35 100644 --- a/src/interpretune/analysis/cache.py +++ b/src/interpretune/analysis/cache.py @@ -22,7 +22,6 @@ import shutil import tempfile from pathlib import Path -from typing import Optional from datasets.fingerprint import ( generate_random_fingerprint, @@ -96,7 +95,7 @@ def _create_tempdir() -> Path: return tmpdir -def get_analysis_cache_dir(module, explicit_cache_dir: Optional[str | Path] = None) -> Path: +def get_analysis_cache_dir(module, explicit_cache_dir: str | Path | None = None) -> Path: """Return the cache directory to use for analysis for the given module. Behavior: diff --git a/src/interpretune/analysis/core.py b/src/interpretune/analysis/core.py index 73848694..6e9f350b 100644 --- a/src/interpretune/analysis/core.py +++ b/src/interpretune/analysis/core.py @@ -1,6 +1,6 @@ from __future__ import annotations # see PEP 749, no longer needed when 3.13 reaches EOL from dataclasses import dataclass, field -from typing import Literal, NamedTuple, Optional, Any, Callable, Sequence, Union, List, Dict, Type +from typing import Literal, NamedTuple, Any, Callable, Sequence, List, Dict, Type from types import MappingProxyType import os from pathlib import Path @@ -393,7 +393,7 @@ def __init__( # Set default format as our custom analysis formatter self.dataset.set_format(type="interpretune") # type: ignore[attr-defined] # datasets API supports set_format - def _format_columns(self, cols_to_fetch: list[str], indices: Optional[Union[int, slice, list[int]]] = None) -> dict: + def _format_columns(self, cols_to_fetch: list[str], indices: int | slice | list[int] | None = None) -> dict: """Internal helper to format specified columns into tensors with proper shape reconstruction. Args: @@ -477,7 +477,7 @@ def _is_tensor_seq(self, data) -> bool: or (isinstance(data, list) and all(isinstance(x, torch.Tensor) for x in data)) ) - def __getitem__(self, key: Union[str, List[str], int, slice]) -> Union[List, Dict]: + def __getitem__(self, key: str | list[str] | int | slice) -> List | Dict: """Enable direct column/row access similar to HF Dataset. Args: @@ -540,7 +540,7 @@ def __getitem__(self, key: Union[str, List[str], int, slice]) -> Union[List, Dic assert self.dataset is not None, "Dataset should be loaded before accessing rows" return self.dataset[key] # type: ignore[index] # datasets support various indexing types - def select_columns(self, column_names: List[str]) -> "AnalysisStore": + def select_columns(self, column_names: list[str]) -> "AnalysisStore": """Select a subset of columns. Args: diff --git a/src/interpretune/analysis/formatters.py b/src/interpretune/analysis/formatters.py index 2a9b4933..edbbaaa0 100644 --- a/src/interpretune/analysis/formatters.py +++ b/src/interpretune/analysis/formatters.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Union, Optional +from typing import Any import torch import numpy as np from datasets.formatting import TorchFormatter @@ -23,7 +23,7 @@ def __init__(self, col_cfg: dict | None = None, **kwargs): self._field_context = [] @contextmanager - def field_context(self, field_info: Union[tuple[Optional[str], dict], Optional[str]]): + def field_context(self, field_info: tuple[str | None, dict] | str | None): """Context manager to track the current field being processed.""" if isinstance(field_info, str) or field_info is None: field_name = field_info @@ -35,13 +35,13 @@ def field_context(self, field_info: Union[tuple[Optional[str], dict], Optional[s finally: self._field_context.pop() - def is_field_non_tensor(self, field_name: Optional[str]) -> bool: + def is_field_non_tensor(self, field_name: str | None) -> bool: """Check if current field or any parent field in context is marked as non-tensor.""" if field_name is not None and field_name in self.non_tensor_fields: return True return any(context[0] in self.non_tensor_fields for context in self._field_context if context[0] is not None) - def is_field_per_latent(self, field_name: Optional[str]) -> bool: + def is_field_per_latent(self, field_name: str | None) -> bool: """Check if current field or any parent field in context is marked as per_latent.""" if field_name is not None and field_name in self.per_latent_fields: return True @@ -58,7 +58,7 @@ def handle_per_latent_dict(self, value: dict, tensorize_fn) -> dict: return {int(k): tensorize_fn(v) for k, v in zip(latents, per_latent_values)} return value - def apply_dynamic_dimension(self, tensor: torch.Tensor, field_name: Optional[str]) -> torch.Tensor: + def apply_dynamic_dimension(self, tensor: torch.Tensor, field_name: str | None) -> torch.Tensor: """Apply dynamic dimension transformation if configured.""" dyn_dim = self.dyn_dims.get(field_name) curr_tensor_dim = tensor.dim() @@ -91,7 +91,7 @@ def __init__(self, features=None, **format_kwargs): col_cfg = format_kwargs.pop("col_cfg", {}) super().__init__(col_cfg=col_cfg, features=features, **format_kwargs) - def _tensorize(self, value: Any, field_name: Optional[str] = None) -> Any: + def _tensorize(self, value: Any, field_name: str | None = None) -> Any: """Enhanced tensorization with support for non-tensor fields, per-latent transformations and non-zero dynamic dimensions.""" if isinstance(value, (str, bytes, type(None))): @@ -145,7 +145,7 @@ def _recursive_tensorize(self, data_struct: Any) -> Any: current_field = self._field_context[-1][0] if self._field_context else None return self._tensorize(data_struct, current_field) - # TODO: validate that we don't want to allow Union[torch.Tensor, Sequence] return type + # TODO: validate that we don't want to allow torch.Tensor | Sequence return type def format_column(self, pa_table: "pa.Table") -> torch.Tensor: # type: ignore[override] """Format a column with enhanced tensorization.""" column = self.numpy_arrow_extractor().extract_column(pa_table) diff --git a/src/interpretune/analysis/ops/auto_columns.py b/src/interpretune/analysis/ops/auto_columns.py index 699a3171..7208b77a 100644 --- a/src/interpretune/analysis/ops/auto_columns.py +++ b/src/interpretune/analysis/ops/auto_columns.py @@ -1,7 +1,7 @@ """Auto-columns system for analysis operations.""" from __future__ import annotations -from typing import Dict, Any, Union, Literal, Tuple +from typing import Any, Literal from dataclasses import dataclass, field from interpretune.analysis.ops.base import ColCfg @@ -12,9 +12,9 @@ class FieldCondition: """Represents a condition for a field in a schema.""" field_name: str - conditions: Dict[str, Any] = field(default_factory=dict) + conditions: dict[str, Any] = field(default_factory=dict) - def matches(self, field_config: Dict[str, Any] | ColCfg) -> bool: + def matches(self, field_config: dict[str, Any] | ColCfg) -> bool: """Check if a field configuration matches this condition.""" if not isinstance(field_config, (dict, ColCfg)): raise TypeError(f"Expected field_config to be dict or ColCfg, got {type(field_config)}") @@ -33,8 +33,8 @@ def matches(self, field_config: Dict[str, Any] | ColCfg) -> bool: class AutoColumnCondition: """Represents a complete condition set for triggering auto-columns.""" - field_conditions: Tuple[FieldCondition, ...] - auto_columns: Dict[str, Union[ColCfg, Dict[str, Any]]] + field_conditions: tuple[FieldCondition, ...] + auto_columns: dict[str, ColCfg | dict[str, Any]] condition_target: Literal["input_schema", "output_schema"] = "input_schema" def __post_init__(self): @@ -42,7 +42,7 @@ def __post_init__(self): if not isinstance(self.field_conditions, tuple): object.__setattr__(self, "field_conditions", tuple(self.field_conditions)) - def matches_schema(self, input_schema: Dict[str, Any], output_schema: Dict[str, Any] = None) -> bool: # type: ignore[assignment] + def matches_schema(self, input_schema: dict[str, Any], output_schema: dict[str, Any] = None) -> bool: # type: ignore[assignment] """Check if schemas match all field conditions.""" # Select the schema the condition should apply to based on condition_target condition_schema = input_schema if self.condition_target == "input_schema" else (output_schema or {}) diff --git a/src/interpretune/analysis/ops/base.py b/src/interpretune/analysis/ops/base.py index f0232112..25930131 100644 --- a/src/interpretune/analysis/ops/base.py +++ b/src/interpretune/analysis/ops/base.py @@ -1,7 +1,7 @@ """Base classes for analysis operations.""" from __future__ import annotations # see PEP 749, no longer needed when 3.13 reaches EOL -from typing import Literal, Union, Optional, Any, Dict, Callable, Sequence +from typing import Literal, Any, Callable, Sequence from dataclasses import dataclass, fields from contextlib import contextmanager import os @@ -148,14 +148,14 @@ class ColCfg: datasets_dtype: str # Explicit datasets dtype string (e.g. "float32", "int64") required: bool = True - dyn_dim: Optional[int] = None - dyn_dim_ceil: Optional[DIM_VAR] = None # helper for dynamic dimension handling in some contexts + dyn_dim: int | None = None + dyn_dim_ceil: DIM_VAR | None = None # helper for dynamic dimension handling in some contexts non_tensor: bool = False per_latent: bool = False per_sae_hook: bool = False # For fields that have per-SAE hook subfields intermediate_only: bool = False # Indicates column used in processing but not written to output connected_obj: Literal["analysis_store", "datamodule"] = "analysis_store" - array_shape: tuple[Optional[Union[int, DIM_VAR]], ...] | None = None # Shape with optional dimension variables + array_shape: tuple[int | DIM_VAR | None, ...] | None = None # Shape with optional dimension variables sequence_type: bool = True # Default to sequence type for most fields array_dtype: str | None = None # Override for array fields, defaults to datasets_dtype @@ -228,7 +228,7 @@ def wrap_summary( tokenizer: PreTrainedTokenizerBase | None = None, save_prompts: bool = False, save_tokens: bool = False, - decode_kwargs: Optional[dict[str, Any]] = None, + decode_kwargs: dict[str, Any] | None = None, ) -> BaseAnalysisBatchProtocol: decode_kwargs = decode_kwargs or {} if save_prompts: @@ -269,9 +269,9 @@ def __init__( name: str, description: str, output_schema: OpSchema, - input_schema: Optional[OpSchema] = None, - aliases: Optional[Sequence[str]] = None, - impl_params: Optional[Dict[str, Any]] = None, + input_schema: OpSchema | None = None, + aliases: Sequence[str] | None = None, + impl_params: dict[str, Any] | None = None, ) -> None: self.name = name self.description = description @@ -279,7 +279,7 @@ def __init__( self.input_schema = input_schema self._ctx_key = None self._aliases = aliases # Store aliases for the operation - self._impl: Optional[Callable] = None + self._impl: Callable | None = None self.impl_params = impl_params or {} @property @@ -302,7 +302,7 @@ def active_ctx_key(self, ctx_key): self._ctx_key = original_ctx_key def _validate_input_schema( - self, analysis_batch: Optional[BaseAnalysisBatchProtocol], batch: Optional[BatchEncoding] + self, analysis_batch: BaseAnalysisBatchProtocol | None, batch: BatchEncoding | None ) -> None: """Validate that required inputs defined in input_schema exist in analysis_batch or batch.""" if self.input_schema is None: @@ -346,7 +346,7 @@ def process_batch( tokenizer: PreTrainedTokenizerBase | None = None, save_prompts: bool = False, save_tokens: bool = False, - decode_kwargs: Optional[dict[str, Any]] = None, + decode_kwargs: dict[str, Any] | None = None, ) -> BaseAnalysisBatchProtocol: """Process analysis batch using provided output schema. @@ -416,7 +416,7 @@ def save_batch( tokenizer: PreTrainedTokenizerBase | None = None, save_prompts: bool = False, save_tokens: bool = False, - decode_kwargs: Optional[dict[str, Any]] = None, + decode_kwargs: dict[str, Any] | None = None, ) -> BaseAnalysisBatchProtocol: """Save analysis batch using process_batch static method.""" return self.process_batch( @@ -459,13 +459,13 @@ def __reduce__(self): return (_reconstruct_op, (self.__class__, self.__dict__.copy())) @property - def impl(self) -> Optional[Callable]: + def impl(self) -> Callable | None: """Get the implementation function.""" return self._impl def _resolve_call_params( self, impl_func: Callable, module, analysis_batch, batch, batch_idx, **kwargs - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Resolve parameters to pass to the implementation function using smart parameter detection.""" import inspect @@ -504,10 +504,10 @@ def _call_with_resolved_params(self, module, analysis_batch, batch, batch_idx, * def __call__( self, - module: Optional[torch.nn.Module] = None, - analysis_batch: Optional[BaseAnalysisBatchProtocol] = None, - batch: Optional[BatchEncoding] = None, - batch_idx: Optional[int] = None, + module: torch.nn.Module | None = None, + analysis_batch: BaseAnalysisBatchProtocol | None = None, + batch: BatchEncoding | None = None, + batch_idx: int | None = None, **kwargs, ) -> BaseAnalysisBatchProtocol: """Execute the operation using the configured implementation.""" @@ -533,9 +533,9 @@ class CompositeAnalysisOp(AnalysisOp): def __init__( self, ops: Sequence[AnalysisOp], - name: Optional[str] = None, - aliases: Optional[Sequence[str]] = None, - description: Optional[str] = None, + name: str | None = None, + aliases: Sequence[str] | None = None, + description: str | None = None, *args, **kwargs, ) -> None: @@ -570,10 +570,10 @@ def __init__( def __call__( self, - module: Optional[torch.nn.Module] = None, - analysis_batch: Optional[BaseAnalysisBatchProtocol] = None, - batch: Optional[BatchEncoding] = None, - batch_idx: Optional[int] = None, + module: torch.nn.Module | None = None, + analysis_batch: BaseAnalysisBatchProtocol | None = None, + batch: BatchEncoding | None = None, + batch_idx: int | None = None, **kwargs, ) -> BaseAnalysisBatchProtocol: """Execute all operations in sequence with automatic parameter resolution.""" diff --git a/src/interpretune/analysis/ops/compiler/cache_manager.py b/src/interpretune/analysis/ops/compiler/cache_manager.py index 1be6db0f..0e4e1cc0 100644 --- a/src/interpretune/analysis/ops/compiler/cache_manager.py +++ b/src/interpretune/analysis/ops/compiler/cache_manager.py @@ -4,7 +4,7 @@ import re import hashlib import importlib.util -from typing import List, Dict, Any, Optional +from typing import Any from pathlib import Path from dataclasses import dataclass, field @@ -27,13 +27,13 @@ class OpDef: implementation: str input_schema: OpSchema output_schema: OpSchema - aliases: List[str] = field(default_factory=list) - importable_params: Dict[str, str] = field(default_factory=dict) - normal_params: Dict[str, Any] = field(default_factory=dict) - required_ops: List[str] = field(default_factory=list) - composition: Optional[List[str]] = None + aliases: list[str] = field(default_factory=list) + importable_params: dict[str, str] = field(default_factory=dict) + normal_params: dict[str, Any] = field(default_factory=dict) + required_ops: list[str] = field(default_factory=list) + composition: list[str] | None = None - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Convert to dictionary format for compatibility with existing code.""" result = { "name": self.name, @@ -67,7 +67,7 @@ def from_path(cls, path: Path) -> "YamlFileInfo": return cls(path, stat.st_mtime, content_hash) -def _get_latest_revision(repo: CachedRepoInfo) -> Optional[CachedRevisionInfo]: +def _get_latest_revision(repo: CachedRepoInfo) -> CachedRevisionInfo | None: """Get the latest revision for a repository, preferring 'main' ref. Args: @@ -102,8 +102,8 @@ class OpDefinitionsCacheManager: def __init__(self, cache_dir: Path): self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(parents=True, exist_ok=True) - self._yaml_files: List[YamlFileInfo] = [] - self._fingerprint: Optional[str] = None + self._yaml_files: list[YamlFileInfo] = [] + self._fingerprint: str | None = None def add_yaml_file(self, yaml_file: Path) -> None: """Add a YAML file to be monitored for changes.""" @@ -119,7 +119,7 @@ def add_yaml_file(self, yaml_file: Path) -> None: # Skip files that don't exist anymore pass - def add_hub_yaml_files(self) -> List[Path]: + def add_hub_yaml_files(self) -> list[Path]: """Add hub YAML files to monitoring.""" hub_yaml_files = [] try: @@ -146,7 +146,7 @@ def add_hub_yaml_files(self) -> List[Path]: rank_zero_debug(f"[ANALYSIS_HUB_CACHE] Returning {len(hub_yaml_files)} YAML files") return hub_yaml_files # type: ignore[return-value] - def discover_hub_yaml_files(self) -> List[Path]: + def discover_hub_yaml_files(self) -> list[Path]: """Discover YAML files from the most recent revision of each hub repository. Uses HuggingFace's cache manager to efficiently find YAML files only from @@ -352,7 +352,7 @@ def is_cache_valid(self) -> bool: rank_zero_debug("[ANALYSIS_HUB_CACHE] Cache is valid") return True - def _generate_module_content(self, op_definitions: Dict[str, OpDef]) -> str: + def _generate_module_content(self, op_definitions: dict[str, OpDef]) -> str: """Generate Python module content for the cache.""" lines = [ "# GENERATED FILE - DO NOT EDIT", @@ -450,7 +450,7 @@ def _serialize_col_cfg(self, col_cfg: ColCfg) -> str: return f"ColCfg({', '.join(args)})" - def save_cache(self, op_definitions: Dict[str, OpDef]) -> Path: + def save_cache(self, op_definitions: dict[str, OpDef]) -> Path: """Save operation definitions to cache.""" cache_path = self._get_cache_module_path() @@ -465,7 +465,7 @@ def save_cache(self, op_definitions: Dict[str, OpDef]) -> Path: return cache_path - def load_cache(self) -> Optional[Dict[str, OpDef]]: + def load_cache(self) -> dict[str, OpDef] | None: """Load operation definitions from cache.""" if not self.is_cache_valid(): return None diff --git a/src/interpretune/analysis/ops/compiler/schema_compiler.py b/src/interpretune/analysis/ops/compiler/schema_compiler.py index b7ef3a68..bcec85bf 100644 --- a/src/interpretune/analysis/ops/compiler/schema_compiler.py +++ b/src/interpretune/analysis/ops/compiler/schema_compiler.py @@ -1,6 +1,6 @@ """Schema compiler for analysis operations to maintain field propagation while minimizing schemas.""" -from typing import Dict, List, Tuple, Any, Union, TypeVar, Callable, Set +from typing import Dict, Any, TypeVar, Callable from dataclasses import replace from copy import deepcopy import re @@ -15,12 +15,12 @@ def _compile_composition_schema_core( - operations: List[Any], - get_schemas_fn: Callable[[Any], Tuple[Dict[str, T_Field], Dict[str, T_Field]]], + operations: list[Any], + get_schemas_fn: Callable[[Any], tuple[dict[str, T_Field], dict[str, T_Field]]], is_intermediate_fn: Callable[[T_Field], bool], handle_object_field_fn: Callable[[T_Field], T_Field], - create_schema_fn: Callable[[Dict[str, T_Field]], T_Schema], -) -> Tuple[T_Schema, T_Schema]: + create_schema_fn: Callable[[dict[str, T_Field]], T_Schema], +) -> tuple[T_Schema, T_Schema]: """Core logic for compiling composition schemas with customizable handling of types. Args: @@ -33,9 +33,9 @@ def _compile_composition_schema_core( Returns: Tuple of (input_schema, output_schema) """ - input_fields: Dict[str, T_Field] = {} - output_fields: Dict[str, T_Field] = {} - intermediate_fields: Dict[str, T_Field] = {} + input_fields: dict[str, T_Field] = {} + output_fields: dict[str, T_Field] = {} + intermediate_fields: dict[str, T_Field] = {} if not operations: raise ValueError("No operations provided for composite schema compilation") @@ -62,8 +62,8 @@ def _compile_composition_schema_core( def jit_compile_composition_schema( - operations: List[Union[str, AnalysisOp]], op_definitions: Dict[str, Dict] -) -> Tuple[OpSchema, OpSchema]: + operations: list[str | AnalysisOp], op_definitions: dict[str, Dict] +) -> tuple[OpSchema, OpSchema]: """Compile the complete schema for a composition of operations using operation definitions. Args: @@ -146,8 +146,8 @@ def create_schema(fields): def compile_operation_composition_schema( - operations: List[str], all_operations_dict: Dict[str, Dict] -) -> Tuple[Dict, Dict]: + operations: list[str], all_operations_dict: dict[str, Dict] +) -> tuple[Dict, Dict]: """Compile the complete schema for a composition of operations. Args: @@ -210,7 +210,7 @@ def create_schema(fields): ) -def resolve_required_ops(op_name: str, op_def: Dict[str, Any], op_definitions: Dict[str, Dict]) -> List[str]: +def resolve_required_ops(op_name: str, op_def: dict[str, Any], op_definitions: dict[str, Dict]) -> list[str]: """Resolve required_ops for an operation, handling namespaced operations. Args: @@ -268,7 +268,7 @@ def resolve_required_ops(op_name: str, op_def: Dict[str, Any], op_definitions: D return resolved_ops -def build_operation_compositions(yaml_config: Dict) -> Dict[str, Any]: +def build_operation_compositions(yaml_config: Dict) -> dict[str, Any]: """Build operation compositions with compiled schemas from YAML configuration. Args: @@ -318,7 +318,7 @@ def build_operation_compositions(yaml_config: Dict) -> Dict[str, Any]: return ops -def _parse_composition_string(composition_str: str) -> List[str]: +def _parse_composition_string(composition_str: str) -> list[str]: """Parse composition string to handle parentheses-wrapped namespaced operations. Examples: @@ -359,8 +359,8 @@ def _parse_composition_string(composition_str: str) -> List[str]: def compile_op_schema( - op_name: str, op_definitions: Dict[str, Dict[str, Any]], _processing: Set[str] | None = None -) -> Dict[str, Any]: # type: ignore[assignment] + op_name: str, op_definitions: dict[str, dict[str, Any]], _processing: set[str] | None = None +) -> dict[str, Any]: # type: ignore[assignment] """Compile operation schema by merging schemas from required operations. Args: diff --git a/src/interpretune/analysis/ops/dispatcher.py b/src/interpretune/analysis/ops/dispatcher.py index 9359fabe..5092fc5c 100644 --- a/src/interpretune/analysis/ops/dispatcher.py +++ b/src/interpretune/analysis/ops/dispatcher.py @@ -1,7 +1,7 @@ """Dispatcher for analysis operations.""" from __future__ import annotations -from typing import Optional, Dict, NamedTuple, List, Tuple, Iterator, Callable, Union, Any +from typing import Dict, NamedTuple, Iterator, Callable, Any from pathlib import Path from functools import wraps from collections import defaultdict @@ -48,7 +48,7 @@ class AnalysisOpDispatcher: # TODO: # - decide whether to make the dispatcher a singleton or not # - decide whether to make the dispatcher thread-safe - def __init__(self, yaml_paths: Optional[Union[Path, List[Path]]] = None, enable_hub_ops: bool = True): + def __init__(self, yaml_paths: Path | list[Path] | None = None, enable_hub_ops: bool = True): # Initialize yaml_paths self.yaml_paths = [Path(p.strip()) for p in IT_ANALYSIS_OP_PATHS] # Start with op_paths @@ -67,7 +67,7 @@ def __init__(self, yaml_paths: Optional[Union[Path, List[Path]]] = None, enable_ ) self.enable_hub_ops = enable_hub_ops - self._op_definitions: Dict[str, OpDef] = {} + self._op_definitions: dict[str, OpDef] = {} self._dispatch_table = {} # {op_name: {context: instantiated_op}} self._aliases = {} # {alias: op_name} self._op_to_aliases = defaultdict(list) # {op_name: [aliases]} @@ -85,7 +85,7 @@ def _normalize_op_name(self, name: str) -> str: # Normalize operation names for consistent lookup (case-insensitive, cross-platform) return name.replace("/", ".").replace("-", "_").lower() - def _discover_yaml_files(self, paths: List[Path]) -> List[Path]: + def _discover_yaml_files(self, paths: list[Path]) -> list[Path]: """Discover all YAML files from the given paths (files or directories).""" yaml_files = [] for path in paths: @@ -147,7 +147,7 @@ def load_definitions(self) -> None: finally: self._loading_in_progress = False - def _load_from_yaml_and_compile(self, yaml_files: List[Path]): + def _load_from_yaml_and_compile(self, yaml_files: list[Path]): """Load from YAML files and compile to cache.""" # Load and merge all YAML files raw_definitions = {} @@ -215,7 +215,7 @@ def _load_from_yaml_and_compile(self, yaml_files: List[Path]): self._loaded = True - def _compile_required_ops_schemas(self, definitions_to_compile: Dict[str, Dict]): + def _compile_required_ops_schemas(self, definitions_to_compile: dict[str, Dict]): """Compile schemas by recursively including required_ops dependencies.""" from interpretune.analysis.ops.compiler.schema_compiler import compile_op_schema @@ -232,7 +232,7 @@ def _compile_required_ops_schemas(self, definitions_to_compile: Dict[str, Dict]) # Remove the operation if it fails to compile definitions_to_compile.pop(op_name, None) - def _convert_raw_definitions_to_opdefs(self, raw_definitions: Dict[str, Dict]): + def _convert_raw_definitions_to_opdefs(self, raw_definitions: dict[str, Dict]): """Convert raw dictionary definitions to OpDef objects.""" for op_name, op_def in raw_definitions.items(): op_name = self._normalize_op_name(op_name) @@ -258,7 +258,7 @@ def _convert_raw_definitions_to_opdefs(self, raw_definitions: Dict[str, Dict]): self._op_definitions[op_name] = op_def_obj - def _apply_hub_namespacing(self, yaml_content: Dict[str, Any], yaml_file: Path) -> Dict[str, Any]: + def _apply_hub_namespacing(self, yaml_content: dict[str, Any], yaml_file: Path) -> dict[str, Any]: """Apply hub namespacing to operations from hub files.""" rank_zero_debug(f"[DISPATCHER] Processing yaml_file: {yaml_file}") @@ -360,7 +360,7 @@ def _populate_aliases_from_definitions(self): self._op_to_aliases[op_name_norm].append(alias_norm) @_ensure_loaded - def list_operations(self) -> List[str]: + def list_operations(self) -> list[str]: """Get a list of all available operation names. Returns: @@ -370,7 +370,7 @@ def list_operations(self) -> List[str]: @property @_ensure_loaded - def registered_ops(self) -> Dict[str, OpDef]: + def registered_ops(self) -> dict[str, OpDef]: """Get all registered operation definitions without instantiating them.""" # TODO: return a generator here instead of a dict? May be better to provide a separate method for that return {name: op_def for name, op_def in self._op_definitions.items()} @@ -384,12 +384,12 @@ def get_op_aliases(self, op_name: str) -> list[str]: return self._op_to_aliases[op_name] @_ensure_loaded - def get_all_aliases(self) -> Iterator[Tuple[str, str]]: + def get_all_aliases(self) -> Iterator[tuple[str, str]]: """Get all registered operation aliases.""" for alias, op_name in self._aliases.items(): yield (alias, op_name) - def _resolve_name_safe(self, op_name: str, visited: Optional[set] = None) -> str: + def _resolve_name_safe(self, op_name: str, visited: set | None = None) -> str: """Safely resolve names with cycle detection.""" if visited is None: visited = set() @@ -516,7 +516,7 @@ def _import_hub_callable(self, op_name: str, op_def: OpDef) -> Callable: return implementation @staticmethod - def _function_param_from_hub_module(param_path: str, implementation: Callable) -> Optional[Callable]: + def _function_param_from_hub_module(param_path: str, implementation: Callable) -> Callable | None: # Try to use the dynamically loaded module if module names match func_name = param_path.rsplit(".", 1)[-1] param_module = param_path.rsplit(".", 1)[0] @@ -611,9 +611,7 @@ def _convert_to_op_schema(self, schema_dict: Dict) -> OpSchema: return OpSchema(result) @_ensure_loaded - def get_op( - self, op_name: str, context: Optional[DispatchContext] = None, lazy: bool = False - ) -> AnalysisOp | Callable: + def get_op(self, op_name: str, context: DispatchContext | None = None, lazy: bool = False) -> AnalysisOp | Callable: """Get an operation by name, optionally instantiating it if needed. Args: @@ -704,7 +702,7 @@ def _maybe_instantiate_op(self, op_ref, context: DispatchContext = DispatchConte return result @_ensure_loaded - def instantiate_all_ops(self) -> Dict[str, AnalysisOp]: + def instantiate_all_ops(self) -> dict[str, AnalysisOp]: """Get all operations as instantiated AnalysisOp objects.""" instantiated_ops = {} @@ -726,7 +724,7 @@ def instantiate_all_ops(self) -> Dict[str, AnalysisOp]: @_ensure_loaded def compile_ops( - self, op_names: str | List[str | AnalysisOp], name: Optional[str] = None, aliases: Optional[List[str]] = None + self, op_names: str | list[str | AnalysisOp], name: str | None = None, aliases: list[str] | None = None ) -> CompositeAnalysisOp: """Create a composition of operations from a list of operation names.""" # See NOTE [Composition and Compilation Limitations] @@ -758,10 +756,10 @@ def compile_ops( def __call__( self, op_name: str, - module: Optional[torch.nn.Module] = None, - analysis_batch: Optional[BaseAnalysisBatchProtocol] = None, - batch: Optional[BatchEncoding] = None, - batch_idx: Optional[int] = None, + module: torch.nn.Module | None = None, + analysis_batch: BaseAnalysisBatchProtocol | None = None, + batch: BatchEncoding | None = None, + batch_idx: int | None = None, ) -> BaseAnalysisBatchProtocol: """Call an operation by name.""" # Support for dot-separated operation names (creating compositions on-demand) diff --git a/src/interpretune/analysis/ops/dynamic_module_utils.py b/src/interpretune/analysis/ops/dynamic_module_utils.py index 7197123e..f6e5269b 100644 --- a/src/interpretune/analysis/ops/dynamic_module_utils.py +++ b/src/interpretune/analysis/ops/dynamic_module_utils.py @@ -12,7 +12,7 @@ import filecmp from pathlib import Path from types import ModuleType -from typing import Callable, Optional, Union, List, Set +from typing import Callable from huggingface_hub import try_to_load_from_cache from transformers.dynamic_module_utils import get_relative_import_files, check_imports @@ -22,7 +22,7 @@ from interpretune.utils.logging import rank_zero_debug, rank_zero_warn # Track paths we've added to sys.path to avoid duplicates -_added_op_paths: Set[str] = set() +_added_op_paths: set[str] = set() logger = logging.get_logger(__name__) _IT_REMOTE_CODE_LOCK = threading.Lock() @@ -48,7 +48,7 @@ def init_it_modules() -> None: importlib.invalidate_caches() -def create_dynamic_module_it(name: Union[str, os.PathLike]) -> None: +def create_dynamic_module_it(name: str | os.PathLike) -> None: """Creates a dynamic module in the interpretune cache directory for modules. Args: @@ -70,7 +70,7 @@ def create_dynamic_module_it(name: Union[str, os.PathLike]) -> None: def get_function_in_module( function_name: str, - module_path: Union[str, os.PathLike], + module_path: str | os.PathLike, *, force_reload: bool = False, ) -> Callable: @@ -96,7 +96,7 @@ def get_function_in_module( if force_reload: sys.modules.pop(name, None) importlib.invalidate_caches() - cached_module: Optional[ModuleType] = sys.modules.get(name) + cached_module: ModuleType | None = sys.modules.get(name) module_spec = importlib.util.spec_from_file_location(name, location=module_file) if module_spec is None: raise ImportError(f"Could not create module spec for {name} from {module_file}") @@ -122,17 +122,17 @@ def get_function_in_module( def get_cached_module_file_it( - op_repo_name_or_path: Union[str, os.PathLike], + op_repo_name_or_path: str | os.PathLike, module_file: str, - cache_dir: Optional[Union[str, os.PathLike]] = None, + cache_dir: str | os.PathLike | None = None, force_download: bool = False, - resume_download: Optional[bool] = None, - proxies: Optional[dict[str, str]] = None, - token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, + resume_download: bool | None = None, + proxies: dict[str, str] | None = None, + token: bool | str | None = None, + revision: str | None = None, local_files_only: bool = False, - repo_type: Optional[str] = None, - _commit_hash: Optional[str] = None, + repo_type: str | None = None, + _commit_hash: str | None = None, **deprecated_kwargs, ) -> str: """Downloads a module from a local folder or a distant repo and returns its path inside the cached interpretune @@ -155,7 +155,7 @@ def get_cached_module_file_it( exist. resume_download: Deprecated and ignored. All downloads are now resumed by default when possible. - proxies (`Dict[str, str]`, *optional*): + proxies (`dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint. token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. @@ -298,16 +298,16 @@ def get_cached_module_file_it( def get_function_from_dynamic_module( function_reference: str, - op_repo_name_or_path: Union[str, os.PathLike], - cache_dir: Optional[Union[str, os.PathLike]] = None, + op_repo_name_or_path: str | os.PathLike, + cache_dir: str | os.PathLike | None = None, force_download: bool = False, - resume_download: Optional[bool] = None, - proxies: Optional[dict[str, str]] = None, - token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, + resume_download: bool | None = None, + proxies: dict[str, str] | None = None, + token: bool | str | None = None, + revision: str | None = None, local_files_only: bool = False, - repo_type: Optional[str] = None, - code_revision: Optional[str] = None, + repo_type: str | None = None, + code_revision: str | None = None, **kwargs, ) -> Callable: """Extracts a function from a module file, present in the local folder or repository of an operations repo. @@ -337,7 +337,7 @@ def get_function_from_dynamic_module( exist. resume_download: Deprecated and ignored. All downloads are now resumed by default when possible. - proxies (`Dict[str, str]`, *optional*): + proxies (`dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint. token (`str` or `bool`, *optional*): The token to use as HTTP bearer authorization for remote files. @@ -419,7 +419,7 @@ def get_function_from_dynamic_module( return get_function_in_module(function_name, final_module, force_reload=force_download) -def ensure_op_paths_in_syspath(op_paths: List[Union[str, Path]]) -> None: +def ensure_op_paths_in_syspath(op_paths: list[str | Path]) -> None: """Ensure all operation paths are in sys.path for module discovery. Args: @@ -459,7 +459,7 @@ def ensure_op_paths_in_syspath(op_paths: List[Union[str, Path]]) -> None: _added_op_paths.add(path_str_resolved) -def remove_op_paths_from_syspath(op_paths: List[Union[str, Path]]) -> None: +def remove_op_paths_from_syspath(op_paths: list[str | Path]) -> None: """Remove operation paths from sys.path. Args: @@ -503,7 +503,7 @@ def cleanup_op_paths() -> None: _added_op_paths.clear() -def get_added_op_paths() -> Set[str]: +def get_added_op_paths() -> set[str]: """Get the set of operation paths we've added to sys.path. Returns: diff --git a/src/interpretune/analysis/ops/hub_manager.py b/src/interpretune/analysis/ops/hub_manager.py index 119cfe78..f1811209 100644 --- a/src/interpretune/analysis/ops/hub_manager.py +++ b/src/interpretune/analysis/ops/hub_manager.py @@ -2,7 +2,6 @@ from __future__ import annotations from pathlib import Path -from typing import Optional, List from dataclasses import dataclass from huggingface_hub import HfApi, snapshot_download @@ -44,7 +43,7 @@ def namespace_prefix(self) -> str: class HubAnalysisOpManager: """Manages downloading and uploading analysis operation definitions from/to Hugging Face Hub.""" - def __init__(self, cache_dir: Optional[Path] = None, token: Optional[str] = None): + def __init__(self, cache_dir: Path | None = None, token: str | None = None): """Initialize the hub manager. Args: @@ -106,7 +105,7 @@ def upload_ops( create_pr: bool = False, private: bool = False, clean_existing: bool = False, - delete_patterns: Optional[List[str]] = None, + delete_patterns: list[str] | None = None, ) -> str: """Upload analysis operations to HuggingFace Hub. @@ -206,7 +205,7 @@ def upload_ops( rank_zero_warn(f"Failed to upload to {repo_id}: {e}") raise - def list_available_collections(self, username: Optional[str] = None) -> List[str]: + def list_available_collections(self, username: str | None = None) -> list[str]: """List available analysis operation collections on the Hub. Args: @@ -235,7 +234,7 @@ def list_available_collections(self, username: Optional[str] = None) -> List[str rank_zero_warn(f"Failed to list collections: {e}") return [] - def discover_hub_ops(self, search_patterns: Optional[List[str]] = None) -> List[HubOpCollection]: + def discover_hub_ops(self, search_patterns: list[str] | None = None) -> list[HubOpCollection]: """Discover and cache analysis operations from the Hub. Args: @@ -267,7 +266,7 @@ def discover_hub_ops(self, search_patterns: Optional[List[str]] = None) -> List[ rank_zero_info(f"Discovered {len(collections)} hub operation collections") return collections - def get_cached_collections(self) -> List[HubOpCollection]: + def get_cached_collections(self) -> list[HubOpCollection]: """Get analysis operation collections that are already cached locally. Returns: diff --git a/src/interpretune/base/call.py b/src/interpretune/base/call.py index 25255bf1..7c0acff9 100644 --- a/src/interpretune/base/call.py +++ b/src/interpretune/base/call.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Union +from typing import Any from interpretune.base import ITDataModule, BaseITModule from interpretune.utils import rank_zero_info @@ -7,7 +7,7 @@ log = logging.getLogger(__name__) -HOOKABLE_ITMODULE = Union[ITDataModule, BaseITModule] +HOOKABLE_ITMODULE = ITDataModule | BaseITModule def it_init(module, datamodule, *args, **kwargs): diff --git a/src/interpretune/base/components/cli.py b/src/interpretune/base/components/cli.py index d1ecc158..502c0ddc 100644 --- a/src/interpretune/base/components/cli.py +++ b/src/interpretune/base/components/cli.py @@ -7,7 +7,7 @@ import logging import weakref from pathlib import Path -from typing import Any, Union, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from collections.abc import Callable, Sequence from typing_extensions import override from functools import reduce @@ -168,7 +168,7 @@ def add_default_arguments_to_parser(self, parser: ArgumentParser) -> None: """Adds default arguments to the parser.""" parser.add_argument( "--seed_everything", - type=Union[bool, int, str, float], + type=bool | int | str | float, default=self.seed_everything_default, help=( "Set to an int to run seed_everything with this value before classes instantiation." diff --git a/src/interpretune/base/components/core.py b/src/interpretune/base/components/core.py index 1ab936ff..4a43e3b7 100644 --- a/src/interpretune/base/components/core.py +++ b/src/interpretune/base/components/core.py @@ -3,7 +3,7 @@ import warnings import logging from datetime import datetime -from typing import Any, Union, TYPE_CHECKING, cast, Optional +from typing import Any, TYPE_CHECKING, cast from functools import reduce, partial from pathlib import Path from copy import deepcopy @@ -279,10 +279,10 @@ def _maybe_dispatch(self, non_dispatch_val: Any | None = None) -> Any | None: """_summary_ Args: - non_dispatch_val (Optional[Any], optional): The value to return if we are not dispatching. Defaults to None. + non_dispatch_val (Any | None, optional): The value to return if we are not dispatching. Defaults to None. Returns: - Optional[Any]: _description_ + Any | None: _description_ """ current_frame = inspect.currentframe() if current_frame is None or current_frame.f_back is None: @@ -310,14 +310,14 @@ def _core_or_framework(self, c2f_map_key: str): return attr_val @property - def core_log_dir(self) -> Optional[StrOrPath]: + def core_log_dir(self) -> StrOrPath | None: result = self._core_or_framework(c2f_map_key="_log_dir") - return cast(Optional[StrOrPath], result) + return cast(StrOrPath | None, result) @property def datamodule(self) -> ITDataModule | None: result = self._core_or_framework(c2f_map_key="_datamodule") - return cast(Union[ITDataModule, None], result) + return cast(ITDataModule | None, result) @property def session_complete(self) -> bool: @@ -328,7 +328,7 @@ def cuda_allocator_history(self) -> bool: return self.it_cfg.memprofiler_cfg.enabled and self.it_cfg.memprofiler_cfg.cuda_allocator_history @property - def dtype(self) -> Union[torch.dtype, "str"] | None: + def dtype(self) -> "torch.dtype | str | None": try: # If `it_cfg` is not present on the object at all, treat this as an unexpected context and return None if not hasattr(self, "it_cfg"): @@ -361,7 +361,7 @@ def device(self) -> torch.device | None: except AttributeError as ae: rank_zero_warn(f"Could not find a device reference (has it been set yet?): {ae}") device = None - return cast(Union[torch.device, None], device) + return cast(torch.device | None, device) @device.setter def device(self, value: str | torch.device | None) -> None: @@ -370,7 +370,7 @@ def device(self, value: str | torch.device | None) -> None: self._it_state._device = value @dtype.setter - def dtype(self, value: Union[torch.dtype, "str"] | None) -> None: + def dtype(self, value: "torch.dtype | str | None") -> None: if value is not None and not isinstance(value, torch.dtype): value = _resolve_dtype(value) self.it_cfg._dtype = value @@ -416,7 +416,7 @@ def current_epoch(self) -> int: return self._core_or_framework(c2f_map_key="_current_epoch") @property - def optimizers(self) -> Optional[list[Optimizable]]: + def optimizers(self) -> list[Optimizable] | None: return self._core_or_framework(c2f_map_key="_it_optimizers") @property diff --git a/src/interpretune/base/components/mixins.py b/src/interpretune/base/components/mixins.py index 1410a548..a5f2bdec 100644 --- a/src/interpretune/base/components/mixins.py +++ b/src/interpretune/base/components/mixins.py @@ -1,7 +1,7 @@ from __future__ import annotations import os import inspect -from typing import Any, TYPE_CHECKING, List, Optional, Dict +from typing import Any, TYPE_CHECKING, Dict from contextlib import contextmanager from functools import wraps @@ -75,7 +75,7 @@ def wrapper(self, *args, **kwargs): class AnalysisStepMixin: @property - def analysis_cfg(self) -> Optional[AnalysisCfgProtocol]: + def analysis_cfg(self) -> AnalysisCfgProtocol | None: if not hasattr(self.it_cfg, "analysis_cfg") or self.it_cfg.analysis_cfg is None: # type: ignore[attr-defined] # mixin provides it_cfg rank_zero_warn("Analysis configuration has not been set.") return @@ -268,10 +268,10 @@ def standardize_logits(self, logits: torch.Tensor) -> torch.Tensor: logits = logits[:, -1:, :] return logits - def labels_to_ids(self, labels: List[str]) -> tuple[torch.Tensor, List[str]]: + def labels_to_ids(self, labels: list[str]) -> tuple[torch.Tensor, list[str]]: return torch.take(self.it_cfg.classification_mapping_indices, labels), labels # type: ignore[attr-defined] # mixin provides it_cfg - def logits_and_labels(self, batch: BatchEncoding, batch_idx: int) -> tuple[torch.Tensor, torch.Tensor, List[str]]: + def logits_and_labels(self, batch: BatchEncoding, batch_idx: int) -> tuple[torch.Tensor, torch.Tensor, list[str]]: label_ids, labels = self.labels_to_ids(batch.pop("labels")) logits = self(**batch) # type: ignore[misc] # mixin provides __call__ through composition # TODO: add another layer of abstraction here to handle different model output types? Tradeoffs to consider... @@ -280,7 +280,7 @@ def logits_and_labels(self, batch: BatchEncoding, batch_idx: int) -> tuple[torch assert isinstance(logits, torch.Tensor), f"Expected logits to be a torch.Tensor but got {type(logits)}" return torch.squeeze(logits[:, -1, :], dim=1), label_ids, labels - def collect_answers(self, logits: torch.Tensor | tuple, labels: torch.Tensor, mode: str = "log") -> Optional[Dict]: + def collect_answers(self, logits: torch.Tensor | tuple, labels: torch.Tensor, mode: str = "log") -> Dict | None: logits = self.standardize_logits(logits) # type: ignore[arg-type] # standardize_logits handles tuple case per_example_answers, _ = torch.max(logits, dim=-2) preds = torch.argmax(per_example_answers, axis=-1) # type: ignore[call-arg] diff --git a/src/interpretune/base/modules.py b/src/interpretune/base/modules.py index 21e3534f..b05e3009 100644 --- a/src/interpretune/base/modules.py +++ b/src/interpretune/base/modules.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import torch @@ -51,7 +51,7 @@ def model_init(self) -> None: def load_metric(self) -> None: """Optionally load a metric at the end of model initialization.""" - def on_session_end(self) -> Optional[Any]: + def on_session_end(self) -> Any | None: """Optionally execute some post-interpretune session (train, test, iterative exploration) steps.""" if getattr(self, "memprofiler", None): self.memprofiler.dump_memory_stats() diff --git a/src/interpretune/config/analysis.py b/src/interpretune/config/analysis.py index 2cbced57..c418de50 100644 --- a/src/interpretune/config/analysis.py +++ b/src/interpretune/config/analysis.py @@ -1,5 +1,5 @@ from __future__ import annotations # see PEP 749, no longer needed when 3.13 reaches EOL -from typing import Optional, Generator, Union, Callable +from typing import Generator, Callable from dataclasses import dataclass, field import datetime import warnings @@ -26,9 +26,9 @@ class AnalysisCfg(ITSerializableCfg): output_store: AnalysisStoreProtocol | None = None # usually constructed on setup() input_store: AnalysisStoreProtocol | None = None # store containing input data from previous op - target_op: Optional[Union[str, AnalysisOp, Callable, list[AnalysisOp]]] = None # input op to be resolved - output_schema: Optional[OpSchema | str | AnalysisOp] = None # Schema, op, or op name to define schema - name: Optional[str] = None # Name for this analysis configuration + target_op: str | AnalysisOp | Callable | list[AnalysisOp] | None = None # input op to be resolved + output_schema: OpSchema | str | AnalysisOp | None = None # Schema, op, or op name to define schema + name: str | None = None # Name for this analysis configuration fwd_hooks: list[tuple] = field(default_factory=list) bwd_hooks: list[tuple] = field(default_factory=list) cache_dict: dict = field(default_factory=dict) @@ -41,10 +41,10 @@ class AnalysisCfg(ITSerializableCfg): step_fn: str = "analysis_step" # Name of the method to use/generate for analysis auto_prune_batch_encoding: bool = True # Automatically prune encoded batches to only include relevant keys _applied_to: dict = field(default_factory=dict) # Dictionary tracking which modules this cfg has been applied to - _op: Optional[Union[str, AnalysisOp, Callable, list[AnalysisOp]]] = None # op via generated analysis step + _op: str | AnalysisOp | Callable | list[AnalysisOp] | None = None # op via generated analysis step @property - def op(self) -> Optional[Union[str, AnalysisOp, Callable, list[AnalysisOp]]]: + def op(self) -> str | AnalysisOp | Callable | list[AnalysisOp] | None: """Get the operation, unwrapping any OpWrapper if present.""" if self._op is None: return None @@ -56,7 +56,7 @@ def op(self) -> Optional[Union[str, AnalysisOp, Callable, list[AnalysisOp]]]: return self._op @op.setter - def op(self, value: Optional[Union[str, AnalysisOp, Callable, list[AnalysisOp]]]) -> None: + def op(self, value: str | AnalysisOp | Callable | list[AnalysisOp] | None) -> None: """Set the operation value.""" self._op = value @@ -176,7 +176,7 @@ def resolve_op(self) -> None: # Otherwise leave as-is; callers will raise clear errors if this is invalid - def materialize_names_filter(self, module, fallback_sae_targets: Optional[SAEAnalysisTargets] = None) -> None: + def materialize_names_filter(self, module, fallback_sae_targets: SAEAnalysisTargets | None = None) -> None: """Set names_filter using sae_analysis_targets if not already set. Args: @@ -204,7 +204,7 @@ def maybe_set_hooks(self) -> None: if not self.fwd_hooks and not self.bwd_hooks: self.check_add_default_hooks() - def prepare_model_ctx(self, module, fallback_sae_targets: Optional[SAEAnalysisTargets] = None) -> None: + def prepare_model_ctx(self, module, fallback_sae_targets: SAEAnalysisTargets | None = None) -> None: """Configure names_filter and hooks for a specific module. Args: @@ -299,7 +299,7 @@ def add_default_cache_hooks(self, include_backward: bool = True) -> None: bwd_hooks = [(self.names_filter, _make_simple_cache_hook(cache_dict=self.cache_dict, is_backward=True))] self.bwd_hooks = bwd_hooks - def check_add_default_hooks(self) -> Optional[tuple[list, list]]: + def check_add_default_hooks(self) -> tuple[list, list] | None: """Construct forward and backward hooks based on analysis operation.""" fwd_hooks, bwd_hooks = [], [] @@ -349,9 +349,9 @@ def reset_applied_state(self, module=None) -> None: def apply( self, module, - cache_dir: Optional[str] = None, - op_output_dataset_path: Optional[str] = None, - fallback_sae_targets: Optional[SAEAnalysisTargets] = None, + cache_dir: str | None = None, + op_output_dataset_path: str | None = None, + fallback_sae_targets: SAEAnalysisTargets | None = None, ): """Set up analysis configuration and configure for the given module. diff --git a/src/interpretune/config/circuit_tracer.py b/src/interpretune/config/circuit_tracer.py index fe051d03..a09721fb 100644 --- a/src/interpretune/config/circuit_tracer.py +++ b/src/interpretune/config/circuit_tracer.py @@ -1,5 +1,4 @@ from __future__ import annotations -from typing import List, Optional from dataclasses import dataclass import torch @@ -17,7 +16,7 @@ class CircuitTracerConfig(ITSerializableCfg): # Model and transcoder settings """Model name to use for circuit tracing. If None, uses the base model name.""" - model_name: Optional[str] = None + model_name: str | None = None """Transcoder set to use. Can be 'gemma', 'llama', or path to custom config. @@ -34,9 +33,9 @@ class CircuitTracerConfig(ITSerializableCfg): """Batch size for backward passes during attribution.""" batch_size: int = 256 """Maximum number of feature nodes to include in attribution.""" - max_feature_nodes: Optional[int] = None + max_feature_nodes: int | None = None """Memory optimization option ('cpu', 'disk', or None).""" - offload: Optional[str] = None + offload: str | None = None """Whether to display detailed progress information.""" verbose: bool = True @@ -50,14 +49,14 @@ class CircuitTracerConfig(ITSerializableCfg): """Whether to automatically save generated graphs.""" save_graphs: bool = True """Directory to save attribution graphs.If None, uses analysis output directory.""" - graph_output_dir: Optional[str] = None + graph_output_dir: str | None = None # Interpretune CT enhancement settings """ Specific tokens to analyze, will use tokens associated with top `max_n_logits` if `None`.""" - analysis_target_tokens: Optional[List[str]] = None + analysis_target_tokens: list[str] | None = None """A tensor of pre-tokenized target token IDs for analysis or a module attribute to be used as a source for them.""" - target_token_ids: Optional[List[int] | torch.Tensor | str] = None + target_token_ids: list[int] | torch.Tensor | str | None = None """Whether to prepare graphs for Neuronpedia graph storage and analysis.""" use_neuronpedia: bool = False @@ -67,4 +66,4 @@ class CircuitTracerConfig(ITSerializableCfg): class CircuitTracerITLensConfig(ITLensConfig): """ITLens configuration with Circuit Tracer support.""" - circuit_tracer_cfg: Optional[CircuitTracerConfig] = None + circuit_tracer_cfg: CircuitTracerConfig | None = None diff --git a/src/interpretune/config/datamodule.py b/src/interpretune/config/datamodule.py index 76143141..16fbd77e 100644 --- a/src/interpretune/config/datamodule.py +++ b/src/interpretune/config/datamodule.py @@ -1,4 +1,4 @@ -from typing import Optional, Any, Dict, Tuple, List +from typing import Any, Tuple, List import logging import os from dataclasses import dataclass, field @@ -18,30 +18,30 @@ @dataclass(kw_only=True) class PromptConfig(ITSerializableCfg): - cust_task_prompt: Dict[str, Any] = field(default_factory=dict) + cust_task_prompt: dict[str, Any] = field(default_factory=dict) - def model_chat_template_fn(self, task_prompt: str, tokenization_pattern: Optional[str] = None) -> str: + def model_chat_template_fn(self, task_prompt: str, tokenization_pattern: str | None = None) -> str: return task_prompt.strip() @dataclass(kw_only=True) class TokenizationConfig(ITSerializableCfg): tokenizers_parallelism: bool = True - local_fast_tokenizer_path: Optional[str] = None - cust_tokenization_pattern: Optional[str] = None - special_tokens_dict: Dict[str, Any] = field(default_factory=dict) + local_fast_tokenizer_path: str | None = None + cust_tokenization_pattern: str | None = None + special_tokens_dict: dict[str, Any] = field(default_factory=dict) max_seq_length: int = 2048 # TODO: force this to be set rather than allowing a default? @dataclass(kw_only=True) class DatasetProcessingConfig(ITSerializableCfg): remove_unused_columns: bool = True - text_fields: Optional[Tuple] = None - dataset_path: Optional[StrOrPath] = None - enable_datasets_cache: Optional[bool] = False # disable caching unless explicitly set to improve reproducibility - data_collator_cfg: Dict[str, Any] = field(default_factory=dict) - signature_columns: Optional[List] = field(default_factory=list) - prepare_data_map_cfg: Dict[str, Any] = field(default_factory=dict) + text_fields: Tuple | None = None + dataset_path: StrOrPath | None = None + enable_datasets_cache: bool | None = False # disable caching unless explicitly set to improve reproducibility + data_collator_cfg: dict[str, Any] = field(default_factory=dict) + signature_columns: List | None = field(default_factory=list) + prepare_data_map_cfg: dict[str, Any] = field(default_factory=dict) @dataclass(kw_only=True) @@ -49,7 +49,7 @@ class ITDataModuleConfig(ITSharedConfig, TokenizationConfig, DatasetProcessingCo # See NOTE [Interpretune Dataclass-Oriented Configuration] train_batch_size: int = 32 eval_batch_size: int = 32 - dataloader_kwargs: Dict[str, Any] = field(default_factory=dict) + dataloader_kwargs: dict[str, Any] = field(default_factory=dict) # note that for prompt_cfg, we: # 1. use (data)classes to minimize special character yaml parsing complications (can override w/ diff init_args) # 2. do not provide a default dataclass to avoid current dataclass subclass limitations diff --git a/src/interpretune/config/extensions.py b/src/interpretune/config/extensions.py index 0db60180..a47ee9e8 100644 --- a/src/interpretune/config/extensions.py +++ b/src/interpretune/config/extensions.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Callable, NamedTuple +from typing import Any, Callable, NamedTuple from dataclasses import field, make_dataclass, dataclass from interpretune.config import ITSerializableCfg @@ -19,8 +19,8 @@ class ExtensionsContext: allows for easy integration of new extensions into the Interpretune framework. """ - SUPPORTED_EXTENSIONS: Dict[str, Callable] = field(default_factory=dict) - SUPPORTED_EXTENSION_CFGS: Dict[str, Any] = field(default_factory=dict) + SUPPORTED_EXTENSIONS: dict[str, Callable] = field(default_factory=dict) + SUPPORTED_EXTENSION_CFGS: dict[str, Any] = field(default_factory=dict) BASE_EXTENSIONS: tuple = ( ITExtension( "debug_lm", diff --git a/src/interpretune/config/module.py b/src/interpretune/config/module.py index 384338a1..31a00295 100644 --- a/src/interpretune/config/module.py +++ b/src/interpretune/config/module.py @@ -1,4 +1,6 @@ -from typing import Any, TYPE_CHECKING, Optional, Tuple, Type +# N.B. we need to avoid annotations import here due to jsonargparse validation issues that emerge when it is used +# from __future__ import annotations +from typing import Any, TYPE_CHECKING, Tuple, Type from dataclasses import dataclass, field import torch @@ -50,13 +52,13 @@ class OptimizerSchedulerConf(ITSerializableCfg): @dataclass(kw_only=True) class ClassificationConf(ITSerializableCfg): - classification_mapping: Optional[Tuple] = None - classification_mapping_indices: Optional[torch.Tensor] = None + classification_mapping: Tuple | None = None + classification_mapping_indices: torch.Tensor | None = None @dataclass(kw_only=True) class MixinsConf(ITSerializableCfg): - analysis_cfg: Optional["AnalysisCfgProtocol"] = None + analysis_cfg: "AnalysisCfgProtocol | None" = None generative_step_cfg: GenerativeClassificationConfig = field(default_factory=GenerativeClassificationConfig) hf_from_pretrained_cfg: HFFromPretrainedConfig | None = None @@ -121,8 +123,8 @@ class ITState: _it_lr_scheduler_configs: list[LRSchedulerConfig] = field(default_factory=list) _it_optimizers: list[Optimizable] = field(default_factory=list) # init'd via `configure_optimizers` - _log_dir: Optional[StrOrPath] = None - _datamodule: Optional["ITDataModule"] = None # datamodule handle attached after init + _log_dir: StrOrPath | None = None + _datamodule: "ITDataModule | None" = None # datamodule handle attached after init _device: torch.device | None = None # root device (sometimes used if not handled by Lightning) _extensions: dict[str, Any] = field(default_factory=dict) _session_complete: bool = False diff --git a/src/interpretune/config/runner.py b/src/interpretune/config/runner.py index e00a0c01..fe441ae7 100644 --- a/src/interpretune/config/runner.py +++ b/src/interpretune/config/runner.py @@ -1,5 +1,5 @@ from __future__ import annotations # see PEP 749, no longer needed when 3.13 reaches EOL -from typing import TYPE_CHECKING, Optional, Union, List, Iterable +from typing import TYPE_CHECKING, Iterable from dataclasses import dataclass, field from pathlib import Path @@ -16,8 +16,8 @@ # Standalone functions for analysis initialization def to_analysis_cfgs( - analysis_cfgs: Optional[Union[AnalysisCfg, "AnalysisOp", Iterable[Union[AnalysisCfg, "AnalysisOp"]]]], -) -> List[AnalysisCfg]: + analysis_cfgs: "AnalysisCfg | AnalysisOp | Iterable[AnalysisCfg | AnalysisOp] | None", +) -> list[AnalysisCfg]: """Convert various input formats to a list of AnalysisCfg objects. Args: @@ -72,9 +72,9 @@ def to_analysis_cfgs( def init_analysis_dirs( module: "SAEAnalysisModuleProtocol", - cache_dir: Optional[Union[str, Path]] = None, - op_output_dataset_path: Optional[Union[str, Path]] = None, - analysis_cfgs: Optional[List[AnalysisCfg]] = None, + cache_dir: str | Path | None = None, + op_output_dataset_path: str | Path | None = None, + analysis_cfgs: list[AnalysisCfg] | None = None, ) -> tuple[Path, Path]: """Initialize the analysis directories for the given module and analysis configurations. @@ -123,10 +123,10 @@ def init_analysis_dirs( def init_analysis_cfgs( module: "SAEAnalysisModuleProtocol", - analysis_cfgs: List[AnalysisCfg], - cache_dir: Optional[Union[str, Path]] = None, - op_output_dataset_path: Optional[Union[str, Path]] = None, - sae_analysis_targets: Optional["SAEAnalysisTargets"] = None, + analysis_cfgs: list[AnalysisCfg], + cache_dir: str | Path | None = None, + op_output_dataset_path: str | Path | None = None, + sae_analysis_targets: "SAEAnalysisTargets | None" = None, ignore_manual: bool = False, ) -> None: """Initialize analysis configurations for the given module. @@ -187,12 +187,12 @@ def _session_validation(self): @dataclass(kw_only=True) class AnalysisRunnerCfg(SessionRunnerCfg): # Change the field to a private attribute that will store the raw value - analysis_cfgs: Optional[Union[AnalysisCfg, AnalysisOp, Iterable[Union[AnalysisCfg, AnalysisOp]]]] = None + analysis_cfgs: AnalysisCfg | AnalysisOp | Iterable[AnalysisCfg | AnalysisOp] | None = None limit_analysis_batches: int = -1 - cache_dir: Optional[str | Path] = None - op_output_dataset_path: Optional[str | Path] = None + cache_dir: str | Path | None = None + op_output_dataset_path: str | Path | None = None # Add optional sae_analysis_targets as a fallback - sae_analysis_targets: Optional[SAEAnalysisTargets] = None + sae_analysis_targets: SAEAnalysisTargets | None = None # Add artifact configuration artifact_cfg: AnalysisArtifactCfg = field(default_factory=AnalysisArtifactCfg) # Global override for ignore_manual setting in analysis configs @@ -217,6 +217,6 @@ def __post_init__(self): self.op_output_dataset_path = str(self.op_output_dataset_path) @property - def _processed_analysis_cfgs(self) -> List[AnalysisCfg]: + def _processed_analysis_cfgs(self) -> list[AnalysisCfg]: """Process and return the analysis_cfgs as a standardized list of AnalysisCfg objects.""" return to_analysis_cfgs(self.analysis_cfgs) diff --git a/src/interpretune/config/shared.py b/src/interpretune/config/shared.py index 512793ac..baa96545 100644 --- a/src/interpretune/config/shared.py +++ b/src/interpretune/config/shared.py @@ -1,4 +1,4 @@ -from typing import Any, TypeVar, TypeAlias, Sequence, Optional +from typing import Any, TypeVar, TypeAlias, Sequence from dataclasses import dataclass, field, fields, make_dataclass import inspect import logging @@ -211,7 +211,7 @@ def issue_noncomposition_feedback(auto_comp_cfg, superclasses, subclasses): def issue_incomplete_composition_feedback( - auto_comp_cfg: AutoCompConfig, kwargs_not_in_target_type: dict, nonsubcls_mixins: Optional[tuple[type, ...]] + auto_comp_cfg: AutoCompConfig, kwargs_not_in_target_type: dict, nonsubcls_mixins: tuple[type, ...] | None ): no_auto_prefix = ( f"Could not find an auto-composition for {auto_comp_cfg._orig_cfg_cls} that supports all of" diff --git a/src/interpretune/config/transformer_lens.py b/src/interpretune/config/transformer_lens.py index 08762325..df00897f 100644 --- a/src/interpretune/config/transformer_lens.py +++ b/src/interpretune/config/transformer_lens.py @@ -1,4 +1,4 @@ -from typing import Optional, Literal, Dict, Any, TypeAlias, Tuple +from typing import Literal, Any, TypeAlias from dataclasses import dataclass from functools import reduce @@ -19,9 +19,9 @@ class ITLensSharedConfig(ITSerializableCfg): """TransformerLens configuration shared across both `from_pretrained` and config based instantiation modes.""" - move_to_device: Optional[bool] = True - default_padding_side: Optional[Literal["left", "right"]] = "right" - use_bridge: Optional[bool] = True # Use TransformerBridge (v3) by default, set False for legacy HookedTransformer + move_to_device: bool | None = True + default_padding_side: Literal["left", "right"] | None = "right" + use_bridge: bool | None = True # Use TransformerBridge (v3) by default, set False for legacy HookedTransformer @dataclass(kw_only=True) @@ -71,18 +71,18 @@ class ITLensBridgeConfig(ITLensSharedConfig): # The model name/path for TransformerBridge - IT handles HF model instantiation via model_name_or_path model_name: str = "gpt2-small" # Optional kwargs to pass to TransformerBridgeConfig constructor - transformer_bridge_config_overrides: Optional[Dict[str, Any]] = None + transformer_bridge_config_overrides: dict[str, Any] | None = None # Whether to call enable_compatibility_mode on the bridge after instantiation # N.B.: See transformer_lens/model_bridge/bridge.py for details, among other things, this mode: # 1. Breaks weight tying between embed and unembed to allow separate unembed centering # 2. Extracts q/k/v from joint qkv matrices for compatibility with HookedTransformer parameterizations enable_compatibility_mode: bool = False # Optional kwargs for enable_compatibility_mode() - enable_compatibility_mode_kwargs: Optional[Dict[str, Any]] = None + enable_compatibility_mode_kwargs: dict[str, Any] | None = None # Bridge config defaults to using bridge - use_bridge: Optional[bool] = True + use_bridge: bool | None = True # Device is commonly set, so we provide a top-level field for convenience - device: Optional[str] = None + device: str | None = None # Dtype is commonly set, so we provide a top-level field for convenience dtype: str = "float32" @@ -101,21 +101,21 @@ def __post_init__(self) -> None: @dataclass(kw_only=True) class ITLensFromPretrainedConfig(ITLensSharedConfig): model_name: str = "gpt2-small" - fold_ln: Optional[bool] = True - center_writing_weights: Optional[bool] = True - center_unembed: Optional[bool] = True - refactor_factored_attn_matrices: Optional[bool] = False - checkpoint_index: Optional[int] = None - checkpoint_value: Optional[int] = None + fold_ln: bool | None = True + center_writing_weights: bool | None = True + center_unembed: bool | None = True + refactor_factored_attn_matrices: bool | None = False + checkpoint_index: int | None = None + checkpoint_value: int | None = None # for pretrained cfg, IT handles the HF model instantiation via model_name or_path - hf_model: Optional[AutoModelForCausalLM | str] = None + hf_model: AutoModelForCausalLM | str | None = None # currently only annotating with str due to omegaconf container dumping limitations wrt torch.device - device: Optional[str] = None - n_devices: Optional[int] = 1 + device: str | None = None + n_devices: int | None = 1 # IT handles the tokenizer instantiation via either tokenizer, tokenizer_name or model_name_or_path - tokenizer: Optional[PreTrainedTokenizerBase] = None # for pretrained cfg, IT instantiates the tokenizer - fold_value_biases: Optional[bool] = True - default_prepend_bos: Optional[bool] = True + tokenizer: PreTrainedTokenizerBase | None = None # for pretrained cfg, IT instantiates the tokenizer + fold_value_biases: bool | None = True + default_prepend_bos: bool | None = True dtype: str = "float32" def __post_init__(self) -> None: @@ -143,14 +143,14 @@ class ITLensCustomConfig(ITLensSharedConfig): Set `use_bridge=False` (default) or interpretune will force the value to False and warn. """ - cfg: HookedTransformerConfig | Dict[str, Any] + cfg: HookedTransformerConfig | dict[str, Any] # When using a custom config, default to legacy HookedTransformer behavior to prevent # misconfiguration. If the user explicitly sets `use_bridge=True`, Interpretune will # warn and force it to False in `ITLensConfig.__post_init__`. - use_bridge: Optional[bool] = False + use_bridge: bool | None = False # IT handles the tokenizer instantiation via either tokenizer, tokenizer_name or model_name_or_path - # tokenizer: Optional[PreTrainedTokenizerBase] = None + # tokenizer: PreTrainedTokenizerBase | None = None def __post_init__(self) -> None: if not isinstance(self.cfg, HookedTransformerConfig): # ensure the user provided a valid dtype (should be handled by HookedTransformerConfig ideally) @@ -160,7 +160,7 @@ def __post_init__(self) -> None: ITLensCfg: TypeAlias = ITLensFromPretrainedConfig | ITLensCustomConfig | ITLensBridgeConfig # for static typing -ITLensCfgTypes: Tuple[type, type, type] = ( +ITLensCfgTypes: tuple[type, type, type] = ( ITLensFromPretrainedConfig, ITLensCustomConfig, ITLensBridgeConfig, @@ -306,11 +306,11 @@ def _sync_hf_tl_dtypes(self, hf_dtype, tl_dtype): @dataclass(kw_only=True) class TLensGenerationConfig(CoreGenerationConfig): stop_at_eos: bool = True - eos_token_id: Optional[int] = None + eos_token_id: int | None = None freq_penalty: float = 0.0 use_past_kv_cache: bool = True - prepend_bos: Optional[bool] = None - padding_side: Optional[Literal["left", "right"]] = None - return_type: Optional[str] = "input" - output_logits: Optional[bool] = None + prepend_bos: bool | None = None + padding_side: Literal["left", "right"] | None = None + return_type: str | None = "input" + output_logits: bool | None = None verbose: bool = True diff --git a/src/interpretune/extensions/debug_generation.py b/src/interpretune/extensions/debug_generation.py index b900819c..c79f6e22 100644 --- a/src/interpretune/extensions/debug_generation.py +++ b/src/interpretune/extensions/debug_generation.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Any, Dict, Tuple, Union, cast +from typing import List, Any, Dict, cast from dataclasses import dataclass, field from copy import deepcopy @@ -17,10 +17,10 @@ @dataclass(kw_only=True) class DebugLMConfig(ITSerializableCfg): enabled: bool = False - debug_raw_preds: Optional[np.ndarray] = None - debug_raw_labels: Optional[np.ndarray] = None - debug_raw_sequences: Optional[List[str]] = None - raw_debug_sequences: List = field(default_factory=list) + debug_raw_preds: np.ndarray | None = None + debug_raw_labels: np.ndarray | None = None + debug_raw_sequences: list[str] | None = None + raw_debug_sequences: list = field(default_factory=list) def __post_init__(self) -> None: if len(self.raw_debug_sequences) == 0 and self.enabled: @@ -51,7 +51,7 @@ class DebugGeneration: # Derive standard output attributes from HF dataclass DEFAULT_OUTPUT_ATTRS = tuple(list(DEFAULT_OUTPUT_DATACLS.__annotations__.keys())) DEFAULT_MODEL_CONFIG_ATTRS = ("cfg", "config") - phandle: Optional[ITModuleGenDebuggable] + phandle: ITModuleGenDebuggable | None def __init__( self, @@ -74,11 +74,11 @@ def _check_phandle(self) -> ITModuleGenDebuggable: assert isinstance(self.phandle, ITModuleGenDebuggable) return cast(ITModuleGenDebuggable, self.phandle) - def debug_sequences(self, sequences: Optional[Union[List, str]] = None) -> List: + def debug_sequences(self, sequences: List | str | None = None) -> List: """_summary_ Args: - sequences (Optional[List], optional): _description_. Defaults to None. + sequences (List | None, optional): _description_. Defaults to None. Returns: List: _description_ @@ -98,11 +98,11 @@ def debug_sequences(self, sequences: Optional[Union[List, str]] = None) -> List: sequences = [sequences] return [f"{ex.strip()}" for ex in sequences] - def chat_debug_sequences(self, sequences: Optional[List] = None, format: Optional[str] = None) -> List: + def chat_debug_sequences(self, sequences: List | None = None, format: str | None = None) -> List: """_summary_ Args: - sequences (Optional[List], optional): _description_. Defaults to None. + sequences (List | None, optional): _description_. Defaults to None. Returns: List: _description_ @@ -141,16 +141,16 @@ def chat_debug_sequences(self, sequences: Optional[List] = None, format: Optiona def _debug_generate( self, inputs: List | torch.Tensor, - gen_kwargs_override: Optional[Dict] = None, - gen_config_override: Optional[Dict] = None, - gen_output_attr: Optional[str] = None, + gen_kwargs_override: Dict | None = None, + gen_config_override: Dict | None = None, + gen_output_attr: str | None = None, ) -> Any: """_summary_ Args: inputs (_type_): _description_ - gen_kwargs_override (Optional[Dict], optional): _description_. Defaults to None. - gen_output_attr (Optional[str], optional): Specific attribute on the model output to use for decoding + gen_kwargs_override (Dict | None, optional): _description_. Defaults to None. + gen_output_attr (str | None, optional): Specific attribute on the model output to use for decoding (e.g., `sequences`). If None, DebugGeneration.DEFAULT_OUTPUT_ATTRS are used for validation. If a raw `torch.Tensor` is returned by the model generate (e.g., HookedTransformer), the DebugGeneration extension will normalize the result into a `transformers.utils.ModelOutput` @@ -178,7 +178,7 @@ def _debug_generate( outputs = ph.it_generate(inputs, **gen_kwargs) return self._normalize_output_to_model_output(outputs, gen_output_attr) - def _normalize_output_to_model_output(self, outputs: Any, gen_output_attr: Optional[str] = None) -> Any: + def _normalize_output_to_model_output(self, outputs: Any, gen_output_attr: str | None = None) -> Any: """Normalize outputs to a Hugging Face ModelOutput if required. - If the model returns a plain torch.Tensor (e.g. legacy HookedTransformer generate), and either @@ -221,9 +221,9 @@ def _normalize_output_to_model_output(self, outputs: Any, gen_output_attr: Optio def perplexity_on_sample( self, - corpus: Optional[Dataset | Dict] = None, - stride: Optional[int] = None, - limit_chars: Optional[int] = None, + corpus: Dataset | Dict | None = None, + stride: int | None = None, + limit_chars: int | None = None, ) -> torch.Tensor: ph = self._check_phandle() @@ -238,7 +238,7 @@ def perplexity_on_sample( perplexity_kwargs = {"stride": stride} if stride else {} return self.naive_perplexity(encoded_corpus, **perplexity_kwargs) - def top1_token_accuracy_on_sample(self, sample: str) -> Tuple[float, List[str]]: + def top1_token_accuracy_on_sample(self, sample: str) -> tuple[float, list[str]]: ph = self._check_phandle() sample_input_ids = ph.datamodule.tokenizer.encode(sample) # type: ignore[attr-defined] # protocol provides datamodule @@ -285,15 +285,15 @@ def naive_perplexity(self, encoded_corpus, stride: int = 512) -> torch.Tensor: return ppl def sanitize_gen_output( - self, outputs: Any, gen_output_attr: Optional[str] = None, decode_cfg_override: Optional[Dict] = None - ) -> Tuple[Any, Dict]: + self, outputs: Any, gen_output_attr: str | None = None, decode_cfg_override: Dict | None = None + ) -> tuple[Any, Dict]: decode_target = self.sanitize_model_output(outputs, gen_output_attr) decode_kwargs = deepcopy(DEFAULT_DECODE_KWARGS) if decode_cfg_override: decode_kwargs.update(decode_cfg_override) return decode_target, decode_kwargs - def sanitize_model_output(self, output: Any, gen_output_attr: Optional[str] = None) -> Any: + def sanitize_model_output(self, output: Any, gen_output_attr: str | None = None) -> Any: # TODO: revisit this logic after getting TL generate PR ready # For simplification, sanitization returns the requested attribute if # specified. Otherwise, return the output directly (e.g. raw tensor or @@ -314,18 +314,18 @@ def sanitize_model_output(self, output: Any, gen_output_attr: Optional[str] = No return output @property - def model_input_names(self) -> List[str]: + def model_input_names(self) -> list[str]: ph = self._check_phandle() return ph.datamodule.tokenizer.model_input_names # type: ignore[attr-defined] # protocol provides datamodule def debug_generate_batch( self, sequences: List, - gen_output_attr: Optional[str] = None, - gen_config_override: Optional[Dict] = None, - gen_kwargs_override: Optional[Dict] = None, - decode_cfg_override: Optional[Dict] = None, - ) -> Tuple[List, List]: + gen_output_attr: str | None = None, + gen_config_override: Dict | None = None, + gen_kwargs_override: Dict | None = None, + decode_cfg_override: Dict | None = None, + ) -> tuple[List, List]: ph = self._check_phandle() test_input_ids = ph.datamodule.tokenizer.batch_encode_plus(sequences) # type: ignore[attr-defined] # protocol provides datamodule @@ -352,11 +352,11 @@ def debug_generate_batch( def debug_generate_serial( self, sequences: List, - gen_output_attr: Optional[str] = None, - gen_config_override: Optional[Dict] = None, - gen_kwargs_override: Optional[Dict] = None, - decode_cfg_override: Optional[Dict] = None, - ) -> Tuple[List, List]: + gen_output_attr: str | None = None, + gen_config_override: Dict | None = None, + gen_kwargs_override: Dict | None = None, + decode_cfg_override: Dict | None = None, + ) -> tuple[List, List]: ph = self._check_phandle() answers = [] diff --git a/src/interpretune/extensions/memprofiler.py b/src/interpretune/extensions/memprofiler.py index 41f6dbfa..a09c3403 100644 --- a/src/interpretune/extensions/memprofiler.py +++ b/src/interpretune/extensions/memprofiler.py @@ -91,7 +91,7 @@ def __post_init__(self) -> None: # @dataclass(kw_only=True) # class PyTorchProfilerCfg(ITSerializableCfg): # # pytorch_profiler_enabled: bool = False -# # pytorch_profiler_cfg: Dict[str, Any] = field(default_factory=dict) +# # pytorch_profiler_cfg: dict[str, Any] = field(default_factory=dict) # accessed in global scope to track non-parameter packed bytes (npp) as a simple proxy (ceiling) for activation memory _npp_bytes = 0 diff --git a/src/interpretune/extensions/neuronpedia.py b/src/interpretune/extensions/neuronpedia.py index 012cba4a..53e7838c 100644 --- a/src/interpretune/extensions/neuronpedia.py +++ b/src/interpretune/extensions/neuronpedia.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Optional, Dict, List, Union, Tuple +from typing import Any, Dict from dataclasses import dataclass, field from pathlib import Path from copy import deepcopy @@ -50,7 +50,7 @@ class NeuronpediaConfig(ITSerializableCfg): """Default prefix for graph slugs when not specified.""" default_slug_prefix: str = "it-generated" """Default metadata to add to graphs.""" - default_metadata: Dict[str, Any] = field( + default_metadata: dict[str, Any] = field( default_factory=lambda: { "info": { "creator_name": "interpretune-user", @@ -110,7 +110,7 @@ def neuronpedia_cfg(self) -> NeuronpediaConfig: assert self.phandle is not None and self.phandle.it_cfg is not None, "IT configuration is not available" return self.phandle.it_cfg.neuronpedia_cfg - def _get_latest_graph_schema(self, schema_path: Union[str, Path] = "graph-schema.json") -> Dict[str, Any]: + def _get_latest_graph_schema(self, schema_path: str | Path = "graph-schema.json") -> dict[str, Any]: """Fetch the latest Neuronpedia graph schema and return it as a dictionary.""" if not _NEURONPEDIA_AVAILABLE: raise RuntimeError("Neuronpedia package is not available.") @@ -178,8 +178,8 @@ def cache_schema(schema_bytes, tag): raise FileNotFoundError("No graph-schema.json available (failed to fetch, no cache, no fallback)") def apply_qparam_transforms( - self, graph_dict: Dict[str, Any], in_place: bool = True, change_log_path: Optional[Union[str, Path]] = None - ) -> Tuple[bool, Dict[str, Any]]: + self, graph_dict: dict[str, Any], in_place: bool = True, change_log_path: str | Path | None = None + ) -> tuple[bool, dict[str, Any]]: """Apply transformations to qParams fields based on predefined rules. Args: @@ -242,7 +242,7 @@ def csv_str_to_list(value): return was_valid, graph_dict def _log_qparam_changes( - self, log: List[Dict], graph_dict: Dict[str, Any], change_log_path: Optional[Union[str, Path]] + self, log: list[Dict], graph_dict: dict[str, Any], change_log_path: str | Path | None ) -> None: """Log qParams changes to console and file.""" # Console logging @@ -278,7 +278,7 @@ def _log_qparam_changes( except Exception as e: rank_zero_warn(f"[NeuronpediaIntegration] Failed to write change log: {e}") - def prune_unsupported_metadata(self, graph_dict: Dict[str, Any]) -> None: + def prune_unsupported_metadata(self, graph_dict: dict[str, Any]) -> None: """Prune metadata fields in graph_dict that do not conform to the graph schema.""" np_graph_schema = self._get_latest_graph_schema() metadata_schema = np_graph_schema.get("properties", {}).get("metadata", {}) @@ -294,10 +294,10 @@ def prune_unsupported_metadata(self, graph_dict: Dict[str, Any]) -> None: def prepare_graph_metadata( self, - graph_dict: Dict[str, Any], - slug: Optional[str] = None, - custom_metadata: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: + graph_dict: dict[str, Any], + slug: str | None = None, + custom_metadata: dict[str, Any] | None = None, + ) -> dict[str, Any]: """Prepare and enrich graph metadata for Neuronpedia. Args: @@ -353,11 +353,11 @@ def prepare_graph_metadata( def transform_circuit_tracer_graph( self, - graph_path: Union[str, Path], - output_path: Optional[Union[str, Path]] = None, - slug: Optional[str] = None, - custom_metadata: Optional[Dict[str, Any]] = None, - ) -> Tuple[Dict[str, Any], Path]: + graph_path: str | Path, + output_path: str | Path | None = None, + slug: str | None = None, + custom_metadata: dict[str, Any] | None = None, + ) -> tuple[dict[str, Any], Path]: """Transform a Circuit Tracer graph for Neuronpedia compatibility. Args: @@ -406,7 +406,7 @@ def transform_circuit_tracer_graph( return graph_dict, output_path - def _save_graph_json(self, graph_dict: Dict[str, Any], output_path: Path) -> None: + def _save_graph_json(self, graph_dict: dict[str, Any], output_path: Path) -> None: """Save graph dictionary as JSON.""" output_path.parent.mkdir(parents=True, exist_ok=True) @@ -418,7 +418,7 @@ def _save_graph_json(self, graph_dict: Dict[str, Any], output_path: Path) -> Non with open(output_path, "w") as f: f.write(json_str) - def validate_graph(self, graph_dict: Dict[str, Any]) -> bool: + def validate_graph(self, graph_dict: dict[str, Any]) -> bool: """Validate the provided graph dictionary against the Neuronpedia schema. Args: @@ -453,7 +453,7 @@ def validate_graph(self, graph_dict: Dict[str, Any]) -> bool: rank_zero_warn(f"[NeuronpediaIntegration] Graph validation failed: {e.message}") return False - def upload_graph_to_neuronpedia(self, graph_path: Union[str, Path], api_key: Optional[str] = None) -> Any: + def upload_graph_to_neuronpedia(self, graph_path: str | Path, api_key: str | None = None) -> Any: """Upload a graph to Neuronpedia. Args: @@ -516,12 +516,12 @@ def upload_graph_to_neuronpedia(self, graph_path: Union[str, Path], api_key: Opt def transform_graph_for_np( self, - graph_path: Union[str, Path], - slug: Optional[str] = None, + graph_path: str | Path, + slug: str | None = None, upload_to_np: bool = False, - custom_metadata: Optional[Dict[str, Any]] = None, - api_key: Optional[str] = None, - ) -> Tuple[Dict[str, Any], Any]: + custom_metadata: dict[str, Any] | None = None, + api_key: str | None = None, + ) -> tuple[dict[str, Any], Any]: """Transform and upload a Circuit Tracer graph to Neuronpedia. Args: diff --git a/src/interpretune/metadata.py b/src/interpretune/metadata.py index 7505b71d..310c7a54 100644 --- a/src/interpretune/metadata.py +++ b/src/interpretune/metadata.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Dict, Tuple, Type +from typing import Any, Type @dataclass(frozen=True) @@ -13,15 +13,15 @@ class ITClassMetadata: `interpretune.base` package during import-time of modules like `interpretune.session`. """ - base_attrs: Dict[Any, Tuple[str, ...]] = field(default_factory=dict) - ready_attrs: Tuple[str, ...] = field(default_factory=tuple) - composition_target_attrs: Tuple[str, ...] = field(default_factory=tuple) - ready_protocols: Tuple[Type, ...] = field(default_factory=tuple) + base_attrs: dict[Any, tuple[str, ...]] = field(default_factory=dict) + ready_attrs: tuple[str, ...] = field(default_factory=tuple) + composition_target_attrs: tuple[str, ...] = field(default_factory=tuple) + ready_protocols: tuple[Type, ...] = field(default_factory=tuple) # Generic extension points used by other components - core_to_framework_attrs_map: Dict[str, Any] = field(default_factory=dict) - property_composition: Dict[str, Any] = field(default_factory=dict) - gen_prepares_inputs_sigs: Tuple[str, ...] = field(default_factory=tuple) + core_to_framework_attrs_map: dict[str, Any] = field(default_factory=dict) + property_composition: dict[str, Any] = field(default_factory=dict) + gen_prepares_inputs_sigs: tuple[str, ...] = field(default_factory=tuple) __all__ = ["ITClassMetadata"] diff --git a/src/interpretune/protocol.py b/src/interpretune/protocol.py index 19207623..1394f910 100644 --- a/src/interpretune/protocol.py +++ b/src/interpretune/protocol.py @@ -2,12 +2,10 @@ from typing import ( Protocol, runtime_checkable, - Union, TypeAlias, NamedTuple, TYPE_CHECKING, Callable, - Optional, Any, Sequence, Iterable, @@ -47,7 +45,7 @@ # Interpretune helper types ################################################################################ -StrOrPath: TypeAlias = Union[str, Path] +StrOrPath: TypeAlias = str | Path ################################################################################ # Interpretune Enhanced Enums @@ -216,12 +214,12 @@ def __init__( def step(self, metrics: float | int | Tensor, epoch: int | None = None) -> None: ... -STEP_OUTPUT = Optional[Union[Tensor, Mapping[str, Any]]] +STEP_OUTPUT = Tensor | Mapping[str, Any] | None -LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler.LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau] +LRSchedulerTypeUnion = torch.optim.lr_scheduler.LRScheduler | torch.optim.lr_scheduler.ReduceLROnPlateau # Protocol-level union covering the scheduler Protocol and the ReduceLROnPlateau Protocol -LRSchedulerProtocolUnion: TypeAlias = Union[LRScheduler, ReduceLROnPlateau] +LRSchedulerProtocolUnion: TypeAlias = LRScheduler | ReduceLROnPlateau @dataclass @@ -256,16 +254,15 @@ class OptimizerLRSchedulerConfig(TypedDict): lr_scheduler: NotRequired[LRSchedulerTypeUnion | LRSchedulerConfigType] -OptimizerLRScheduler = Optional[ - Union[ - Optimizer, - Sequence[Optimizer], - tuple[Sequence[Optimizer], Sequence[Union[LRSchedulerTypeUnion, LRSchedulerConfig]]], - OptimizerLRSchedulerConfig, - ] -] +OptimizerLRScheduler = ( + Optimizer + | Sequence[Optimizer] + | tuple[Sequence[Optimizer], Sequence[LRSchedulerTypeUnion | LRSchedulerConfig]] + | OptimizerLRSchedulerConfig + | None +) -ArgsType = Optional[Union[list[str], dict[str, Any], Namespace]] +ArgsType = list[str] | dict[str, Any] | Namespace | None AnyDataClass = TypeVar("AnyDataClass") @@ -383,7 +380,7 @@ def gen_protocol_variants( # supported protocol variant generation. Also add an issue tracker for this approach to solicit ideas for a # cleaner/more pythonic approach. As Python structural subtyping features are still evolving, if a cleaner and # more pythonic approach isn't available now, one will hopefully be available in the near future. -InterpretunableType: TypeAlias = Union[ITDataModuleProtocol, ITModuleProtocol] +InterpretunableType: TypeAlias = ITDataModuleProtocol | ITModuleProtocol class InterpretunableTuple(NamedTuple): @@ -450,7 +447,7 @@ class ITModuleGenDebuggable(ITModuleBase, GenerativeStepProtocol, Protocol): # Analysis Protocols ################################################################################ -NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str], str]] +NamesFilter = Callable[[str], bool] | Sequence[str] | str | None class SAEFqn(NamedTuple): @@ -464,7 +461,7 @@ class AnalysisOpProtocol(Protocol): name: str description: str output_schema: dict - input_schema: Optional[dict] + input_schema: dict | None def save_batch( self, @@ -473,7 +470,7 @@ def save_batch( tokenizer: PreTrainedTokenizerBase | None = None, save_prompts: bool = False, save_tokens: bool = False, - decode_kwargs: Optional[dict] = None, + decode_kwargs: dict | None = None, ) -> BaseAnalysisBatchProtocol: ... @@ -493,7 +490,7 @@ def apply_op_by_sae(self, operation: Callable | str, *args, **kwargs) -> "SAEDic class AnalysisStoreProtocol(Protocol): """Protocol verifying core analysis store functionality.""" - dataset: Union[HfDataset, StrOrPath, PathLike, None] + dataset: HfDataset | StrOrPath | PathLike | None streaming: bool cache_dir: str | None @@ -621,50 +618,50 @@ class DefaultAnalysisBatchProtocol(BaseAnalysisBatchProtocol): or change existing attributes as needed. Attributes: - logit_diffs (Optional[torch.Tensor | dict[str, dict[int, torch.Tensor]]]): + logit_diffs (torch.Tensor | dict[str, dict[int, torch.Tensor]] | None): Per batch logit differences with shape [batch_size] - answer_logits (Optional[torch.Tensor | dict[str, dict[int, torch.Tensor]]]): + answer_logits (torch.Tensor | dict[str, dict[int, torch.Tensor]] | None): Model output logits with shape [batch_size, 1, num_classes] - loss (Optional[torch.Tensor | dict[str, dict[int, torch.Tensor]]]): + loss (torch.Tensor | dict[str, dict[int, torch.Tensor]] | None): Loss values with shape [batch_size] - label_ids (Optional[torch.Tensor]): + label_ids (torch.Tensor | None): Input labels translated to token ids with shape [batch_size] (if labels provided & translation is needed) - orig_labels (Optional[torch.Tensor]): + orig_labels (torch.Tensor | None): Ground truth unmodified labels with shape [batch_size] - preds (Optional[torch.Tensor | dict[str, dict[int, torch.Tensor]]]): + preds (torch.Tensor | dict[str, dict[int, torch.Tensor]] | None): Model predictions with shape [batch_size] - cache (Optional[ActivationCacheProtocol]): + cache (ActivationCacheProtocol | None): Forward pass activation cache - grad_cache (Optional[ActivationCacheProtocol]): + grad_cache (ActivationCacheProtocol | None): Backward pass gradient cache - answer_indices (Optional[torch.Tensor]): + answer_indices (torch.Tensor | None): Indices of answers with shape [batch_size] - alive_latents (Optional[dict[str, list[int]]]): + alive_latents (dict[str, list[int]] | None): Active latent indices per SAE hook - correct_activations (Optional[dict[str, torch.Tensor]]): + correct_activations (dict[str, torch.Tensor] | None): SAE activations after corrections with shape [batch_size, d_sae] for each SAE - attribution_values (Optional[dict[str, torch.Tensor]]): + attribution_values (dict[str, torch.Tensor] | None): Attribution values per SAE hook - tokens (Optional[torch.Tensor]): + tokens (torch.Tensor | None): Input token IDs - prompts (Optional[list[str]]): + prompts (list[str] | None): Text prompts """ - logit_diffs: Optional[torch.Tensor | dict[str, dict[int, torch.Tensor]]] - answer_logits: Optional[torch.Tensor | dict[str, dict[int, torch.Tensor]]] - loss: Optional[torch.Tensor | dict[str, dict[int, torch.Tensor]]] - preds: Optional[torch.Tensor | dict[str, dict[int, torch.Tensor]]] - label_ids: Optional[torch.Tensor] - orig_labels: Optional[torch.Tensor] - cache: Optional[ActivationCacheProtocol] - grad_cache: Optional[ActivationCacheProtocol] - answer_indices: Optional[torch.Tensor] - alive_latents: Optional[dict[str, list[int]]] - correct_activations: Optional[dict[str, torch.Tensor]] - attribution_values: Optional[dict[str, torch.Tensor]] - tokens: Optional[torch.Tensor] - prompts: Optional[list[str]] + logit_diffs: torch.Tensor | dict[str, dict[int, torch.Tensor]] | None + answer_logits: torch.Tensor | dict[str, dict[int, torch.Tensor]] | None + loss: torch.Tensor | dict[str, dict[int, torch.Tensor]] | None + preds: torch.Tensor | dict[str, dict[int, torch.Tensor]] | None + label_ids: torch.Tensor | None + orig_labels: torch.Tensor | None + cache: ActivationCacheProtocol | None + grad_cache: ActivationCacheProtocol | None + answer_indices: torch.Tensor | None + alive_latents: dict[str, list[int]] | None + correct_activations: dict[str, torch.Tensor] | None + attribution_values: dict[str, torch.Tensor] | None + tokens: torch.Tensor | None + prompts: list[str] | None class CircuitAnalysisBatchProtocol(DefaultAnalysisBatchProtocol): @@ -673,17 +670,17 @@ class CircuitAnalysisBatchProtocol(DefaultAnalysisBatchProtocol): Extends the default protocol with circuit tracing specific attributes. Attributes: - attribution_graphs (Optional[list]): + attribution_graphs (list | None): Generated attribution graphs for each prompt in the batch - graph_metadata (Optional[list[dict]]): + graph_metadata (list[dict] | None): Metadata for each generated graph including parameters used - graph_paths (Optional[list[str]]): + graph_paths (list[str] | None): File paths where graphs are saved (if saved) - circuit_prompts (Optional[list[str]]): + circuit_prompts (list[str] | None): Prompts used for circuit attribution (may differ from input prompts) """ - attribution_graphs: Optional[list] - graph_metadata: Optional[list[dict]] - graph_paths: Optional[list[str]] - circuit_prompts: Optional[list[str]] + attribution_graphs: list | None + graph_metadata: list[dict] | None + graph_paths: list[str] | None + circuit_prompts: list[str] | None diff --git a/src/interpretune/registry.py b/src/interpretune/registry.py index a0d8fddd..cb44aa4d 100644 --- a/src/interpretune/registry.py +++ b/src/interpretune/registry.py @@ -1,5 +1,4 @@ from typing import ( - Optional, Any, Dict, Tuple, @@ -61,11 +60,11 @@ def register( phase: str, model_src_key: str, task_name: str, - adapter_combinations: Tuple[Adapter] | Tuple[Tuple[Adapter]], + adapter_combinations: tuple[Adapter] | tuple[tuple[Adapter]], reg_key: str, registered_cfg: RegisteredCfg, - cfg_dict: Optional[Dict[str, Any]] = None, - description: Optional[str] = None, + cfg_dict: dict[str, Any] | None = None, + description: str | None = None, ) -> None: """Registers valid component + adapter compositions mapped to composition keys with required metadata. @@ -76,11 +75,11 @@ def register( task_name: adapter_combination: tuple identifying the valid adapter composition reg_key: The canonical key of the test/example module. - registered_cfg: Tuple[Callable], + registered_cfg: tuple[Callable], description : composition description cfg_dict: optionally save original configuration dictionary """ - supported_composition: Dict[str | Adapter | Tuple[Adapter | str], Any] = {} + supported_composition: dict[str | Adapter | tuple[Adapter | str], Any] = {} supported_composition[reg_key] = registered_cfg supported_composition["description"] = description if description is not None else "" supported_composition["cfg_dict"] = cfg_dict @@ -102,7 +101,7 @@ def available_keys(self, key_type: RegKeyType | str = "string") -> None: def available_keys_feedback(self, target_key: str | Tuple) -> str: assert isinstance(target_key, (str, tuple)), "`target_key` must be either a str or a tuple" # Collect entries as (displayable_key, description) and sort by the displayable key - entries: List[Tuple[str, str]] = [] + entries: list[tuple[str, str]] = [] for key in self.keys(): if not isinstance(key, type(target_key)): continue @@ -154,11 +153,11 @@ def get(self, target: Tuple | str | RegKeyQueryable, default: Any = None) -> Any ) raise KeyError(err_msg) - def remove(self, composition_key: Tuple[Adapter | str]) -> None: + def remove(self, composition_key: tuple[Adapter | str]) -> None: """Removes the registered adapter composition by name.""" del self[composition_key] - def available_compositions(self, adapter_filter: Optional[Sequence[Adapter] | Adapter] = None) -> Set: + def available_compositions(self, adapter_filter: Sequence[Adapter] | Adapter | None = None) -> Set: """Returns a list of registered compositions, optionally filtering by an adapter or sequence of adapters.""" if adapter_filter is not None: @@ -175,12 +174,12 @@ def __str__(self) -> str: def instantiate_and_register( reg_key: str, - rv: Dict[str, Any], - datamodule_cls: Optional[Type[DataModuleInitable] | str] = None, - module_cls: Optional[Type[ModuleSteppable] | str] = None, + rv: dict[str, Any], + datamodule_cls: Type[DataModuleInitable] | str | None = None, + module_cls: Type[ModuleSteppable] | str | None = None, target_registry: ModuleRegistry = MODULE_REGISTRY, - itdm_cfg_defaults_fn: Optional[Callable] = None, - it_cfg_defaults_fn: Optional[Callable] = None, + itdm_cfg_defaults_fn: Callable | None = None, + it_cfg_defaults_fn: Callable | None = None, ) -> None: cfg_dict = deepcopy(rv) reg_info, shared_cfg, registered_cfg = rv["reg_info"], rv["shared_config"], rv["registered_cfg"] @@ -262,7 +261,7 @@ def apply_defaults(cfg: ITConfig | ITDataModuleConfig, defaults: Dict, force_ove setattr(cfg, k, v) -def itdm_cfg_factory(cfg: Dict, shared_config: Dict, defaults_func: Optional[Callable] = None): +def itdm_cfg_factory(cfg: Dict, shared_config: Dict, defaults_func: Callable | None = None): prompt_cfg = cfg.get("prompt_cfg", {}) # instantiate supported class_path refs # TODO: add path for specifying custom datamodule_cfg subclass when necessary @@ -274,7 +273,7 @@ def itdm_cfg_factory(cfg: Dict, shared_config: Dict, defaults_func: Optional[Cal return instantiated_cfg -def it_cfg_factory(cfg: Dict, shared_config: Optional[Dict] = None, defaults_func: Optional[Callable] = None): +def it_cfg_factory(cfg: Dict, shared_config: Dict | None = None, defaults_func: Callable | None = None): if "class_path" in cfg: cfg["init_args"] = cfg["init_args"] | shared_config if "init_args" in cfg else shared_config instantiated_cfg = instantiate_nested(cfg) diff --git a/src/interpretune/runners/analysis.py b/src/interpretune/runners/analysis.py index 046497e6..d90ae006 100644 --- a/src/interpretune/runners/analysis.py +++ b/src/interpretune/runners/analysis.py @@ -1,5 +1,5 @@ from __future__ import annotations # see PEP 749, no longer needed when 3.13 reaches EOL -from typing import Any, TYPE_CHECKING, Union, List, Optional, Dict, Tuple +from typing import Any, TYPE_CHECKING import logging from functools import partialmethod from pathlib import Path @@ -47,7 +47,7 @@ def analysis_store_generator( _call_itmodule_hook(module, hook_name="on_analysis_epoch_end", hook_msg="Running analysis epoch end hooks") -def maybe_init_analysis_cfg(module: "ITModule", analysis_cfg: Optional[AnalysisCfg] = None, **kwargs) -> dict: +def maybe_init_analysis_cfg(module: "ITModule", analysis_cfg: AnalysisCfg | None = None, **kwargs) -> dict: """Initialize analysis configuration if needed and return updated kwargs. Args: @@ -75,7 +75,7 @@ def maybe_init_analysis_cfg(module: "ITModule", analysis_cfg: Optional[AnalysisC return kwargs -def dataset_features_and_format(module: "ITModule", kwargs: dict) -> Tuple[dict, dict, dict]: +def dataset_features_and_format(module: "ITModule", kwargs: dict) -> tuple[dict, dict, dict]: """Generate dataset features and formatting parameters based on module configuration. Args: @@ -162,7 +162,7 @@ def core_analysis_loop( limit_analysis_batches: int = -1, step_fn: str = "analysis_step", max_epochs: int = 1, - analysis_cfg: Optional[AnalysisCfg] = None, + analysis_cfg: AnalysisCfg | None = None, *args, **kwargs, ): @@ -227,7 +227,7 @@ def it_init(self): "no analysis configuration to generate one." ) - def _run(self, phase, loop_fn, step_fn: Optional[str] = None, *args: Any, **kwargs: Any) -> Any | None: + def _run(self, phase, loop_fn, step_fn: str | None = None, *args: Any, **kwargs: Any) -> Any | None: self.phase = AllPhases[phase] # type: ignore[assignment] # phase attribute assignment phase_artifacts = loop_fn(step_fn=step_fn, **self.run_cfg.__dict__) self.it_session_end() @@ -237,11 +237,11 @@ def _run(self, phase, loop_fn, step_fn: Optional[str] = None, *args: Any, **kwar def run_analysis( self, - analysis_cfgs: Optional[Union[AnalysisCfg, Any, List[Union[AnalysisCfg, Any]]]] = None, + analysis_cfgs: AnalysisCfg | Any | list[AnalysisCfg | Any] | None = None, cache_dir: str | Path | None = None, op_output_dataset_path: str | Path | None = None, **kwargs, - ) -> AnalysisStoreProtocol | Dict[str, AnalysisStoreProtocol]: + ) -> AnalysisStoreProtocol | dict[str, AnalysisStoreProtocol]: """Unified method to run analysis operations based on the provided configuration. Args: diff --git a/src/interpretune/session.py b/src/interpretune/session.py index 891afe7d..522ac82b 100644 --- a/src/interpretune/session.py +++ b/src/interpretune/session.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Tuple, Callable, Mapping, Type, Sequence +from typing import Any, Dict, Callable, Mapping, Type, Sequence import os import importlib from dataclasses import dataclass, field @@ -28,14 +28,14 @@ def __new__(mcs, name, bases, classdict, **kwargs): component, input_cls, ctx = mcs._validate_build_ctx(kwargs) # TODO: add runtime checks for adherence to IT protocol here? composition_classes = mcs._map_composition_target(component, ctx) - new_bases: Tuple[type, ...] = (NamedWrapper, input_cls, *composition_classes) # type: ignore[misc] + new_bases: tuple[type, ...] = (NamedWrapper, input_cls, *composition_classes) # type: ignore[misc] built_class = super().__new__(mcs, name, new_bases, classdict) built_class._orig_module_name = input_cls.__qualname__ # type: ignore[attr-defined] # dynamic attribute for session tracking built_class._composed_classes = composition_classes # type: ignore[attr-defined] # dynamic attribute for session tracking return built_class @staticmethod - def _validate_build_ctx(kwargs: Dict) -> Tuple[str, Callable, Tuple]: + def _validate_build_ctx(kwargs: dict) -> tuple[str, Callable, tuple]: required_kwargs = ("component", "input", "ctx") for kwarg in required_kwargs: if kwarg not in kwargs: @@ -44,7 +44,7 @@ def _validate_build_ctx(kwargs: Dict) -> Tuple[str, Callable, Tuple]: raise ValueError(f"Specified component was {component}, should be either 'module' or 'datamodule'") if not callable(input := kwargs.get("input")): raise ValueError(f"Specified input {input} is not a callable, it should be the class to be enriched.") - if not isinstance(ctx := kwargs.get("ctx"), Tuple): + if not isinstance(ctx := kwargs.get("ctx"), tuple): raise ValueError(f"Specified ctx {ctx} must be a tuple specifying the desired class enrichment") return component, input, ctx @@ -64,10 +64,10 @@ def _map_composition_target(component, ctx): class UnencapsulatedArgs(ITSerializableCfg): # Most use cases will encapsulate datamodule/module config by subclassing the relevant dataclasses for a given # experiment/application but we also allow unencapsulated args/kwargs to be passed to the datamodule and module - dm_args: Tuple = () - dm_kwargs: Dict[str, Any] = field(default_factory=dict) - module_args: Tuple = () - module_kwargs: Dict[str, Any] = field(default_factory=dict) + dm_args: tuple = () + dm_kwargs: dict[str, Any] = field(default_factory=dict) + module_args: tuple = () + module_kwargs: dict[str, Any] = field(default_factory=dict) @dataclass(kw_only=True) @@ -75,11 +75,11 @@ class ITSessionConfig(UnencapsulatedArgs): adapter_ctx: Sequence[Adapter | str] = (Adapter.core,) datamodule_cfg: ITDataModuleConfig module_cfg: ITConfig - shared_cfg: Optional[ITSharedConfig | Dict] = None - datamodule_cls: Optional[Type[DataModuleInitable] | str] = None - module_cls: Optional[Type[ModuleSteppable] | str] = None - datamodule: Optional[ITDataModuleProtocol] = None - module: Optional[ITModuleProtocol] = None + shared_cfg: ITSharedConfig | Dict | None = None + datamodule_cls: Type[DataModuleInitable] | str | None = None + module_cls: Type[ModuleSteppable] | str | None = None + datamodule: ITDataModuleProtocol | None = None + module: ITModuleProtocol | None = None def __post_init__(self): self.adapter_ctx = ADAPTER_REGISTRY.canonicalize_composition(self.adapter_ctx) @@ -136,7 +136,7 @@ def __iter__(self): def __len__(self): return len(self.to_dict()) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return {self._ctx[0]: self.datamodule, self._ctx[1]: self.module} def __repr__(self) -> str: diff --git a/src/interpretune/utils/data_movement.py b/src/interpretune/utils/data_movement.py index 6478ecbb..229e8202 100644 --- a/src/interpretune/utils/data_movement.py +++ b/src/interpretune/utils/data_movement.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Any import torch @@ -12,12 +12,10 @@ # Data movement utils ################################################################################ -_DEVICE = Union[torch.device, str, int] +_DEVICE = torch.device | str | int -def to_device( - device: _DEVICE, obj: Union[torch.nn.Module, torch.Tensor, Any] -) -> Union[torch.nn.Module, torch.Tensor, Any]: +def to_device(device: _DEVICE, obj: torch.nn.Module | torch.Tensor | Any) -> torch.nn.Module | torch.Tensor | Any: r"""Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not already on that device. diff --git a/src/interpretune/utils/exceptions.py b/src/interpretune/utils/exceptions.py index bd413103..898c80cd 100644 --- a/src/interpretune/utils/exceptions.py +++ b/src/interpretune/utils/exceptions.py @@ -6,7 +6,7 @@ import tempfile from pathlib import Path from datetime import datetime -from typing import Dict, Any, Optional, Union, Sequence +from typing import Any, Sequence log = logging.getLogger(__name__) @@ -20,9 +20,9 @@ class MisconfigurationException(Exception): def handle_exception_with_debug_dump( e: Exception, - context_data: Union[Dict[str, Any], Sequence[Any]], + context_data: dict[str, Any] | Sequence[Any], operation_name: str = "operation", - debug_dir_override: Optional[Union[str, Path]] = None, + debug_dir_override: str | Path | None = None, ) -> None: """Handle an exception by creating a detailed debug dump file and re-raising the exception. @@ -46,7 +46,7 @@ def handle_exception_with_debug_dump( dump_file = debug_dir / f"{operation_name}_error_{timestamp}.json" # Add exception information to debug data - debug_info: Dict[str, Any] = { + debug_info: dict[str, Any] = { "error": str(e), "traceback": traceback.format_exc(), } @@ -114,7 +114,7 @@ def handle_exception_with_debug_dump( raise e -def _introspect_variable(var: Any) -> Dict[str, Any]: +def _introspect_variable(var: Any) -> dict[str, Any]: """Introspect a variable to create a detailed representation for debugging. Args: @@ -123,7 +123,7 @@ def _introspect_variable(var: Any) -> Dict[str, Any]: Returns: A dictionary with detailed information about the variable """ - result: Dict[str, Any] = { + result: dict[str, Any] = { "type": str(type(var).__name__), } diff --git a/src/interpretune/utils/import_utils.py b/src/interpretune/utils/import_utils.py index 13187727..b588d832 100644 --- a/src/interpretune/utils/import_utils.py +++ b/src/interpretune/utils/import_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Union, Optional, Dict, Tuple, Callable, List +from typing import Any, Callable import importlib from functools import lru_cache from importlib.util import find_spec @@ -11,7 +11,7 @@ def instantiate_class( - init: Dict[str, Any], args: Optional[Union[Any, Tuple[Any, ...]]] = None, import_only: bool = False + init: dict[str, Any], args: Any | tuple[Any, ...] | None = None, import_only: bool = False ) -> Any: """Instantiates a class with the given args and init. Accepts class definitions with a "class_path". @@ -49,7 +49,7 @@ def instantiate_class( return args_class(**kwargs) if not args else args_class(*args, **kwargs) -def resolve_funcs(cfg_obj: Any, func_type: str) -> List[Callable[..., Any]]: +def resolve_funcs(cfg_obj: Any, func_type: str) -> list[Callable[..., Any]]: resolved_funcs = [] funcs_to_resolve = getattr(cfg_obj, func_type) if not isinstance(funcs_to_resolve, list): @@ -74,7 +74,7 @@ def resolve_funcs(cfg_obj: Any, func_type: str) -> List[Callable[..., Any]]: return resolved_funcs -def _resolve_dtype(dtype: Union[torch.dtype, str]) -> Optional[torch.dtype]: +def _resolve_dtype(dtype: torch.dtype | str) -> torch.dtype | None: """Resolve a dtype which may be a torch.dtype or a string to a torch.dtype.""" if isinstance(dtype, torch.dtype): return dtype @@ -82,7 +82,7 @@ def _resolve_dtype(dtype: Union[torch.dtype, str]) -> Optional[torch.dtype]: return _str_to_dtype(dtype) -def _str_to_dtype(str_dtype: str) -> Optional[torch.dtype]: +def _str_to_dtype(str_dtype: str) -> torch.dtype | None: if hasattr(torch, str_dtype): return getattr(torch, str_dtype) elif hasattr(torch, str_dtype.split(".")[-1]): diff --git a/src/interpretune/utils/logging.py b/src/interpretune/utils/logging.py index 698b98fb..a7c40d68 100644 --- a/src/interpretune/utils/logging.py +++ b/src/interpretune/utils/logging.py @@ -2,7 +2,7 @@ import sys import logging import warnings -from typing import Optional, Callable, Any, TypeVar, Union, Dict +from typing import Callable, Any, TypeVar, Dict from typing_extensions import ParamSpec, overload from functools import wraps from contextlib import contextmanager @@ -71,7 +71,7 @@ def collect_env_info() -> Dict: ################################################################################ -def _get_rank() -> Optional[int]: +def _get_rank() -> int | None: rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") for key in rank_keys: rank = os.environ.get(key) @@ -82,21 +82,21 @@ def _get_rank() -> Optional[int]: @overload -def rank_zero_only(fn: Callable[P, T]) -> Callable[P, Optional[T]]: ... +def rank_zero_only(fn: Callable[P, T]) -> Callable[P, T | None]: ... @overload def rank_zero_only(fn: Callable[P, T], default: T) -> Callable[P, T]: ... -def rank_zero_only(fn: Callable[P, T], default: Optional[T] = None) -> Callable[P, Optional[T]]: +def rank_zero_only(fn: Callable[P, T], default: T | None = None) -> Callable[P, T | None]: """Wrap a function to call internal function only in rank zero. Function that can be used as a decorator to enable a function/method being called only on global rank 0. """ @wraps(fn) - def wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> Optional[T]: + def wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> T | None: rank = getattr(rank_zero_only, "rank", None) if rank is None: raise RuntimeError("The `rank_zero_only.rank` needs to be set before use") @@ -134,12 +134,12 @@ def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None: _info(*args, stacklevel=stacklevel, **kwargs) -def _warn(message: Union[str, Warning], stacklevel: int = 2, **kwargs: Any) -> None: +def _warn(message: str | Warning, stacklevel: int = 2, **kwargs: Any) -> None: warnings.warn(message, stacklevel=stacklevel, **kwargs) @rank_zero_only -def rank_zero_warn(message: Union[str, Warning], stacklevel: int = 4, **kwargs: Any) -> None: +def rank_zero_warn(message: str | Warning, stacklevel: int = 4, **kwargs: Any) -> None: """Emit warn-level messages only on global rank 0.""" _warn(message, stacklevel=stacklevel, **kwargs) @@ -147,7 +147,7 @@ def rank_zero_warn(message: Union[str, Warning], stacklevel: int = 4, **kwargs: rank_zero_deprecation_category = DeprecationWarning -def rank_zero_deprecation(message: Union[str, Warning], stacklevel: int = 5, **kwargs: Any) -> None: +def rank_zero_deprecation(message: str | Warning, stacklevel: int = 5, **kwargs: Any) -> None: """Emit a deprecation warning only on global rank 0.""" category = kwargs.pop("category", rank_zero_deprecation_category) rank_zero_warn(message, stacklevel=stacklevel, category=category, **kwargs) diff --git a/src/interpretune/utils/warnings.py b/src/interpretune/utils/warnings.py index bb0e501c..cec1b819 100644 --- a/src/interpretune/utils/warnings.py +++ b/src/interpretune/utils/warnings.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, Union, Type +from typing import Type from pathlib import Path _default_format_warning = warnings.formatwarning @@ -12,7 +12,7 @@ def _is_path_in_interpretune(path: Path) -> bool: # adapted from lightning.fabric.utilities.warnings def _custom_format_warning( - message: Union[Warning, str], category: Type[Warning], filename: str, lineno: int, line: Optional[str] = None + message: Warning | str, category: Type[Warning], filename: str, lineno: int, line: str | None = None ) -> str: """Custom formatting that avoids an extra line in case warnings are emitted from the `rank_zero`-functions.""" if _is_path_in_interpretune(Path(filename)): diff --git a/src/it_examples/example_module_registry.py b/src/it_examples/example_module_registry.py index 466a86ba..f3c13fbb 100644 --- a/src/it_examples/example_module_registry.py +++ b/src/it_examples/example_module_registry.py @@ -5,7 +5,7 @@ from pathlib import Path from functools import partial import threading -from typing import Any, Union, Tuple, TYPE_CHECKING +from typing import Any, Tuple, TYPE_CHECKING if TYPE_CHECKING: from interpretune.registry import ModuleRegistry @@ -83,7 +83,7 @@ def _create_registry(self): return registry - def get(self, target: Union[Tuple, str, Any], default: Any = None) -> Any: + def get(self, target: Tuple | str | Any, default: Any = None) -> Any: """Get item from registry, initializing if needed.""" return self.registry.get(target, default) diff --git a/src/it_examples/example_prompt_configs.py b/src/it_examples/example_prompt_configs.py index d3a1e352..a3fab99d 100644 --- a/src/it_examples/example_prompt_configs.py +++ b/src/it_examples/example_prompt_configs.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional from it_examples.experiments.rte_boolq import RTEBoolqPromptConfig #################################### @@ -21,7 +20,7 @@ def __post_init__(self) -> None: self.USER_ROLE_START = self.B_TURN + self.USER_ROLE + "\n" self.USER_ROLE_END = self.E_TURN + self.B_TURN + self.ASSISTANT_ROLE + "\n" - def model_chat_template_fn(self, task_prompt: str, tokenization_pattern: Optional[str] = None) -> str: + def model_chat_template_fn(self, task_prompt: str, tokenization_pattern: str | None = None) -> str: if tokenization_pattern == "gemma2-chat": sequence = self.USER_ROLE_START + f"{task_prompt.strip()} {self.USER_ROLE_END}" else: @@ -65,7 +64,7 @@ def __post_init__(self) -> None: ) self.USER_ROLE_END = self.E_TURN + self.ASSISTANT_ROLE_HEADER + "\n" - def model_chat_template_fn(self, task_prompt: str, tokenization_pattern: Optional[str] = None) -> str: + def model_chat_template_fn(self, task_prompt: str, tokenization_pattern: str | None = None) -> str: if tokenization_pattern == "llama3-chat": sequence = self.SYS_ROLE_START + f"{task_prompt.strip()} {self.USER_ROLE_END}" else: diff --git a/src/it_examples/experiments/rte_boolq.py b/src/it_examples/experiments/rte_boolq.py index 28e2f90b..00c50b16 100644 --- a/src/it_examples/experiments/rte_boolq.py +++ b/src/it_examples/experiments/rte_boolq.py @@ -18,7 +18,7 @@ # pyright: reportOptionalMemberAccess=false import os -from typing import Any, Dict, Optional, Tuple, List, Callable, Generator +from typing import Any, Tuple, Callable, Generator from dataclasses import dataclass, field from pprint import pformat import logging @@ -64,7 +64,7 @@ @dataclass(kw_only=True) class RTEBoolqEntailmentMapping: entailment_mapping: Tuple = ("Yes", "No") # RTE style, invert mapping for BoolQ - entailment_mapping_indices: Optional[torch.Tensor] = None + entailment_mapping_indices: torch.Tensor | None = None @dataclass(kw_only=True) @@ -80,7 +80,7 @@ def __repr__(self): class RTEBoolqPromptConfig(PromptConfig): ctx_question_join: str = "Does the previous passage imply that " question_suffix: str = "? Answer with only one word, either Yes or No." - cust_task_prompt: Dict[str, Any] | None = None # type: ignore[assignment] # intentional override for demo + cust_task_prompt: dict[str, Any] | None = None # type: ignore[assignment] # intentional override for demo # add our custom model attributes @@ -104,7 +104,7 @@ def __init__(self, itdm_cfg: ITDataModuleConfig) -> None: itdm_cfg.text_fields = TASK_TEXT_FIELD_MAP[itdm_cfg.task_name] super().__init__(itdm_cfg=itdm_cfg) - def prepare_data(self, target_model: Optional[torch.nn.Module] = None) -> None: + def prepare_data(self, target_model: torch.nn.Module | None = None) -> None: """Load the SuperGLUE dataset.""" # N.B. prepare_data is called in a single process (rank 0, either per node or globally) so do not use it to # assign state (e.g. self.x=y) @@ -156,10 +156,10 @@ def predict_dataloader(self) -> DataLoader: def encode_for_rteboolq( example_batch: LazyDict, tokenizer: PreTrainedTokenizerBase, - text_fields: List[str], + text_fields: list[str], prompt_cfg: PromptConfig, template_fn: Callable, - tokenization_pattern: Optional[str] = None, + tokenization_pattern: str | None = None, ) -> BatchEncoding: example_batch["sequences"] = [] # TODO: use promptsource instead of this manual approach after tinkering @@ -204,14 +204,14 @@ def training_step(self, batch: BatchEncoding, batch_idx: int) -> STEP_OUTPUT: return loss @MemProfilerHooks.memprofilable - def validation_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]: + def validation_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> STEP_OUTPUT | None: answer_logits, labels, orig_labels = self.logits_and_labels(batch, batch_idx) val_loss = self.loss_fn(answer_logits, labels) self.log("val_loss", val_loss, prog_bar=True, sync_dist=True) self.collect_answers(answer_logits, orig_labels) @MemProfilerHooks.memprofilable - def test_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]: + def test_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> STEP_OUTPUT | None: if self.it_cfg.generative_step_cfg.enabled: self.generative_classification_test_step(batch, batch_idx, dataloader_idx=dataloader_idx) else: @@ -219,7 +219,7 @@ def test_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = def generative_classification_test_step( self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0 - ) -> Optional[STEP_OUTPUT]: + ) -> STEP_OUTPUT | None: labels = batch.pop("labels") outputs = self.it_generate(batch, **self.it_cfg.generative_step_cfg.lm_generation_cfg.generate_kwargs) # We expect a HF ModelOutput with a `.logits` attribute when `output_logits` is used. @@ -231,12 +231,12 @@ def generative_classification_test_step( raise ValueError("Expected ModelOutput with `logits` or a logits tensor from generate()") self.collect_answers(logits, labels) - def default_test_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]: + def default_test_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> STEP_OUTPUT | None: labels = batch.pop("labels") outputs = self(**batch) self.collect_answers(outputs.logits, labels) - def predict_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]: + def predict_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> STEP_OUTPUT | None: labels = batch.pop("labels") outputs = self(**batch) return self.collect_answers(outputs, labels, mode="return") @@ -246,7 +246,7 @@ def analysis_step( batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0, - analysis_batch: Optional[AnalysisBatch] = None, + analysis_batch: AnalysisBatch | None = None, ) -> Generator[STEP_OUTPUT, None, None]: """Run analysis operations on a batch and yield results.""" # Demo mixing model methods and native IT analysis ops @@ -280,7 +280,7 @@ def load_metric(self) -> None: ) # we override the default labels_to_ids method to demo using our module-specific attributes/logic - def labels_to_ids(self, labels: List[str]) -> List[int]: + def labels_to_ids(self, labels: list[str]) -> list[int]: return torch.take(self.it_cfg.entailment_mapping_indices, labels), labels # We override the default standardize_logits method to demo using custom attributes etc. diff --git a/src/it_examples/notebooks/dev/attribution_analysis/analysis_points.py b/src/it_examples/notebooks/dev/attribution_analysis/analysis_points.py index 460db761..5e0003d6 100644 --- a/src/it_examples/notebooks/dev/attribution_analysis/analysis_points.py +++ b/src/it_examples/notebooks/dev/attribution_analysis/analysis_points.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any import torch @@ -21,8 +21,8 @@ # ruff: noqa: F821 -def ap_compute_attribution_end(local_vars: Dict[str, Any]) -> None: - data: Dict[str, Any] = {} +def ap_compute_attribution_end(local_vars: dict[str, Any]) -> None: + data: dict[str, Any] = {} # Collect shapes from attribution data collect_shapes( @@ -36,7 +36,7 @@ def ap_compute_attribution_end(local_vars: Dict[str, Any]) -> None: analysis_log_point("after attribution component computation", data) -def ap_precomputation_phase_end(local_vars: Dict[str, Any]) -> None: +def ap_precomputation_phase_end(local_vars: dict[str, Any]) -> None: # Use dict directly for cleaner access v = get_analysis_vars(context_keys=["target_token_analysis"], local_keys=["ctx"], local_vars=local_vars) v["target_token_analysis"].act_matrix = v["ctx"].activation_matrix @@ -49,7 +49,7 @@ def ap_precomputation_phase_end(local_vars: Dict[str, Any]) -> None: analysis_log_point("after precomputation phase", data) -def ap_forward_pass_end(local_vars: Dict[str, Any]) -> None: +def ap_forward_pass_end(local_vars: dict[str, Any]) -> None: # Use dict directly for cleaner access v = get_analysis_vars(context_keys=["target_token_analysis"], local_keys=["ctx", "model"], local_vars=local_vars) @@ -63,7 +63,7 @@ def ap_forward_pass_end(local_vars: Dict[str, Any]) -> None: analysis_log_point("after forward pass", data) -def ap_build_input_vectors_end(local_vars: Dict[str, Any]) -> None: +def ap_build_input_vectors_end(local_vars: dict[str, Any]) -> None: # Use dict directly for cleaner access v = get_analysis_vars( context_keys=["target_token_analysis"], @@ -115,7 +115,7 @@ def ap_build_input_vectors_end(local_vars: Dict[str, Any]) -> None: analysis_log_point("after building input vectors w/ target logits", data) -def ap_compute_logit_attribution_end(local_vars: Dict[str, Any]) -> None: +def ap_compute_logit_attribution_end(local_vars: dict[str, Any]) -> None: # Use dict directly for cleaner access v = get_analysis_vars( context_keys=["target_token_analysis"], @@ -148,7 +148,7 @@ def ap_compute_logit_attribution_end(local_vars: Dict[str, Any]) -> None: analysis_log_point("after logit attribution", data) -def ap_compute_feature_attributions_end(local_vars: Dict[str, Any]) -> None: +def ap_compute_feature_attributions_end(local_vars: dict[str, Any]) -> None: # Use dict directly for cleaner access v = get_analysis_vars( context_keys=["target_token_analysis"], @@ -195,7 +195,7 @@ def ap_compute_feature_attributions_end(local_vars: Dict[str, Any]) -> None: analysis_log_point("after feature attribution", data) -def ap_graph_creation_start(local_vars: Dict[str, Any]) -> None: +def ap_graph_creation_start(local_vars: dict[str, Any]) -> None: # Use dict directly for cleaner access v = get_analysis_vars( context_keys=["target_token_analysis"], @@ -207,7 +207,7 @@ def ap_graph_creation_start(local_vars: Dict[str, Any]) -> None: n_logits = v["n_logits"] tta.reorg_logit_indices = (v["edge_matrix"].shape[0] - n_logits) + tta.logit_indices tta.graph_logit_indices = (full_edge_matrix.shape[0] - n_logits) + tta.logit_indices - data: Dict[str, Any] = {} + data: dict[str, Any] = {} collect_shapes(data, local_vars, ["full_edge_matrix", "edge_matrix"]) pre_normalized_logit_node_sum = full_edge_matrix[tta.graph_logit_indices.to(full_edge_matrix.device), :].sum(1) data["pre_normalized_logit_node_sum"] = VarAnnotate( @@ -243,7 +243,7 @@ def ap_graph_creation_start(local_vars: Dict[str, Any]) -> None: analysis_log_point("Graph packaging complete", data) -def ap_node_compute_influence_init(local_vars: Dict[str, Any]) -> None: +def ap_node_compute_influence_init(local_vars: dict[str, Any]) -> None: """Collect initial current_influence vector in compute_influence.""" # Check call stack to determine context context = get_caller_context( @@ -325,7 +325,7 @@ def ap_node_compute_influence_init(local_vars: Dict[str, Any]) -> None: analysis_log_point("After initial compute_influence computation (node context)", data) -def ap_node_compute_influence(local_vars: Dict[str, Any]) -> None: +def ap_node_compute_influence(local_vars: dict[str, Any]) -> None: """Collect current_influence vectors after each iteration of compute_influence.""" # Check call stack to determine context context = get_caller_context( @@ -356,7 +356,7 @@ def ap_node_compute_influence(local_vars: Dict[str, Any]) -> None: analysis_log_point("After compute_influence iteration (node context)", data) -def ap_graph_prune_node_influence_end(local_vars: Dict[str, Any]) -> None: +def ap_graph_prune_node_influence_end(local_vars: dict[str, Any]) -> None: v = get_analysis_vars( local_keys=["node_influence", "node_mask", "node_threshold", "pruned_matrix", "n_logits", "n_tokens"], local_vars=local_vars, @@ -402,7 +402,7 @@ def ap_graph_prune_node_influence_end(local_vars: Dict[str, Any]) -> None: analysis_log_point("After node_influence threshold pruning applied", data) -def ap_graph_prune_edge_influence_post_norm(local_vars: Dict[str, Any]) -> None: +def ap_graph_prune_edge_influence_post_norm(local_vars: dict[str, Any]) -> None: v = get_analysis_vars( context_keys=["target_token_analysis", "n_pos"], local_keys=["edge_scores", "normalized_pruned", "pruned_influence", "pruned_matrix", "max_n_logits"], @@ -538,7 +538,7 @@ def ap_graph_prune_edge_influence_post_norm(local_vars: Dict[str, Any]) -> None: analysis_log_point("After edge influence calculation", data) -def ap_graph_prune_edge_influence_pre_mask(local_vars: Dict[str, Any]) -> None: +def ap_graph_prune_edge_influence_pre_mask(local_vars: dict[str, Any]) -> None: v = get_analysis_vars( # context_keys=["target_token_analysis", "n_pos"], local_keys=["edge_mask", "node_mask", "logit_weights", "edge_scores", "n_logits"], @@ -596,7 +596,7 @@ def ap_graph_prune_edge_influence_pre_mask(local_vars: Dict[str, Any]) -> None: analysis_log_point("After edge influence calculation", data) -def ap_graph_prune_edge_influence_end(local_vars: Dict[str, Any]) -> None: +def ap_graph_prune_edge_influence_end(local_vars: dict[str, Any]) -> None: v = get_analysis_vars( context_keys=["target_token_analysis"], local_keys=[ diff --git a/src/it_examples/notebooks/dev/example_op_collections/hub_op_collection/hub_op_definitions.py b/src/it_examples/notebooks/dev/example_op_collections/hub_op_collection/hub_op_definitions.py index 814e7a81..d3a1af9f 100644 --- a/src/it_examples/notebooks/dev/example_op_collections/hub_op_collection/hub_op_definitions.py +++ b/src/it_examples/notebooks/dev/example_op_collections/hub_op_collection/hub_op_definitions.py @@ -1,7 +1,5 @@ """Trivial example of a hub-based analysis operation for interpretune framework testing.""" -from typing import Optional - import torch from interpretune.protocol import BaseAnalysisBatchProtocol, DefaultAnalysisBatchProtocol @@ -11,8 +9,8 @@ class SomeDifferentBatchDef(BaseAnalysisBatchProtocol): """Example of batch definition for a trivial demo op.""" # Define any additional attributes or methods specific to this batch definition - preds: Optional[torch.Tensor] - pred_sum: Optional[torch.Tensor] + preds: torch.Tensor | None + pred_sum: torch.Tensor | None def trivial_test_op_impl(analysis_batch: DefaultAnalysisBatchProtocol) -> DefaultAnalysisBatchProtocol: diff --git a/src/it_examples/notebooks/dev/example_op_collections/op_collection_demo_utils.py b/src/it_examples/notebooks/dev/example_op_collections/op_collection_demo_utils.py index e296d4ba..369cd6c2 100644 --- a/src/it_examples/notebooks/dev/example_op_collections/op_collection_demo_utils.py +++ b/src/it_examples/notebooks/dev/example_op_collections/op_collection_demo_utils.py @@ -10,7 +10,7 @@ import io import contextlib from pathlib import Path -from typing import Tuple, Generator +from typing import Generator from interpretune.utils import rank_zero_warn from interpretune.analysis import IT_ANALYSIS_OP_PATHS @@ -197,7 +197,7 @@ def print_env_summary( def setup_local_op_collection( source_local_op_collection: Path, tmp_local_op_collection: Path = Path("/tmp/local_op_collection") -) -> Tuple[str, str]: +) -> tuple[str, str]: """Setup local operation collection by copying to /tmp/ and updating environment variables. Args: @@ -359,7 +359,7 @@ def cleanup_hub_repository(download_result) -> None: print("⚠️ No download_result available - cannot determine what to clean up") -def reimport_interpretune_with_capture() -> Tuple[str, str, object]: # type: ignore[misc] # DISPATCHER type unknown +def reimport_interpretune_with_capture() -> tuple[str, str, object]: # type: ignore[misc] # DISPATCHER type unknown """Re-import interpretune with stdout and stderr capture to check for expected warnings. Returns: @@ -400,7 +400,7 @@ def inspect_err_for_composite_op_warning(stderr_output: str) -> None: print(stderr_output) -def generate_test_batches(num_batches: int = 2) -> Generator[Tuple[str, object, object], None, None]: +def generate_test_batches(num_batches: int = 2) -> Generator[tuple[str, object, object], None, None]: """Generator that yields test analysis_batch objects with random orig_labels. Args: diff --git a/src/it_examples/notebooks/publish/.notebook_hashes.json b/src/it_examples/notebooks/publish/.notebook_hashes.json index e50be2be..19396955 100644 --- a/src/it_examples/notebooks/publish/.notebook_hashes.json +++ b/src/it_examples/notebooks/publish/.notebook_hashes.json @@ -2,17 +2,17 @@ "notebooks/dev/attribution_analysis/__pycache__/analysis_points.cpython-310.pyc": "3d1c0c51fe0e6026c710b90e2b3c3b892aecdb8c3c50356273228a58dec4610e", "notebooks/dev/attribution_analysis/__pycache__/analysis_points.cpython-312.pyc": "32665520ebe6c5becf19c9ec52aa260cc5ce8cde167d5b21ed1f7d4148d31e98", "notebooks/dev/attribution_analysis/analysis_injection_config.yaml": "0153616832195ab9adfc9008cdf7f26d73426128f9fe01b9a614929018e1a53a", - "notebooks/dev/attribution_analysis/analysis_points.py": "b27a7eb3d6cf02194ed776b4d7e9a885c2ebb563aac73822e3d21b0366c3e8a0", + "notebooks/dev/attribution_analysis/analysis_points.py": "d50932960f524b5fa30f950048d687d3a75cf2576189ce1f75e8e99ef16070a4", "notebooks/dev/attribution_analysis/attribution_analysis.ipynb": "09a357b3db12c6bcaa8cae2d2c854aa62fb2155443e484e5b2721fd616e19464", "notebooks/dev/circuit_tracer_examples/circuit_tracer_adapter_example_basic.ipynb": "760c051bfb3932815e400918724661e46ef49fbe10a4e9231442653707dc83c9", "notebooks/dev/circuit_tracer_examples/circuit_tracer_adapter_example_basic_clt.ipynb": "c101e6594dd34cc55c7dd46439e316a8b4a4fff8d7afc1edb9b5bba88156ab0a", "notebooks/dev/circuit_tracer_examples/gradient_flow_analysis.ipynb": "21818d3fa9765b29b84d38d977a744c185d89fced8913893147d61600c124b45", "notebooks/dev/example_op_collections/__pycache__/op_collection_demo_utils.cpython-312.pyc": "26040b9a677df03c8b52c34bb31682936a16a5fa6e790b856e8596d64452b12e", "notebooks/dev/example_op_collections/hub_op_collection/hub_op_collection.yaml": "93d81adff9eb7380500df6e598a712d7ac1b73c54a831bdfbfaa05ed701b244b", - "notebooks/dev/example_op_collections/hub_op_collection/hub_op_definitions.py": "ebcc238e3c0cefc94dba7963197f5a6d81a28158f5f13b1bb826859dd5475f60", + "notebooks/dev/example_op_collections/hub_op_collection/hub_op_definitions.py": "423bf5982fae43be9e340401d27818062b05fc961919cf4dd72eb7dd21fae718", "notebooks/dev/example_op_collections/local_op_collection/local_op_collection.yaml": "9542558d80cbd561dcc59173073c4c28abb641fcd956da40bd47a59d7d8d670d", "notebooks/dev/example_op_collections/local_op_collection/local_op_definitions.py": "e7a21b8151f745ed60fabf992f9742eae30a4a4b9c98d1ced6a5d47a0461d450", - "notebooks/dev/example_op_collections/op_collection_demo_utils.py": "705c962e03656f8ef9f3ec7ccb973cc8895db1073fc29a9e5b10f075ef217bde", + "notebooks/dev/example_op_collections/op_collection_demo_utils.py": "5259ef92a22b99ff18b3f2caf76afbf53ab81df09419229085c436d3be1938ec", "notebooks/dev/example_op_collections/op_collection_example.ipynb": "cce5a0226bb60e558c57016580d900a164169fb80dd3720614ac4422e5bf9bd2", "notebooks/dev/neuronpedia_example/circuit_tracer_w_neuronpedia_example.ipynb": "dcf25c891ee0170aa766a17cdc4e71010b66448a85260e7c704341e747eaf1a2", "notebooks/dev/saelens_adapter_example/emphasized_yaml.png": "586a308c0d875b32900f41d65b95321798c8f36a39123d2dd1245949e35914ab", diff --git a/src/it_examples/notebooks/publish/attribution_analysis/analysis_points.py b/src/it_examples/notebooks/publish/attribution_analysis/analysis_points.py index 460db761..5e0003d6 100644 --- a/src/it_examples/notebooks/publish/attribution_analysis/analysis_points.py +++ b/src/it_examples/notebooks/publish/attribution_analysis/analysis_points.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any import torch @@ -21,8 +21,8 @@ # ruff: noqa: F821 -def ap_compute_attribution_end(local_vars: Dict[str, Any]) -> None: - data: Dict[str, Any] = {} +def ap_compute_attribution_end(local_vars: dict[str, Any]) -> None: + data: dict[str, Any] = {} # Collect shapes from attribution data collect_shapes( @@ -36,7 +36,7 @@ def ap_compute_attribution_end(local_vars: Dict[str, Any]) -> None: analysis_log_point("after attribution component computation", data) -def ap_precomputation_phase_end(local_vars: Dict[str, Any]) -> None: +def ap_precomputation_phase_end(local_vars: dict[str, Any]) -> None: # Use dict directly for cleaner access v = get_analysis_vars(context_keys=["target_token_analysis"], local_keys=["ctx"], local_vars=local_vars) v["target_token_analysis"].act_matrix = v["ctx"].activation_matrix @@ -49,7 +49,7 @@ def ap_precomputation_phase_end(local_vars: Dict[str, Any]) -> None: analysis_log_point("after precomputation phase", data) -def ap_forward_pass_end(local_vars: Dict[str, Any]) -> None: +def ap_forward_pass_end(local_vars: dict[str, Any]) -> None: # Use dict directly for cleaner access v = get_analysis_vars(context_keys=["target_token_analysis"], local_keys=["ctx", "model"], local_vars=local_vars) @@ -63,7 +63,7 @@ def ap_forward_pass_end(local_vars: Dict[str, Any]) -> None: analysis_log_point("after forward pass", data) -def ap_build_input_vectors_end(local_vars: Dict[str, Any]) -> None: +def ap_build_input_vectors_end(local_vars: dict[str, Any]) -> None: # Use dict directly for cleaner access v = get_analysis_vars( context_keys=["target_token_analysis"], @@ -115,7 +115,7 @@ def ap_build_input_vectors_end(local_vars: Dict[str, Any]) -> None: analysis_log_point("after building input vectors w/ target logits", data) -def ap_compute_logit_attribution_end(local_vars: Dict[str, Any]) -> None: +def ap_compute_logit_attribution_end(local_vars: dict[str, Any]) -> None: # Use dict directly for cleaner access v = get_analysis_vars( context_keys=["target_token_analysis"], @@ -148,7 +148,7 @@ def ap_compute_logit_attribution_end(local_vars: Dict[str, Any]) -> None: analysis_log_point("after logit attribution", data) -def ap_compute_feature_attributions_end(local_vars: Dict[str, Any]) -> None: +def ap_compute_feature_attributions_end(local_vars: dict[str, Any]) -> None: # Use dict directly for cleaner access v = get_analysis_vars( context_keys=["target_token_analysis"], @@ -195,7 +195,7 @@ def ap_compute_feature_attributions_end(local_vars: Dict[str, Any]) -> None: analysis_log_point("after feature attribution", data) -def ap_graph_creation_start(local_vars: Dict[str, Any]) -> None: +def ap_graph_creation_start(local_vars: dict[str, Any]) -> None: # Use dict directly for cleaner access v = get_analysis_vars( context_keys=["target_token_analysis"], @@ -207,7 +207,7 @@ def ap_graph_creation_start(local_vars: Dict[str, Any]) -> None: n_logits = v["n_logits"] tta.reorg_logit_indices = (v["edge_matrix"].shape[0] - n_logits) + tta.logit_indices tta.graph_logit_indices = (full_edge_matrix.shape[0] - n_logits) + tta.logit_indices - data: Dict[str, Any] = {} + data: dict[str, Any] = {} collect_shapes(data, local_vars, ["full_edge_matrix", "edge_matrix"]) pre_normalized_logit_node_sum = full_edge_matrix[tta.graph_logit_indices.to(full_edge_matrix.device), :].sum(1) data["pre_normalized_logit_node_sum"] = VarAnnotate( @@ -243,7 +243,7 @@ def ap_graph_creation_start(local_vars: Dict[str, Any]) -> None: analysis_log_point("Graph packaging complete", data) -def ap_node_compute_influence_init(local_vars: Dict[str, Any]) -> None: +def ap_node_compute_influence_init(local_vars: dict[str, Any]) -> None: """Collect initial current_influence vector in compute_influence.""" # Check call stack to determine context context = get_caller_context( @@ -325,7 +325,7 @@ def ap_node_compute_influence_init(local_vars: Dict[str, Any]) -> None: analysis_log_point("After initial compute_influence computation (node context)", data) -def ap_node_compute_influence(local_vars: Dict[str, Any]) -> None: +def ap_node_compute_influence(local_vars: dict[str, Any]) -> None: """Collect current_influence vectors after each iteration of compute_influence.""" # Check call stack to determine context context = get_caller_context( @@ -356,7 +356,7 @@ def ap_node_compute_influence(local_vars: Dict[str, Any]) -> None: analysis_log_point("After compute_influence iteration (node context)", data) -def ap_graph_prune_node_influence_end(local_vars: Dict[str, Any]) -> None: +def ap_graph_prune_node_influence_end(local_vars: dict[str, Any]) -> None: v = get_analysis_vars( local_keys=["node_influence", "node_mask", "node_threshold", "pruned_matrix", "n_logits", "n_tokens"], local_vars=local_vars, @@ -402,7 +402,7 @@ def ap_graph_prune_node_influence_end(local_vars: Dict[str, Any]) -> None: analysis_log_point("After node_influence threshold pruning applied", data) -def ap_graph_prune_edge_influence_post_norm(local_vars: Dict[str, Any]) -> None: +def ap_graph_prune_edge_influence_post_norm(local_vars: dict[str, Any]) -> None: v = get_analysis_vars( context_keys=["target_token_analysis", "n_pos"], local_keys=["edge_scores", "normalized_pruned", "pruned_influence", "pruned_matrix", "max_n_logits"], @@ -538,7 +538,7 @@ def ap_graph_prune_edge_influence_post_norm(local_vars: Dict[str, Any]) -> None: analysis_log_point("After edge influence calculation", data) -def ap_graph_prune_edge_influence_pre_mask(local_vars: Dict[str, Any]) -> None: +def ap_graph_prune_edge_influence_pre_mask(local_vars: dict[str, Any]) -> None: v = get_analysis_vars( # context_keys=["target_token_analysis", "n_pos"], local_keys=["edge_mask", "node_mask", "logit_weights", "edge_scores", "n_logits"], @@ -596,7 +596,7 @@ def ap_graph_prune_edge_influence_pre_mask(local_vars: Dict[str, Any]) -> None: analysis_log_point("After edge influence calculation", data) -def ap_graph_prune_edge_influence_end(local_vars: Dict[str, Any]) -> None: +def ap_graph_prune_edge_influence_end(local_vars: dict[str, Any]) -> None: v = get_analysis_vars( context_keys=["target_token_analysis"], local_keys=[ diff --git a/src/it_examples/notebooks/publish/example_op_collections/hub_op_collection/hub_op_definitions.py b/src/it_examples/notebooks/publish/example_op_collections/hub_op_collection/hub_op_definitions.py index 814e7a81..d3a1af9f 100644 --- a/src/it_examples/notebooks/publish/example_op_collections/hub_op_collection/hub_op_definitions.py +++ b/src/it_examples/notebooks/publish/example_op_collections/hub_op_collection/hub_op_definitions.py @@ -1,7 +1,5 @@ """Trivial example of a hub-based analysis operation for interpretune framework testing.""" -from typing import Optional - import torch from interpretune.protocol import BaseAnalysisBatchProtocol, DefaultAnalysisBatchProtocol @@ -11,8 +9,8 @@ class SomeDifferentBatchDef(BaseAnalysisBatchProtocol): """Example of batch definition for a trivial demo op.""" # Define any additional attributes or methods specific to this batch definition - preds: Optional[torch.Tensor] - pred_sum: Optional[torch.Tensor] + preds: torch.Tensor | None + pred_sum: torch.Tensor | None def trivial_test_op_impl(analysis_batch: DefaultAnalysisBatchProtocol) -> DefaultAnalysisBatchProtocol: diff --git a/src/it_examples/notebooks/publish/example_op_collections/op_collection_demo_utils.py b/src/it_examples/notebooks/publish/example_op_collections/op_collection_demo_utils.py index e296d4ba..369cd6c2 100644 --- a/src/it_examples/notebooks/publish/example_op_collections/op_collection_demo_utils.py +++ b/src/it_examples/notebooks/publish/example_op_collections/op_collection_demo_utils.py @@ -10,7 +10,7 @@ import io import contextlib from pathlib import Path -from typing import Tuple, Generator +from typing import Generator from interpretune.utils import rank_zero_warn from interpretune.analysis import IT_ANALYSIS_OP_PATHS @@ -197,7 +197,7 @@ def print_env_summary( def setup_local_op_collection( source_local_op_collection: Path, tmp_local_op_collection: Path = Path("/tmp/local_op_collection") -) -> Tuple[str, str]: +) -> tuple[str, str]: """Setup local operation collection by copying to /tmp/ and updating environment variables. Args: @@ -359,7 +359,7 @@ def cleanup_hub_repository(download_result) -> None: print("⚠️ No download_result available - cannot determine what to clean up") -def reimport_interpretune_with_capture() -> Tuple[str, str, object]: # type: ignore[misc] # DISPATCHER type unknown +def reimport_interpretune_with_capture() -> tuple[str, str, object]: # type: ignore[misc] # DISPATCHER type unknown """Re-import interpretune with stdout and stderr capture to check for expected warnings. Returns: @@ -400,7 +400,7 @@ def inspect_err_for_composite_op_warning(stderr_output: str) -> None: print(stderr_output) -def generate_test_batches(num_batches: int = 2) -> Generator[Tuple[str, object, object], None, None]: +def generate_test_batches(num_batches: int = 2) -> Generator[tuple[str, object, object], None, None]: """Generator that yields test analysis_batch objects with random orig_labels. Args: diff --git a/src/it_examples/patching/dep_patch_shim.py b/src/it_examples/patching/dep_patch_shim.py index 676b700a..a2d72de1 100644 --- a/src/it_examples/patching/dep_patch_shim.py +++ b/src/it_examples/patching/dep_patch_shim.py @@ -2,7 +2,7 @@ import sys import os from enum import Enum -from typing import NamedTuple, Tuple, Callable +from typing import NamedTuple, Callable from it_examples.patching._patch_utils import lwt_compare_version @@ -18,7 +18,7 @@ class DependencyPatch(NamedTuple): must default) to '1' """ - condition: Tuple[Callable] # typically a tuple of `lwt_compare_version` to define version dependency + condition: tuple[Callable] # typically a tuple of `lwt_compare_version` to define version dependency env_flag: OSEnvToggle # a tuple defining the environment variable based condition and its default if not set function: Callable patched_package: str diff --git a/src/it_examples/test_examples.py b/src/it_examples/test_examples.py index bc602741..bcfa3f53 100644 --- a/src/it_examples/test_examples.py +++ b/src/it_examples/test_examples.py @@ -38,7 +38,7 @@ # BASE_DEBUG_CONFIG = IT_CONFIG_GLOBAL / "base_debug.yaml" # BASE_TL_CONFIG = IT_CONFIG_GLOBAL / "base_transformer_lens.yaml" -# def gen_experiment_cfg_sets(test_keys: Sequence[Tuple[str, str, str, Optional[str], bool]]) -> Dict: +# def gen_experiment_cfg_sets(test_keys: Sequence[tuple[str, str, str, str | None, bool]]) -> Dict: # exp_cfg_sets = {} # for exp, model, subexp, adapter_ctx, debug_mode in test_keys: # base_model_cfg = EXPERIMENTS_BASE / exp / f"{model}.yaml" diff --git a/src/it_examples/utils/analysis_injection/analysis_hook_patcher.py b/src/it_examples/utils/analysis_injection/analysis_hook_patcher.py index a0f9261b..e4f15d88 100644 --- a/src/it_examples/utils/analysis_injection/analysis_hook_patcher.py +++ b/src/it_examples/utils/analysis_injection/analysis_hook_patcher.py @@ -10,7 +10,7 @@ import sys import tempfile from pathlib import Path -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable import logging from .config_parser import FileHook @@ -20,15 +20,15 @@ class HookRegistry: """Global registry for managing analysis hook functions.""" def __init__(self): - self._hooks: Dict[str, Callable] = {} + self._hooks: dict[str, Callable] = {} self._enabled: bool = False - self._context: Dict[str, Any] = {} # Shared context for all hooks + self._context: dict[str, Any] = {} # Shared context for all hooks def register(self, point_id: str, func: Callable) -> None: """Register an analysis function for a point ID.""" self._hooks[point_id] = func - def execute(self, point_id: str, local_vars: Dict[str, Any]) -> None: + def execute(self, point_id: str, local_vars: dict[str, Any]) -> None: """Execute hook if enabled and registered.""" if not self._enabled: return @@ -97,7 +97,7 @@ def get_analysis_vars( return result -def find_line_by_regex(file_path: Path, regex_pattern: str) -> Optional[int]: +def find_line_by_regex(file_path: Path, regex_pattern: str) -> int | None: """Find line number that matches regex pattern. Args: @@ -128,8 +128,8 @@ def find_line_by_regex(file_path: Path, regex_pattern: str) -> Optional[int]: def patch_file_with_hooks( file_path: Path, - hooks: List[FileHook], - output_path: Optional[Path] = None, + hooks: list[FileHook], + output_path: Path | None = None, ) -> Path: """Patch a file by inserting hook calls at regex-matched lines. @@ -231,10 +231,10 @@ def create_patched_module_loader(module_name: str, patched_file_path: Path) -> N def patch_target_package_files( - file_hooks: Dict[str, FileHook], + file_hooks: dict[str, FileHook], target_package_path: str | Path, target_package_name: str, -) -> Dict[str, Path]: +) -> dict[str, Path]: """Patch target package files with hooks. Args: @@ -248,7 +248,7 @@ def patch_target_package_files( target_package_path = Path(target_package_path) # Group hooks by file - hooks_by_file: Dict[Path, List[FileHook]] = {} + hooks_by_file: dict[Path, list[FileHook]] = {} for hook in file_hooks.values(): if not hook.enabled: continue @@ -272,7 +272,7 @@ def patch_target_package_files( return patched_modules -def install_patched_modules(patched_modules: Dict[str, Path]) -> None: +def install_patched_modules(patched_modules: dict[str, Path]) -> None: """Install patched modules into sys.modules. Args: @@ -282,7 +282,7 @@ def install_patched_modules(patched_modules: Dict[str, Path]) -> None: create_patched_module_loader(module_name, patched_path) -def install_patched_modules_with_references(patched_modules: Dict[str, Path]) -> None: +def install_patched_modules_with_references(patched_modules: dict[str, Path]) -> None: """Install patched modules and update all cross-references to patched functions. This function addresses the issue where modules that have already imported functions @@ -340,15 +340,15 @@ def count_regex_matches(file_path: Path, regex_pattern: str) -> int: return matches -def validate_file_hooks(file_hooks: Dict[str, FileHook], target_package_path: str | Path) -> Dict[str, List[str]]: +def validate_file_hooks(file_hooks: dict[str, FileHook], target_package_path: str | Path) -> dict[str, list[str]]: """Validate that each FileHook's regex matches exactly one line in the target file. Returns a dict with keys 'missing' and 'multiple', each a list of human-readable descriptions for hooks that failed validation. """ target_package_path = Path(target_package_path) - missing: List[str] = [] - multiple: List[str] = [] + missing: list[str] = [] + multiple: list[str] = [] for hook in file_hooks.values(): if not hook.enabled: @@ -370,7 +370,7 @@ def validate_file_hooks(file_hooks: Dict[str, FileHook], target_package_path: st return {"missing": missing, "multiple": multiple} -def get_module_debug_info(module_name: str) -> Dict[str, Any]: +def get_module_debug_info(module_name: str) -> dict[str, Any]: """Get debug information about a module's patching status. Args: @@ -436,7 +436,7 @@ def get_module_debug_info(module_name: str) -> Dict[str, Any]: return info -def verify_patching(patched_modules: Dict[str, Path]) -> Dict[str, Dict[str, Any]]: +def verify_patching(patched_modules: dict[str, Path]) -> dict[str, dict[str, Any]]: """Verify that patched modules are correctly loaded and contain hooks. Args: diff --git a/src/it_examples/utils/analysis_injection/config_parser.py b/src/it_examples/utils/analysis_injection/config_parser.py index deea24f5..323d8b24 100644 --- a/src/it_examples/utils/analysis_injection/config_parser.py +++ b/src/it_examples/utils/analysis_injection/config_parser.py @@ -11,7 +11,7 @@ from copy import deepcopy from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, List, Mapping, Optional, cast +from typing import Any, Mapping, cast import yaml @@ -73,14 +73,14 @@ class AnalysisInjectionConfig: log_to_file: bool log_dir: str analysis_log_prefix: str - enabled_points: List[str] - file_hooks: Dict[str, FileHook] - target_package_version: Optional[str] = None - analysis_points_module_path: Optional[Path] = None - shared_context: Optional[Dict[str, Any]] = None + enabled_points: list[str] + file_hooks: dict[str, FileHook] + target_package_version: str | None = None + analysis_points_module_path: Path | None = None + shared_context: dict[str, Any] | None = None -def merge_config_dict(base: Mapping[str, Any], override: Mapping[str, Any]) -> Dict[str, Any]: +def merge_config_dict(base: Mapping[str, Any], override: Mapping[str, Any]) -> dict[str, Any]: """Deep merge two configuration dictionaries. The merge strategy is: @@ -94,7 +94,7 @@ def _merge(base_value: Any, override_value: Any, *, key: str | None = None) -> A return _merge_file_hooks(base_value, override_value) if isinstance(base_value, Mapping) and isinstance(override_value, Mapping): - merged: Dict[str, Any] = {k: deepcopy(v) for k, v in base_value.items()} + merged: dict[str, Any] = {k: deepcopy(v) for k, v in base_value.items()} for child_key, child_value in override_value.items(): existing = merged.get(child_key) if existing is not None: @@ -108,9 +108,9 @@ def _merge(base_value: Any, override_value: Any, *, key: str | None = None) -> A def _merge_file_hooks( base_hooks: Mapping[str, Any], override_hooks: Mapping[str, Any] - ) -> Dict[str, Dict[str, Any]]: + ) -> dict[str, dict[str, Any]]: """Merge file_hooks with recursive merging by point_id.""" - merged: Dict[str, Dict[str, Any]] = {k: deepcopy(v) for k, v in base_hooks.items()} + merged: dict[str, dict[str, Any]] = {k: deepcopy(v) for k, v in base_hooks.items()} for point_id, hook_config in override_hooks.items(): if isinstance(hook_config, Mapping): @@ -119,14 +119,14 @@ def _merge_file_hooks( merged[point_id] = _merge(merged[point_id], hook_config) else: # New hook - merged[point_id] = cast(Dict[str, Any], deepcopy(hook_config)) + merged[point_id] = cast(dict[str, Any], deepcopy(hook_config)) else: # Override completely merged[point_id] = deepcopy(hook_config) return merged - merged_root: Dict[str, Any] = {k: deepcopy(v) for k, v in base.items()} + merged_root: dict[str, Any] = {k: deepcopy(v) for k, v in base.items()} for top_key, override_value in override.items(): existing_value = merged_root.get(top_key) if existing_value is not None: @@ -141,7 +141,7 @@ def parse_config_dict(raw_config: Mapping[str, Any], *, source_path: Path | None raw_config = deepcopy(raw_config) - def _resolve_path(maybe_path: Optional[str]) -> Optional[Path]: + def _resolve_path(maybe_path: str | None) -> Path | None: if maybe_path is None: return None candidate = Path(maybe_path) @@ -216,7 +216,7 @@ def _resolve_path(maybe_path: Optional[str]) -> Optional[Path]: ) -def load_config(config_path: Path | str, overrides: Optional[Mapping[str, Any]] = None) -> AnalysisInjectionConfig: +def load_config(config_path: Path | str, overrides: Mapping[str, Any] | None = None) -> AnalysisInjectionConfig: """Load and parse analysis injection configuration from YAML. Args: @@ -242,7 +242,7 @@ def load_config(config_path: Path | str, overrides: Optional[Mapping[str, Any]] return parse_config_dict(raw_config, source_path=config_path) -def get_enabled_points(config: AnalysisInjectionConfig) -> List[str]: +def get_enabled_points(config: AnalysisInjectionConfig) -> list[str]: """Get list of enabled analysis points. Args: diff --git a/src/it_examples/utils/analysis_injection/orchestrator.py b/src/it_examples/utils/analysis_injection/orchestrator.py index 2e2170e0..b04e68b6 100644 --- a/src/it_examples/utils/analysis_injection/orchestrator.py +++ b/src/it_examples/utils/analysis_injection/orchestrator.py @@ -17,7 +17,7 @@ from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union +from typing import Any, Callable, Dict, Mapping, Sequence import torch import yaml @@ -82,7 +82,7 @@ def output(self, value: Any) -> None: # Global analysis logger holder _ANALYSIS_LOGGER = None -ANALYSIS_FUNCTIONS: Dict[str, Callable] = {} +ANALYSIS_FUNCTIONS: dict[str, Callable] = {} # Global data collection dictionary ANALYSIS_DATA: OrderedDict[str, dict | None] = OrderedDict() @@ -103,7 +103,7 @@ def clear_analysis_data(): ANALYSIS_DATA.clear() -def _load_analysis_functions_from_module(module_path: Path) -> Dict[str, Any]: +def _load_analysis_functions_from_module(module_path: Path) -> dict[str, Any]: """Load analysis functions from a Python module. The module must define an ``AP_FUNCTIONS`` mapping. The returned dictionary @@ -667,18 +667,18 @@ def __init__( self.target_package_name = target_package_name self.config = None - self.patched_modules: Dict[str, Path] = {} - self.verification_info: Dict[str, Dict[str, Any]] = {} + self.patched_modules: dict[str, Path] = {} + self.verification_info: dict[str, dict[str, Any]] = {} self._logger_initialized = False self._modules_patched = False self._hooks_registered = False self.logger = None # Reference to the analysis logger - self._config_override_data: Optional[Dict[str, Any]] = None - self._config_override_source: Optional[Path] = None - self._version_manager: Optional[Any] = None # PackageVersionManager instance for cleanup + self._config_override_data: dict[str, Any] | None = None + self._config_override_source: Path | None = None + self._version_manager: Any | None = None # PackageVersionManager instance for cleanup - def set_config_override(self, config_data: Dict[str, Any], source_path: Optional[Path] = None) -> None: + def set_config_override(self, config_data: dict[str, Any], source_path: Path | None = None) -> None: """Provide an in-memory configuration that supersedes ``config_path``. Args: @@ -810,7 +810,7 @@ def teardown(self) -> None: self._version_manager.cleanup() @property - def analysis_log(self) -> Optional[str]: + def analysis_log(self) -> str | None: """Get the path to the analysis log file if one exists. Returns: @@ -886,7 +886,7 @@ def get_output( self, key: str, tablefmt: str = "html", - skip: Union[str, Sequence[str], None] = None, + skip: str | Sequence[str] | None = None, format_tensor_kwargs: dict | None = None, ) -> None: """Get formatted output for a specific analysis point and display/print it. @@ -982,13 +982,13 @@ def get_output( # Convenience function for notebook usage def setup_analysis_injection( - config_path: Optional[Path | str] = None, - target_package: Optional[str] = None, - target_package_path: Optional[Path | str] = None, - analysis_functions: Optional[Dict] = None, + config_path: Path | str | None = None, + target_package: str | None = None, + target_package_path: Path | str | None = None, + analysis_functions: Dict | None = None, tokenizer=None, *, - config_overrides: Optional[str] = None, + config_overrides: str | None = None, ) -> AnalysisInjectionOrchestrator: """Set up analysis injection with default paths. @@ -1102,7 +1102,7 @@ def setup_analysis_injection( merged_config_yaml = yaml.safe_dump(merged_config_dict, sort_keys=False) _get_pkg_logger().info("Final merged analysis injection config:\n%s", merged_config_yaml) - auto_functions: Dict[str, Any] = {} + auto_functions: dict[str, Any] = {} if config.analysis_points_module_path: auto_functions = _load_analysis_functions_from_module(config.analysis_points_module_path) _get_pkg_logger().info( @@ -1111,7 +1111,7 @@ def setup_analysis_injection( config.analysis_points_module_path, ) - combined_functions: Dict[str, Any] = dict(auto_functions) + combined_functions: dict[str, Any] = dict(auto_functions) if analysis_functions: combined_functions.update(analysis_functions) _get_pkg_logger().info("Merged %s caller-supplied analysis functions", len(analysis_functions)) diff --git a/src/it_examples/utils/analysis_injection/version_manager.py b/src/it_examples/utils/analysis_injection/version_manager.py index ed3a4eee..963aa7c9 100644 --- a/src/it_examples/utils/analysis_injection/version_manager.py +++ b/src/it_examples/utils/analysis_injection/version_manager.py @@ -26,7 +26,6 @@ import sys import tempfile from pathlib import Path -from typing import Optional import warnings logger = logging.getLogger(__name__) @@ -82,11 +81,11 @@ def __init__(self, package_name: str, required_version: str): """ self.package_name = package_name self.required_version = required_version - self.temp_dir: Optional[Path] = None - self.temp_site_packages: Optional[Path] = None - self._original_path_entry: Optional[str] = None + self.temp_dir: Path | None = None + self.temp_site_packages: Path | None = None + self._original_path_entry: str | None = None - def get_installed_version(self) -> Optional[str]: + def get_installed_version(self) -> str | None: """Get currently installed version of the package. Returns: diff --git a/src/it_examples/utils/example_helpers.py b/src/it_examples/utils/example_helpers.py index 4cd9cb43..9e0f1c67 100644 --- a/src/it_examples/utils/example_helpers.py +++ b/src/it_examples/utils/example_helpers.py @@ -14,7 +14,7 @@ import os import logging from pathlib import Path -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Sequence from dataclasses import dataclass import torch @@ -37,28 +37,28 @@ class TargetTokenAnalysis: """ # Core token info - at least one of tokens or token_ids must be provided - tokens: Optional[List[str]] = None # e.g., ['▁Dallas', '▁Austin'] - token_ids: Optional["torch.Tensor"] = None # e.g., tensor([26865, 22605]) + tokens: list[str] | None = None # e.g., ['▁Dallas', '▁Austin'] + token_ids: torch.Tensor | None = None # e.g., tensor([26865, 22605]) # Optional tokenizer for conversion between tokens and token_ids - tokenizer: Optional[Any] = None # Tokenizer handle for conversion + tokenizer: Any | None = None # Tokenizer handle for conversion # Default device for tensor operations - default_device: Optional[str] = None # e.g., 'cuda', 'cpu' + default_device: str | None = None # e.g., 'cuda', 'cpu' # Activation matrix for feature conversion - act_matrix: Optional["torch.Tensor"] = None # Activation matrix for nodes_to_features conversion + act_matrix: torch.Tensor | None = None # Activation matrix for nodes_to_features conversion # Runtime-derived logit info (torch.Tensor for operations) - logit_indices: Optional["torch.Tensor"] = None # e.g., tensor([0, 1]) - indices into logit arrays - logit_probabilities: Optional["torch.Tensor"] = None # e.g., tensor([0.298, 0.456]) + logit_indices: torch.Tensor | None = None # e.g., tensor([0, 1]) - indices into logit arrays + logit_probabilities: torch.Tensor | None = None # e.g., tensor([0.298, 0.456]) # Edge matrix info (torch.Tensor for operations) - edge_matrix_indices: Optional["torch.Tensor"] = None # e.g., tensor([7358, 8921]) - indices into edge_matrix + edge_matrix_indices: torch.Tensor | None = None # e.g., tensor([7358, 8921]) - indices into edge_matrix # Optional additional analysis fields - top_init_edge_vals: Optional["torch.Tensor"] = None # Shape: (n_tokens, k) - e.g., top 5 vals per token - top_init_edge_indices: Optional["torch.Tensor"] = None # Shape: (n_tokens, k) - e.g., top 5 indices per token + top_init_edge_vals: torch.Tensor | None = None # Shape: (n_tokens, k) - e.g., top 5 vals per token + top_init_edge_indices: torch.Tensor | None = None # Shape: (n_tokens, k) - e.g., top 5 indices per token def __post_init__(self): """Validate and initialize fields.""" @@ -200,8 +200,8 @@ def to_dataframe(self): def nodes_to_features( self, - target_nodes: Union[list[int], list[list[int]], "torch.Tensor"], - act_matrix: Optional["torch.Tensor"] = None, + target_nodes: list[int] | list[list[int]] | torch.Tensor, + act_matrix: torch.Tensor | None = None, feats_only: bool = False, ) -> Any: """Convert target nodes to feature tuples using the activation matrix. @@ -294,7 +294,7 @@ def _safe_nodes_to_features(nodes_tensor): return result_dict -EnvVarSpec = Tuple[str, Union[str, Callable[[str], bool]]] +EnvVarSpec = tuple[str, str | Callable[[str], bool]] def validate_env_vars(env_specs: Sequence[EnvVarSpec]) -> bool: @@ -362,7 +362,7 @@ def required_os_env(env_path: str | Path | None = None, env_reqs: Sequence[EnvVa return True -def collect_shapes(data: dict, local_vars: dict, var_inspects: Sequence[Union[str, VarAnnotate]]) -> dict: +def collect_shapes(data: dict, local_vars: dict, var_inspects: Sequence[str | VarAnnotate]) -> dict: """Collect shape information from specified variables/attributes in local_vars context. Args: diff --git a/src/it_examples/utils/raw_graph_analysis.py b/src/it_examples/utils/raw_graph_analysis.py index 07a6ebe3..2eb6f02e 100644 --- a/src/it_examples/utils/raw_graph_analysis.py +++ b/src/it_examples/utils/raw_graph_analysis.py @@ -1,6 +1,7 @@ +from __future__ import annotations import json from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any import numpy as np import plotly.express as px @@ -25,8 +26,8 @@ def get_topk_2nd_order_adjacency( limit_node: int, inspect_logit_idxs: torch.Tensor, n_logits: int = 10, - graph: Optional["Graph"] = None, - adjacency_matrix: Optional[torch.Tensor] = None, + graph: Graph | None = None, + adjacency_matrix: torch.Tensor | None = None, adj_matrix_name: str = "adjacency_matrix", ): """Returns the top-k 2nd order adjacency values and indices for non-error, non-logit nodes. @@ -40,7 +41,7 @@ def get_topk_2nd_order_adjacency( adj_matrix_name (str): Name of the adjacency matrix attribute in graph. Returns: - Tuple[torch.Tensor, torch.Tensor]: (topk_values, topk_indices), both of shape [n_rows, k, k] + tuple[torch.Tensor, torch.Tensor]: (topk_values, topk_indices), both of shape [n_rows, k, k] """ if adjacency_matrix is None: if graph is None: @@ -101,7 +102,7 @@ def tensor_distribution_summary_stats(tensor: torch.Tensor): tensor (torch.Tensor): shape (num_iterations, num_values) Returns: - List[dict]: List of dicts containing summary stats for each iteration. + list[dict]: List of dicts containing summary stats for each iteration. """ stats = [] for i in range(tensor.shape[0]): @@ -123,7 +124,7 @@ def tensor_distribution_summary_stats(tensor: torch.Tensor): def plot_ridgeline_convergence( data: torch.Tensor, - stats: Optional[List[Dict[str, Any]]] = None, + stats: list[dict[str, Any]] | None = None, title: str = "Convergence Distribution Ridgeline Plot", ): """Plot ridgeline distributions for convergence data using Plotly. @@ -314,8 +315,8 @@ def get_topk_edges_for_node_range(node_range: tuple, adjacency_matrix: torch.Ten def get_logit_indices_for_tokens( graph, - token_ids: Optional[torch.Tensor] = None, - token_strings: Optional[list] = None, + token_ids: torch.Tensor | None = None, + token_strings: list | None = None, tokenizer=None, ): """Given a tensor of token ids or a list of token strings (with tokenizer), return the corresponding indices in @@ -424,7 +425,7 @@ def get_node_ids_for_adj_matrix_indices(adj_indices, node_mapping): """Returns the node_ids for all adjacency_matrix target nodes provided. Args: - adj_indices (Union[list, torch.Tensor]): Indices of nodes in the adjacency matrix. + adj_indices (list | torch.Tensor): Indices of nodes in the adjacency matrix. node_mapping (dict): Mapping from node index to node_id. Returns: @@ -458,7 +459,7 @@ class RawGraphOverview: node_mapping: dict adj_matrix_target_logit_idxs: torch.Tensor target_logit_vec_idxs: torch.Tensor - extra: Optional[Dict[str, Any]] = None + extra: dict[str, Any] | None = None def node_ids_for(self, idxs): """Given indices (tensor or list), return corresponding node_ids using the node_mapping.""" @@ -483,12 +484,12 @@ def adj_matrix_target_logit_node_ids(self): def gen_raw_graph_overview( k: int, target_token_ids: torch.Tensor, - graph: Optional["Graph"] = None, - adjacency_matrix: Optional[torch.Tensor] = None, + graph: Graph | None = None, + adjacency_matrix: torch.Tensor | None = None, adj_matrix_name: str = "adjacency_matrix", - node_ranges: Optional[dict] = None, - node_mapping: Optional[dict] = None, - node_mask: Optional[torch.Tensor] = None, + node_ranges: dict | None = None, + node_mapping: dict | None = None, + node_mask: torch.Tensor | None = None, ): """Returns the top-k 2nd order adjacency values and indices for non-error, non-logit nodes, as well as the first order adjacency values and indices for the specified logit nodes. Also returns node_ranges, node_mapping, diff --git a/tests/base_defaults.py b/tests/base_defaults.py index edfbfdb2..bb3202a0 100644 --- a/tests/base_defaults.py +++ b/tests/base_defaults.py @@ -1,16 +1,17 @@ +from __future__ import annotations from dataclasses import dataclass, field -from typing import List, Optional, Tuple, Callable, Any, Dict, Sequence, TYPE_CHECKING, Iterable, Union +from typing import List, Tuple, Callable, Any, Dict, Sequence, TYPE_CHECKING, Iterable import pytest from interpretune.adapters import ADAPTER_REGISTRY -from interpretune.config import HFFromPretrainedConfig, GenerativeClassificationConfig, AutoCompConfig +from interpretune.config import HFFromPretrainedConfig, GenerativeClassificationConfig, AutoCompConfig, AnalysisCfg from interpretune.extensions import MemProfilerCfg, DebugLMConfig from interpretune.protocol import Adapter from interpretune.analysis import SAEAnalysisTargets from tests.runif import RunIf, RUNIF_ALIASES if TYPE_CHECKING: - from interpretune.analysis import AnalysisOp, AnalysisCfg + from interpretune.analysis import AnalysisOp default_test_task = "rte" @@ -26,11 +27,11 @@ @dataclass(kw_only=True) class BaseAugTest: alias: str - cfg: Optional[Tuple] = None - marks: Optional[Dict] = None # test instance-specific marks - expected: Optional[Dict] = None - result_gen: Optional[Callable] = None - function_marks: Dict[str, Any] = field(default_factory=dict) # marks applied at test function level + cfg: Tuple | None = None + marks: Dict | None = None # test instance-specific marks + expected: Dict | None = None + result_gen: Callable | None = None + function_marks: dict[str, Any] = field(default_factory=dict) # marks applied at test function level def __post_init__(self): if self.expected is None and self.result_gen is not None: @@ -41,7 +42,7 @@ def __post_init__(self): if self.marks or self.function_marks: self.marks = self._get_marks(self.marks, self.function_marks) - def _get_marks(self, marks: Optional[Dict | str], function_marks: Dict) -> Optional[RunIf]: + def _get_marks(self, marks: Dict | str | None, function_marks: Dict) -> RunIf | None: # support RunIf aliases applied to function level if marks: if isinstance(marks, Dict): @@ -54,7 +55,7 @@ def _get_marks(self, marks: Optional[Dict | str], function_marks: Dict) -> Optio return RunIf(**function_marks) -def pytest_factory(test_configs: List[BaseAugTest], unpack: bool = True, fq_alias: bool = False) -> List: +def pytest_factory(test_configs: list[BaseAugTest], unpack: bool = True, fq_alias: bool = False) -> List: return [ pytest.param( config.alias, @@ -73,28 +74,28 @@ class BaseCfg: model_key: str = default_test_task # "real-model"-based acceptance/parity testing/profiling precision: str | int = "torch.float32" adapter_ctx: Sequence[Adapter | str] = (Adapter.core,) - model_src_key: Optional[str] = None - datamodule_cls: Optional[str] = None # Fully qualified class name (e.g., "tests.modules.DivergeTestITModule") - module_cls: Optional[str] = None # Fully qualified class name (e.g., "tests.modules.DivergeTestITModule") - limit_train_batches: Optional[int] = 1 - limit_val_batches: Optional[int] = 1 - limit_test_batches: Optional[int] = 1 - dm_override_cfg: Optional[Dict] = None - generative_step_cfg: Optional[GenerativeClassificationConfig] = None - hf_from_pretrained_cfg: Optional[HFFromPretrainedConfig] = None - memprofiler_cfg: Optional[MemProfilerCfg] = None - debug_lm_cfg: Optional[DebugLMConfig] = None - model_cfg: Optional[Dict] = None - tl_cfg: Optional[Dict] = None - sae_cfgs: Optional[Dict] = None - auto_comp_cfg: Optional[AutoCompConfig] = None + model_src_key: str | None = None + datamodule_cls: str | None = None # Fully qualified class name (e.g., "tests.modules.DivergeTestITModule") + module_cls: str | None = None # Fully qualified class name (e.g., "tests.modules.DivergeTestITModule") + limit_train_batches: int | None = 1 + limit_val_batches: int | None = 1 + limit_test_batches: int | None = 1 + dm_override_cfg: Dict | None = None + generative_step_cfg: GenerativeClassificationConfig | None = None + hf_from_pretrained_cfg: HFFromPretrainedConfig | None = None + memprofiler_cfg: MemProfilerCfg | None = None + debug_lm_cfg: DebugLMConfig | None = None + model_cfg: Dict | None = None + tl_cfg: Dict | None = None + sae_cfgs: Dict | None = None + auto_comp_cfg: AutoCompConfig | None = None add_saes_on_init: bool = False - req_grad_mask: Optional[Tuple] = None # used to toggle requires grad for non-fts contexts - max_epochs: Optional[int] = 1 - cust_fwd_kwargs: Optional[Dict] = None + req_grad_mask: Tuple | None = None # used to toggle requires grad for non-fts contexts + max_epochs: int | None = 1 + cust_fwd_kwargs: Dict | None = None # used when adding a new test dataset or changing a test model to force re-caching of test datasets force_prepare_data: bool = False # TODO: make this settable via an env variable as well - max_steps: Optional[int] = None + max_steps: int | None = None save_checkpoints: bool = False req_deterministic: bool = False logging_level: str | int = "INFO" # Logging level for test runs @@ -107,14 +108,14 @@ def __post_init__(self): @dataclass(kw_only=True) class AnalysisBaseCfg(BaseCfg): # TODO: we may want to narrow Iterable to Sequence here - analysis_cfgs: Union["AnalysisCfg", "AnalysisOp", Iterable[Union["AnalysisCfg", "AnalysisOp"]]] = None + analysis_cfgs: AnalysisCfg | AnalysisOp | Iterable[AnalysisCfg | AnalysisOp] = None limit_analysis_batches: int = 2 - cache_dir: Optional[str] = None - op_output_dataset_path: Optional[str] = None + cache_dir: str | None = None + op_output_dataset_path: str | None = None # Add optional sae_analysis_targets as a fallback - sae_analysis_targets: Optional[SAEAnalysisTargets] = None + sae_analysis_targets: SAEAnalysisTargets | None = None # Add artifact configuration - artifact_cfg: Optional[Dict] = None + artifact_cfg: Dict | None = None # Global override for ignore_manual setting in analysis configs ignore_manual: bool = False @@ -127,9 +128,9 @@ class OpTestConfig: """Configuration for operation testing.""" target_op: Any # The operation to test - resolved_op: Optional["AnalysisOp"] = None + resolved_op: AnalysisOp | None = None session_fixt: str = "get_it_session__sl_gpt2_analysis__setup" batch_size: int = 1 generate_required_only: bool = True - override_req_cols: Optional[tuple] = None + override_req_cols: tuple | None = None deepcopy_session_fixt: bool = False diff --git a/tests/configuration.py b/tests/configuration.py index 61b82600..16915b13 100644 --- a/tests/configuration.py +++ b/tests/configuration.py @@ -11,7 +11,7 @@ # limitations under the License. # Initially based on https://bit.ly/3oQ8Vqf import os -from typing import Optional, Any, Union, Dict, Sequence +from typing import Any, Dict, Sequence from copy import deepcopy import torch @@ -54,7 +54,7 @@ def apply_itdm_test_cfg(base_itdm_cfg: ITDataModuleConfig, test_cfg: BaseCfg, ** return test_itdm_cfg -def apply_it_test_cfg(base_it_cfg: ITConfig, test_cfg: BaseCfg, core_log_dir: Optional[StrOrPath] = None) -> ITConfig: +def apply_it_test_cfg(base_it_cfg: ITConfig, test_cfg: BaseCfg, core_log_dir: StrOrPath | None = None) -> ITConfig: # TODO: for attributes that don't actually belong to ITConfig (and existing subclasses), we should avoid adding them # e.g. right now, `sae_analysis_targets` is the only one that doesn't belong to ITConfig or defined subclasses test_cfg_override_attrs = [ @@ -84,7 +84,7 @@ def apply_it_test_cfg(base_it_cfg: ITConfig, test_cfg: BaseCfg, core_log_dir: Op return it_cfg -def configure_device_precision(cfg: Dict, device_type: str, precision: Union[int, str]) -> Dict[str, Any]: +def configure_device_precision(cfg: Dict, device_type: str, precision: int | str) -> dict[str, Any]: # TODO: As we accommodate many different device/precision setting sources at the moment, it may make sense # to refactor hf and tl support via additional adapter functions and only test adherence to the # common Interpretune protocol here (testing the adapter functions separately with smaller unit tests) @@ -108,7 +108,7 @@ def configure_device_precision(cfg: Dict, device_type: str, precision: Union[int return cfg -def _update_tl_cfg_device_precision(cfg: Dict, device_type: str, precision: Union[int, str]) -> None: +def _update_tl_cfg_device_precision(cfg: Dict, device_type: str, precision: int | str) -> None: dev_prec_override = {"dtype": get_model_input_dtype(precision), "device": device_type} if isinstance(cfg.tl_cfg, ITLensCustomConfig): # initialized TL custom model config cfg.tl_cfg.cfg.__dict__.update(dev_prec_override) @@ -120,7 +120,7 @@ def _update_tl_cfg_device_precision(cfg: Dict, device_type: str, precision: Unio def _update_sae_cfg_device_precision( - sae_cfg: SAELensCustomConfig | SAELensFromPretrainedConfig, device_type: str, precision: Union[int, str] + sae_cfg: SAELensCustomConfig | SAELensFromPretrainedConfig, device_type: str, precision: int | str ) -> None: dev_prec_override = {"dtype": precision, "device": device_type} # SAEConfig currently requires strings if isinstance(sae_cfg, SAELensCustomConfig): @@ -175,7 +175,7 @@ def cfg_op_env( input_data=None, batches=1, generate_required_only: bool = True, - override_req_cols: Optional[tuple] = None, + override_req_cols: tuple | None = None, ) -> tuple: """Set up a test environment for an operation using a real model. @@ -308,7 +308,7 @@ def config_modules( test_alias, expected_results, tmp_path, - prewrapped_modules: Optional[Dict[str, Any]] = None, + prewrapped_modules: dict[str, Any] | None = None, state_log_mode: bool = False, cfg_only: bool = False, ) -> ITSessionConfig | ITSession: diff --git a/tests/conftest.py b/tests/conftest.py index 3da7621e..80a1a111 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -127,7 +127,7 @@ class ITSessionFixture: class AnalysisSessionFixture(ITSessionFixture): """Dataclass for analysis session fixtures that extends ITSessionFixture.""" - result: AnalysisStore | Dict[AnalysisStore] = None + result: AnalysisStore | dict[AnalysisStore] = None runner: AnalysisRunner = None run_config: Dict | AnalysisRunnerCfg = None @@ -139,7 +139,7 @@ class FixtureCfg: module_cls: Type[ModuleSteppable] = TestITModule datamodule_cls: Type[DataModuleInitable] = TestITDataModule scope: str = "class" - variants: Dict[str, Sequence[FixtPhase]] = field(default_factory=lambda: defaultdict(list)) + variants: dict[str, Sequence[FixtPhase]] = field(default_factory=lambda: defaultdict(list)) FIXTURE_CFGS = { @@ -542,7 +542,7 @@ def gen_fixture(fixt_type, fixt_key, phase): GEN_CLI_CFGS = [TEST_CONFIGS_CLI_PARITY, TEST_CONFIGS_CLI_UNIT] -def gen_experiment_cfg_sets(test_keys: Sequence[Tuple[str, str, bool]], sess_paths: Tuple) -> Dict: +def gen_experiment_cfg_sets(test_keys: Sequence[tuple[str, str, bool]], sess_paths: Tuple) -> Dict: EXPERIMENTS_BASE, BASE_DEBUG_CONFIG = sess_paths exp_cfg_sets = {} for exp, subexp, debug_mode in test_keys: diff --git a/tests/core/cfg_aliases.py b/tests/core/cfg_aliases.py index e8abf6aa..edaba93b 100644 --- a/tests/core/cfg_aliases.py +++ b/tests/core/cfg_aliases.py @@ -2,7 +2,7 @@ from copy import deepcopy from enum import auto from dataclasses import dataclass, field -from typing import Iterable, Union, Optional +from typing import Iterable from pathlib import Path import tempfile @@ -283,7 +283,7 @@ class CoreSLGPT2(BaseCfg): model_name="gpt2-small", default_padding_side="left", use_bridge=False ) ) - # force_prepare_data: Optional[bool] = True # sometimes useful to enable for test debugging + # force_prepare_data: bool | None = True # sometimes useful to enable for test debugging @dataclass(kw_only=True) @@ -317,10 +317,10 @@ class CoreSLGPT2Analysis(AnalysisBaseCfg): ) ) # TODO: customize these cache paths for testing efficiency - # cache_dir: Optional[str] = None - # op_output_dataset_path: Optional[str] = None + # cache_dir: str | None = None + # op_output_dataset_path: str | None = None # important for ephemeral CI runner alignment - force_prepare_data: Optional[bool] = True + force_prepare_data: bool | None = True dm_override_cfg: dict | None = field( default_factory=lambda: { "enable_datasets_cache": True, @@ -340,28 +340,28 @@ def __post_init__(self): @dataclass(kw_only=True) class CoreSLGPT2LogitDiffsBase(CoreSLGPT2Analysis): - analysis_cfgs: Union[AnalysisCfg, AnalysisOp, Iterable[Union[AnalysisCfg, AnalysisOp]]] = ( + analysis_cfgs: AnalysisCfg | AnalysisOp | Iterable[AnalysisCfg | AnalysisOp] = ( AnalysisCfg(target_op=it.logit_diffs_base, save_prompts=False, save_tokens=False, ignore_manual=True), ) @dataclass(kw_only=True) class CoreSLGPT2LogitDiffsSAE(CoreSLGPT2Analysis): - analysis_cfgs: Union[AnalysisCfg, AnalysisOp, Iterable[Union[AnalysisCfg, AnalysisOp]]] = ( + analysis_cfgs: AnalysisCfg | AnalysisOp | Iterable[AnalysisCfg | AnalysisOp] = ( AnalysisCfg(target_op=it.logit_diffs_sae, save_prompts=True, save_tokens=True, ignore_manual=True), ) @dataclass(kw_only=True) class CoreSLGPT2LogitDiffsAttrGrad(CoreSLGPT2Analysis): - analysis_cfgs: Union[AnalysisCfg, AnalysisOp, Iterable[Union[AnalysisCfg, AnalysisOp]]] = ( + analysis_cfgs: AnalysisCfg | AnalysisOp | Iterable[AnalysisCfg | AnalysisOp] = ( AnalysisCfg(target_op=it.logit_diffs_attr_grad, save_prompts=False, save_tokens=False, ignore_manual=True), ) @dataclass(kw_only=True) class CoreSLGPT2LogitDiffsAttrAblation(CoreSLGPT2Analysis): - analysis_cfgs: Union[AnalysisCfg, AnalysisOp, Iterable[Union[AnalysisCfg, AnalysisOp]]] = ( + analysis_cfgs: AnalysisCfg | AnalysisOp | Iterable[AnalysisCfg | AnalysisOp] = ( AnalysisCfg(target_op=it.logit_diffs_attr_ablation, save_prompts=False, save_tokens=False, ignore_manual=True), ) diff --git a/tests/core/test_analysis_ops_definitions.py b/tests/core/test_analysis_ops_definitions.py index 792b232a..09feb4d2 100644 --- a/tests/core/test_analysis_ops_definitions.py +++ b/tests/core/test_analysis_ops_definitions.py @@ -3,7 +3,7 @@ import torch from unittest.mock import MagicMock, patch -from typing import Any, List, Dict, Optional +from typing import Any, Dict from transformers import BatchEncoding from torch.testing import assert_close @@ -951,7 +951,7 @@ def _validate_column_shape( shape_info: torch.Size, loaded_column: torch.Tensor, col_cfg, - batch_count: Optional[int] = None, + batch_count: int | None = None, context: str = "", ) -> None: """Helper to validate column shape based on config and expected shape. @@ -1007,7 +1007,7 @@ def _should_validate_column(self, column_name: str, shape_info: Any, op_cfg: OpT return col_cfg is not None def _validate_format_column_path( - self, op_cfg: OpTestConfig, result_batches: List[AnalysisBatch], loaded_dataset, pre_serialization_shapes: Dict + self, op_cfg: OpTestConfig, result_batches: list[AnalysisBatch], loaded_dataset, pre_serialization_shapes: Dict ) -> None: """Validate loaded dataset using direct column access (format_column path).""" if not pre_serialization_shapes: @@ -1037,7 +1037,7 @@ def _validate_format_column_path( print(f"Warning: Column access validation failed for '{column_name}': {e}") def _validate_format_batch_path( - self, op_cfg: OpTestConfig, result_batches: List[AnalysisBatch], loaded_dataset, pre_serialization_shapes: Dict + self, op_cfg: OpTestConfig, result_batches: list[AnalysisBatch], loaded_dataset, pre_serialization_shapes: Dict ) -> None: """Validate loaded dataset using batch access (format_batch path).""" if not pre_serialization_shapes or len(result_batches) <= 1: @@ -1085,7 +1085,7 @@ def _validate_format_batch_path( print(f"Warning: {method_name} access validation failed for '{column_name}': {e}") def _validate_format_row_path( - self, op_cfg: OpTestConfig, result_batches: List[AnalysisBatch], loaded_dataset + self, op_cfg: OpTestConfig, result_batches: list[AnalysisBatch], loaded_dataset ) -> None: """Validate loaded dataset using row-by-row access (format_row path).""" for i, original_result in enumerate(result_batches): @@ -1145,7 +1145,7 @@ def _validate_format_row_path( def validate_loaded_dataset( self, op_cfg: OpTestConfig, - result_batches: List[AnalysisBatch], + result_batches: list[AnalysisBatch], loaded_dataset, pre_serialization_shapes: Dict = None, ) -> None: diff --git a/tests/core/test_transformer_lens.py b/tests/core/test_transformer_lens.py index d01e1f94..2bea0aa7 100644 --- a/tests/core/test_transformer_lens.py +++ b/tests/core/test_transformer_lens.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field import inspect import re -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict import pytest import torch @@ -59,21 +59,21 @@ class ArchitectureExpectations: has_biases: bool = False # TL-style parameter suffixes expected per layer - expected_tl_attn_params: List[str] = field(default_factory=lambda: ["attn.W_Q", "attn.W_K", "attn.W_V", "attn.W_O"]) - expected_tl_mlp_params: List[str] = field(default_factory=lambda: ["mlp.W_in", "mlp.W_out"]) - expected_tl_embed_params: List[str] = field(default_factory=lambda: ["embed.W_E", "unembed.W_U"]) + expected_tl_attn_params: list[str] = field(default_factory=lambda: ["attn.W_Q", "attn.W_K", "attn.W_V", "attn.W_O"]) + expected_tl_mlp_params: list[str] = field(default_factory=lambda: ["mlp.W_in", "mlp.W_out"]) + expected_tl_embed_params: list[str] = field(default_factory=lambda: ["embed.W_E", "unembed.W_U"]) # Canonical naming patterns (regex) - architecture specific # These are used to validate bidirectional mapping - canonical_attn_pattern: Optional[str] = None - canonical_mlp_pattern: Optional[str] = None - canonical_embed_pattern: Optional[str] = None + canonical_attn_pattern: str | None = None + canonical_mlp_pattern: str | None = None + canonical_embed_pattern: str | None = None # Expected mapping counts for bidirectional validation - expected_mapped_tl_count: Optional[int] = None + expected_mapped_tl_count: int | None = None expected_unmapped_tl_count: int = 0 # All TL params should map - expected_mapped_canonical_count: Optional[int] = None - expected_unmapped_canonical_count: Optional[int] = None + expected_mapped_canonical_count: int | None = None + expected_unmapped_canonical_count: int | None = None # Pre-defined architecture expectations @@ -527,12 +527,12 @@ class TestArchitectureParameterMapping: } @staticmethod - def get_tl_params_from_bridge(bridge) -> Dict[str, torch.Tensor]: + def get_tl_params_from_bridge(bridge) -> dict[str, torch.Tensor]: """Get TL-style parameters from a TransformerBridge instance.""" return dict(bridge.tl_named_parameters()) @staticmethod - def get_canonical_params_from_module(module) -> Dict[str, torch.Tensor]: + def get_canonical_params_from_module(module) -> dict[str, torch.Tensor]: """Get canonical parameters from a LightningModule.""" return dict(module.named_parameters()) @@ -551,7 +551,7 @@ def categorize_tl_param(name: str) -> str: return "other" @staticmethod - def validate_mapping_structure(tl_to_canonical: Dict, canonical_to_tl: Dict) -> Dict[str, Set[str]]: + def validate_mapping_structure(tl_to_canonical: Dict, canonical_to_tl: Dict) -> dict[str, set[str]]: """Validate the structure of bidirectional mappings. Returns: @@ -617,7 +617,7 @@ def test_param_categorization(self): def _validate_tl_param_structure( self, bridge, module, arch_expectations: ArchitectureExpectations - ) -> Tuple[Dict, Dict]: + ) -> tuple[Dict, Dict]: """Verify TransformerBridge has expected TL-style parameter structure. Validates: @@ -828,7 +828,6 @@ def test_custom_model_view(self, get_it_session__l_tl_bridge_gpt2__setup): """Test that a simple custom ModelView can be successfully instantiated and used.""" from interpretune.adapters.model_view import ModelView import os - from typing import Dict, List, Optional, Union # Define a simple custom ModelView that prefixes all param names class PrefixedModelView(ModelView): @@ -842,21 +841,21 @@ def build_param_mapping(self) -> None: """No complex mapping needed for this simple test.""" pass - def transform_to_canonical(self, param_names: List[str], inspect_only: bool = False) -> List[str]: + def transform_to_canonical(self, param_names: list[str], inspect_only: bool = False) -> list[str]: """Strip custom prefix to get canonical names.""" return [ name.replace(self.prefix, "", 1) if name.startswith(self.prefix) else name for name in param_names ] - def transform_from_canonical(self, param_names: List[str]) -> List[str]: + def transform_from_canonical(self, param_names: list[str]) -> list[str]: """Add custom prefix to canonical names.""" return [f"{self.prefix}{name}" for name in param_names] - def get_named_params(self) -> Dict[str, torch.Tensor]: + def get_named_params(self) -> dict[str, torch.Tensor]: """Get params with custom prefix.""" return {f"{self.prefix}{name}": param for name, param in self.pl_module.named_parameters()} - def gen_schedule(self, dump_loc: Union[str, os.PathLike]) -> Optional[os.PathLike]: + def gen_schedule(self, dump_loc: str | os.PathLike) -> os.PathLike | None: """Not tested in this simple test.""" return None diff --git a/tests/data_generation.py b/tests/data_generation.py index c1318a6e..213a107c 100644 --- a/tests/data_generation.py +++ b/tests/data_generation.py @@ -9,7 +9,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from functools import partial import torch @@ -313,7 +312,7 @@ def gen_or_validate_input_data( input_data=None, num_batches=1, required_only=True, - override_req_cols: Optional[tuple] = None, + override_req_cols: tuple | None = None, predefined_indices=True, ): """Generate or validate input data based on an input schema. diff --git a/tests/dynamic_fixture_benchmark.py b/tests/dynamic_fixture_benchmark.py index 43624873..6bc6f834 100755 --- a/tests/dynamic_fixture_benchmark.py +++ b/tests/dynamic_fixture_benchmark.py @@ -37,7 +37,7 @@ import time from datetime import datetime from pathlib import Path -from typing import Dict, Tuple, Optional, Any +from typing import Any import psutil import torch @@ -85,7 +85,7 @@ def __init__(self): self.gpu_memory_before: float = 0.0 self.gpu_memory_after: float = 0.0 self.gpu_peak_memory: float = 0.0 - self.error: Optional[str] = None + self.error: str | None = None @property def memory_delta(self) -> float: @@ -111,7 +111,7 @@ def get_gpu_memory_usage_mb() -> float: return 0.0 -def discover_generated_fixtures() -> Dict[str, Tuple[Any, str]]: +def discover_generated_fixtures() -> dict[str, tuple[Any, str]]: """Discover all dynamically generated fixtures from conftest.py.""" fixtures = {} @@ -267,7 +267,7 @@ def test_minimal(): Path(baseline_test_path).unlink(missing_ok=True) -def measure_baseline_pytest_startup_with_profiling() -> Tuple[float, Dict[str, str]]: +def measure_baseline_pytest_startup_with_profiling() -> tuple[float, dict[str, str]]: """Measure baseline pytest startup time with import profiling. Returns: @@ -548,7 +548,7 @@ def test_fixture_benchmark({fixture_name}): return metrics -def get_fixture_analysis() -> Dict[str, Dict[str, Any]]: +def get_fixture_analysis() -> dict[str, dict[str, Any]]: """Complete analysis of all fixtures with dynamic discovery and usage counts.""" raw_fixtures = discover_generated_fixtures() @@ -572,7 +572,7 @@ def get_fixture_analysis() -> Dict[str, Dict[str, Any]]: return processed_fixtures -def discover_static_fixtures() -> Dict[str, Dict[str, Any]]: +def discover_static_fixtures() -> dict[str, dict[str, Any]]: """Discover static fixtures from conftest.py.""" conftest_path = Path(__file__).parent / "conftest.py" static_fixtures = {} @@ -621,11 +621,11 @@ def discover_static_fixtures() -> Dict[str, Dict[str, Any]]: def run_full_benchmark( - max_fixtures: Optional[int] = None, -) -> Tuple[ - Dict[str, Tuple[Dict[str, Any], FixtureMetrics]], + max_fixtures: int | None = None, +) -> tuple[ + dict[str, tuple[dict[str, Any], FixtureMetrics]], float, - Dict[str, str], + dict[str, str], ]: """Run comprehensive benchmark of all fixtures.""" fixtures = get_fixture_analysis() @@ -702,9 +702,9 @@ def run_full_benchmark( def generate_markdown_report( - results: Dict[str, Tuple[Dict[str, Any], FixtureMetrics]], + results: dict[str, tuple[dict[str, Any], FixtureMetrics]], baseline_time: float = 0.0, - baseline_artifacts: Optional[Dict[str, str]] = None, + baseline_artifacts: dict[str, str] | None = None, ) -> str: """Generate comprehensive markdown report.""" timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") diff --git a/tests/examples/test_notebooks.py b/tests/examples/test_notebooks.py index 15e48570..5051cfb7 100644 --- a/tests/examples/test_notebooks.py +++ b/tests/examples/test_notebooks.py @@ -10,7 +10,7 @@ import os import shutil from pathlib import Path -from typing import Any, Dict +from typing import Any import pytest @@ -22,7 +22,7 @@ def execute_notebook_with_params( notebook_path: Path, - parameters: Dict[str, Any], + parameters: dict[str, Any], output_dir: Path, timeout: int = 1800, # 30 minutes ) -> Path: @@ -67,7 +67,7 @@ def _cleanup_notebook_artifacts(): def validate_notebook_outputs( output_notebook: Path, - params: Dict[str, Any], + params: dict[str, Any], check_prompt_errors: bool = True, check_analysis_points: bool = True, check_prompt_success: bool = True, @@ -164,7 +164,7 @@ def validate_notebook_outputs( @RunIf(standalone=True, bf16_cuda=True) @pytest.mark.parametrize("params", ATTRIBUTION_ANALYSIS_PARAMS) -def test_attribution_analysis_notebook(params: Dict[str, Any], tmp_path: Path): +def test_attribution_analysis_notebook(params: dict[str, Any], tmp_path: Path): """Test attribution analysis notebook with different parameterizations.""" notebook_path = NOTEBOOKS_DIR / "attribution_analysis" / "attribution_analysis.ipynb" @@ -236,7 +236,7 @@ def test_op_collection_notebooks(notebook_file: str, tmp_path: Path): @RunIf(standalone=True, bf16_cuda=True) @pytest.mark.parametrize("params", CIRCUIT_TRACER_PARAMS) -def test_circuit_tracer_notebooks(params: Dict[str, Any], tmp_path: Path): +def test_circuit_tracer_notebooks(params: dict[str, Any], tmp_path: Path): """Test circuit tracer notebooks with different parameterizations.""" # Use the CLT notebook for these tests notebook_path = NOTEBOOKS_DIR / "circuit_tracer_examples" / "circuit_tracer_adapter_example_basic_clt.ipynb" diff --git a/tests/modules.py b/tests/modules.py index d686dbcf..b26ff830 100644 --- a/tests/modules.py +++ b/tests/modules.py @@ -1,7 +1,7 @@ import os from pathlib import Path from jaxtyping import Float, Int -from typing import Optional, Any, Dict, Union, Callable, List, Tuple +from typing import Any, Dict, Callable, List, Tuple from unittest import mock from functools import reduce, partial from dataclasses import dataclass @@ -57,7 +57,7 @@ def sample_dataset_state(self) -> List: def sample_step_input(self, batch: BatchEncoding) -> List: return [] - def prepare_data(self, target_model: Optional[torch.nn.Module] = None) -> None: + def prepare_data(self, target_model: torch.nn.Module | None = None) -> None: """Load the SuperGLUE dataset.""" tokenization_func = partial( @@ -161,11 +161,11 @@ class TestModelArgs: dropout_p: float = 0.1 use_attn_mask: bool = True weight_tying: bool = True - tokenizer: Optional[Callable] = None - device: Optional[torch.device] = None - dtype: Optional[torch.dtype] = None + tokenizer: Callable | None = None + device: torch.device | None = None + dtype: torch.dtype | None = None # handle below can be used at runtime to allow this model's `generate` to adapt to various configuration contexts - ctx_handle: Optional[ITModuleProtocol] = None + ctx_handle: ITModuleProtocol | None = None def __post_init__(self): if self.ctx_handle: @@ -277,19 +277,19 @@ def forward(self, tokens): @torch.inference_mode() def generate( self, - tokens: Union[str, Float[torch.Tensor, "batch pos"]] = "", + tokens: str | Float[torch.Tensor, "batch pos"] = "", max_new_tokens: int = 5, - eos_token_id: Optional[int] = None, + eos_token_id: int | None = None, output_logits: bool = False, verbose: bool = True, **kwargs, - ) -> Union[ModelOutput, Int[torch.Tensor, "batch pos_plus_new_tokens"]]: + ) -> ModelOutput | Int[torch.Tensor, "batch pos_plus_new_tokens"]: """Toy generate function to support non-HF/TransformerLens tests with the same interface. Args: - tokens (Union[str, Int[torch.Tensor, "batch pos"])]): A batch of tokens ([batch, pos]). + tokens (str | Int[torch.Tensor, "batch pos"])): A batch of tokens ([batch, pos]). max_new_tokens (int): Maximum number of tokens to generate. - eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end of sentence. + eos_token_id (int | Sequence | None): The token ID to use for end of sentence. output_logits (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. verbose (bool): If True, show tqdm progress bars for generation. @@ -369,12 +369,12 @@ class StateLogInspectMixin: def __init__( self, *args, - expected_exact: Optional[Dict] = None, - expected_close: Optional[Dict] = None, - expected_memstats: Optional[Tuple] = None, - tolerance_map: Optional[Dict] = None, - test_alias: Optional[str] = None, - state_log_dir: Optional[str] = None, + expected_exact: Dict | None = None, + expected_close: Dict | None = None, + expected_memstats: Tuple | None = None, + tolerance_map: Dict | None = None, + test_alias: str | None = None, + state_log_dir: str | None = None, **kwargs, ) -> None: self.expected_memstats = expected_memstats @@ -432,7 +432,7 @@ class DivergeOnEpochMixin: current_epoch/max_epochs ratio, causing gradual loss divergence. """ - def __init__(self, *args, diverge_on_epoch: Optional[int] = None, **kwargs) -> None: + def __init__(self, *args, diverge_on_epoch: int | None = None, **kwargs) -> None: super().__init__(*args, **kwargs) it_cfg = reduce(lambda o, a: getattr(o, a, None), ("it_cfg", "model_cfg"), self) self.diverge_on_epoch = it_cfg.get("diverge_on_epoch", diverge_on_epoch) if it_cfg else diverge_on_epoch @@ -491,7 +491,7 @@ def training_step(self, batch: BatchEncoding, batch_idx: int) -> STEP_OUTPUT: return loss @MemProfilerHooks.memprofilable - def validation_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]: + def validation_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> STEP_OUTPUT | None: answer_logits, labels, orig_labels = self.logits_and_labels(batch, batch_idx) val_loss = self._compute_diverging_loss(answer_logits, labels) self.log("val_loss", val_loss, prog_bar=True, sync_dist=True) @@ -499,7 +499,7 @@ def validation_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: class BaseTestModule(StateLogInspectMixin): - def __init__(self, *args, req_grad_mask: Optional[Dict] = None, **kwargs) -> None: + def __init__(self, *args, req_grad_mask: Dict | None = None, **kwargs) -> None: super().__init__(*args, **kwargs) self.req_grad_mask = req_grad_mask or {} self.epoch_losses = {} @@ -591,7 +591,7 @@ def on_train_epoch_start(self, *args, **kwargs): def on_train_epoch_end(self, *args, **kwargs): self._epoch_end_validation(*args, **kwargs) - def on_session_end(self) -> Optional[Any]: + def on_session_end(self) -> Any | None: super().on_session_end() if self.it_cfg.memprofiler_cfg and self.expected_memstats: self._validate_memory_stats() diff --git a/tests/orchestration.py b/tests/orchestration.py index 9f811be1..242197a7 100644 --- a/tests/orchestration.py +++ b/tests/orchestration.py @@ -12,7 +12,7 @@ # Initially based on https://bit.ly/3oQ8Vqf from pathlib import Path from unittest import mock -from typing import Dict, Optional, Tuple +from typing import Dict import torch from datasets import Dataset @@ -72,7 +72,7 @@ def init_it_runner(it_session: ITSession, test_cfg: BaseCfg, *args, **kwargs): def run_it( it_session: ITSession, test_cfg: BaseCfg, init_only: bool = False -) -> SessionRunner | AnalysisStoreProtocol | Dict[str, AnalysisStoreProtocol] | None: +) -> SessionRunner | AnalysisStoreProtocol | dict[str, AnalysisStoreProtocol] | None: # Check if test_cfg is an AnalysisBaseCfg and use the appropriate runner initialization if isinstance(test_cfg, AnalysisBaseCfg): runner = init_analysis_runner(it_session, test_cfg) @@ -241,7 +241,7 @@ def run_op_with_config(request, op_cfg: OpTestConfig): def save_reload_results_dataset( - it_session, result_batches, batches, features_format: Optional[Tuple[Dict]] = None, split: str = "validation" + it_session, result_batches, batches, features_format: tuple[Dict] | None = None, split: str = "validation" ): if features_format is not None: features, it_format_kwargs = features_format diff --git a/tests/parity_acceptance/test_it_cli.py b/tests/parity_acceptance/test_it_cli.py index 70c9dc4e..d1bb659a 100644 --- a/tests/parity_acceptance/test_it_cli.py +++ b/tests/parity_acceptance/test_it_cli.py @@ -4,7 +4,7 @@ import sys from subprocess import TimeoutExpired, PIPE from unittest import mock -from typing import Optional, List, Sequence +from typing import List, Sequence from dataclasses import dataclass from interpretune.adapters import ADAPTER_REGISTRY @@ -19,14 +19,14 @@ @dataclass(kw_only=True) class CLICfg: cli_adapter: Adapter = Adapter.core - run: Optional[str] = None - env_seed: Optional[str] = None + run: str | None = None + env_seed: str | None = None compose_cfg: bool = False adapter_ctx: Sequence[Adapter | str] = (Adapter.core,) debug_mode: bool = False use_harness: bool = False - bootstrap_args: Optional[List] = None - extra_args: Optional[ArgsType] = None + bootstrap_args: List | None = None + extra_args: ArgsType | None = None req_deterministic: bool = False def __post_init__(self): @@ -109,8 +109,8 @@ def gen_cli_args( cli_adapter, compose_cfg, config_files, - bootstrap_args: Optional[ArgsType] = None, - extra_args: Optional[ArgsType] = None, + bootstrap_args: ArgsType | None = None, + extra_args: ArgsType | None = None, ): cli_main_kwargs = {"run_mode": run} if run else {"run_mode": False} cli_main_kwargs["args"] = extra_args if extra_args else None diff --git a/tests/parity_acceptance/test_it_fts.py b/tests/parity_acceptance/test_it_fts.py index 72c0eef0..fa9cebd9 100644 --- a/tests/parity_acceptance/test_it_fts.py +++ b/tests/parity_acceptance/test_it_fts.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # TODO: fill in this placeholder with actual core tests -from typing import Optional, Callable, Dict, Any, Sequence +from typing import Callable, Dict, Any, Sequence from copy import copy from dataclasses import dataclass, field from functools import partial @@ -38,21 +38,21 @@ @dataclass(kw_only=True) class FTSParityCfg(BaseCfg): adapter_ctx: Sequence[Adapter | str] = (Adapter.lightning, Adapter.transformer_lens) - model_src_key: Optional[str] = "gpt2" - callback_cfgs: Optional[Dict[Any, Dict]] = field(default_factory=lambda: {}) - limit_train_batches: Optional[int] = 2 - limit_val_batches: Optional[int] = 2 - limit_test_batches: Optional[int] = 2 - max_epochs: Optional[int] = 4 - fts_schedule_key: Optional[tuple] = None - model_cfg: Optional[dict] = field(default_factory=lambda: {}) - max_steps: Optional[int] = -1 + model_src_key: str | None = "gpt2" + callback_cfgs: dict[Any, Dict] | None = field(default_factory=lambda: {}) + limit_train_batches: int | None = 2 + limit_val_batches: int | None = 2 + limit_test_batches: int | None = 2 + max_epochs: int | None = 4 + fts_schedule_key: tuple | None = None + model_cfg: dict | None = field(default_factory=lambda: {}) + max_steps: int | None = -1 save_checkpoints: bool = True @dataclass class FTSParityTest(BaseAugTest): - result_gen: Optional[Callable] = partial(collect_results, fts_parity_results, normalize=False) + result_gen: Callable | None = partial(collect_results, fts_parity_results, normalize=False) PARITY_FTS_CONFIGS = ( diff --git a/tests/parity_acceptance/test_it_l.py b/tests/parity_acceptance/test_it_l.py index 5a1f6c13..eb5ebcfc 100644 --- a/tests/parity_acceptance/test_it_l.py +++ b/tests/parity_acceptance/test_it_l.py @@ -9,7 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Callable +from typing import Callable from dataclasses import dataclass from functools import partial @@ -46,12 +46,12 @@ @dataclass(kw_only=True) class CoreCfg(BaseCfg): - model_src_key: Optional[str] = "cust" + model_src_key: str | None = "cust" @dataclass class ParityTest(BaseAugTest): - result_gen: Optional[Callable] = partial(collect_results, l_parity_results) + result_gen: Callable | None = partial(collect_results, l_parity_results) PARITY_BASIC_CONFIGS = ( @@ -86,13 +86,13 @@ def test_parity_l(recwarn, tmp_path, request, test_alias, test_cfg): @dataclass(kw_only=True) class ProfParityCfg(BaseCfg): - model_src_key: Optional[str] = "gpt2" + model_src_key: str | None = "gpt2" force_prepare_data: bool = True # force data preparation for profiling and CI runner cache reproduction @dataclass class ProfilingTest(BaseAugTest): - result_gen: Optional[Callable] = partial(collect_results, profiling_results) + result_gen: Callable | None = partial(collect_results, profiling_results) L_PROFILING_CONFIGS = ( diff --git a/tests/parity_acceptance/test_it_sl.py b/tests/parity_acceptance/test_it_sl.py index e7da5be0..39abe0dc 100644 --- a/tests/parity_acceptance/test_it_sl.py +++ b/tests/parity_acceptance/test_it_sl.py @@ -9,7 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Callable, Sequence +from typing import Callable, Sequence from dataclasses import dataclass from functools import partial @@ -31,7 +31,7 @@ @dataclass(kw_only=True) class SLParityCfg(BaseCfg): adapter_ctx: Sequence[Adapter | str] = (Adapter.core, Adapter.sae_lens) - model_src_key: Optional[str] = "cust" + model_src_key: str | None = "cust" add_saes_on_init: bool = True # SAE lens doesn't support TransformerBridge yet, must use legacy HookedTransformer # This will be handled in the test configuration by ensuring tl_cfg has use_bridge=False @@ -39,7 +39,7 @@ class SLParityCfg(BaseCfg): @dataclass class SLParityTest(BaseAugTest): - result_gen: Optional[Callable] = partial(collect_results, sl_parity_results) + result_gen: Callable | None = partial(collect_results, sl_parity_results) PARITY_SL_CONFIGS = ( diff --git a/tests/parity_acceptance/test_it_tl.py b/tests/parity_acceptance/test_it_tl.py index 57cf5275..98f0faa6 100644 --- a/tests/parity_acceptance/test_it_tl.py +++ b/tests/parity_acceptance/test_it_tl.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # TODO: fill in this placeholder with actual core tests -from typing import Optional, Callable, Sequence +from typing import Callable, Sequence from dataclasses import dataclass from functools import partial @@ -39,12 +39,12 @@ @dataclass(kw_only=True) class TLParityCfg(BaseCfg): adapter_ctx: Sequence[Adapter | str] = (Adapter.core, Adapter.transformer_lens) - model_src_key: Optional[str] = "cust" + model_src_key: str | None = "cust" @dataclass class TLParityTest(BaseAugTest): - result_gen: Optional[Callable] = partial(collect_results, tl_parity_results) + result_gen: Callable | None = partial(collect_results, tl_parity_results) PARITY_TL_CONFIGS = ( @@ -83,13 +83,13 @@ def test_parity_tl(recwarn, tmp_path, request, test_alias, test_cfg): @dataclass class ProfilingTest(BaseAugTest): - result_gen: Optional[Callable] = partial(collect_results, profiling_results) + result_gen: Callable | None = partial(collect_results, profiling_results) @dataclass(kw_only=True) class TLProfileCfg(BaseCfg): adapter_ctx: Sequence[Adapter | str] = (Adapter.core, Adapter.transformer_lens) - model_src_key: Optional[str] = "gpt2" + model_src_key: str | None = "gpt2" TL_PROFILING_CONFIGS = ( diff --git a/tests/results.py b/tests/results.py index 3de9513f..1a59e15a 100644 --- a/tests/results.py +++ b/tests/results.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union, Dict, NamedTuple +from typing import List, Tuple, Dict, NamedTuple from collections import defaultdict from pathlib import Path import yaml @@ -109,12 +109,12 @@ class MemProfResult(NamedTuple): class TestResult(NamedTuple): - result_alias: Optional[str] = None # N.B. diff test aliases may map to the same result alias (e.g. parity tests) - exact_results: Optional[Dict] = None - close_results: Optional[Tuple] = None - mem_results: Optional[Dict] = None - tolerance_map: Optional[Dict[str, float]] = None - callback_results: Optional[Dict] = None + result_alias: str | None = None # N.B. diff test aliases may map to the same result alias (e.g. parity tests) + exact_results: Dict | None = None + close_results: Tuple | None = None + mem_results: Dict | None = None + tolerance_map: dict[str, float] | None = None + callback_results: Dict | None = None def mem_results(results: Dict, test_alias: str): @@ -129,7 +129,7 @@ def mem_results(results: Dict, test_alias: str): return {**tolerance_map, "expected_memstats": (step_key, expected_mem)} -def close_results(close_map: Tuple, test_alias: Optional[str] = None): +def close_results(close_map: Tuple, test_alias: str | None = None): """Result generation function that packages expected close results with a provided tolerance dict or generates a default one based upon the test_alias.""" expected_close = defaultdict(dict) @@ -141,25 +141,25 @@ def close_results(close_map: Tuple, test_alias: Optional[str] = None): return {**closestats_tol, "expected_close": expected_close} -def exact_results(expected_exact: Tuple, test_alias: Optional[str] = None): +def exact_results(expected_exact: Tuple, test_alias: str | None = None): """Result generation function that packages.""" return {"expected_exact": expected_exact} -def callback_results(callback_results: Dict, test_alias: Optional[str] = None): +def callback_results(callback_results: Dict, test_alias: str | None = None): """Result generation function that packages.""" return {"callback_results": callback_results} class DatasetState(NamedTuple): tokenizer_name: str - deterministic_token_ids: List[int] + deterministic_token_ids: list[int] expected_first_fwd_ids: List def def_results( device_type: str, - precision: Union[int, str], + precision: int | str, ds_cfg: str = "no_sample", task_name: str = default_test_task, tokenizer_cls_name: str = "GPT2TokenizerFast", @@ -192,7 +192,7 @@ def parity_normalize(test_alias) -> str: return test_alias -def collect_results(result_map: Dict[str, Tuple], test_alias: str, normalize: bool = True): +def collect_results(result_map: dict[str, Tuple], test_alias: str, normalize: bool = True): if normalize: test_alias = parity_normalize(test_alias) test_result: TestResult = result_map[test_alias] diff --git a/tests/runif.py b/tests/runif.py index d778d2fa..36b2e352 100644 --- a/tests/runif.py +++ b/tests/runif.py @@ -13,7 +13,7 @@ import os import re import sys -from typing import Optional, Union, Dict, Set +from typing import Dict, Set import pytest import torch @@ -25,7 +25,7 @@ EXTENDED_VER_PAT = re.compile(r"([0-9]+\.){2}[0-9]+") -def maybe_mark_exp(exp_patch_set: Set[ExpPatch], mark_if_false: Optional[Dict] = None): +def maybe_mark_exp(exp_patch_set: set[ExpPatch], mark_if_false: Dict | None = None): """This allows us to evaluate whether an experimental patch set that is conditionally required for a given test is required in the current execution context. @@ -96,11 +96,11 @@ def __new__( self, *args, min_cuda_gpus: int = 0, - min_torch: Optional[str] = None, - max_torch: Optional[str] = None, - min_python: Optional[str] = None, - max_python: Optional[str] = None, - env_mask: Optional[str] = None, + min_torch: str | None = None, + max_torch: str | None = None, + min_python: str | None = None, + max_python: str | None = None, + env_mask: str | None = None, bf16_cuda: bool = False, skip_windows: bool = False, skip_mac_os: bool = False, @@ -111,7 +111,7 @@ def __new__( lightning: bool = False, finetuning_scheduler: bool = False, bitsandbytes: bool = False, - exp_patch: Optional[Union[ExpPatch, Set[ExpPatch]]] = None, + exp_patch: ExpPatch | set[ExpPatch] | None = None, **kwargs, ): """ diff --git a/tests/utils.py b/tests/utils.py index 27a114dd..071c6ba8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,4 @@ -from typing import Tuple, List, Iterator, Dict, Optional, Union, Type, Any, Callable, NamedTuple +from typing import Tuple, List, Iterator, Dict, Type, Any, Callable, NamedTuple import random import importlib from collections import defaultdict @@ -37,7 +37,7 @@ def _recursive_defaultdict(): # useful for manipulating segments of nested dictionaries (e.g. generating config file sets for CLI composition tests) -def set_nested(chained_keys: List | str, orig_dict: Optional[Dict] = None): +def set_nested(chained_keys: List | str, orig_dict: Dict | None = None): orig_dict = {} if orig_dict is None else orig_dict chained_keys = chained_keys if isinstance(chained_keys, list) else chained_keys.split(".") reduce(lambda d, k: d.setdefault(k, {}), chained_keys, orig_dict) @@ -151,7 +151,7 @@ def kwargs_from_cfg_obj(cfg_obj, source_obj, base_kwargs=None): return kwargs -def get_super_method(cls_path_or_type: Union[str, Type], instance: Any, method_name: str) -> Callable: +def get_super_method(cls_path_or_type: str | Type, instance: Any, method_name: str) -> Callable: """Retrieves a method from a parent class by using standard super() resolution. This is useful for testing specific implementations of methods that might be overridden in subclasses. diff --git a/tests/warns.py b/tests/warns.py index d34b2743..b876c991 100644 --- a/tests/warns.py +++ b/tests/warns.py @@ -1,6 +1,6 @@ import re from functools import partial -from typing import List, Optional +from typing import List from warnings import WarningMessage from packaging.version import Version from importlib.metadata import version as get_version @@ -97,7 +97,7 @@ def multiwarn_check( rec_warns: List, expected_warns: List | str, expected_mode: bool = False -) -> List[Optional[WarningMessage]]: +) -> list[WarningMessage | None]: if isinstance(expected_warns, str): expected_warns = [expected_warns] msg_search = lambda w1, w2: re.compile(w1).search(w2.message.args[0]) # noqa: E731