Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed bug causing disabling edits #44

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions rome/compute_u.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def compute_u(
context_templates: List[str],
) -> torch.Tensor:
"""
Computes the right vector used in constructing the rank-1 update matrix.
Computes the left vector (u) used in constructing the rank-1 update matrix,
as well as the key representing the subject (averaged over random prefixes).
"""

print("Computing left vector (u)...")
Expand Down Expand Up @@ -117,4 +118,4 @@ def compute_u(
) @ u.unsqueeze(1)
u = u.squeeze()

return u / u.norm()
return u / u.norm(), cur_repr
10 changes: 5 additions & 5 deletions rome/compute_v.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def compute_v(
hparams: ROMEHyperParams,
layer: int,
left_vector: torch.Tensor,
k_star: torch.Tensor,
context_templates: List[str],
) -> torch.Tensor:
"""
Expand Down Expand Up @@ -165,9 +166,8 @@ def edit_output_fn(cur_out, cur_layer):

target = target_init + delta

# Retrieve cur_input, the current input to the 2nd MLP layer, and
# cur_output, the original output of the 2nd MLP layer.
cur_input, cur_output = get_module_input_output_at_word(
# Retrieve cur_output, the original output of the 2nd MLP layer.
_, cur_output = get_module_input_output_at_word(
model,
tok,
layer,
Expand All @@ -178,12 +178,12 @@ def edit_output_fn(cur_out, cur_layer):
)

# Solving the linear system to compute the right vector
right_vector = (target - cur_output) / torch.dot(cur_input, left_vector)
right_vector = (target - cur_output) / torch.dot(k_star, left_vector)
print(f"Delta norm: {(target - cur_output).norm().item()}")
print(
f"Change in target norm: {target_init.norm().item()} to {target.norm().item()} => {(target.norm() - target_init.norm()).item()}"
)
print(f"Division Factor: {torch.dot(cur_input, left_vector).item()}")
print(f"Division Factor: {torch.dot(k_star, left_vector).item()}")
print(f"Right vector norm: {right_vector.norm()}")

return right_vector
Expand Down
5 changes: 4 additions & 1 deletion rome/rome_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def execute_rome(
deltas = {}
for layer in sorted(hparams.layers):
# Compute rank-1 update matrix
left_vector: torch.Tensor = compute_u(
left_vector: torch.Tensor
k_star: torch.Tensor
left_vector, k_star = compute_u(
model,
tok,
request,
Expand All @@ -107,6 +109,7 @@ def execute_rome(
hparams,
layer,
left_vector,
k_star,
get_context_templates(model, tok, hparams.context_template_length_params),
)
print("Right vector shape:", right_vector.shape)
Expand Down