Skip to content

Commit 9101c67

Browse files
committed
Add test for empty subspaces forward pass
1 parent 896a7dd commit 9101c67

File tree

1 file changed

+34
-1
lines changed

1 file changed

+34
-1
lines changed

tests/integration_tests/IntervenableBasicTestCase.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,40 @@ def test_linked_intervention_and_weights_sharing(self):
427427
subspaces=[[[1]], [[0]]],
428428
)
429429
self.assertTrue(torch.equal(pv_out3.last_hidden_state, pv_out4.last_hidden_state))
430-
430+
431+
def test_empty_subspaces_matches_default_behavior(self):
432+
433+
_, tokenizer, gpt2 = pv.create_gpt2(cache_dir=self._test_dir)
434+
435+
config = pv.IntervenableConfig([
436+
{"layer": 0, "component": "block_output"},
437+
], intervention_types=pv.VanillaIntervention)
438+
439+
pv_gpt2 = pv.IntervenableModel(config, model=gpt2)
440+
441+
base = tokenizer("The capital of Spain is", return_tensors="pt")
442+
source = tokenizer("The capital of Italy is", return_tensors="pt")
443+
444+
_, default_subspaces_output = pv_gpt2(
445+
base,
446+
[source],
447+
{"sources->base": 4},
448+
)
449+
450+
_, empty_subspaces_output = pv_gpt2(
451+
base,
452+
[source],
453+
{"sources->base": 4},
454+
subspaces=[],
455+
)
456+
457+
self.assertTrue(
458+
torch.equal(
459+
default_subspaces_output.last_hidden_state,
460+
empty_subspaces_output.last_hidden_state,
461+
)
462+
)
463+
431464
def test_new_model_type(self):
432465
try:
433466
import sentencepiece

0 commit comments

Comments
 (0)