
chlo.lgamma const prop #182

Open · wants to merge 11 commits into base: main
Conversation

@vimarsh6739 (Member) commented Dec 9, 2024

for #179

// Bail out unless the operand is a compile-time constant.
if (!matchPattern(op.getOperand(), m_Constant(&inputAttr)))
  return failure();

// Reuse the chlo-to-stablehlo expansion so the result is stablehlo ops over constants.
Value result = materializeLgamma(rewriter, op.getLoc(), op->getOperands());
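For context, here is a minimal sketch of how these lines could sit inside the GammaConstProp pattern referenced later in this thread. The struct layout, the DenseElementsAttr type for inputAttr, and the rewriter.replaceOp call are assumptions rather than the PR's exact code; materializeLgamma is the (currently static) stablehlo helper discussed below.

// Sketch only: the surrounding struct and the replaceOp call are assumptions.
struct GammaConstProp final : OpRewritePattern<chlo::LgammaOp> {
  using OpRewritePattern<chlo::LgammaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(chlo::LgammaOp op,
                                PatternRewriter &rewriter) const override {
    // Only fire when the operand is a compile-time constant.
    DenseElementsAttr inputAttr;
    if (!matchPattern(op.getOperand(), m_Constant(&inputAttr)))
      return failure();

    // Expand lgamma into stablehlo ops over constants; downstream folders
    // can then evaluate the expansion to a single constant.
    Value result = materializeLgamma(rewriter, op.getLoc(), op->getOperands());
    rewriter.replaceOp(op, result);
    return success();
  }
};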
Member

So this is fun: I thought stablehlo at least had an eval method with the literal constant implemented. Apparently it doesn't =/.

Since this code looks like it comes from the chlo-to-stablehlo lowering, maybe we can just run the relevant lowering function here when the op has a constant operand (rather than copying the lowering here).


Member

Yeah, but now that I look at it, it's marked static.

GleasonK

Would be happy to expose the functions in some shareable way! Could make a header ChloDecompositionUtils.h and can expose the individual decomp pattern, or the materialize function of interest.
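A minimal sketch of what such a header could declare, assuming the declaration mirrors the materializeLgamma call from the PR description; the include guard, include paths, and namespace here are assumptions:

// ChloDecompositionUtils.h (sketch; exact location and namespace are assumptions)
#ifndef STABLEHLO_TRANSFORMS_CHLODECOMPOSITIONUTILS_H
#define STABLEHLO_TRANSFORMS_CHLODECOMPOSITIONUTILS_H

#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/ValueRange.h"

namespace mlir {
namespace chlo {

// Expands chlo.lgamma at `loc` into stablehlo ops and returns the result value.
Value materializeLgamma(OpBuilder &rewriter, Location loc, ValueRange args);

}  // namespace chlo
}  // namespace mlir

#endif  // STABLEHLO_TRANSFORMS_CHLODECOMPOSITIONUTILS_H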

@vimarsh6739 (Member Author) Dec 9, 2024

That would be helpful! It'd also be good to have a single implementation for the function definition.

@wsmoses requested a review from Pangoraw on December 9, 2024.
@vimarsh6739 changed the title from "[draft] chlo::lgamma const prop" to "[draft] chlo.lgamma const prop" on Dec 11, 2024.
@vimarsh6739 changed the title from "[draft] chlo.lgamma const prop" to "chlo.lgamma const prop" on Dec 11, 2024.
GleasonK pushed a commit to openxla/stablehlo that referenced this pull request on Dec 15, 2024:

We want to perform constant propagation through `chlo.lgamma` in Enzyme-JAX.

[Kevin](EnzymeAD/Enzyme-JAX#182 (comment)) mentioned he was open to exposing some materialize functions (which are currently static, and not callable from [our pass](https://github.com/EnzymeAD/Enzyme-JAX/blob/main/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp) atm).

@wsmoses @GleasonK
@wsmoses (Member) commented Jan 3, 2025

@vimarsh6739 anything blocking here (I think we've since updated to a stablehlo version with the interface)?

@vimarsh6739 (Member Author)

Not really, I'll clean up and update the PR in a bit (need to add a lit test).
We also need to propagate through log and log1p, but that can be a separate PR.
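A rough sketch of what that lit test could look like; the RUN line mirrors the pipeline used elsewhere in this thread, and the CHECK expectation is deliberately minimal, since the expansion only folds down to a single constant once log/log1p const prop also lands:

// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(enzyme-hlo-opt)" %s | FileCheck %s

func.func @lgamma_f32() -> tensor<f32> {
  // CHECK-NOT: chlo.lgamma
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = chlo.lgamma %cst : tensor<f32> -> tensor<f32>
  return %0 : tensor<f32>
}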

@vimarsh6739 (Member Author)

Interestingly, it seems like some other optimization is messing with chlo.lgamma after rebasing to main, because by itself the constprop lowering works. I have attached the lit test output for test/lit_tests/chlo_lower_to_stablehlo.mlir below (with basically only GammaConstProp enabled in the patterns):

module {
  func.func @lgamma_f32() -> tensor<f32> {
    %cst = stablehlo.constant dense<0x7F800000> : tensor<f32>
    %cst_0 = stablehlo.constant dense<1.14472985> : tensor<f32>
    %cst_1 = stablehlo.constant dense<3.14159274> : tensor<f32>
    %cst_2 = stablehlo.constant dense<0.918938517> : tensor<f32>
    %cst_3 = stablehlo.constant dense<2.01490307> : tensor<f32>
    %cst_4 = stablehlo.constant dense<7.500000e+00> : tensor<f32>
    %cst_5 = stablehlo.constant dense<8.000000e+00> : tensor<f32>
    %cst_6 = stablehlo.constant dense<1.50563267E-7> : tensor<f32>
    %cst_7 = stablehlo.constant dense<7.000000e+00> : tensor<f32>
    %cst_8 = stablehlo.constant dense<9.98436917E-6> : tensor<f32>
    %cst_9 = stablehlo.constant dense<6.000000e+00> : tensor<f32>
    %cst_10 = stablehlo.constant dense<-0.138571098> : tensor<f32>
    %cst_11 = stablehlo.constant dense<5.000000e+00> : tensor<f32>
    %cst_12 = stablehlo.constant dense<12.5073433> : tensor<f32>
    %cst_13 = stablehlo.constant dense<4.000000e+00> : tensor<f32>
    %cst_14 = stablehlo.constant dense<-176.615036> : tensor<f32>
    %cst_15 = stablehlo.constant dense<3.000000e+00> : tensor<f32>
    %cst_16 = stablehlo.constant dense<771.323425> : tensor<f32>
    %cst_17 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
    %cst_18 = stablehlo.constant dense<-1259.13916> : tensor<f32>
    %cst_19 = stablehlo.constant dense<676.520386> : tensor<f32>
    %cst_20 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %cst_21 = stablehlo.constant dense<5.000000e-01> : tensor<f32>
    %0 = stablehlo.compare  LT, %cst_20, %cst_21 : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %1 = stablehlo.negate %cst_20 : tensor<f32>
    %2 = stablehlo.subtract %cst_20, %cst_20 : tensor<f32>
    %3 = stablehlo.select %0, %1, %2 : tensor<i1>, tensor<f32>
    %4 = stablehlo.add %3, %cst_20 : tensor<f32>
    %5 = stablehlo.divide %cst_19, %4 : tensor<f32>
    %6 = stablehlo.add %cst_20, %5 : tensor<f32>
    %7 = stablehlo.add %3, %cst_17 : tensor<f32>
    %8 = stablehlo.divide %cst_18, %7 : tensor<f32>
    %9 = stablehlo.add %6, %8 : tensor<f32>
    %10 = stablehlo.add %3, %cst_15 : tensor<f32>
    %11 = stablehlo.divide %cst_16, %10 : tensor<f32>
    %12 = stablehlo.add %9, %11 : tensor<f32>
    %13 = stablehlo.add %3, %cst_13 : tensor<f32>
    %14 = stablehlo.divide %cst_14, %13 : tensor<f32>
    %15 = stablehlo.add %12, %14 : tensor<f32>
    %16 = stablehlo.add %3, %cst_11 : tensor<f32>
    %17 = stablehlo.divide %cst_12, %16 : tensor<f32>
    %18 = stablehlo.add %15, %17 : tensor<f32>
    %19 = stablehlo.add %3, %cst_9 : tensor<f32>
    %20 = stablehlo.divide %cst_10, %19 : tensor<f32>
    %21 = stablehlo.add %18, %20 : tensor<f32>
    %22 = stablehlo.add %3, %cst_7 : tensor<f32>
    %23 = stablehlo.divide %cst_8, %22 : tensor<f32>
    %24 = stablehlo.add %21, %23 : tensor<f32>
    %25 = stablehlo.add %3, %cst_5 : tensor<f32>
    %26 = stablehlo.divide %cst_6, %25 : tensor<f32>
    %27 = stablehlo.add %24, %26 : tensor<f32>
    %28 = stablehlo.add %cst_4, %3 : tensor<f32>
    %29 = stablehlo.divide %3, %cst_4 : tensor<f32>
    %30 = stablehlo.log_plus_one %29 : tensor<f32>
    %31 = stablehlo.add %cst_3, %30 : tensor<f32>
    %32 = stablehlo.divide %28, %31 : tensor<f32>
    %33 = stablehlo.add %3, %cst_21 : tensor<f32>
    %34 = stablehlo.subtract %33, %32 : tensor<f32>
    %35 = stablehlo.multiply %34, %31 : tensor<f32>
    %36 = stablehlo.log %27 : tensor<f32>
    %37 = stablehlo.add %cst_2, %35 : tensor<f32>
    %38 = stablehlo.add %37, %36 : tensor<f32>
    %39 = stablehlo.abs %cst_20 : tensor<f32>
    %40 = stablehlo.floor %39 : tensor<f32>
    %41 = stablehlo.subtract %39, %40 : tensor<f32>
    %42 = stablehlo.compare  LT, %cst_21, %41 : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %43 = stablehlo.subtract %cst_20, %41 : tensor<f32>
    %44 = stablehlo.select %42, %43, %41 : tensor<i1>, tensor<f32>
    %45 = stablehlo.multiply %cst_1, %44 : tensor<f32>
    %46 = stablehlo.sine %45 : tensor<f32>
    %47 = stablehlo.log %46 : tensor<f32>
    %48 = stablehlo.subtract %cst_0, %47 : tensor<f32>
    %49 = stablehlo.subtract %48, %38 : tensor<f32>
    %50 = stablehlo.is_finite %47 : (tensor<f32>) -> tensor<i1>
    %51 = stablehlo.negate %47 : tensor<f32>
    %52 = stablehlo.select %50, %49, %51 : tensor<i1>, tensor<f32>
    %53 = stablehlo.select %0, %52, %38 : tensor<i1>, tensor<f32>
    %54 = chlo.is_inf %cst_20 : tensor<f32> -> tensor<i1>
    %55 = stablehlo.select %54, %cst, %53 : tensor<i1>, tensor<f32>
    return %55 : tensor<f32>
  }
}

After re-enabling all the other patterns, though, I get this:

PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.      Program arguments: bazel-bin/enzymexlamlir-opt --pass-pipeline=builtin.module(enzyme-hlo-opt) test/lit_tests/chlo_to_stablehlo.mlir
[1]    58336 segmentation fault  bazel-bin/enzymexlamlir-opt --pass-pipeline="builtin.module(enzyme-hlo-opt)" 

Currently trying to sweep through the list of emitted ops, but I would appreciate pointers on why this is crashing.

@wsmoses (Member) commented Jan 4, 2025

If you run it in gdb, can you tell which pattern is triggering it?

Another option is to pass --debug, which will spew output for every transform.
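For example, something along these lines (the exact invocation is an assumption, mirroring the crashing command above):

bazel-bin/enzymexlamlir-opt --debug --pass-pipeline="builtin.module(enzyme-hlo-opt)" test/lit_tests/chlo_to_stablehlo.mlir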

@vimarsh6739 (Member Author)

Haven't enabled debug symbols (laptop build), but it seems to crash in DivideSqrtToMultiplyRsqrt:

% lldb bazel-bin/enzymexlamlir-opt -- --pass-pipeline="builtin.module(enzyme-hlo-opt)" test/lit_tests/debugger.mlir
(lldb) target create "bazel-bin/enzymexlamlir-opt"
Current executable set to '/Users/vsathia/dev/Enzyme-JAX/bazel-bin/enzymexlamlir-opt' (arm64).
(lldb) settings set -- target.run-args  "--pass-pipeline=builtin.module(enzyme-hlo-opt)" "test/lit_tests/debugger.mlir"
(lldb) run
Process 59839 launched: '/Users/vsathia/dev/Enzyme-JAX/bazel-bin/enzymexlamlir-opt' (arm64)
Process 59839 stopped
* thread #1, queue = 'com.apple.main-thread', stop reason = EXC_BAD_ACCESS (code=1, address=0x24)
    frame #0: 0x00000001003db86c enzymexlamlir-opt`DivideSqrtToMultiplyRsqrt::matchAndRewrite(mlir::stablehlo::DivOp, mlir::PatternRewriter&) const + 80
enzymexlamlir-opt`DivideSqrtToMultiplyRsqrt::matchAndRewrite:
->  0x1003db86c <+80>: ldr    w8, [x8]
    0x1003db870 <+84>: cmp    w8, #0x0
    0x1003db874 <+88>: mov    x9, #-0x10 ; =-16 
    0x1003db878 <+92>: csel   x9, xzr, x9, eq
(lldb) 

@vimarsh6739 (Member Author)

Which is again strange, as the expanded op doesn't have a sqrt....

@wsmoses (Member) commented Jan 4, 2025

cc @avik-pal

@vimarsh6739 (Member Author)

Ok, disabling it works for now. @avik-pal, you can use the above example as a test case.

@wsmoses (Member) commented Jan 4, 2025

@vimarsh6739 does #216 fix it for you?

@vimarsh6739 (Member Author)

Let me check.

@vimarsh6739 (Member Author)

That works.

@wsmoses (Member) left a review comment

LGTM, but maybe it would make sense to add the log, is_inf, and log_plus_one constprop ones first (so this test just becomes a single const return)?

@vimarsh6739 (Member Author)

Yep, will add those as a separate PR. Let's hold off on merging for now then.

@vimarsh6739 (Member Author)

log and log1p are handled in #218.
