[Intel] special handling for transferWithinBlock for boolean values #3599

Open: 2 commits to be merged into main

alexbaden
Contributor

Given this convert op layout lowering:

#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [16], order = [0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [16], order = [0]}>

    %15 = arith.cmpi ne, %14, %cst : tensor<2048xi8, #blocked> loc(#loc9)
    %21 = ttg.convert_layout %15 : tensor<2048xi1, #blocked> -> tensor<2048xi1, #blocked1> loc(#loc15)
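
To make the two layouts concrete, here is a minimal sketch of which tensor elements a single thread holds before and after the conversion. It assumes a simplified model of a 1-D blocked layout (linear thread id = warp * 32 + lane, and the layout tile repeating until the tensor is covered); `owned_elements` is a hypothetical helper for illustration, not a Triton API.

```python
def owned_elements(thread_id, size_per_thread, num_threads, num_elements):
    """Return the tensor indices held by one thread under a 1-D blocked layout.

    Simplified model (an assumption, not Triton code): each thread owns
    `size_per_thread` contiguous elements per tile, and the tile repeats
    until all `num_elements` are covered. Broadcasting is not modelled.
    """
    tile = size_per_thread * num_threads  # elements covered by one pass over all threads
    owned = []
    for tile_start in range(0, num_elements, tile):
        first = tile_start + thread_id * size_per_thread
        owned.extend(range(first, first + size_per_thread))
    return owned


NUM_THREADS = 32 * 16  # threadsPerWarp * warpsPerCTA

# Thread 3 under #blocked (sizePerThread = 4): one tile of 2048 elements.
src = owned_elements(3, 4, NUM_THREADS, 2048)
# Thread 3 under #blocked1 (sizePerThread = 2): two tiles of 1024 elements.
dst = owned_elements(3, 2, NUM_THREADS, 2048)
print(src, dst)
```

Under this model each thread holds four elements in both layouts, but after the conversion they sit in two pairs that are 1024 elements apart, so the values have to be exchanged through shared memory.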

We discovered that the generated IR changed after https://github.com/intel/intel-xpu-backend-for-triton/pull/3515/files#diff-fd4c24537e95bcab1b909fd764c84d63e5a844e1aa1ffaf5354510572a7d8bc6

Previously the shared memory load (after transformation) was returned as an i8 pointer and each element was extracted using gep instructions:

tail call spir_func void @_Z7barrierj(i32 1) #4, !dbg !33
  %54 = shl nuw nsw i32 %urem, 1, !dbg !33
  %55 = zext nneg i32 %54 to i64, !dbg !33
  %56 = getelementptr i8, ptr addrspace(3) %4, i64 %55, !dbg !33
  %57 = load i8, ptr addrspace(3) %56, align 2, !dbg !33
  %58 = icmp ne i8 %57, 0, !dbg !33
  %59 = getelementptr inbounds nuw i8, ptr addrspace(3) %56, i64 1, !dbg !33
  %60 = load i8, ptr addrspace(3) %59, align 1, !dbg !33
  %61 = icmp ne i8 %60, 0, !dbg !33
  %62 = or disjoint i32 %54, 1024, !dbg !33
  %63 = zext nneg i32 %62 to i64, !dbg !33
  %64 = getelementptr i8, ptr addrspace(3) %4, i64 %63, !dbg !33
  %65 = load i8, ptr addrspace(3) %64, align 2, !dbg !33
  %66 = icmp ne i8 %65, 0, !dbg !33
  %67 = getelementptr inbounds nuw i8, ptr addrspace(3) %64, i64 1, !dbg !33
  %68 = load i8, ptr addrspace(3) %67, align 1, !dbg !33
  %69 = icmp ne i8 %68, 0, !dbg !33
  %70 = zext i1 %58 to i64, !dbg !34
  %71 = zext i1 %61 to i64, !dbg !34
  %72 = zext i1 %66 to i64, !dbg !34
  %73 = zext i1 %69 to i64, !dbg !34
  %74 = zext i1 %34 to i64, !dbg !34
  %75 = zext i1 %35 to i64, !dbg !34
  %76 = zext i1 %36 to i64, !dbg !34
  %77 = zext i1 %37 to i64, !dbg !34
  tail call spir_func void @_Z7barrierj(i32 1) #4, !dbg !35

With the upstream linear-layout method of transferring within a block, however, the elements are extracted from an i1 vector of length 16:

 tail call spir_func void @_Z7barrierj(i32 1) #4, !dbg !33
  %53 = zext nneg i32 %17 to i64, !dbg !33
  %54 = getelementptr inbounds nuw i8, ptr addrspace(3) %4, i64 %53, !dbg !33
  %55 = load <16 x i1>, ptr addrspace(3) %54, align 2, !dbg !33
  %56 = zext nneg i32 %18 to i64, !dbg !33
  %57 = getelementptr inbounds nuw i8, ptr addrspace(3) %4, i64 %56, !dbg !33
  %58 = load <16 x i1>, ptr addrspace(3) %57, align 2, !dbg !33
  %59 = extractelement <16 x i1> %55, i64 0, !dbg !33
  %60 = extractelement <16 x i1> %55, i64 8, !dbg !33
  %61 = extractelement <16 x i1> %58, i64 0, !dbg !33
  %62 = extractelement <16 x i1> %58, i64 8, !dbg !33
  %63 = zext i1 %59 to i64, !dbg !34
  %64 = zext i1 %60 to i64, !dbg !34
  %65 = zext i1 %61 to i64, !dbg !34
  %66 = zext i1 %62 to i64, !dbg !34
  %67 = zext i1 %35 to i64, !dbg !34
  %68 = zext i1 %36 to i64, !dbg !34
  %69 = zext i1 %37 to i64, !dbg !34
  %70 = zext i1 %38 to i64, !dbg !34
  tail call spir_func void @_Z7barrierj(i32 1) #4, !dbg !35

This seems to be causing trouble in the IGC lowering.

Inserting the icmp_ne instruction (which was previously used in processReplica here: https://github.com/intel/intel-xpu-backend-for-triton/pull/3515/files#diff-3fa75fa6b39886d9576a671c306d98b0deb43f81c2fc7873ad08892d190d2622L215) forces us back to the existing method of doing the conversion. We need to figure out whether this is a true hardware limitation or a bug, and there may be better ways to handle this case when converting the layouts. For now, I have left the change in common upstream code and am marking this as a draft. But if this is the most expedient way to resolve the regression without side effects, then I think we should move forward.

cc #3570

@chengjunlu
Contributor

chengjunlu commented Mar 4, 2025

I think i1 is just a conceptual type that needs to be materialized later, because all hardware architectures have only byte-addressable memory. What PTX does the NV backend generate for this case?

This code snippet makes no sense to me. It loads 16 x i1 (two bytes) from the SLM, but uses only two values from it, at indices 0 and 8. Should we just load it as 2 x i8 and then convert it to i1 with trunc or cmp.ne?

  %58 = load <16 x i1>, ptr addrspace(3) %57, align 2, !dbg !33
  %61 = extractelement <16 x i1> %58, i64 0, !dbg !33
  %62 = extractelement <16 x i1> %58, i64 8, !dbg !33
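
The equivalence behind this suggestion can be sketched in a few lines. The model below assumes that a `<16 x i1>` load is bit-packed with element k stored in bit k of a little-endian 16-bit word, and that each boolean was stored as one byte; all function names are hypothetical and only illustrate the arithmetic.

```python
def extract_packed_i1(two_bytes: bytes, index: int) -> int:
    """Model extractelement on a bit-packed <16 x i1> loaded from two bytes."""
    word = int.from_bytes(two_bytes, "little")
    return (word >> index) & 1


def bool_via_trunc(byte: int) -> int:
    """Model `trunc i8 to i1`: keep only the lowest bit."""
    return byte & 1


def bool_via_cmp_ne(byte: int) -> int:
    """Model `icmp ne i8 %x, 0`: any non-zero byte is true."""
    return int(byte != 0)


slm = bytes([1, 0])  # two booleans stored one per byte
packed = [extract_packed_i1(slm, 0), extract_packed_i1(slm, 8)]
truncated = [bool_via_trunc(b) for b in slm]
compared = [bool_via_cmp_ne(b) for b in slm]
print(packed, truncated, compared)
```

Elements 0 and 8 of the packed vector are the lowest bits of the two bytes, so they match the `trunc` form exactly. The `cmp.ne` form agrees only as long as the stored bytes are strictly 0 or 1; a byte such as 2 would be false under `trunc` and true under `cmp.ne`.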

@alexbaden
Contributor Author

This code snippet makes no sense to me. It loads 16 x i1 (two bytes) from the SLM, but uses only two values from it, at indices 0 and 8.

That is also what the working code does: it loads two bytes from each location, but takes double the number of load instructions to do it. The "broken" code is actually a bit clearer: it loads two bytes each from two locations in SLM and reads the first bit from each byte. What I do not understand is why the broken code is not working. We looked at the IGC shader dumps and didn't see anything obvious:

Working shader (four 1-byte loads; note that they have been optimized to two 2-byte loads, just as we would expect):

call void @llvm.genx.GenISA.threadgroupbarrier(), !dbg !408
  %IntToPtr2102 = inttoptr i32 %16 to i16 addrspace(3)*, !dbg !408
  %vCastload103 = load i16, i16 addrspace(3)* %IntToPtr2102, align 2, !dbg !408
  %IntToPtr2104 = inttoptr i32 %17 to i16 addrspace(3)*, !dbg !408
  %vCastload105 = load i16, i16 addrspace(3)* %IntToPtr2104, align 2, !dbg !408
  %64 = bitcast i16 %vCastload103 to <2 x i8>, !dbg !408
  %65 = extractelement <2 x i8> %64, i32 0, !dbg !408
  %66 = extractelement <2 x i8> %64, i32 1, !dbg !408
  %67 = bitcast i16 %vCastload105 to <2 x i8>, !dbg !408
  %68 = extractelement <2 x i8> %67, i32 0, !dbg !408
  %69 = extractelement <2 x i8> %67, i32 1, !dbg !408
  %b2s121 = zext i8 %65 to i16, !dbg !408
  %b2s122 = icmp ne i16 %b2s121, 0, !dbg !408
  %70 = sext i1 %b2s122 to i64, !dbg !409
  %71 = sub i64 0, %70, !dbg !409
  %b2s123 = zext i8 %66 to i16, !dbg !408
  %b2s124 = icmp ne i16 %b2s123, 0, !dbg !408
  %72 = sext i1 %b2s124 to i64, !dbg !409
  %73 = sub i64 0, %72, !dbg !409
  %b2s125 = zext i8 %68 to i16, !dbg !408
  %b2s126 = icmp ne i16 %b2s125, 0, !dbg !408
  %74 = sext i1 %b2s126 to i64, !dbg !409
  %75 = sub i64 0, %74, !dbg !409
  %b2s127 = zext i8 %69 to i16, !dbg !408
  %b2s128 = icmp ne i16 %b2s127, 0, !dbg !408
  %76 = sext i1 %b2s128 to i64, !dbg !409
  %77 = sub i64 0, %76, !dbg !409
  call void @llvm.genx.GenISA.memoryfence(i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i32 0), !dbg !410
  call void @llvm.genx.GenISA.threadgroupbarrier(), !dbg !410
  %78 = call i32 @llvm.genx.GenISA.WaveBallot(i1 true, i32 0), !dbg !410
  %79 = call i32 @llvm.genx.GenISA.bfrev(i32 %78), !dbg !410
  %80 = call i32 @llvm.ctlz.i32(i32 %79, i1 false) #10, !dbg !410, !range !411

Broken shader (the loads are 4 bytes each, and there are 4 of them!):

call void @llvm.genx.GenISA.threadgroupbarrier(), !dbg !408
  %IntToPtr2134 = inttoptr i32 %16 to i32 addrspace(3)*, !dbg !408
  %vCastload135 = load i32, i32 addrspace(3)* %IntToPtr2134, align 2, !dbg !408
  %64 = add nuw nsw i32 %16, 8, !dbg !408
  %IntToPtr2136 = inttoptr i32 %64 to i32 addrspace(3)*, !dbg !408
  %vCastload137 = load i32, i32 addrspace(3)* %IntToPtr2136, align 2, !dbg !408
  %IntToPtr2138 = inttoptr i32 %17 to i32 addrspace(3)*, !dbg !408
  %vCastload139 = load i32, i32 addrspace(3)* %IntToPtr2138, align 2, !dbg !408
  %65 = add nuw nsw i32 %16, 1032, !dbg !408
  %IntToPtr2140 = inttoptr i32 %65 to i32 addrspace(3)*, !dbg !408
  %vCastload141 = load i32, i32 addrspace(3)* %IntToPtr2140, align 2, !dbg !408
  %66 = bitcast i32 %vCastload135 to <4 x i8>, !dbg !408
  %67 = bitcast i32 %vCastload137 to <4 x i8>, !dbg !408
  %68 = extractelement <4 x i8> %66, i32 0, !dbg !408
  %b2s149 = sext i8 %68 to i16, !dbg !408
  %69 = extractelement <4 x i8> %67, i32 0, !dbg !408
  %b2s150 = sext i8 %69 to i16, !dbg !408
  %70 = bitcast i32 %vCastload139 to <4 x i8>, !dbg !408
  %71 = bitcast i32 %vCastload141 to <4 x i8>, !dbg !408
  %72 = extractelement <4 x i8> %70, i32 0, !dbg !408
  %b2s151 = sext i8 %72 to i16, !dbg !408
  %73 = extractelement <4 x i8> %71, i32 0, !dbg !408
  %b2s152 = sext i8 %73 to i16, !dbg !408
  %b2s153 = and i16 %b2s149, 1, !dbg !408
  %b2s154 = and i16 %b2s150, 1, !dbg !408
  %b2s155 = and i16 %b2s151, 1, !dbg !408
  %b2s156 = and i16 %b2s152, 1, !dbg !409
  call void @llvm.genx.GenISA.memoryfence(i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i32 0), !dbg !409
  call void @llvm.genx.GenISA.threadgroupbarrier(), !dbg !409
  %74 = call i32 @llvm.genx.GenISA.WaveBallot(i1 true, i32 0), !dbg !409
  %75 = call i32 @llvm.genx.GenISA.bfrev(i32 %74), !dbg !409
  %76 = call i32 @llvm.ctlz.i32(i32 %75, i1 false) #10, !dbg !409, !range !410

Should we just load it as 2 x i8 and then convert it to i1 with trunc or cmp.ne?

Well, that's what the patch does. But I think long term we should figure out why the code as generated is not working.

I am working on a unit test to make the IR a little easier to read and then I can try and get some PTX.

@alexbaden
Contributor Author

After further investigation, the code generated by Triton without this patch turns out to be correct. However, the PromoteBools pass within IGC appears to change the element type of the LLVM vector from bits to bytes, so the load reads incorrect data. A ticket has been filed with IGC, but I think we should merge this patch and the test for now, and then revert the changes in common code once IGC resolves the problem.
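
As a rough illustration of why widening the vector element type changes the result, the toy model below compares reading element 8 of a bit-packed `<16 x i1>` with reading element 8 after the elements have been widened to one byte each. This is a sketch of the suspected effect under those assumptions, not a reproduction of IGC's PromoteBools pass; the function names are hypothetical.

```python
def read_packed(memory: bytes, base: int, index: int) -> int:
    """Intended semantics: element `index` is bit `index` of the packed vector."""
    byte = memory[base + index // 8]
    return (byte >> (index % 8)) & 1


def read_promoted(memory: bytes, base: int, index: int) -> int:
    """Suspected semantics after promotion: element `index` is byte `index`."""
    return memory[base + index] & 1


# Sixteen bytes of SLM; only the boolean stored in byte 1 is true.
slm = bytes([0, 1] + [0] * 14)

intended = [read_packed(slm, 0, 0), read_packed(slm, 0, 8)]
promoted = [read_promoted(slm, 0, 0), read_promoted(slm, 0, 8)]
print(intended, promoted)
```

In this model the promoted read fetches byte 8 instead of bit 0 of byte 1, which is consistent with the extra loads at offsets +8 and +1032 in the broken shader dump.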

alexbaden force-pushed the alex/boolean_convert branch from 2af840a to 212b45e on March 6, 2025 01:46
2 participants