atomic instruction of global value considered active (should be inactive) #2212

Open
martinjrobins opened this issue Dec 29, 2024 · 2 comments


I'm using the EnzymeCreateForwardDiff function in Enzyme to calculate the gradient of a vector-valued function (rhs). I'm using threading: the function is designed to be called from different threads, passing in two ints that give the current thread id and the total number of threads. Each thread computes part of the vector of values.
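For illustration, the per-thread split can be sketched in Rust as follows. This is a reading of the `%start` / `%end3` arithmetic in the IR reproduced below (vector length 50, `umin` clamp on the end index), not the code that generated it; the name `thread_range` and its parameter names are hypothetical.

```rust
/// Half-open index range [start, end) handled by `thread_id` out of
/// `thread_count` threads, for a vector of `size` elements.
/// Mirrors the IR: start = (i * size) / n, end = min(((i + 1) * size) / n, size).
/// `thread_range` is a hypothetical name used only for this sketch.
fn thread_range(thread_id: u32, thread_count: u32, size: u32) -> (u32, u32) {
    let start = thread_id * size / thread_count;
    let end = ((thread_id + 1) * size / thread_count).min(size);
    (start, end)
}
```

With 4 threads and a vector of 50 elements, `thread_range(1, 4, 50)` gives the slice starting at 12 and ending before 25. A thread whose `start` is past the last element has no work and only takes part in the barrier, which corresponds to the `exit.thread` block in the IR.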

I'm using atomics to implement synchronisation barriers within the function, and this is where my problem is occurring. I have a global value thread_counter which is used to implement the barrier. Enzyme incorrectly considers this global active, even though it has no effect on the values I'm taking the gradient of. Is there a way of forcing Enzyme to consider thread_counter as inactive?
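For illustration, a minimal sketch of a counter-based barrier of this kind, written in Rust with only the standard library. It is an assumption about the general pattern, not the code that produced the IR: the names `barrier_wait` and `run_barrier` are hypothetical, and the sketch uses acquire/release ordering where the IR uses `monotonic` (Rust's `Relaxed`). The point is that the counter only ever holds an integer arrival count, never a value that is differentiated.

```rust
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::thread;

/// One-shot spin barrier: announce arrival, then wait until every thread has arrived.
/// The counter carries synchronisation state only, no floating-point data.
fn barrier_wait(counter: &AtomicU32, thread_count: u32) {
    // Corresponds to the `atomicrmw add` in the IR.
    counter.fetch_add(1, Ordering::AcqRel);
    // Corresponds to the `load atomic`; spin until all threads have incremented.
    while counter.load(Ordering::Acquire) < thread_count {
        std::hint::spin_loop();
    }
}

/// Spawns `thread_count` threads that each pass the barrier once and
/// returns the final counter value.
fn run_barrier(thread_count: u32) -> u32 {
    let counter = Arc::new(AtomicU32::new(0));
    let handles: Vec<_> = (0..thread_count)
        .map(|_| {
            let counter = Arc::clone(&counter);
            thread::spawn(move || barrier_wait(&counter, thread_count))
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    counter.load(Ordering::Acquire)
}
```

Calling `run_barrier(4)` lets four threads pass the barrier once and returns the final count. A barrier that is reused across several phases would additionally need to reset the counter or track a generation number, which this sketch leaves out.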

The IR for the function and the error I'm getting are reproduced below:

error: <unknown>:0:0: in function preprocess_rhs void (double, ptr, ptr, ptr, i32, i32): Enzyme: ; Function Attrs: mustprogress nofree nounwind willreturn memory(readwrite, inaccessiblemem: none)
define void @preprocess_rhs(double %0, ptr noalias nocapture readonly %1, ptr noalias nocapture writeonly %2, ptr noalias nocapture writeonly %3, i32 %4, i32 %5) local_unnamed_addr #7 {
entry:
  %y = getelementptr inbounds double, ptr %1, i64 50
  %i_times_size = mul i32 %4, 50
  %start = udiv i32 %i_times_size, %5
  %done = icmp ugt i32 %start, 49
  br i1 %done, label %exit.thread, label %threading_block

exit.thread:                                      ; preds = %entry
  %6 = atomicrmw add ptr @thread_counter, i32 1 monotonic, align 4
  %current_value.i74 = load atomic i32, ptr @thread_counter monotonic, align 4
  br label %exit67

threading_block:                                  ; preds = %entry
  %i_plus_one_times_size = add i32 %i_times_size, 50
  %end = udiv i32 %i_plus_one_times_size, %5
  %end3 = tail call i32 @llvm.umin.i32(i32 %end, i32 50) #8
  %7 = zext i32 %start to i64
  %8 = add nuw i32 %start, 1
  %umax = tail call i32 @llvm.umax.i32(i32 %end3, i32 %8) #8
  %9 = xor i32 %start, -1
  %10 = add i32 %umax, %9
  %11 = zext i32 %10 to i64
  %12 = add nuw nsw i64 %11, 1
  %min.iters.check = icmp eq i32 %10, 0
  br i1 %min.iters.check, label %r-0.preheader, label %vector.ph

vector.ph:                                        ; preds = %threading_block
  %n.vec = and i64 %12, -2
  %ind.end = add nuw nsw i64 %n.vec, %7
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %vector.ph
  %iv = phi i64 [ %iv.next, %vector.body ], [ 0, %vector.ph ]
  %13 = shl nuw i64 %iv, 1
  %iv.next = add nuw nsw i64 %iv, 1
  %offset.idx = add i64 %13, %7
  %14 = getelementptr inbounds double, ptr %1, i64 %offset.idx
  %wide.load = load <2 x double>, ptr %14, align 8
  %15 = getelementptr inbounds double, ptr %y, i64 %offset.idx
  %wide.load95 = load <2 x double>, ptr %15, align 8
  %16 = fadd <2 x double> %wide.load, %wide.load95
  %17 = getelementptr inbounds double, ptr %2, i64 %13
  store <2 x double> %16, ptr %17, align 8
  %index.next = add nuw i64 %13, 2
  %18 = icmp eq i64 %index.next, %n.vec
  br i1 %18, label %middle.block, label %vector.body, !llvm.loop !4

middle.block:                                     ; preds = %vector.body
  %cmp.n = icmp eq i64 %12, %n.vec
  br i1 %cmp.n, label %threading_block23, label %r-0.preheader

r-0.preheader:                                    ; preds = %middle.block, %threading_block
  %indvars.iv83.ph = phi i64 [ %7, %threading_block ], [ %ind.end, %middle.block ]
  %indvars.iv.ph = phi i64 [ 0, %threading_block ], [ %n.vec, %middle.block ]
  br label %r-0

r-0:                                              ; preds = %r-0, %r-0.preheader
  %iv1 = phi i64 [ %iv.next2, %r-0 ], [ 0, %r-0.preheader ]
  %19 = add nuw nsw i64 %indvars.iv.ph, %iv1
  %iv.next2 = add nuw nsw i64 %iv1, 1
  %20 = add nuw nsw i64 %indvars.iv83.ph, %iv1
  %indvars.iv.next = add nuw nsw i64 %19, 1
  %r-06 = getelementptr inbounds double, ptr %1, i64 %20
  %r-07 = load double, ptr %r-06, align 8
  %r-08 = getelementptr inbounds double, ptr %y, i64 %20
  %r-09 = load double, ptr %r-08, align 8
  %r-010 = fadd double %r-07, %r-09
  %r-012 = getelementptr inbounds double, ptr %2, i64 %19
  store double %r-010, ptr %r-012, align 8
  %indvars.iv.next84 = add nuw nsw i64 %20, 1
  %21 = trunc i64 %indvars.iv.next84 to i32
  %r-014 = icmp ugt i32 %end3, %21
  br i1 %r-014, label %r-0, label %threading_block23.loopexit, !llvm.loop !5

threading_block23.loopexit:                       ; preds = %r-0
  br label %threading_block23

threading_block23:                                ; preds = %threading_block23.loopexit, %middle.block
  %22 = atomicrmw add ptr @thread_counter, i32 1 monotonic, align 4
  %current_value.i = load atomic i32, ptr @thread_counter monotonic, align 4
  %23 = add nuw i32 %start, 1
  %umax96 = tail call i32 @llvm.umax.i32(i32 %end3, i32 %23) #8
  %24 = xor i32 %start, -1
  %25 = add i32 %umax96, %24
  %26 = zext i32 %25 to i64
  %27 = add nuw nsw i64 %26, 1
  %min.iters.check99 = icmp eq i32 %25, 0
  br i1 %min.iters.check99, label %F-0.preheader, label %vector.ph100

vector.ph100:                                     ; preds = %threading_block23
  %n.vec102 = and i64 %27, -2
  %ind.end103 = add nuw nsw i64 %n.vec102, %7
  %ind.end105 = trunc i64 %n.vec102 to i32
  br label %vector.body108

vector.body108:                                   ; preds = %vector.body108, %vector.ph100
  %iv3 = phi i64 [ %iv.next4, %vector.body108 ], [ 0, %vector.ph100 ]
  %28 = shl nuw i64 %iv3, 1
  %iv.next4 = add nuw nsw i64 %iv3, 1
  %offset.idx111 = add i64 %28, %7
  %29 = getelementptr inbounds double, ptr %1, i64 %offset.idx111
  %wide.load112 = load <2 x double>, ptr %29, align 8
  %sext = shl i64 %28, 32
  %30 = ashr exact i64 %sext, 32
  %31 = getelementptr inbounds double, ptr %3, i64 %30
  store <2 x double> %wide.load112, ptr %31, align 8
  %index.next113 = add nuw i64 %28, 2
  %32 = icmp eq i64 %index.next113, %n.vec102
  br i1 %32, label %middle.block97, label %vector.body108, !llvm.loop !6

middle.block97:                                   ; preds = %vector.body108
  %cmp.n107 = icmp eq i64 %27, %n.vec102
  br i1 %cmp.n107, label %F-1.preheader, label %F-0.preheader

F-0.preheader:                                    ; preds = %middle.block97, %threading_block23
  %indvars.iv88.ph = phi i64 [ %7, %threading_block23 ], [ %ind.end103, %middle.block97 ]
  %next_expr_index3279.ph = phi i32 [ 0, %threading_block23 ], [ %ind.end105, %middle.block97 ]
  %33 = zext i32 %next_expr_index3279.ph to i64
  br label %F-0

F-0:                                              ; preds = %F-0, %F-0.preheader
  %iv5 = phi i64 [ %iv.next6, %F-0 ], [ 0, %F-0.preheader ]
  %34 = add i64 %33, %iv5
  %iv.next6 = add nuw nsw i64 %iv5, 1
  %35 = trunc i64 %34 to i32
  %36 = add nuw nsw i64 %indvars.iv88.ph, %iv5
  %next_expr_index32 = add i32 %35, 1
  %F-033 = getelementptr inbounds double, ptr %1, i64 %36
  %F-034 = load double, ptr %F-033, align 8
  %37 = sext i32 %35 to i64
  %F-036 = getelementptr inbounds double, ptr %3, i64 %37
  store double %F-034, ptr %F-036, align 8
  %indvars.iv.next89 = add nuw nsw i64 %36, 1
  %38 = trunc i64 %indvars.iv.next89 to i32
  %F-039 = icmp ugt i32 %end3, %38
  br i1 %F-039, label %F-0, label %F-1.preheader.loopexit, !llvm.loop !7

F-1.preheader.loopexit:                           ; preds = %F-0
  br label %F-1.preheader

F-1.preheader:                                    ; preds = %F-1.preheader.loopexit, %middle.block97
  %39 = add nuw i32 %start, 1
  %umax114 = tail call i32 @llvm.umax.i32(i32 %end3, i32 %39) #8
  %40 = xor i32 %start, -1
  %41 = add i32 %umax114, %40
  %42 = zext i32 %41 to i64
  %43 = add nuw nsw i64 %42, 1
  %min.iters.check117 = icmp eq i32 %41, 0
  br i1 %min.iters.check117, label %F-1.preheader133, label %vector.ph118

vector.ph118:                                     ; preds = %F-1.preheader
  %n.vec120 = and i64 %43, -2
  %ind.end121 = add nuw nsw i64 %n.vec120, %7
  %ind.end123 = trunc i64 %n.vec120 to i32
  br label %vector.body126

vector.body126:                                   ; preds = %vector.body126, %vector.ph118
  %iv7 = phi i64 [ %iv.next8, %vector.body126 ], [ 0, %vector.ph118 ]
  %44 = shl nuw i64 %iv7, 1
  %iv.next8 = add nuw nsw i64 %iv7, 1
  %offset.idx129 = add i64 %44, %7
  %45 = getelementptr inbounds double, ptr %y, i64 %offset.idx129
  %wide.load130 = load <2 x double>, ptr %45, align 8
  %46 = shl i64 %44, 32
  %sext132 = add i64 %46, 214748364800
  %47 = ashr exact i64 %sext132, 32
  %48 = getelementptr inbounds double, ptr %3, i64 %47
  store <2 x double> %wide.load130, ptr %48, align 8
  %index.next131 = add nuw i64 %44, 2
  %49 = icmp eq i64 %index.next131, %n.vec120
  br i1 %49, label %middle.block115, label %vector.body126, !llvm.loop !8

middle.block115:                                  ; preds = %vector.body126
  %cmp.n125 = icmp eq i64 %43, %n.vec120
  br i1 %cmp.n125, label %exit67, label %F-1.preheader133

F-1.preheader133:                                 ; preds = %middle.block115, %F-1.preheader
  %indvars.iv91.ph = phi i64 [ %7, %F-1.preheader ], [ %ind.end121, %middle.block115 ]
  %next_expr_index5881.ph = phi i32 [ 0, %F-1.preheader ], [ %ind.end123, %middle.block115 ]
  %50 = zext i32 %next_expr_index5881.ph to i64
  br label %F-1

F-1:                                              ; preds = %F-1, %F-1.preheader133
  %iv9 = phi i64 [ %iv.next10, %F-1 ], [ 0, %F-1.preheader133 ]
  %51 = add i64 %50, %iv9
  %iv.next10 = add nuw nsw i64 %iv9, 1
  %52 = trunc i64 %51 to i32
  %53 = add nuw nsw i64 %indvars.iv91.ph, %iv9
  %next_expr_index58 = add i32 %52, 1
  %F-159 = getelementptr inbounds double, ptr %y, i64 %53
  %F-160 = load double, ptr %F-159, align 8
  %F-161 = add i32 %52, 50
  %54 = sext i32 %F-161 to i64
  %F-162 = getelementptr inbounds double, ptr %3, i64 %54
  store double %F-160, ptr %F-162, align 8
  %indvars.iv.next92 = add nuw nsw i64 %53, 1
  %55 = trunc i64 %indvars.iv.next92 to i32
  %F-165 = icmp ugt i32 %end3, %55
  br i1 %F-165, label %F-1, label %exit67.loopexit, !llvm.loop !9

exit67.loopexit:                                  ; preds = %F-1
  br label %exit67

exit67:                                           ; preds = %exit67.loopexit, %middle.block115, %exit.thread
  ret void
}

  %6 = atomicrmw add ptr @thread_counter, i32 1 monotonic, align 4
 Active atomic inst not yet handled

martinjrobins commented Dec 31, 2024

After looking through the Enzyme code, I figured that I could mark thread_counter as inactive by either:

  1. adding the metadata "enzyme_inactive", or
  2. incorporating the string "enzyme_const" in the name of the global

I tried 1, but this didn't seem to have any effect (it could be an inkwell bug, as printing the module segfaulted when it reached the metadata I added). I tried 2 and this worked: Enzyme no longer errors on the atomicrmw instruction. I'll carry on and check that the forward gradient is calculated correctly.


wsmoses commented Jan 4, 2025

In this case the easiest things are:

  1. Marking thread_counter and/or the instruction as inactive (e.g. with metadata)
  2. Marking thread_counter and/or the result as an integer, which implies inactive (e.g. by adding an enzyme_type attribute)
  3. Marking all non-marked globals as inactive (there's a flag:
    cl::opt<bool> nonmarkedglobals_inactiveloads(
    )
