旋转位置嵌入(Rotary Position Embedding)模块。定义于 InfiniCore/python/infinicore/nn/modules/rope.py。
class RoPE(Module):
def __init__(
self,
max_position_embeddings: int,
rope_theta: float,
head_dim: int,
device=None,
dtype=None,
)RoPE(Rotary Position Embedding)是一种位置编码方法,通过旋转矩阵将位置信息编码到注意力机制中。广泛应用于 GPT-J、LLaMA 等现代语言模型。
max_position_embeddings:最大序列长度。rope_theta:RoPE 的基础周期(base period)。head_dim:注意力头的维度。device:sin/cos 表所在的设备。dtype:sin/cos 表的数据类型。
forward(states, position_ids, algo=RopeAlgo.GPT_NEOX):前向传播,应用旋转位置嵌入。
- 输入
states:(bs, seq_len, num_heads, head_dim)。 - 输入
position_ids:(bs, seq_len)。 - 输出:
(bs, seq_len, num_heads, head_dim),与输入states形状相同。
RopeAlgo.GPT_J:GPT-J 风格的 RoPE 算法。RopeAlgo.GPT_NEOX:GPT-NeoX 风格的 RoPE 算法(默认)。
import infinicore as ic
from infinicore.nn import RoPE
from infinicore.nn.functional import RopeAlgo
device = ic.device("cuda:0")
# 创建 RoPE 模块
rope = RoPE(
max_position_embeddings=2048,
rope_theta=10000.0,
head_dim=128,
device=device,
dtype=ic.float16
)
# 输入状态和位置 ID
states = ic.empty((2, 10, 8, 128), dtype=ic.float16, device=device)
position_ids = ic.empty((2, 10), dtype=ic.int64, device=device)
# ... 填充 position_ids ...
# 应用 RoPE
output = rope(states, position_ids, algo=RopeAlgo.GPT_NEOX)