Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
fb888de
Initial plan
Copilot Aug 27, 2025
30c8116
Fix type errors in registration.py, lightning.py, and registry.py
Copilot Aug 27, 2025
b57c3a0
Fix type errors in transformer_lens.py and sae_lens.py
Copilot Aug 27, 2025
ad7b07f
Complete Phase 4 type checking fixes for 8 files with 20+ errors reso…
Copilot Aug 27, 2025
0759623
simplify copilot instructions and setup steps to use editable install
speediedan Aug 27, 2025
172c93a
add it_release scripts to build env
speediedan Aug 27, 2025
b50b41f
Fix 9 type checking errors with minimal targeted changes
Copilot Aug 27, 2025
516b28c
Complete type checking cleanup: Fix all 14 remaining errors and add 1…
Copilot Aug 27, 2025
933682b
Fix 18 type errors in definitions.py and runner.py with targeted changes
Copilot Aug 27, 2025
07c688b
Fix 43 type errors in debug_generation.py and circuit_tracer.py with …
Copilot Aug 27, 2025
2988ba3
Fix 12 type errors in cli.py with Lightning CLI compatibility patterns
Copilot Aug 27, 2025
13c2b34
Fix 46 type errors in runners/analysis.py with protocol compatibility…
Copilot Aug 27, 2025
b3b9c19
update regen_reqfiles.py to use --update flag to avoid relying on cac…
speediedan Aug 27, 2025
02cda9b
minor cleanup and refactor
speediedan Aug 28, 2025
db5732f
Phase 4+ Type Checking Cleanup: Fix 197 errors across 10 files comple…
Copilot Aug 28, 2025
78a6c7a
Complete type checking cleanup: Remove all custom pyright directives …
Copilot Aug 28, 2025
62acef3
fix type errors generated by `generate_op_stubs.py`, remove type warn…
speediedan Aug 28, 2025
c78e3ec
Enable type checking for 12 it_examples files with minimal surgical f…
Copilot Aug 29, 2025
d748380
Complete type checking for all it_examples Python files - 13 files to…
Copilot Aug 29, 2025
c2957b9
enhance type stub generation for static analysis
speediedan Aug 29, 2025
d876d61
Fix remaining 4 type errors in rte_boolq.py to enable complete type c…
Copilot Aug 29, 2025
d87b7e3
fixed op collection example hub op namespacing bug, converted pyproje…
speediedan Aug 29, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions .github/copilot-instructions.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
### Required Before Each Commit
- Unless guidance in a comment, pull request, or target issue description states otherwise, always run our basic tests (which mirror the `ci_test-full.yml` workflow) in your local environment and ensure all tests pass before committing, for example:
```bash
cd /home/runner/work/interpretune/interpretune && python -m pytest src/interpretune tests -v --import-mode=importlib
cd /home/runner/work/interpretune/interpretune && python -m pytest src/interpretune tests -v
```
- Ensure all pre-commit hooks pass.
- If the copilot session is still failing despite trying to get tests and pre-commit hooks passing for some time, it's okay to commit your intermediate work with a comment about the present challenge to be dealt with in a subsequent session.
Expand All @@ -36,10 +36,10 @@ Always install dependencies in order to avoid conflicts:
```bash
# Basic development setup
python -m pip install --upgrade pip setuptools wheel build
python -m pip install '.[test]' -c requirements/ci_constraints.txt

# Full development with examples (may fail due to circuit-tracer dependency)
python -m pip install '.[test,examples,lightning]' -c requirements/ci_constraints.txt
python -m pip install -r requirements/ci/requirements.txt -r requirements/platform_dependent.txt
python -m pip install -e '.[test,examples,lightning]'
# enabled until further notice (will be removed from instructions once no longer necessary)
pip install --upgrade -r requirements/post_upgrades.txt

# If circuit-tracer install fails, use the built-in tool after basic install:
pip install interpretune[examples]
Expand Down Expand Up @@ -79,10 +79,10 @@ pre-commit run --all-files
**Test command:**
```bash
# Basic test run (requires full dependencies)
cd /home/runner/work/interpretune/interpretune && python -m pytest src/interpretune tests -v --import-mode=importlib
cd /home/runner/work/interpretune/interpretune && python -m pytest src/interpretune tests -v

