Skip to content

Commit 2b2a88b

Browse files
committed
Avoid unnecessary copy in for loop.
1 parent 886dc04 commit 2b2a88b

File tree

5 files changed: 7 additions and 10 deletions.

5 files changed: 7 additions and 10 deletions.

examples/85_ada_ampere_gemm_with_blockwise_scaling/85a_ada_fp8_gemm_with_groupwise_scaling_cute.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ struct Options {
248 248
float alpha = 1.f, beta = 0.f;
249 249
int iterations = 1000;
250 250
int warmup = 1000;
251-
int m = 1024, n = 512, k = 1024, l = 1;
251+
int m = 1024, n = 1024, k = 1024, l = 1;
252 252
float epsilon = 0.02f;
253 253
float non_zero_floor = 1.f;
254 254

examples/85_ada_ampere_gemm_with_blockwise_scaling/85b_ada_fp8_gemm_with_blockwise_scaling_cute.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ struct Options {
250 250
float alpha = 1.f, beta = 0.f;
251 251
int iterations = 1000;
252 252
int warmup = 1000;
253-
int m = 1024, n = 512, k = 1024, l = 1;
253+
int m = 1024, n = 1024, k = 1024, l = 1;
254 254
float epsilon = 0.02f;
255 255
float non_zero_floor = 1.f;
256 256

examples/85_ada_ampere_gemm_with_blockwise_scaling/85c_ampere_int8_gemm_with_groupwise_scaling_cute.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ struct Options {
250 250
float alpha = 1.f, beta = 0.f;
251 251
int iterations = 1000;
252 252
int warmup = 1000;
253-
int m = 1024, n = 512, k = 1024, l = 1;
253+
int m = 1024, n = 1024, k = 1024, l = 1;
254 254
float epsilon = 0.02f;
255 255
float non_zero_floor = 1.f;
256 256

examples/85_ada_ampere_gemm_with_blockwise_scaling/85d_ampere_int8_gemm_with_blockwise_scaling_cute.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ struct Options {
251 251
float alpha = 1.f, beta = 0.f;
252 252
int iterations = 1000;
253 253
int warmup = 1000;
254-
int m = 1024, n = 512, k = 1024, l = 1;
254+
int m = 1024, n = 1024, k = 1024, l = 1;
255 255
float epsilon = 0.02f;
256 256
float non_zero_floor = 1.f;
257 257

include/cutlass/gemm/collective/sm80_mma_multistage_blockwise_scaling.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -454,14 +454,14 @@ struct CollectiveMma<
454 454
// Prefetch the first rmem from the first k-tile
455 455
copy(smem_tiled_copy_A, tCsA_p(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{}));
456 456
copy(smem_tiled_copy_B, tCsB_p(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{}));
457-
// Load per block scale values from shared memory to registers
458-
copy(tCsSFA(_,_,_,make_coord(_0{}, _0{})), tCrSFA);
459-
copy(tCsSFB(_,_,_,make_coord(_0{}, _0{})), tCrSFB);
460 457
}
461 458

462 459
CUTLASS_PRAGMA_NO_UNROLL
463 460
for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count)
464 461
{
462+
// Load per block scale values from shared memory to registers
463+
copy(tCsSFA(_,_,_,make_coord(_0{}, smem_pipe_read)), tCrSFA);
464+
copy(tCsSFB(_,_,_,make_coord(_0{}, smem_pipe_read)), tCrSFB);
465 465
// Pipeline the outer products with a static for loop.
466 466
//
467 467
// Note, the for_each() function is required here to ensure `k_block` is of type Int<N>.
@@ -552,9 +552,6 @@ struct CollectiveMma<
552 552
tCrAccum(i) = 0;
553 553
}
554 554
}
555-
// Load per block scale values from shared memory to registers
556-
copy(tCsSFA(_,_,_,make_coord(_0{}, smem_pipe_read)), tCrSFA);
557-
copy(tCsSFB(_,_,_,make_coord(_0{}, smem_pipe_read)), tCrSFB);
558 555
}
559 556

560 557
cp_async_wait<0>();

0 commit comments

Comments
 (0)