Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions genlm/eval/domains/ds1000/runtime_no_error_potential.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from typing import Dict, List, Optional
from typing import Callable, Dict, List, Optional
import tempfile
import subprocess
import sys
import textwrap
import os

from genlm.control import Potential
from genlm.eval.domains.ds1000.utils import (
_sandbox_env,
_postprocess_code,
)
from genlm.eval.domains.ds1000.utils import _postprocess_code, _sandbox_env


class DS1000RuntimeNoErrorPotential(Potential):
Expand All @@ -25,6 +22,7 @@ def __init__(
timeout_seconds: float = 30.0,
python_executable: Optional[str] = None,
extra_env: Optional[Dict[str, str]] = None,
f: Optional[Callable[[List[bytes]], List[bytes]]] = None,
):
vocabulary = vocabulary or [bytes([i]) for i in range(256)]
super().__init__(vocabulary=vocabulary)
Expand All @@ -33,15 +31,21 @@ def __init__(
self.python_executable = python_executable or sys.executable
self.extra_env = dict(extra_env or {})
self.last_was_syntax_error = False
self.f = f

def coerce(self, other, f=None, prune=True):
# Overwrite coerce to adopt the LLM vocabulary without mapping tokens.
def coerce(
self,
other,
f: Optional[Callable[[List[bytes]], List[bytes]]] = None,
prune: bool = True,
):
return DS1000RuntimeNoErrorPotential(
vocabulary=list(other.vocab),
code_context=self.code_context,
timeout_seconds=self.timeout_seconds,
python_executable=self.python_executable,
extra_env=self.extra_env,
f=f,
)

def _bytes_to_str(self, toks):
Expand All @@ -54,6 +58,8 @@ def _bytes_to_str(self, toks):
return bytes_str

async def prefix(self, context: List[bytes]) -> float:
if self.f is not None:
context = self.f(context)
code = self._bytes_to_str(context)
# Newline guardrail when using the default sampler.
if not code.endswith("\n"):
Expand All @@ -63,6 +69,9 @@ async def prefix(self, context: List[bytes]) -> float:
return out

async def complete(self, context: List[bytes]):
# Apply transformation before processing
if self.f is not None:
context = self.f(context)
code = self._bytes_to_str(context)
code = _postprocess_code(code)
out = await self._score_no_error(code)
Expand Down
3 changes: 2 additions & 1 deletion genlm/eval/domains/goal_inference/goal_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,10 @@ def from_hf_planetarium(
& (pl.col("init_is_abstract") == 0)
& (pl.col("goal_is_abstract") == 0)
)
.unique(subset=["goal_natural_language"], keep="first", maintain_order=True)
.unique(subset=["goal_natural_language"])
.sample(fraction=1, shuffle=True, seed=1234)
.head(n_examples)
.sort("id")
.select(
pl.col("id").alias("instance_id"),
pl.col("goal_natural_language").alias("nl_goal"),
Expand Down