# With coverage (as used in non-editable CI setup)
python -m coverage run --append --source src/interpretune -m pytest src/interpretune tests -v --import-mode=importlib
# With coverage
python -m coverage run --append --source src/interpretune -m pytest src/interpretune tests -v
coverage report

# Test collection only (to check test discovery)
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/copilot-setup-steps.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ jobs:
python -m pip install --upgrade pip setuptools wheel build
# Prefer pinned CI requirements if present
if [ -f requirements/ci/requirements.txt ]; then
pip install -r requirements/ci/requirements.txt
python -m pip install -r requirements/ci/requirements.txt -r requirements/platform_dependent.txt
python -m pip install -e '.[test,examples,lightning]'
else
python -m pip install '.[test,examples,lightning]' -c requirements/ci_constraints.txt
python -m pip install -e '.[test,examples,lightning]' -c requirements/ci_constraints.txt
fi
# Optional post-upgrades (enabled by default; set APPLY_POST_UPGRADES=0 to skip; no-op if requirements/post_upgrades.txt is missing or empty)
if [ "${APPLY_POST_UPGRADES:-1}" = "1" ] && [ -s requirements/post_upgrades.txt ]; then
Expand Down
101 changes: 4 additions & 97 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -175,109 +175,16 @@ max-complexity = 10
autoSearchPaths=false
typeCheckingMode="standard"
include = [
"src/interpretune/__about__.py",
"src/interpretune/__init__.py",
"src/interpretune/adapter_registry.py",
"src/interpretune/protocol.py",
"src/interpretune/session.py",
# "src/interpretune/registry.py", # 15 errors - protocol compatibility, dict typing

# Adapter files
"src/interpretune/adapters/__init__.py",
"src/interpretune/adapters/core.py",
# "src/interpretune/adapters/lightning.py", # 2 errors
# "src/interpretune/adapters/registration.py", # 1 error
# "src/interpretune/adapters/transformer_lens.py", # 6 errors
# "src/interpretune/adapters/sae_lens.py", # 4 errors
# "src/interpretune/adapters/circuit_tracer.py", # 23 errors

# Analysis files
"src/interpretune/analysis/__init__.py",
"src/interpretune/analysis/formatters.py",
# "src/interpretune/analysis/core.py", # 78 errors - very complex type issues
# "src/interpretune/analysis/formatters.py", # 1 error
# "src/interpretune/analysis/ops/definitions.py", # 8 errors
"src/interpretune/analysis/ops/dynamic_module_utils.py",
"src/interpretune/analysis/ops/__init__.py",
"src/interpretune/analysis/ops/auto_columns.py",
"src/interpretune/analysis/ops/dispatcher.py",
"src/interpretune/analysis/ops/base.py",
"src/interpretune/analysis/ops/hub_manager.py",
"src/interpretune/analysis/ops/compiler/__init__.py",
"src/interpretune/analysis/ops/compiler/cache_manager.py",
"src/interpretune/analysis/ops/compiler/schema_compiler.py",

# Base components
"src/interpretune/base/call.py",
"src/interpretune/base/hooks.py",
"src/interpretune/base/modules.py",
"src/interpretune/base/components/__init__.py",
# "src/interpretune/base/components/mixins.py", # 35 errors - optional member access, callable issues
# "src/interpretune/base/components/cli.py", # 24 errors - path operations, argument types
# "src/interpretune/base/components/core.py", # 1 error
"src/interpretune/base/datamodules.py",

# Config files
"src/interpretune/config/__init__.py",
"src/interpretune/config/circuit_tracer.py",
"src/interpretune/config/datamodule.py",
"src/interpretune/config/extensions.py",
"src/interpretune/config/mixins.py",
"src/interpretune/config/module.py",
"src/interpretune/config/sae_lens.py",
"src/interpretune/config/transformer_lens.py",
"src/interpretune/config/analysis.py",
# "src/interpretune/config/runner.py", # 10 errors
# "src/interpretune/config/shared.py", # 11 errors - type variable usage, tuple operations


# Extensions
"src/interpretune/extensions/__init__.py",
"src/interpretune/extensions/memprofiler.py",
"src/interpretune/extensions/neuronpedia.py",
# "src/interpretune/extensions/debug_generation.py", # 20 errors

# Runners
"src/interpretune/runners/__init__.py",
# "src/interpretune/runners/analysis.py", # 28 errors - protocol mismatches, argument types
# "src/interpretune/runners/core.py", # 3 errors

# Tools
"src/interpretune/tools/",

# Utils
"src/interpretune/utils/__init__.py",
"src/interpretune/utils/data_movement.py",
"src/interpretune/utils/exceptions.py",
"src/interpretune/utils/import_utils.py",
"src/interpretune/utils/logging.py",
"src/interpretune/utils/schema_validation.py",
"src/interpretune/utils/tokenization.py",
"src/interpretune/utils/warnings.py",

"src/",
]
exclude = [
# patch of external code
"src/interpretune/utils/patched_tlens_generate.py",
"src/it_examples/patching/patched_sae_from_pretrained.py",
"src/it_examples/raw_graph_analysis.py",
]

