Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Questions about the causal_conv1d_update_ref Function #42

Open
guoguo1314 opened this issue Dec 31, 2024 · 0 comments
Open

Questions about the causal_conv1d_update_ref Function #42

guoguo1314 opened this issue Dec 31, 2024 · 0 comments

Comments

@guoguo1314
Copy link

guoguo1314 commented Dec 31, 2024

Hello,

I have some questions regarding the function causal_conv1d_update_ref. I am trying to implement the function causal_conv1d_update on my CPU. To do this, I followed the reference implementation provided by causal_conv1d_update_ref. However, when I tested both implementations on my GPU, I found that, for the same input, the outputs of causal_conv1d_update and causal_conv1d_update_ref differ. Below is my test script, test.py:

import torch
import torch.nn.functional as F
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update,causal_conv1d_update_ref,causal_conv1d_ref
from einops import rearrange

def test_causal_conv1d_update():
    """Compare ``causal_conv1d_update`` with ``causal_conv1d_update_ref``.

    Builds one random single-step input ``x`` of shape
    ``(batch_size, dim, seq_len)``, a ``(dim, width)`` weight, a ``(dim,)``
    bias and a left-padded ``conv_state``, runs both implementations with
    ``activation="silu"``, and prints the shapes plus the max/mean absolute
    difference and an ``allclose`` verdict.

    Each implementation receives its own copy of ``conv_state``. The upstream
    library updates ``conv_state`` in place (shift, then write ``x`` into the
    tail), so sharing one tensor between the two calls would make the second
    call start from an already-updated state and produce a different output
    even when both implementations are correct.
    """
    # Fix the random seed so the run is reproducible.
    torch.manual_seed(42)

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    # Problem sizes.
    batch_size = 1
    seq_len = 1
    dim = 1024
    width = 4  # d_conv

    # Create the inputs and move them to the selected device.
    x = torch.randn(batch_size, dim, seq_len).to(device)
    weight = torch.randn(dim, width).to(device)
    bias = torch.randn(dim).to(device)

    # Build conv_state by left-padding x along the last axis up to `width`.
    conv_state = F.pad(x, (width - x.shape[-1], 0)).to(device)
    print(f"conv_state shape: {conv_state.shape}")

    # Run both implementations from identical, independent starting states.
    # NOTE: x is (batch, dim, seq_len), so axis 1 has size `dim`; squeeze(1)
    # is therefore a no-op whenever dim > 1 and x keeps its 3-D shape.
    out_update = causal_conv1d_update(
        x=x.squeeze(1),
        conv_state=conv_state.clone(),  # private copy: the call mutates it
        weight=weight,
        bias=bias,
        activation="silu"
    )

    out_update_ref = causal_conv1d_update_ref(
        x=x.squeeze(1),
        conv_state=conv_state.clone(),  # private copy: the call mutates it
        weight=weight,
        bias=bias,
        activation="silu"
    )

    # Report the input shapes.
    print("\n输入形状:")
    print(f"x shape: {x.shape}")
    print(f"weight shape: {weight.shape}")
    print(f"bias shape: {bias.shape}")
    print(f"conv_state shape: {conv_state.shape}")

    # Report the output shapes.
    print("\n输出形状:")
    print(f"out_update shape: {out_update.shape}")
    print(f"out_update_ref shape: {out_update_ref.shape}")

    # Measure the element-wise difference between the two outputs.
    abs_diff = (out_update - out_update_ref).abs()
    max_diff = abs_diff.max().item()
    mean_diff = abs_diff.mean().item()
    is_close = torch.allclose(out_update, out_update_ref, rtol=1e-5, atol=1e-5)

    # Report the comparison.
    print("\n比较结果:")
    print(f"最大差异: {max_diff:.2e}")
    print(f"平均差异: {mean_diff:.2e}")
    print(f"输出是否一致: {is_close}")

# Run the comparison only when this file is executed as a script.
if __name__ == "__main__":
    test_causal_conv1d_update()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant