diff --git a/pyproject.toml b/pyproject.toml index d1af2ba..7241286 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ examples = [ "neuronpedia", # TODO: add our packaged circuit-tracer dep (either pypi or pypi fork) once it is available and remove the # `install_circuit_tracer` tool -# "circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3", +# "circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87", ] docs = [ diff --git a/requirements/ci/circuit_tracer_pin.txt b/requirements/ci/circuit_tracer_pin.txt index e4cd311..2f48da4 100644 --- a/requirements/ci/circuit_tracer_pin.txt +++ b/requirements/ci/circuit_tracer_pin.txt @@ -1 +1 @@ -b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3 +004f1b2822eca3f0c1ddd2389e9105b3abffde87 diff --git a/requirements/ci/requirements.in b/requirements/ci/requirements.in index 5d01688..faaaf7b 100644 --- a/requirements/ci/requirements.in +++ b/requirements/ci/requirements.in @@ -33,4 +33,4 @@ pip < 25.3 huggingface_hub[hf_xet] nbmake >= 1.5.0 papermill >= 2.4.0 -git+https://github.com/speediedan/circuit-tracer.git@b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3#egg=circuit-tracer +git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87#egg=circuit-tracer diff --git a/requirements/ci/requirements.txt b/requirements/ci/requirements.txt index f60b0d6..335c4ed 100644 --- a/requirements/ci/requirements.txt +++ b/requirements/ci/requirements.txt @@ -4,7 +4,7 @@ # # pip-compile --no-strip-extras --output-file=requirements/ci/requirements.txt requirements/ci/requirements.in # -circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3 +circuit-tracer @ git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87 # via -r requirements/ci/requirements.in coverage==7.11.0 # via diff --git a/scripts/build_it_env.sh b/scripts/build_it_env.sh index 8ac05fa..9e64f09 100755 --- a/scripts/build_it_env.sh +++ b/scripts/build_it_env.sh @@ -195,7 +195,7 @@ it_install(){ # Verify only the editable source installation is installed echo "Verifying circuit-tracer installation..." - if pip show circuit_tracer | grep -q "Editable project location:"; then + if pip show circuit_tracer 2>/dev/null | grep -q "Editable project location:"; then echo "✓ circuit_tracer is installed in editable mode" else echo "✗ circuit_tracer is not installed in editable mode" diff --git a/scripts/infra_utils.sh b/scripts/infra_utils.sh index 9122e17..50465f2 100755 --- a/scripts/infra_utils.sh +++ b/scripts/infra_utils.sh @@ -19,7 +19,5 @@ show_elapsed_time(){ # Function to safely deactivate a virtual environment if one is active maybe_deactivate(){ - if [ -n "$VIRTUAL_ENV" ]; then - deactivate - fi + deactivate 2>/dev/null || true } diff --git a/src/interpretune/adapters/circuit_tracer.py b/src/interpretune/adapters/circuit_tracer.py index 02ef012..4f6bdd0 100644 --- a/src/interpretune/adapters/circuit_tracer.py +++ b/src/interpretune/adapters/circuit_tracer.py @@ -115,30 +115,24 @@ def set_input_require_grads(self) -> None: # Circuit tracer handles gradient requirements internally rank_zero_info("Input gradient requirements handled by circuit tracer internally.") - def _get_analysis_target_indices(self) -> Optional[torch.Tensor]: - """Determine the value for compute_specific_logits based on CircuitTracerConfig. + def _get_attribution_targets(self) -> Optional[list | torch.Tensor]: + """Determine the attribution_targets value based on CircuitTracerConfig. - Returns a 1D tensor of token ids, or None. + Returns: + - None: Auto-select salient logits (default behavior) + - list[str]: Token strings to analyze (will be converted by AttributionTargets) + - torch.Tensor: Tensor of token IDs """ cfg = self.circuit_tracer_cfg if not cfg: return None - # If analysis_target_tokens is set, tokenize them + # If analysis_target_tokens is set, return as list of strings + # AttributionTargets will handle tokenization internally if cfg.analysis_target_tokens is not None: - tokenizer = self.datamodule.tokenizer if self.datamodule else self.it_cfg.tokenizer - # Tokenize and flatten to 1D tensor of token ids - token_ids = [] - for token in cfg.analysis_target_tokens: - assert tokenizer is not None, "Tokenizer must be available to tokenize analysis_target_tokens" - ids = tokenizer.encode(token, add_special_tokens=False) - token_ids.extend(ids) - if token_ids: - return torch.tensor(token_ids, dtype=torch.long) - else: - return None + return cfg.analysis_target_tokens - # If target_token_ids is set + # If target_token_ids is set, process it if cfg.target_token_ids is not None: ids = cfg.target_token_ids if isinstance(ids, torch.Tensor): @@ -155,7 +149,7 @@ def _get_analysis_target_indices(self) -> Optional[torch.Tensor]: else: return None - # If neither is set, return None + # If neither is set, return None (use salient logits) return None def generate_attribution_graph(self, prompt: str, **kwargs) -> Graph: @@ -165,18 +159,18 @@ def generate_attribution_graph(self, prompt: str, **kwargs) -> Graph: cfg = self.circuit_tracer_cfg - # Determine compute_specific_logits using the new method - analysis_target_indices = self._get_analysis_target_indices() + # Determine attribution_targets using the new method + attribution_targets = self._get_attribution_targets() # Set default attribution parameters attribution_kwargs = { + "attribution_targets": attribution_targets, "max_n_logits": cfg.max_n_logits if cfg else 10, "desired_logit_prob": cfg.desired_logit_prob if cfg else 0.95, "batch_size": cfg.batch_size if cfg else 256, "max_feature_nodes": cfg.max_feature_nodes if cfg else None, "offload": cfg.offload if cfg else None, "verbose": cfg.verbose if cfg else True, - "analysis_target_indices": analysis_target_indices, } # Override with any provided kwargs diff --git a/src/it_examples/notebooks/dev/attribution_analysis/analysis_points.py b/src/it_examples/notebooks/dev/attribution_analysis/analysis_points.py index 76a2f07..460db76 100644 --- a/src/it_examples/notebooks/dev/attribution_analysis/analysis_points.py +++ b/src/it_examples/notebooks/dev/attribution_analysis/analysis_points.py @@ -68,48 +68,49 @@ def ap_build_input_vectors_end(local_vars: Dict[str, Any]) -> None: v = get_analysis_vars( context_keys=["target_token_analysis"], local_keys=[ - "logit_idx", - "logit_p", + "targets", "total_nodes", "total_active_feats", "max_feature_nodes", "edge_matrix", "row_to_node_index", - "logit_vecs", "n_layers", "n_pos", ], local_vars=local_vars, ) tta = v["target_token_analysis"] - tta.update_logit_info(v["logit_idx"], v["logit_p"]) - max_n_logits = len(v["logit_idx"]) + targets = v["targets"] + # Extract token IDs and probabilities from targets object + logit_idx, logit_p = targets.token_ids, targets.logit_probabilities + tta.update_logit_info(logit_idx, logit_p) + max_n_logits = len(targets) HOOK_REGISTRY.set_context( max_feature_nodes=v["max_feature_nodes"], total_nodes=v["total_nodes"], - logit_idx=v["logit_idx"], - logit_p=v["logit_p"], + logit_idx=logit_idx, + logit_p=logit_p, n_pos=v["n_pos"], max_n_logits=max_n_logits, ) data = { - "logit_idx": v["logit_idx"], - "logit_p": VarAnnotate("logit_p", var_value=v["logit_p"], annotation="non-demeaned logits probabilities"), + "logit_idx": logit_idx, + "logit_p": VarAnnotate("logit_p", var_value=logit_p, annotation="logit probabilities"), "target_tokens": tta.tokens, "target_logit_indices": tta.logit_indices, "target_logit_p": tta.logit_probabilities, - "logit_cumulative_prob": float(v["logit_p"].sum().item()), + "logit_cumulative_prob": float(logit_p.sum().item()), "total_nodes": v["total_nodes"], "max_feature_nodes": v["max_feature_nodes"], "total_active_feats": v["total_active_feats"], - "n_logits": len(v["logit_idx"]), + "n_logits": len(targets), "n_layers": v["n_layers"], "n_pos": v["n_pos"], "max_n_logits": max_n_logits, "edge_matrix.shape": v["edge_matrix"].shape, "row_to_node_index.shape": v["row_to_node_index"].shape if v["row_to_node_index"] is not None else None, - "logit_vecs.shape": v["logit_vecs"].shape, + "logit_vecs.shape": targets.logit_vectors.shape, } analysis_log_point("after building input vectors w/ target logits", data) @@ -151,7 +152,7 @@ def ap_compute_feature_attributions_end(local_vars: Dict[str, Any]) -> None: # Use dict directly for cleaner access v = get_analysis_vars( context_keys=["target_token_analysis"], - local_keys=["n_visited", "max_feature_nodes", "logit_p", "edge_matrix", "ctx"], + local_keys=["n_visited", "max_feature_nodes", "targets", "edge_matrix", "ctx"], local_vars=local_vars, ) tta = v["target_token_analysis"] diff --git a/src/it_examples/notebooks/dev/attribution_analysis/attribution_analysis.ipynb b/src/it_examples/notebooks/dev/attribution_analysis/attribution_analysis.ipynb index 1bd29e8..f11e95e 100644 --- a/src/it_examples/notebooks/dev/attribution_analysis/attribution_analysis.ipynb +++ b/src/it_examples/notebooks/dev/attribution_analysis/attribution_analysis.ipynb @@ -61,7 +61,7 @@ "if should_install_package(\"circuit-tracer\"):\n", " # Note: Using line continuation for long git URL\n", " !python -m pip install \\\n", - " 'git+https://github.com/speediedan/circuit-tracer.git@b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3'" + " 'git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87'" ] }, { diff --git a/src/it_examples/notebooks/dev/circuit_tracer_examples/circuit_tracer_adapter_example_basic_clt.ipynb b/src/it_examples/notebooks/dev/circuit_tracer_examples/circuit_tracer_adapter_example_basic_clt.ipynb index fb34bf7..bec9afa 100644 --- a/src/it_examples/notebooks/dev/circuit_tracer_examples/circuit_tracer_adapter_example_basic_clt.ipynb +++ b/src/it_examples/notebooks/dev/circuit_tracer_examples/circuit_tracer_adapter_example_basic_clt.ipynb @@ -107,7 +107,9 @@ "source": [ "# Parameters - These will be injected by papermill during parameterized test runs\n", "use_baseline_salient_logits = True # logits computation mode: True->salient logits, False->specific logits\n", - "use_baseline_transcoder_arch = True # transcoder architecture: True->SingleLayerTranscoder, False->CrossLayerTranscoder\n", + "use_baseline_transcoder_arch = (\n", + " False # transcoder architecture: True->SingleLayerTranscoder, False->CrossLayerTranscoder\n", + ")\n", "core_log_dir = None # Directory to save analysis logs (if None, a temp directory will be created)" ] }, @@ -347,11 +349,13 @@ "if port_forwarding:\n", " hostname = \"localhost\" # use localhost for port forwarding\n", " print(\n", - " f\"Using port forwarding (ensure it is configured) and localhost. Open your graph here at http://{hostname}:{port}/index.html\"\n", + " f\"Using port forwarding (ensure it is configured) and localhost.\"\n", + " f\" Open your graph here at http://{hostname}:{port}/index.html\"\n", " )\n", "else:\n", " print(\n", - " f\"Not using port forwarding. Use the IFrame below, or open your graph here directly at http://{hostname}:{port}/index.html\"\n", + " f\"Not using port forwarding. Use the IFrame below, or\"\n", + " f\" open your graph here directly at http://{hostname}:{port}/index.html\"\n", " )\n", "\n", "if enable_iframe:\n", diff --git a/src/it_examples/notebooks/publish/.notebook_hashes.json b/src/it_examples/notebooks/publish/.notebook_hashes.json index 7461ce0..9620ec2 100644 --- a/src/it_examples/notebooks/publish/.notebook_hashes.json +++ b/src/it_examples/notebooks/publish/.notebook_hashes.json @@ -1,9 +1,9 @@ { "notebooks/dev/attribution_analysis/analysis_injection_config.yaml": "0153616832195ab9adfc9008cdf7f26d73426128f9fe01b9a614929018e1a53a", - "notebooks/dev/attribution_analysis/analysis_points.py": "a74256e26f10d00b4e0abeeb69f3bab7332357dc19ab006fa56afd1d0502db94", - "notebooks/dev/attribution_analysis/attribution_analysis.ipynb": "f70568bf095c184fd50123560e76e1ee4f0cb72cbd1529c841c993680b02fe54", + "notebooks/dev/attribution_analysis/analysis_points.py": "b27a7eb3d6cf02194ed776b4d7e9a885c2ebb563aac73822e3d21b0366c3e8a0", + "notebooks/dev/attribution_analysis/attribution_analysis.ipynb": "89d1509a99ba14a391dbde18d4bda68b276551b52eb4207942b7308f06980a00", "notebooks/dev/circuit_tracer_examples/circuit_tracer_adapter_example_basic.ipynb": "760c051bfb3932815e400918724661e46ef49fbe10a4e9231442653707dc83c9", - "notebooks/dev/circuit_tracer_examples/circuit_tracer_adapter_example_basic_clt.ipynb": "c74595c292c5964558943f3a2d8786e8e3e75425ab56eb48d112518d9133496e", + "notebooks/dev/circuit_tracer_examples/circuit_tracer_adapter_example_basic_clt.ipynb": "c101e6594dd34cc55c7dd46439e316a8b4a4fff8d7afc1edb9b5bba88156ab0a", "notebooks/dev/circuit_tracer_examples/gradient_flow_analysis.ipynb": "21818d3fa9765b29b84d38d977a744c185d89fced8913893147d61600c124b45", "notebooks/dev/example_op_collections/hub_op_collection/hub_op_collection.yaml": "93d81adff9eb7380500df6e598a712d7ac1b73c54a831bdfbfaa05ed701b244b", "notebooks/dev/example_op_collections/hub_op_collection/hub_op_definitions.py": "ebcc238e3c0cefc94dba7963197f5a6d81a28158f5f13b1bb826859dd5475f60", diff --git a/src/it_examples/notebooks/publish/attribution_analysis/analysis_points.py b/src/it_examples/notebooks/publish/attribution_analysis/analysis_points.py index 76a2f07..460db76 100644 --- a/src/it_examples/notebooks/publish/attribution_analysis/analysis_points.py +++ b/src/it_examples/notebooks/publish/attribution_analysis/analysis_points.py @@ -68,48 +68,49 @@ def ap_build_input_vectors_end(local_vars: Dict[str, Any]) -> None: v = get_analysis_vars( context_keys=["target_token_analysis"], local_keys=[ - "logit_idx", - "logit_p", + "targets", "total_nodes", "total_active_feats", "max_feature_nodes", "edge_matrix", "row_to_node_index", - "logit_vecs", "n_layers", "n_pos", ], local_vars=local_vars, ) tta = v["target_token_analysis"] - tta.update_logit_info(v["logit_idx"], v["logit_p"]) - max_n_logits = len(v["logit_idx"]) + targets = v["targets"] + # Extract token IDs and probabilities from targets object + logit_idx, logit_p = targets.token_ids, targets.logit_probabilities + tta.update_logit_info(logit_idx, logit_p) + max_n_logits = len(targets) HOOK_REGISTRY.set_context( max_feature_nodes=v["max_feature_nodes"], total_nodes=v["total_nodes"], - logit_idx=v["logit_idx"], - logit_p=v["logit_p"], + logit_idx=logit_idx, + logit_p=logit_p, n_pos=v["n_pos"], max_n_logits=max_n_logits, ) data = { - "logit_idx": v["logit_idx"], - "logit_p": VarAnnotate("logit_p", var_value=v["logit_p"], annotation="non-demeaned logits probabilities"), + "logit_idx": logit_idx, + "logit_p": VarAnnotate("logit_p", var_value=logit_p, annotation="logit probabilities"), "target_tokens": tta.tokens, "target_logit_indices": tta.logit_indices, "target_logit_p": tta.logit_probabilities, - "logit_cumulative_prob": float(v["logit_p"].sum().item()), + "logit_cumulative_prob": float(logit_p.sum().item()), "total_nodes": v["total_nodes"], "max_feature_nodes": v["max_feature_nodes"], "total_active_feats": v["total_active_feats"], - "n_logits": len(v["logit_idx"]), + "n_logits": len(targets), "n_layers": v["n_layers"], "n_pos": v["n_pos"], "max_n_logits": max_n_logits, "edge_matrix.shape": v["edge_matrix"].shape, "row_to_node_index.shape": v["row_to_node_index"].shape if v["row_to_node_index"] is not None else None, - "logit_vecs.shape": v["logit_vecs"].shape, + "logit_vecs.shape": targets.logit_vectors.shape, } analysis_log_point("after building input vectors w/ target logits", data) @@ -151,7 +152,7 @@ def ap_compute_feature_attributions_end(local_vars: Dict[str, Any]) -> None: # Use dict directly for cleaner access v = get_analysis_vars( context_keys=["target_token_analysis"], - local_keys=["n_visited", "max_feature_nodes", "logit_p", "edge_matrix", "ctx"], + local_keys=["n_visited", "max_feature_nodes", "targets", "edge_matrix", "ctx"], local_vars=local_vars, ) tta = v["target_token_analysis"] diff --git a/src/it_examples/notebooks/publish/attribution_analysis/attribution_analysis.ipynb b/src/it_examples/notebooks/publish/attribution_analysis/attribution_analysis.ipynb index c08dc81..2986313 100644 --- a/src/it_examples/notebooks/publish/attribution_analysis/attribution_analysis.ipynb +++ b/src/it_examples/notebooks/publish/attribution_analysis/attribution_analysis.ipynb @@ -78,7 +78,7 @@ "if should_install_package(\"circuit-tracer\"):\n", " # Note: Using line continuation for long git URL\n", " !python -m pip install \\\n", - " 'git+https://github.com/speediedan/circuit-tracer.git@b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3'" + " 'git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87'" ] }, { diff --git a/src/it_examples/notebooks/publish/circuit_tracer_examples/circuit_tracer_adapter_example_basic_clt.ipynb b/src/it_examples/notebooks/publish/circuit_tracer_examples/circuit_tracer_adapter_example_basic_clt.ipynb index 80dcafc..575d6c3 100644 --- a/src/it_examples/notebooks/publish/circuit_tracer_examples/circuit_tracer_adapter_example_basic_clt.ipynb +++ b/src/it_examples/notebooks/publish/circuit_tracer_examples/circuit_tracer_adapter_example_basic_clt.ipynb @@ -124,7 +124,9 @@ "source": [ "# Parameters - These will be injected by papermill during parameterized test runs\n", "use_baseline_salient_logits = True # logits computation mode: True->salient logits, False->specific logits\n", - "use_baseline_transcoder_arch = True # transcoder architecture: True->SingleLayerTranscoder, False->CrossLayerTranscoder\n", + "use_baseline_transcoder_arch = (\n", + " False # transcoder architecture: True->SingleLayerTranscoder, False->CrossLayerTranscoder\n", + ")\n", "core_log_dir = None # Directory to save analysis logs (if None, a temp directory will be created)" ] }, @@ -364,11 +366,13 @@ "if port_forwarding:\n", " hostname = \"localhost\" # use localhost for port forwarding\n", " print(\n", - " f\"Using port forwarding (ensure it is configured) and localhost. Open your graph here at http://{hostname}:{port}/index.html\"\n", + " f\"Using port forwarding (ensure it is configured) and localhost.\"\n", + " f\" Open your graph here at http://{hostname}:{port}/index.html\"\n", " )\n", "else:\n", " print(\n", - " f\"Not using port forwarding. Use the IFrame below, or open your graph here directly at http://{hostname}:{port}/index.html\"\n", + " f\"Not using port forwarding. Use the IFrame below, or\"\n", + " f\" open your graph here directly at http://{hostname}:{port}/index.html\"\n", " )\n", "\n", "if enable_iframe:\n", diff --git a/src/it_examples/utils/analysis_injection/version_manager.py b/src/it_examples/utils/analysis_injection/version_manager.py index f29362b..d3a92ac 100644 --- a/src/it_examples/utils/analysis_injection/version_manager.py +++ b/src/it_examples/utils/analysis_injection/version_manager.py @@ -20,6 +20,7 @@ import importlib.metadata import logging +import re import subprocess import sys import tempfile @@ -30,16 +31,12 @@ # Mapping of (package_name, version) to git-based installation URLs # Used as fallback when package is not available on PyPI -# Note: Supports both hyphenated and underscored package names +# Package names are normalized according to PEP 503 (canonical form) GIT_FALLBACK_URLS = { ( "circuit-tracer", "0.1.0", - ): "git+https://github.com/speediedan/circuit-tracer.git@b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3", - ( - "circuit_tracer", - "0.1.0", - ): "git+https://github.com/speediedan/circuit-tracer.git@b228bf190fadb3cb30f6a5ba6691dc4c86d76ba3", + ): "git+https://github.com/speediedan/circuit-tracer.git@004f1b2822eca3f0c1ddd2389e9105b3abffde87", } @@ -58,6 +55,22 @@ class PackageVersionManager: >>> mgr.cleanup() """ + @staticmethod + def _normalize_package_name(name: str) -> str: + """Normalize package name according to PEP 503. + + Converts package name to canonical form by: + - Converting to lowercase + - Replacing any sequence of [-_.] with a single hyphen + + Args: + name: Package name to normalize + + Returns: + Normalized package name + """ + return re.sub(r"[-_.]+", "-", name).lower() + def __init__(self, package_name: str, required_version: str): """Initialize the version manager. @@ -137,7 +150,8 @@ def install_temp_version(self) -> Path: # If PyPI installation fails, try git-based fallback if not success: - git_url = GIT_FALLBACK_URLS.get((self.package_name, self.required_version)) + normalized_name = self._normalize_package_name(self.package_name) + git_url = GIT_FALLBACK_URLS.get((normalized_name, self.required_version)) if git_url: logger.warning( f"{self.package_name}=={self.required_version} not available on PyPI. " diff --git a/tests/examples/test_version_manager.py b/tests/examples/test_version_manager.py index f6b3dbd..4845d5d 100644 --- a/tests/examples/test_version_manager.py +++ b/tests/examples/test_version_manager.py @@ -26,6 +26,35 @@ def test_git_fallback_urls_configured(): assert "circuit-tracer" in url +def test_package_name_normalization(): + """Test that package name normalization works according to PEP 503.""" + # Test various forms normalize to the same canonical name + assert PackageVersionManager._normalize_package_name("circuit_tracer") == "circuit-tracer" + assert PackageVersionManager._normalize_package_name("circuit-tracer") == "circuit-tracer" + assert PackageVersionManager._normalize_package_name("Circuit.Tracer") == "circuit-tracer" + assert PackageVersionManager._normalize_package_name("circuit__tracer") == "circuit-tracer" + assert PackageVersionManager._normalize_package_name("CIRCUIT_TRACER") == "circuit-tracer" + + +def test_normalized_git_fallback_lookup(): + """Test that git fallback lookup works with normalized package names.""" + # Both forms should find the same URL after normalization + mgr_underscore = PackageVersionManager("circuit_tracer", "0.1.0") + mgr_hyphen = PackageVersionManager("circuit-tracer", "0.1.0") + + normalized_underscore = mgr_underscore._normalize_package_name(mgr_underscore.package_name) + normalized_hyphen = mgr_hyphen._normalize_package_name(mgr_hyphen.package_name) + + assert normalized_underscore == normalized_hyphen == "circuit-tracer" + + url_underscore = GIT_FALLBACK_URLS.get((normalized_underscore, "0.1.0")) + url_hyphen = GIT_FALLBACK_URLS.get((normalized_hyphen, "0.1.0")) + + assert url_underscore == url_hyphen + assert url_underscore is not None + assert url_underscore.startswith("git+https://") + + def test_version_manager_init(): """Test that PackageVersionManager can be initialized.""" mgr = PackageVersionManager("circuit-tracer", "0.1.0") diff --git a/tests/parity_acceptance/expected.py b/tests/parity_acceptance/expected.py index 56b2aa0..b425856 100644 --- a/tests/parity_acceptance/expected.py +++ b/tests/parity_acceptance/expected.py @@ -16,7 +16,7 @@ exact_results=def_results("cuda", 32, ds_cfg="train"), close_results=((0, "loss", 13.356880),) ), "train_cuda_bf16": TestResult( - exact_results=def_results("cuda", "bf16", ds_cfg="train"), close_results=((0, "loss", 13.402528),) + exact_results=def_results("cuda", "bf16", ds_cfg="train"), close_results=((0, "loss", 14.748956),) ), "train_cpu_bf16": TestResult(exact_results=def_results("cpu", "bf16", ds_cfg="train")), "test_cpu_32": TestResult(exact_results=def_results("cpu", 32, ds_cfg="test")),