diff --git a/genlm/eval/domains/ds1000/runtime_no_error_potential.py b/genlm/eval/domains/ds1000/runtime_no_error_potential.py index 497f5f6..25b89c8 100644 --- a/genlm/eval/domains/ds1000/runtime_no_error_potential.py +++ b/genlm/eval/domains/ds1000/runtime_no_error_potential.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Callable, Dict, List, Optional import tempfile import subprocess import sys @@ -6,10 +6,7 @@ import os from genlm.control import Potential -from genlm.eval.domains.ds1000.utils import ( - _sandbox_env, - _postprocess_code, -) +from genlm.eval.domains.ds1000.utils import _postprocess_code, _sandbox_env class DS1000RuntimeNoErrorPotential(Potential): @@ -25,6 +22,7 @@ def __init__( timeout_seconds: float = 30.0, python_executable: Optional[str] = None, extra_env: Optional[Dict[str, str]] = None, + f: Optional[Callable[[List[bytes]], List[bytes]]] = None, ): vocabulary = vocabulary or [bytes([i]) for i in range(256)] super().__init__(vocabulary=vocabulary) @@ -33,15 +31,21 @@ def __init__( self.python_executable = python_executable or sys.executable self.extra_env = dict(extra_env or {}) self.last_was_syntax_error = False + self.f = f - def coerce(self, other, f=None, prune=True): - # Overwrite coerce to adopt the LLM vocabulary without mapping tokens. + def coerce( + self, + other, + f: Optional[Callable[[List[bytes]], List[bytes]]] = None, + prune: bool = True, + ): return DS1000RuntimeNoErrorPotential( vocabulary=list(other.vocab), code_context=self.code_context, timeout_seconds=self.timeout_seconds, python_executable=self.python_executable, extra_env=self.extra_env, + f=f, ) def _bytes_to_str(self, toks): @@ -54,6 +58,8 @@ def _bytes_to_str(self, toks): return bytes_str async def prefix(self, context: List[bytes]) -> float: + if self.f is not None: + context = self.f(context) code = self._bytes_to_str(context) # Newline guardrail when using the default sampler. if not code.endswith("\n"): @@ -63,6 +69,9 @@ async def prefix(self, context: List[bytes]) -> float: return out async def complete(self, context: List[bytes]): + # Apply transformation before processing + if self.f is not None: + context = self.f(context) code = self._bytes_to_str(context) code = _postprocess_code(code) out = await self._score_no_error(code) diff --git a/genlm/eval/domains/goal_inference/goal_inference.py b/genlm/eval/domains/goal_inference/goal_inference.py index 44c1cd5..a49a993 100644 --- a/genlm/eval/domains/goal_inference/goal_inference.py +++ b/genlm/eval/domains/goal_inference/goal_inference.py @@ -132,9 +132,10 @@ def from_hf_planetarium( & (pl.col("init_is_abstract") == 0) & (pl.col("goal_is_abstract") == 0) ) - .unique(subset=["goal_natural_language"], keep="first", maintain_order=True) + .unique(subset=["goal_natural_language"]) .sample(fraction=1, shuffle=True, seed=1234) .head(n_examples) + .sort("id") .select( pl.col("id").alias("instance_id"), pl.col("goal_natural_language").alias("nl_goal"),