Skip to content

Commit

Permalink
Fixed assertions for fast forward and backtrack tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
nopdive committed Nov 11, 2024
1 parent cfb303d commit 7122614
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions tests/model_specific/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,21 @@ def ff_prompt(lm):

# We should have significantly less output tokens in the fast-forwarded version (1 output)

# gpt2_noff = models.Transformers("gpt2", enable_ff_tokens=False, enable_backtrack=False)
# gpt2_noff += ff_prompt()
# assert gpt2_noff.engine.metrics.engine_output_tokens > 1
gpt2_noff = models.Transformers("gpt2", enable_ff_tokens=False)
gpt2_noff += ff_prompt()
noff_count = gpt2_noff.engine.metrics.engine_output_tokens

gpt2_ff = models.Transformers("gpt2", enable_ff_tokens=True)
gpt2_nobt = models.Transformers("gpt2", enable_backtrack=False)
gpt2_nobt += ff_prompt()
nobt_count = gpt2_nobt.engine.metrics.engine_output_tokens

gpt2_ff = models.Transformers("gpt2")
gpt2_ff += ff_prompt()
assert gpt2_ff.engine.metrics.engine_output_tokens == 1
ff_count = gpt2_ff.engine.metrics.engine_output_tokens

assert nobt_count == 3
assert ff_count == 3
assert noff_count > ff_count



Expand Down

0 comments on commit 7122614

Please sign in to comment.