HLO Canonicalizations Todo list #54

Open · 4 tasks done
wsmoses opened this issue Mar 14, 2024 · 31 comments

Comments
wsmoses (Member) commented Mar 14, 2024

To track which of these we consider worth doing, are currently doing, or still need to do.

cc @ivanradanov @ftynse

  • iota reshape (becomes single iota)
    %195 = stablehlo.iota dim = 0 : tensor<1024xi32>
    %196 = stablehlo.reshape %195 : (tensor<1024xi32>) -> tensor<1x1x1024xi32>
  • reshape of pad (becomes diff pad)
    %175 = stablehlo.pad %174, %148, low = [0, 0, 1024, 0, 0], high = [0, 0, 0, 0, 0], interior = [0, 0, 0, 0, 0] : (tensor<1x3x1024x1x1xf32>, tensor<f32>) -> tensor<1x3x2048x1x1xf32>
    %176 = stablehlo.reshape %175 : (tensor<1x3x2048x1x1xf32>) -> tensor<1x3x2048xf32>
  • mul of pad with 0 (becomes pad of mul) 44026d4
    %175 = stablehlo.pad %174, %constant_0, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>
    %177 = stablehlo.multiply %176, %112 : tensor<1x3x2048xf32>
  • broadcast of pad (becomes pad of broadcast)
    %175 = stablehlo.pad %174, %constant_0, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>
    %189 = stablehlo.broadcast_in_dim %177, dims = [0, 2, 4] : (tensor<1x3x2048xf32>) -> tensor<1x1x3x1024x2048xf32>
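
For reference, here is roughly what the rewritten IR should look like for the first three items (a hand-written sketch only: operand names are reused from the snippets above, the new intermediate names are made up, and shapes are inferred rather than checked against the original module):

    // iota reshape: fold into a single iota along the mapped dimension
    %196 = stablehlo.iota dim = 2 : tensor<1x1x1024xi32>

    // reshape of pad: pad the reshaped operand directly, with the unit dims dropped
    %r   = stablehlo.reshape %174 : (tensor<1x3x1024x1x1xf32>) -> tensor<1x3x1024xf32>
    %176 = stablehlo.pad %r, %148, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>

    // mul of pad with 0 (assuming %constant_0 is a zero splat): multiply on the unpadded window, then pad the smaller product
    %s   = stablehlo.slice %112 [0:1, 0:3, 1024:2048] : (tensor<1x3x2048xf32>) -> tensor<1x3x1024xf32>
    %m   = stablehlo.multiply %174, %s : tensor<1x3x1024xf32>
    %177 = stablehlo.pad %m, %constant_0, low = [0, 0, 1024], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<1x3x1024xf32>, tensor<f32>) -> tensor<1x3x2048xf32>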
ftynse (Collaborator) commented Mar 14, 2024

broadcast of pad (becomes pad of broadcast)

Is this beneficial? Pad will operate on a larger buffer as a result.

ftynse (Collaborator) commented Mar 14, 2024

iota reshape (becomes single iota)
reshape of pad (becomes diff pad)

I expect this to have little practical effect: reshape should be a metadata operation, I'm not sure it even affects the generated code.

wsmoses (Member Author) commented Mar 14, 2024

Yeah, I think the issue right now is that there are a bunch of these sitting unnecessarily in the way of other ops (which hopefully do have practical impact when canonicalized together), and these small rewrites would hopefully enable those, even though in isolation they probably aren't performance-improving on their own.

I decided to start by writing down the immediate ones I saw in the MLIR file, rather than the bigger ones to do down the line (like batching dots), since I had mini IRs for these.

ftynse (Collaborator) commented Mar 14, 2024

I started doing pad/mul, as that one is clearly beneficial by making the mul smaller.

wsmoses (Member Author) commented Mar 14, 2024

So in the code I saw, the actual pattern was mul(reshape(pad(...))), which I split out into the reshape(pad) and mul(reshape) pieces. So I think we'll also need to get that to work end-to-end (I can start on it, assuming you haven't).

Ironically you can see I was lazy and didn't even rename the result ops (e.g. reshape(pad) was 175/176 and mul(pad) was 175/177 [in the real code it was using the reshape result])

wsmoses (Member Author) commented Mar 14, 2024

Is this beneficial? Pad will operate on a larger buffer as a result.

I hypothesize so. In principle a pad of a broadcast should be fusable (e.g. a pad is just doing out[idx] = in bounds ? in[idx2] : pad_value, and the in[idx2] being a broadcast may get simplified). I suppose the reverse is also the case, but if they don't get fused, doing the memset(x, 0) for the bigger buffer all at once seems wise. Also, as above, moving the pads out may help downstream opts.
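
Concretely, for the broadcast-of-pad snippet in the list above, the rewrite would look roughly like this (a sketch; operand names reused from above, %b made up, and note the pad now fills the larger broadcast shape, which is exactly the trade-off in question):

    %b   = stablehlo.broadcast_in_dim %174, dims = [0, 2, 4] : (tensor<1x3x1024xf32>) -> tensor<1x1x3x1024x1024xf32>
    %189 = stablehlo.pad %b, %constant_0, low = [0, 0, 0, 0, 1024], high = [0, 0, 0, 0, 0], interior = [0, 0, 0, 0, 0] : (tensor<1x1x3x1024x1024xf32>, tensor<f32>) -> tensor<1x1x3x1024x2048xf32>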

wsmoses (Member Author) commented Mar 15, 2024

  • Reshape transpose (I'm iffy on the utility of this one)
  • broadcast iota (sketch below)
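
A made-up minimal example of the broadcast-iota case, where the broadcast_in_dim folds into a single iota along the mapped dimension (value names and shapes are hypothetical):

    %0 = stablehlo.iota dim = 0 : tensor<8xi32>
    %1 = stablehlo.broadcast_in_dim %0, dims = [2] : (tensor<8xi32>) -> tensor<4x5x8xi32>
    // folds to:
    %1 = stablehlo.iota dim = 2 : tensor<4x5x8xi32>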

wsmoses (Member Author) commented Mar 15, 2024

  • generalize reshape of concat

actually jk the outer one is hard

    %1384 = stablehlo.concatenate %1382, %1383, dim = 1 : (tensor<1x2x48x4xbf16>, tensor<1x1x48x4xbf16>) -> tensor<1x3x48x4xbf16>
    %1385 = stablehlo.reshape %1384 : (tensor<1x3x48x4xbf16>) -> tensor<1x144x4xbf16>
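
For this particular example the generalized rewrite would presumably be a concat of reshapes (hand-checked for these shapes only; %1382/%1383 reused from above, new value names made up):

    %a    = stablehlo.reshape %1382 : (tensor<1x2x48x4xbf16>) -> tensor<1x96x4xbf16>
    %b    = stablehlo.reshape %1383 : (tensor<1x1x48x4xbf16>) -> tensor<1x48x4xbf16>
    %1385 = stablehlo.concatenate %a, %b, dim = 1 : (tensor<1x96x4xbf16>, tensor<1x48x4xbf16>) -> tensor<1x144x4xbf16>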

wsmoses (Member Author) commented Mar 15, 2024

  • Reduce sum of reshape -> reduce of unreshaped
    %1819 = stablehlo.reshape %1818 : (tensor<56xf32>) -> tensor<7x8xf32>
    %1913 = stablehlo.multiply %1819, %1819 : tensor<7x8xf32>
    %1914 = stablehlo.reduce(%1913 init: %147) applies stablehlo.add across dimensions = [0, 1] : (tensor<7x8xf32>, tensor<f32>) -> tensor<f32>
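
The base pattern, with the multiply elided for clarity (made-up names; since this is a full reduction with an associative and commutative reducer, the reshape can simply be dropped and the reduce dimensions adjusted):

    %r = stablehlo.reshape %x : (tensor<56xf32>) -> tensor<7x8xf32>
    %s = stablehlo.reduce(%r init: %cst) applies stablehlo.add across dimensions = [0, 1] : (tensor<7x8xf32>, tensor<f32>) -> tensor<f32>
    // becomes:
    %s = stablehlo.reduce(%x init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<56xf32>, tensor<f32>) -> tensor<f32>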

wsmoses (Member Author) commented Mar 15, 2024

  • Full Reduce of concat -> concat of reduces
  • Full reduce of transposes -> reduce of operands
  • Reduce sum of convert -> move the convert inside the reduce [jk this is possibly not representable]
  • Reduce of batched dot (aka matmul) -> dot of reduced operands
        %1205 = stablehlo.dot_general %1204, %778, contracting_dims = [0, 1] x [0, 1], precision = [DEFAULT, DEFAULT] : (tensor<46x123x56xbf16>, tensor<46x123x32xbf16>) -> tensor<56x32xbf16>
        %1914 = stablehlo.reduce(%1205 init: %147) applies stablehlo.add across dimensions = [0, 1] : (tensor<56x32xf32>, tensor<f32>) -> tensor<f32>
  • Sum reduce of add -> add of sum reduce
  • Sum reduce of pad -> sum reduce of padded op + (number of inserted vals) * padded value [easier if the pad value is 0, which we may just do]
    • Did the zero version of this
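
A made-up minimal example of the zero-pad case of the sum-reduce-of-pad item (names and shapes hypothetical):

    %p = stablehlo.pad %x, %zero, low = [2], high = [3], interior = [0] : (tensor<8xf32>, tensor<f32>) -> tensor<13xf32>
    %r = stablehlo.reduce(%p init: %zero) applies stablehlo.add across dimensions = [0] : (tensor<13xf32>, tensor<f32>) -> tensor<f32>
    // the inserted values are all 0, so this is the same as reducing the unpadded operand:
    %r = stablehlo.reduce(%x init: %zero) applies stablehlo.add across dimensions = [0] : (tensor<8xf32>, tensor<f32>) -> tensor<f32>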

wsmoses (Member Author) commented Mar 15, 2024

  • convert of pad -> pad of convert [esp for constants]
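
Made-up example of the intended rewrite (the padding value is converted as well; names and shapes hypothetical):

    %p = stablehlo.pad %x, %pv, low = [1], high = [0], interior = [0] : (tensor<4xbf16>, tensor<bf16>) -> tensor<5xbf16>
    %c = stablehlo.convert %p : (tensor<5xbf16>) -> tensor<5xf32>
    // becomes:
    %xc = stablehlo.convert %x : (tensor<4xbf16>) -> tensor<4xf32>
    %pc = stablehlo.convert %pv : (tensor<bf16>) -> tensor<f32>
    %c  = stablehlo.pad %xc, %pc, low = [1], high = [0], interior = [0] : (tensor<4xf32>, tensor<f32>) -> tensor<5xf32>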

wsmoses (Member Author) commented Mar 15, 2024

  • pad of constants (which may be different, but should do the small size check)

wsmoses (Member Author) commented Mar 15, 2024

  • negate(mul(broadcast(x), y)) -> mul(broadcast(negate(x)), y)

wsmoses (Member Author) commented Mar 15, 2024

  • negate(divide(constant, b)) -> divide(constant2, b), where constant2 = negate(constant) [assuming no infinite values]

wsmoses (Member Author) commented Mar 15, 2024

  • pad of same values [also taking into account negative vs. positive zero]
    %147 = stablehlo.constant dense<0.000000e+00> 
    %81 = stablehlo.constant dense<-0.000000e+00> 
    %1185 = stablehlo.pad %81, %147, low = [...], high = [...], interior = [...]

wsmoses (Member Author) commented Mar 15, 2024

  • dot(pad(zero from up to i and j, axis=contract, x), y) -> dot(x, y[i:j])
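
A made-up example in IR (x is zero-padded along the contracting dimension; the padded rows contribute nothing to the sum, so the other operand can be sliced instead; names and shapes hypothetical):

    %p = stablehlo.pad %x, %zero, low = [3], high = [2], interior = [0] : (tensor<5xf32>, tensor<f32>) -> tensor<10xf32>
    %d = stablehlo.dot_general %p, %y, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10xf32>, tensor<10x7xf32>) -> tensor<7xf32>
    // becomes:
    %ys = stablehlo.slice %y [3:8, 0:7] : (tensor<10x7xf32>) -> tensor<5x7xf32>
    %d  = stablehlo.dot_general %x, %ys, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<5xf32>, tensor<5x7xf32>) -> tensor<7xf32>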

wsmoses (Member Author) commented Mar 15, 2024

  • mul(bcast x, mul(bcast y, z)) -> mul(z, bcast(mul x, y))

  • distributive property full reduce add mul(z, bcast y) -> mul(full reduce add z, y)
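
A made-up example of the first rewrite above, assuming both broadcasts use the same dimension mapping (names and shapes hypothetical):

    %bx = stablehlo.broadcast_in_dim %x, dims = [0] : (tensor<4xf32>) -> tensor<4x8xf32>
    %by = stablehlo.broadcast_in_dim %y, dims = [0] : (tensor<4xf32>) -> tensor<4x8xf32>
    %m1 = stablehlo.multiply %by, %z : tensor<4x8xf32>
    %m2 = stablehlo.multiply %bx, %m1 : tensor<4x8xf32>
    // becomes (one multiply now happens on the small, un-broadcast shape):
    %xy  = stablehlo.multiply %x, %y : tensor<4xf32>
    %bxy = stablehlo.broadcast_in_dim %xy, dims = [0] : (tensor<4xf32>) -> tensor<4x8xf32>
    %m2  = stablehlo.multiply %z, %bxy : tensor<4x8xf32>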

wsmoses (Member Author) commented Mar 16, 2024

  • slice of broadcast -> broadcast
  • slice of transpose -> transpose of slice
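
A made-up example of the slice-of-transpose rewrite, with the slice bounds permuted to match (names and shapes hypothetical):

    %t = stablehlo.transpose %x, dims = [1, 0] : (tensor<8x16xf32>) -> tensor<16x8xf32>
    %s = stablehlo.slice %t [2:6, 0:8] : (tensor<16x8xf32>) -> tensor<4x8xf32>
    // becomes:
    %xs = stablehlo.slice %x [0:8, 2:6] : (tensor<8x16xf32>) -> tensor<8x4xf32>
    %s  = stablehlo.transpose %xs, dims = [1, 0] : (tensor<8x4xf32>) -> tensor<4x8xf32>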

wsmoses (Member Author) commented Mar 16, 2024

  • slice of convert -> convert of slice. (perhaps generalize to any unary op, if only user)

wsmoses (Member Author) commented Mar 16, 2024

wsmoses (Member Author) commented Mar 16, 2024

  • (partial) sum reduce of broadcast

wsmoses (Member Author) commented Mar 17, 2024

  • convert of convert
  • generalize transpose transpose to transpose convert transpose

wsmoses (Member Author) commented Mar 17, 2024

  • dot general transpose(A), B or A, transpose(B) -> dot general A, B (where applicable)
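
A made-up example where the transpose can be absorbed by flipping the contracting dimension (names and shapes hypothetical):

    %at = stablehlo.transpose %a, dims = [1, 0] : (tensor<16x8xf32>) -> tensor<8x16xf32>
    %d  = stablehlo.dot_general %at, %b, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<8x16xf32>, tensor<16x4xf32>) -> tensor<8x4xf32>
    // becomes (contract over the other dimension of %a instead):
    %d  = stablehlo.dot_general %a, %b, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x8xf32>, tensor<16x4xf32>) -> tensor<8x4xf32>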

ftynse (Collaborator) commented Mar 18, 2024

(partial) sum reduce of broadcast

Do you have an example of this? There are a bunch of cases that are already supported: https://github.com/EnzymeAD/Enzyme-JAX/blob/main/test/lit_tests/broadcastreduce.mlir

wsmoses (Member Author) commented Mar 22, 2024

  • Select of Pad

wsmoses (Member Author) commented Mar 22, 2024

(partial) sum reduce of broadcast

Do you have an example of this? There are a bunch of cases that are already supported: https://github.com/EnzymeAD/Enzyme-JAX/blob/main/test/lit_tests/broadcastreduce.mlir

Unfortunately not presently, but I'll post when it comes up again.

wsmoses (Member Author) commented Mar 22, 2024

wsmoses (Member Author) commented Mar 29, 2024

  • slice of dot general
    %1205 = stablehlo.dot_general %1181, %482, batching_dims = [0, 2] x [0, 1], contracting_dims = [3, 1] x [2, 3], precision = [DEFAULT, DEFAULT] : (tensor<1x16x1x20x100xbf16>, tensor<1x1x20x16x123xbf16>) -> tensor<1x1x100x123xbf16>
    %1208 = stablehlo.slice %1205 [0:1, 0:1, 75:100, 0:123] : (tensor<1x1x100x123xbf16>) -> tensor<1x1x25x123xbf16>
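
Presumably the rewrite slices the lhs operand before the (much more expensive) dot, roughly like this (a sketch; result dim 2 corresponds to the lhs's remaining dim 4, and the new value name is made up):

    %ls   = stablehlo.slice %1181 [0:1, 0:16, 0:1, 0:20, 75:100] : (tensor<1x16x1x20x100xbf16>) -> tensor<1x16x1x20x25xbf16>
    %1208 = stablehlo.dot_general %ls, %482, batching_dims = [0, 2] x [0, 1], contracting_dims = [3, 1] x [2, 3], precision = [DEFAULT, DEFAULT] : (tensor<1x16x1x20x25xbf16>, tensor<1x1x20x16x123xbf16>) -> tensor<1x1x25x123xbf16>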

wsmoses (Member Author) commented Apr 13, 2024

  • reshape(concat(constants or reshapes...)) -> concat(constants or reshapes...)

wsmoses (Member Author) commented Apr 14, 2024

Generalize pad propagations to work with an interstitial reshape

  • PadPad (need PadReshapePad)
  • BinopPadtoConcat (need BinopReshapePadtoConcat)
  • ConcatPad (need ConcatReshapePad)
  • ReducePad
  • BroadcastPad
  • MulZeroPad
  • DivZeroPad
  • BinopConstPad
  • BinopBinopPadPad
  • AddPadPadtoConcat
  • UnaryPadPush
  • TransposePad
  • PadDotGeneral

wsmoses (Member Author) commented Apr 15, 2024

  • concatenate of N of the same elements -> broadcast
  • select of pad
  • select reshape pad ( select(x, reshape(pad(x) with const_0), const_0) -> pad(reshape(select(x, const_0)), const_0) )
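
A made-up example of the first item, concatenating N copies of the same value along a unit dimension (names and shapes hypothetical):

    %c = stablehlo.concatenate %x, %x, %x, dim = 0 : (tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<3x4xf32>
    // becomes:
    %c = stablehlo.broadcast_in_dim %x, dims = [0, 1] : (tensor<1x4xf32>) -> tensor<3x4xf32>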
