We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 01a38a9 · commit 6653120 · Copy full SHA for 6653120
optimum/neuron/utils/optimization_utils.py
@@ -85,7 +85,7 @@ def to3d(x):
85
attention_scores = torch.bmm(key, query.transpose(-1, -2)) * (1 / math.sqrt(query.size(-1)))
86
attention_probs = attention_scores.softmax(dim=1)
87
if query.size() == key.size():
88
- attention_probs = attention_probs.permute(0, 2, 1)
+ attention_probs = attention_probs.permute(0, 2, 1)
89
attn_out = torch.bmm(attention_probs, value)
90
if orig_shape:
91
attn_out = attn_out.reshape(orig_shape[0], orig_shape[1], attn_out.shape[1], attn_out.shape[2])
0 commit comments