Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specify a budget limit in terms of the number of leaves found. The limit equals the beam size. #45

Open
wants to merge 3 commits into
base: keep-searching
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions ults/ults.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def __init__(
)
self.betaparameters = torch.from_numpy(self.init_prior()).to(self.device)

self._leaves_found = 0

def init_prior(self) -> np.ndarray:
"""Build the approximate prior over Delta or load if already exists.

Expand Down Expand Up @@ -268,7 +270,10 @@ def budget_left(self) -> bool:
Returns:
is_budget_left: `True` if there is budget left, otherwise `False`.
"""
return self.max_beam_size >= self.used_max_beam_size[-1]
return (
self.max_beam_size >= self.used_max_beam_size[-1]
and self._leaves_found < self.max_beam_size
)

def log_diversity(self, tokens) -> float:
"""Diversity measure of a token sequence (Also see: https://arxiv.org/pdf/2202.06417)"""
Expand Down Expand Up @@ -355,7 +360,7 @@ def search(self) -> tuple[torch.Tensor, float, int]:
best_observed_value: Total logprob of the best path.
n_llm_calls: Number of LLM forward passes done during the search.
"""
best_path: torch.Tensor = torch.tensor(0).long()
best_path: torch.Tensor = torch.tensor([[0]]).long()
best_observed_value: float = -np.inf
n_llm_calls: int = 0
prob_result_nodes: float = 0
Expand Down Expand Up @@ -456,12 +461,11 @@ def search(self) -> tuple[torch.Tensor, float, int]:
if child_depth == self.depth or (
self.stop_at_eos and child_tokens[0, -1] == self.eos_token
):
if self.use_full_budget:
# we want to compare by average log likelihood
observed_value = (
child_obs / child_tokens.size(-1)
+ self.ngram_penalty * penalty
)
# we want to compare by average log likelihood
observed_value = (
child_obs / child_tokens.size(-1)
+ self.ngram_penalty * penalty
)

if observed_value > best_observed_value:
best_path = children_tokens[i][None, :]
Expand All @@ -471,6 +475,8 @@ def search(self) -> tuple[torch.Tensor, float, int]:
if self.use_full_budget:
best_observed_loglike /= child_tokens.size(-1)

self._leaves_found += 1

# Update optimal value distribution of parents
self.backup(new_node_name)

Expand All @@ -480,9 +486,15 @@ def search(self) -> tuple[torch.Tensor, float, int]:
else:
overall_max_samples = self.tree.nodes["0"]["samples"]

prob_result_nodes = (
torch.sum(best_observed_value >= overall_max_samples) / self.sample_size
)
if self.use_full_budget:
# If using the full budget, set this to 0 so that it is always < 1 - epsilon,
# i.e., we ignore this termination criterion.
prob_result_nodes = 0
else:
prob_result_nodes = (
torch.sum(best_observed_value >= overall_max_samples)
/ self.sample_size
)

if self.ngram_penalty > 0:
best_observed_value = best_observed_loglike
Expand Down
Loading