metal lowbit kernels: check contiguity of scales and zeros
Differential Revision: D65957327

Pull Request resolved: #1287
manuelcandales authored Nov 18, 2024
1 parent d4ca98f commit 20b08ee
Showing 3 changed files with 8 additions and 5 deletions.
3 changes: 2 additions & 1 deletion torchao/experimental/kernels/mps/test/test_lowbit.mm
@@ -101,7 +101,8 @@ void init() {
   int32_t ceil_K_group_size = (K + qGroupSize - 1) / qGroupSize;
   for (int idx = 0; idx < N * ceil_K_group_size; ++idx) {
     s_ptr[idx] = (idx + 1.0) / N;
-    z_ptr[idx] = int_distrib(generator);
+    auto zp = int_distrib(generator);
+    z_ptr[idx] = -zp * s_ptr[idx];
   }
   for (int idx = 0; idx < M * N; ++idx) {
     c_ptr[idx] = -1.0;
2 changes: 2 additions & 0 deletions torchao/experimental/ops/mps/register.mm
@@ -58,13 +58,15 @@ void check_linear_mps_args(
       ": expect S to be 2d tensor with shape [:, ",
       N,
       "]");
+  TORCH_CHECK(S.is_contiguous(), __func__, " : expect S to be contiguous.");
 
   TORCH_CHECK(
       Z.dim() == 2 && Z.size(1) == N,
       __func__,
       ": expect Z to be 2d tensor with shape [:, ",
       N,
       "]");
+  TORCH_CHECK(Z.is_contiguous(), __func__, " : expect Z to be contiguous.");
 }
 
 template <int nbit>
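The two new TORCH_CHECK calls reject non-contiguous scale and zero-point tensors before they reach the Metal kernel, which presumably reads S and Z through raw pointers in row-major order. A minimal caller-side sketch of what the check guards against (the tensors below are made up for the example, not taken from this PR):

import torch

S = torch.rand(8, 4).t()    # a transposed view: same data, strides swapped
print(S.is_contiguous())    # False -> the new check would raise
S = S.contiguous()          # materialize a row-major copy
print(S.is_contiguous())    # True -> the check passes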
8 changes: 4 additions & 4 deletions torchao/experimental/ops/mps/test/test_lowbit.py
@@ -46,18 +46,18 @@ class TestLowBitQuantWeightsLinear(unittest.TestCase):
     ]
 
     def _init_tensors(self, group_size, M, K, N, nbit, device="mps"):
-        max_abs = 1 << (nbit - 1)
         ceil_K_group_size = (K + group_size - 1) // group_size
-        A = 2 * torch.rand(M, K, dtype=torch.float32, device=device) - 1
-        W = torch.randint(0, 2 * max_abs, (N, K), dtype=torch.uint8, device=device)
+        A = torch.rand(M, K, dtype=torch.float32, device=device)
+        W = torch.randint(0, 1 << nbit, (N, K), dtype=torch.uint8, device=device)
         S = torch.rand(ceil_K_group_size, N, dtype=torch.float32, device=device) + 0.01
         Z = torch.randint(
             0,
-            2 * max_abs,
+            1 << nbit,
             (ceil_K_group_size, N),
             dtype=torch.float32,
             device=device,
         )
+        Z = -Z * S
         return A, W, S, Z
 
     def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z, nbit):
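Both test updates move the zero points from the integer domain to the float domain: Z now holds -zp * S rather than zp, so dequantization can be computed as W * S + Z instead of (W - zp) * S. A short sketch of that identity, assuming per-group scales broadcast along K (the shapes and the broadcasting helper below are illustrative, not taken from the kernel):

import torch

nbit, group_size, K, N = 4, 32, 64, 8
groups = (K + group_size - 1) // group_size

W = torch.randint(0, 1 << nbit, (N, K), dtype=torch.uint8)
S = torch.rand(groups, N) + 0.01
zp = torch.randint(0, 1 << nbit, (groups, N)).float()
Z = -zp * S  # the convention the updated tests now use

# broadcast per-group scales / zero points to per-element values along K
expand = lambda t: t.repeat_interleave(group_size, dim=0)[:K].t()  # (N, K)
lhs = (W.float() - expand(zp)) * expand(S)   # classic (w - zp) * s
rhs = W.float() * expand(S) + expand(Z)      # w * s + z, with z = -zp * s
assert torch.allclose(lhs, rhs, atol=1e-5)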