reportMissingTypeStubs = false
reportUnknownMemberType = "none"
reportUnknownArgumentType = "none"
reportUnknownVariableType = "none"
reportUntypedFunctionDecorator = "none"
reportUnnecessaryIsInstance = "none"
reportUnnecessaryComparison = "none"
reportConstantRedefinition = "none"
reportUnknownLambdaType = "none"
reportPrivateUsage = "none"
reportDeprecated = "none"
reportPrivateImportUsage = "none"
reportUnusedImport = "none"
reportUnusedVariable = "none"
reportImplicitOverride = "none"
reportNoReturnInFunction = "none"
reportMissingImports = false
# Custom directives removed for stricter type checking

[tool.coverage.run]
source = ["src/interpretune"]
Expand Down
15 changes: 6 additions & 9 deletions requirements/ci/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#
circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@6c74ea291c410bb3391e572cd6a8d020be714922
# via -r requirements/ci/requirements.in
coverage==7.10.3
coverage==7.10.5
# via
# -r requirements/ci/requirements.in
# nbval
Expand Down Expand Up @@ -42,11 +42,11 @@ neuronpedia==1.0.22
# via -r requirements/ci/requirements.in
notebook==7.4.5
# via -r requirements/ci/requirements.in
peft==0.17.0
peft==0.17.1
# via -r requirements/ci/requirements.in
pip-tools==7.5.0
# via -r requirements/ci/requirements.in
plotly==5.24.1
plotly==6.3.0
# via
# -r requirements/ci/requirements.in
# plotly-express
Expand All @@ -55,14 +55,13 @@ pre-commit==4.3.0
# via -r requirements/ci/requirements.in
psycopg==3.2.9
# via -r requirements/ci/requirements.in
pyright==1.1.403
pyright==1.1.404
# via -r requirements/ci/requirements.in
pytest==8.4.1
# via
# -r requirements/ci/requirements.in
# circuit-tracer
# nbval
# pytest-profiling
# pytest-rerunfailures
pytest-rerunfailures==15.1
# via -r requirements/ci/requirements.in
Expand All @@ -71,12 +70,10 @@ python-dotenv==1.1.1
# -r requirements/ci/requirements.in
# neuronpedia
# sae-lens
sae-lens==6.6.0
sae-lens==6.7.0
# via -r requirements/ci/requirements.in
scikit-learn==1.7.1
# via
# -r requirements/ci/requirements.in
# automated-interpretability
# via -r requirements/ci/requirements.in
tabulate==0.9.0
# via -r requirements/ci/requirements.in
toml==0.10.2
Expand Down
3 changes: 2 additions & 1 deletion requirements/regen_reqfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ def run_pip_compile(req_in_path, output_path):
if not pip_compile:
print("pip-compile not found in PATH; install pip-tools to generate full pinned requirements.txt")
return False
cmd = [pip_compile, "--output-file", output_path, req_in_path]
# Use --upgrade to ensure we don't rely on cached/resolved older versions when regenerating pins
cmd = [pip_compile, "--output-file", output_path, req_in_path, "--upgrade"]
print("Running:", " ".join(shlex.quote(c) for c in cmd))
subprocess.check_call(cmd)
return True
Expand Down
8 changes: 4 additions & 4 deletions scripts/build_it_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ base_env_build(){
pip install ${pip_install_flags} torch torchvision --index-url https://download.pytorch.org/whl/cu128
fi
;;
# it_latest_pt_2_4)
# clear_activate_env python3.11
# pip install ${pip_install_flags} torch==2.4.1 torchvision --index-url https://download.pytorch.org/whl/cu118
# ;;
it_release)
clear_activate_env python3.12
pip install ${pip_install_flags} torch torchvision --index-url https://download.pytorch.org/whl/cu128
;;
*)
echo "no matching environment found, exiting..."
exit 1
Expand Down
2 changes: 1 addition & 1 deletion scripts/gen_it_coverage.sh
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ env_rebuild(){
${repo_home}/scripts/build_it_env.sh --repo_home=${repo_home} --target_env_name=$1 ${fts_from_source_param} ${ct_from_source_param} ${pip_flags_param} ${no_commit_pin_param} ${apply_post_upgrades_param}
fi
;;
it_latest_pt_2_4 )
it_release )
${repo_home}/scripts/build_it_env.sh --repo_home=${repo_home} --target_env_name=$1 ${fts_from_source_param} ${ct_from_source_param} ${pip_flags_param} ${no_commit_pin_param} ${apply_post_upgrades_param}
;;
*)
Expand Down
70 changes: 56 additions & 14 deletions scripts/generate_op_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def format_schema_doc(schema_dict: Dict) -> str:
field_str += " (required)"
lines.append(field_str)

return "\n ".join(lines)
return "\n".join(lines)


def wrap_signature(name: str, params: List[str], return_type: str = "", max_width: int = 120) -> str:
Expand All @@ -97,17 +97,33 @@ def wrap_signature(name: str, params: List[str], return_type: str = "", max_widt
return signature


def format_docstring(description: str, input_schema: Dict, output_schema: Dict) -> str:
def format_docstring(
description: str, input_schema: Dict, output_schema: Dict, function_param_defaults: Dict[str, str] | None = None
) -> str:
"""Format a docstring with proper wrapping and sections."""
doc_lines = [f'"""{description}']

if input_schema:
doc_lines.append("\nInput Schema:")
doc_lines.append(f" {format_schema_doc(input_schema)}")
doc_lines.append("\n Input Schema:")
schema_doc = format_schema_doc(input_schema)
if schema_doc:
# Add proper indentation to each line
indented_schema = "\n".join(f" {line}" for line in schema_doc.split("\n"))
doc_lines.append(indented_schema)

if output_schema:
doc_lines.append("\nOutput Schema:")
doc_lines.append(f" {format_schema_doc(output_schema)}")
doc_lines.append("\n Output Schema:")
schema_doc = format_schema_doc(output_schema)
if schema_doc:
# Add proper indentation to each line
indented_schema = "\n".join(f" {line}" for line in schema_doc.split("\n"))
doc_lines.append(indented_schema)

# Document any function-parameter defaults that were present in the YAML (FQ callable paths).
if function_param_defaults:
doc_lines.append("\n Function parameter defaults (from YAML):")
for param_name, fq_path in function_param_defaults.items():
doc_lines.append(f" {param_name}: {fq_path}")

doc_lines.append('"""')
return "\n".join(doc_lines)
Expand All @@ -125,17 +141,21 @@ def generate_operation_stub(op_name: str, op_def: Dict[str, Any], yaml_content:

# Create parameters list
params = []
# Collect function-parameter defaults to document them in the docstring
function_param_defaults: Dict[str, str] = {}
for name, param in sig.parameters.items():
annotation = format_type_annotation(param.annotation)
if annotation:
annotation = f": {annotation}"

default = ""
if param.default is not param.empty:
# Check if this parameter has a corresponding function_param in the YAML definition
if "function_params" in op_def and name in op_def["function_params"]:
# Use fully qualified function name as a string
default = f' = "{op_def["function_params"][name]}"'
# Check if this parameter has a corresponding importable_param in the YAML definition
if "importable_params" in op_def and name in op_def["importable_params"]:
# DO NOT emit the FQ path as the default in the stub (string default breaks type checkers).
# Instead, set default to ... and record the FQ path for documentation in the docstring.
default = " = ..."
function_param_defaults[name] = op_def["importable_params"][name]
elif param.default is None:
default = " = None"
elif isinstance(param.default, str):
Expand All @@ -151,9 +171,12 @@ def generate_operation_stub(op_name: str, op_def: Dict[str, Any], yaml_content:
# Create function signature
signature = wrap_signature(op_name, params, return_type)

# Create formatted docstring
# Create formatted docstring (include the recorded function_param_defaults)
docstring = format_docstring(
op_def.get("description", ""), op_def.get("input_schema", {}), op_def.get("output_schema", {})
op_def.get("description", ""),
op_def.get("input_schema", {}),
op_def.get("output_schema", {}),
function_param_defaults or None,
)

# Build the complete stub
Expand Down Expand Up @@ -197,7 +220,7 @@ def generate_composition_stub(op_name: str, op_def: Dict[str, Any]) -> str:
)

# Create docstring
doc = f' """Composition of operations: {composition_str}'
doc = f' """Composition of operations:\n {composition_str}'
if "description" in op_def:
doc += f"\n\n {op_def['description']}"
doc += '\n """'
Expand All @@ -221,11 +244,30 @@ def generate_stubs(yaml_path: Path, output_path: Path) -> None:
'"""Type stubs for Interpretune analysis operations."""',
"# This file is auto-generated. Do not modify directly.",
"",
"from typing import Any, Callable, Dict, List, Optional, Union, Tuple, Sequence, Literal",
"from typing import Callable, Optional",
"import torch",
"from transformers import BatchEncoding",
"from interpretune.protocol import BaseAnalysisBatchProtocol, DefaultAnalysisBatchProtocol",
"",
"# Main module exports - added for static analysis",
"# These imports resolve pyright 'unknown import symbol' errors caused by the complex import hook",
"# mechanism used for analysis operations.",
"from interpretune.base.datamodules import ITDataModule as ITDataModule",
"from interpretune.base.components.mixins import MemProfilerHooks as MemProfilerHooks",
"from interpretune.analysis.ops import AnalysisBatch as AnalysisBatch",
"from interpretune.config import (",
" ITLensConfig as ITLensConfig,",
" SAELensConfig as SAELensConfig,",
" PromptConfig as PromptConfig,",
" ITDataModuleConfig as ITDataModuleConfig,",
" ITConfig as ITConfig,",
" GenerativeClassificationConfig as GenerativeClassificationConfig,",
" BaseGenerationConfig as BaseGenerationConfig,",
" HFGenerationConfig as HFGenerationConfig,",
")",
"from interpretune.utils import rank_zero_warn as rank_zero_warn, sanitize_input_name as sanitize_input_name",
"from interpretune.protocol import STEP_OUTPUT as STEP_OUTPUT",
"",
"# Basic operations",
"",
]
Expand Down
Loading