Hi, thank you for sharing your great work! I have a question about how `inputs_embeds` is handled in the PPLM code. Looking at `run_pplm.py`, I found something whose intention I cannot understand: https://github.com/uber-research/PPLM/blob/master/run_pplm.py#L220
```python
if loss_type == PPLM_DISCRIM or loss_type == PPLM_BOW_DISCRIM:
    ce_loss = torch.nn.CrossEntropyLoss()
    # TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
    curr_unpert_past = unpert_past
    curr_probs = torch.unsqueeze(probs, dim=1)
    wte = model.resize_token_embeddings()
    for _ in range(horizon_length):
        inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
        _, curr_unpert_past, curr_all_hidden = model(
            past=curr_unpert_past,
            inputs_embeds=inputs_embeds
        )
        curr_hidden = curr_all_hidden[-1]
        new_accumulated_hidden = new_accumulated_hidden + torch.sum(
            curr_hidden, dim=1)
```
`inputs_embeds` is updated in the for loop, but `curr_probs` and `wte.weight.data` seem not to be updated in the loop. Could you please tell me why `inputs_embeds` is recomputed inside the for loop? Thank you in advance!
I think that is actually a bug, and it might also explain why our experiments with horizon_length > 1 did not work so well (we used horizon_length=1 in all of our experiments). If you're running with horizon_length=1 it shouldn't matter, but it is a bug: we do need to update `curr_probs` inside the loop here.
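For concreteness, here is a minimal sketch of what that fix might look like (untested; the variable names and the `(logits, past, all_hidden)` return signature follow the snippet quoted above, which uses the older transformers model API; this is not an official patch):

```python
import torch

# Sketch of the fix discussed above: recompute curr_probs from the new
# logits on each horizon step, so each iteration feeds the *next*
# predicted distribution back into the model instead of re-embedding
# the same probs every time.
curr_unpert_past = unpert_past
curr_probs = torch.unsqueeze(probs, dim=1)  # shape: (batch, 1, vocab)
wte = model.resize_token_embeddings()
for _ in range(horizon_length):
    # Embed the current soft distribution over the vocabulary.
    inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
    curr_logits, curr_unpert_past, curr_all_hidden = model(
        past=curr_unpert_past,
        inputs_embeds=inputs_embeds
    )
    # Update curr_probs from the logits of the newly predicted token.
    curr_probs = torch.nn.functional.softmax(
        curr_logits[:, -1, :], dim=-1
    ).unsqueeze(dim=1)
    curr_hidden = curr_all_hidden[-1]
    new_accumulated_hidden = new_accumulated_hidden + torch.sum(
        curr_hidden, dim=1
    )
```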
Excuse me for my delayed reply; I missed the notification that you had kindly responded to this issue.
Thank you for the detailed information. I now understand that what I mentioned is a bug, and that you used horizon_length=1 in your experiments, so it did not matter there.
I also understand that `curr_probs` needs to be updated inside the loop.
I'm not a strong programmer, but if I can be of any help, please let me know.