CPU AVX implementation for Softmax, Norm #357

Merged Jan 9, 2024 (76 commits)

Commits
14cbb3b
initial commit
fishingguy456 Jul 6, 2023
7896c45
works on multidimensional, axis=-1
fishingguy456 Jul 25, 2023
ff90ed5
initial commit
fishingguy456 Jul 6, 2023
fc61204
change imports
fishingguy456 Jul 20, 2023
f84201f
fix for diff size, compiledmodule error fix
fishingguy456 Jul 21, 2023
6f2e43c
works on multidimensional, axis=-1
fishingguy456 Jul 25, 2023
25f22cf
initial commit
fishingguy456 Jul 6, 2023
aafbb0f
initial commit
fishingguy456 Jul 6, 2023
44993e2
change imports
fishingguy456 Jul 20, 2023
a86d866
fix for diff size, compiledmodule error fix
fishingguy456 Jul 21, 2023
b59ffa2
works on multidimensional, axis=-1
fishingguy456 Jul 25, 2023
7edf0eb
wrap up softmax, starting layernorm
fishingguy456 Jul 28, 2023
44c04b3
layernorm kinda works but not rly
fishingguy456 Jul 31, 2023
2ccc4b6
better code for softmax
fishingguy456 Jul 31, 2023
13ea5dc
layernorm works for last layer
fishingguy456 Aug 1, 2023
d89036d
move find sum and find max to registered function
fishingguy456 Aug 1, 2023
b0659f6
find max in registered func
fishingguy456 Aug 1, 2023
904760b
not working softmax on not last dim, minor changes
fishingguy456 Aug 3, 2023
29b7ba7
layernorm works for any dims
fishingguy456 Aug 3, 2023
0c8dc3a
comments
fishingguy456 Aug 4, 2023
77fe8d9
tuning, fix for flowgraph operator resolve
fishingguy456 Aug 4, 2023
ac40695
softmax works
fishingguy456 Aug 5, 2023
4938a1f
commented tensors dont work, i.e. axis is not last 2 AND not multiple…
fishingguy456 Aug 5, 2023
1d447cf
actually works rn frfr so fast :100:
fishingguy456 Aug 8, 2023
30224ce
cleanup
fishingguy456 Aug 8, 2023
67d4d56
more cleanup
fishingguy456 Aug 9, 2023
09ca2f8
random testing stuff
fishingguy456 Aug 11, 2023
8352dd8
allow epilogue
fishingguy456 Aug 18, 2023
27f6cbb
better epiloguing
fishingguy456 Aug 18, 2023
cce1d42
janky matmul resolve
fishingguy456 Aug 25, 2023
f92de53
still epilogue problem?
fishingguy456 Aug 25, 2023
63dfed4
initial commit
fishingguy456 Jul 6, 2023
73a063a
works on multidimensional, axis=-1
fishingguy456 Jul 25, 2023
1c129c0
initial commit
fishingguy456 Jul 6, 2023
bf8a5b5
change imports
fishingguy456 Jul 20, 2023
3aa5cb6
fix for diff size, compiledmodule error fix
fishingguy456 Jul 21, 2023
b849ebf
works on multidimensional, axis=-1
fishingguy456 Jul 25, 2023
12fdbd1
initial commit
fishingguy456 Jul 6, 2023
9c7ecd0
initial commit
fishingguy456 Jul 6, 2023
b155bbd
change imports
fishingguy456 Jul 20, 2023
de72bc6
fix for diff size, compiledmodule error fix
fishingguy456 Jul 21, 2023
17b8d76
works on multidimensional, axis=-1
fishingguy456 Jul 25, 2023
1b52167
wrap up softmax, starting layernorm
fishingguy456 Jul 28, 2023
e479db7
layernorm kinda works but not rly
fishingguy456 Jul 31, 2023
c623630
better code for softmax
fishingguy456 Jul 31, 2023
b44b69e
layernorm works for last layer
fishingguy456 Aug 1, 2023
29ea558
move find sum and find max to registered function
fishingguy456 Aug 1, 2023
339e549
find max in registered func
fishingguy456 Aug 1, 2023
88c423c
not working softmax on not last dim, minor changes
fishingguy456 Aug 3, 2023
9c91875
layernorm works for any dims
fishingguy456 Aug 3, 2023
6e0d8e5
comments
fishingguy456 Aug 4, 2023
552aebb
tuning, fix for flowgraph operator resolve
fishingguy456 Aug 4, 2023
dc258e3
softmax works
fishingguy456 Aug 5, 2023
95f6be7
commented tensors dont work, i.e. axis is not last 2 AND not multiple…
fishingguy456 Aug 5, 2023
d0b99a4
actually works rn frfr so fast :100:
fishingguy456 Aug 8, 2023
67a43a5
cleanup
fishingguy456 Aug 8, 2023
4443780
more cleanup
fishingguy456 Aug 9, 2023
4088fc6
random testing stuff
fishingguy456 Aug 11, 2023
7430696
allow epilogue
fishingguy456 Aug 18, 2023
8a1167e
better epiloguing
fishingguy456 Aug 18, 2023
0f4876f
janky matmul resolve
fishingguy456 Aug 25, 2023
49c072f
still epilogue problem?
fishingguy456 Aug 25, 2023
0bd13d8
Merge remote-tracking branch 'origin/main'
fishingguy456 Sep 14, 2023
de74231
clean up for pr
fishingguy456 Sep 14, 2023
9ab0bac
fix test
fishingguy456 Sep 18, 2023
f779a1d
lint
fishingguy456 Sep 18, 2023
124fb09
minor pr edits
fishingguy456 Sep 19, 2023
6c4efd9
pytests, cpu child class
fishingguy456 Sep 19, 2023
40fd71f
potential fix for failing tests? but prob not will have to investigat…
fishingguy456 Sep 19, 2023
90c4ffb
weird diff
fishingguy456 Jan 4, 2024
587ba64
merge conflict resolve build.py
fishingguy456 Jan 4, 2024
89d5646
remove shady batch mat mul
fishingguy456 Jan 4, 2024
a3a4b03
lint thing
fishingguy456 Jan 5, 2024
aec95d2
move helpers to new file
fishingguy456 Jan 8, 2024
7a41b5c
lint
fishingguy456 Jan 8, 2024
dcc6a45
change tolerance for flaky test for test_dynamic_shape
fishingguy456 Jan 8, 2024
1 change: 1 addition & 0 deletions python/hidet/backend/build.py
@@ -241,6 +241,7 @@ def compile(
# support avx intrinsics
'-mavx2',
'-m64',
'-ffast-math',
'-march={arch}'.format(arch=arch),
# compile into position independent code.
'-fPIC',
10 changes: 8 additions & 2 deletions python/hidet/graph/ops/activation.py
Expand Up @@ -15,7 +15,7 @@
from hidet.ir.expr import if_then_else, BitwiseAnd
from .utils import Tensor, Operator, normalize_dim, input_like
from .arithmetic import UnaryElementwiseOp, BinaryElementwiseOp
from .softmax import SoftmaxTask
from .softmax import CPUSoftmaxTask, SoftmaxTask


class ReluOp(UnaryElementwiseOp):
@@ -189,7 +189,13 @@ def __init__(self, x: Tensor, lambda_val: float = 0.5):
class SoftmaxOp(Operator):
def __init__(self, x: Tensor, axis: int = 1):
axis = normalize_dim(axis, len(x.shape))
super().__init__(inputs=[x], attributes={'axis': axis}, task=SoftmaxTask(input_like(x, 'x'), axis))
super().__init__(
inputs=[x],
attributes={'axis': axis},
task=CPUSoftmaxTask(input_like(x, 'x'), axis)
if x.device.is_cpu()
else SoftmaxTask(input_like(x, 'x'), axis),
)


def relu(x) -> Tensor:
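
For context, a minimal usage sketch of the dispatch added above, assuming hidet.randn and hidet.ops.softmax remain the public entry points (illustrative only, not part of the diff): a CPU tensor routes SoftmaxOp to CPUSoftmaxTask, while a CUDA tensor keeps the original SoftmaxTask.

# Illustrative sketch; assumes hidet.randn / hidet.ops.softmax as the public API.
import hidet

x_cpu = hidet.randn([4, 1000], device='cpu')    # CPU input -> SoftmaxOp builds a CPUSoftmaxTask
y_cpu = hidet.ops.softmax(x_cpu, axis=-1)

x_cuda = hidet.randn([4, 1000], device='cuda')  # CUDA input -> SoftmaxOp keeps the original SoftmaxTask
y_cuda = hidet.ops.softmax(x_cuda, axis=-1)
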
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/matmul/matmul_f32_x86.py
@@ -73,7 +73,7 @@ def __init__(self, a: TensorNode, b: TensorNode):
)

def allow_epilogue(self) -> bool:
return True
return False
Contributor comment: Why should we change this to False? 🤔


def allow_prologue(self) -> bool:
return False
122 changes: 120 additions & 2 deletions python/hidet/graph/ops/normalize/norm.py
@@ -9,7 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import List, Union
from hidet.ir import primitives as prim
from hidet.ir.library import tune
from hidet.ir.module import IRModule
@@ -26,6 +26,7 @@
from hidet.graph.ops.utils import Task, Operator, Tensor, TensorNode
from hidet.graph.ops.utils import compute, input_like, normalize_dim
from hidet.utils import prod
from hidet.lang import float32


class NormalizeTask(Task):
@@ -353,14 +354,131 @@ def norm_kernel(x: dtype[x.shape], y: dtype[y.shape]):
return ir_module


class CPUNormalizeTask(NormalizeTask):
def allow_prologue(self) -> bool:
return False

def allow_epilogue(self) -> bool:
return False

def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
if self.dims[-1] != len(self.inputs[0].shape) - 1 or self.inputs[0].type.dtype != float32:
return NotImplemented
return tune.extract_ir_modules(self.schedule_norm_cpu)

@tune.space(2, nthreads=['', 4, 8, 16, 32, 64, 96])
@tune.space(1, nthreads=['', 8, 16])
def schedule_norm_cpu(self, nthreads='') -> IRModule:
import hidet
from hidet.ir.primitives.cpu.avx import (
avx_f32x8_subtract,
avx_f32x8_load,
avx_f32x8_setzero,
avx_f32x8_store,
avx_f32x8_add,
avx_f32x8_set1,
avx_f32x8_divide,
avx_f32x8_multiply,
avx_f32x8_sum,
avx_f32x8_sqrt,
)
from hidet.lang import tensor

shape = self.inputs[0].shape
head = shape[: -len(self.dims)]
tail = shape[-len(self.dims) :]
head_size = prod(head)
tail_size = prod(tail)
with hidet.script_module() as module:

@hidet.script
def norm_cpu_kernel(x: float32[shape], out: float32[shape]):
attrs.func_kind = "cpu_kernel"
para = "p" + str(nthreads)
for k in grid(head_size, attrs=para):
head_idx = spatial(*head).map(k)

mean_vec = avx_f32x8_setzero()
M2_vec = avx_f32x8_setzero()
epsilon_vec = avx_f32x8_set1(self.attrs['epsilon'])

mean_combined = 0.0
M2_combined = 0.0
if tail_size >= 8:
for i in range(tail_size // 8):
tail_idx = spatial(*tail).map(i * 8)
# welford algorithm
n_vec = avx_f32x8_set1(cast(i + 1, float32))
data_vec = avx_f32x8_load(~x[head_idx][tail_idx])
delta = avx_f32x8_subtract(data_vec, mean_vec)
mean_vec = avx_f32x8_add(mean_vec, avx_f32x8_divide(delta, n_vec))
delta2 = avx_f32x8_subtract(data_vec, mean_vec)
M2_vec = avx_f32x8_add(M2_vec, avx_f32x8_multiply(delta, delta2))

# welford combine
# TODO: case for numerical stability? (number too high for large matrix)
# TODO: look at the cascade thing in pytorch github
mean_combined = avx_f32x8_sum(mean_vec) / 8
mean_combined_vec = avx_f32x8_set1(mean_combined)
delta_vec = avx_f32x8_subtract(mean_vec, mean_combined_vec)
M2_combined = avx_f32x8_sum(M2_vec) + avx_f32x8_sum(
avx_f32x8_multiply(delta_vec, delta_vec)
) * (tail_size // 8)
mean_tail = 0.0
M2_tail = 0.0
# welford on remaining parts past 8
for i in range(tail_size % 8):
tail_idx = spatial(*tail).map(tail_size - tail_size % 8 + i)
delta_tail = x[head_idx][tail_idx] - mean_tail
mean_tail += delta_tail / cast(i + 1, float32)
delta_tail2 = x[head_idx][tail_idx] - mean_tail
M2_tail += delta_tail * delta_tail2
# welford combine vectorized and unvectorized
delta_end = mean_tail - mean_combined
mean = (mean_combined * (tail_size - tail_size % 8) + mean_tail * (tail_size % 8)) / tail_size
var = (
M2_combined
+ M2_tail
+ delta_end * delta_end * (tail_size - tail_size % 8) * (tail_size % 8) / tail_size
) / tail_size
mean_vec = avx_f32x8_set1(mean)
var_vec = avx_f32x8_set1(var)
if tail_size >= 8:
for i in range(tail_size // 8):
# norm calculation
tail_idx = spatial(*tail).map(i * 8)
temp_out = tensor(dtype=float32, shape=[8])
avx_f32x8_store(
temp_out,
avx_f32x8_divide(
avx_f32x8_subtract(avx_f32x8_load(~x[head_idx][tail_idx]), mean_vec),
avx_f32x8_sqrt(avx_f32x8_add(var_vec, epsilon_vec)),
),
)
for j in range(8):
tail_idx = spatial(*tail).map(i * 8 + j)
out[head_idx][tail_idx] = temp_out[j]
for i in range(tail_size % 8):
tail_idx = spatial(*tail).map(tail_size - tail_size % 8 + i)
out[head_idx][tail_idx] = (x[head_idx][tail_idx] - mean) * prim.rsqrt(
var + self.attrs['epsilon']
)

assert isinstance(norm_cpu_kernel, hidet.ir.Function)
ir_module = module.ir_module()
return ir_module


class NormalizeOp(Operator):
def __init__(self, x: Tensor, dims, epsilon: float, accumulate_dtype: str):
rank = len(x.shape)
dims = normalize_dim(dims, rank=rank)
super().__init__(
inputs=[x],
attributes={'dims': dims, 'epsilon': epsilon, 'accumulate_dtype': accumulate_dtype},
task=NormalizeTask(input_like(x, 'x'), dims, epsilon, accumulate_dtype),
task=CPUNormalizeTask(input_like(x, 'x'), dims, epsilon, accumulate_dtype)
if x.device.is_cpu()
else NormalizeTask(input_like(x, 'x'), dims, epsilon, accumulate_dtype),
)


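For reference, the per-row statistics that norm_cpu_kernel vectorizes are Welford's one-pass mean/variance update together with the standard parallel-combine step (the same combine used to fold the 8-lane accumulator into the scalar tail). A scalar sketch of that arithmetic, illustrative only and not part of the diff:

# Scalar sketch of the Welford update/combine mirrored by the AVX kernel above (illustrative).
def welford_update(count, mean, m2, value):
    # One-pass update of count, running mean, and sum of squared deviations (M2).
    count += 1
    delta = value - mean
    mean += delta / count
    delta2 = value - mean
    m2 += delta * delta2
    return count, mean, m2

def welford_combine(count_a, mean_a, m2_a, count_b, mean_b, m2_b):
    # Chan et al. combine of two partial accumulators; variance is M2 / count at the end.
    count = count_a + count_b
    delta = mean_b - mean_a
    mean = mean_a + delta * count_b / count
    m2 = m2_a + m2_b + delta * delta * count_a * count_b / count
    return count, mean, m2

The kernel then writes (x - mean) / sqrt(var + epsilon) per element, with var = M2 / tail_size.
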
159 changes: 159 additions & 0 deletions python/hidet/graph/ops/softmax.py
@@ -9,12 +9,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
from hidet.ir.module import IRModule
from hidet.ir import primitives as prim
from hidet.ir.expr import is_constant
from hidet.ir.stmt import Stmt, AssignStmt
from hidet.ir.builders import StmtBuilder
from hidet.ir.primitives import active_mask, shfl_down_sync, shfl_sync
from hidet.ir.dtypes import float32
from hidet.ir.library import tune
from .utils import Task, TensorNode, compute, reduce


@@ -153,3 +156,159 @@ def softmax_kernel(xs: xdtype[shape], ys: xdtype[shape]):
ir_module = module.ir_module()

return ir_module

def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
if self.inputs[0].type.dtype != float32:
return NotImplemented # use auto-scheduler
return tune.extract_ir_modules(self.schedule_softmax_cpu)


class CPUSoftmaxTask(SoftmaxTask):
def allow_epilogue(self) -> bool:
return False

def allow_prologue(self) -> bool:
return False
Member comment on lines +167 to +171: Created a CPU version of the operator because the CUDA version allows prologue & epilogue.


@tune.space(2, nthreads=['', 4, 8, 16, 32, 64, 96])
@tune.space(1, nthreads=['', 8, 16])
def schedule_softmax_cpu(self, nthreads='') -> IRModule:
import hidet
from hidet.ir.primitives.cpu.avx import (
avx_f32x8_subtract,
avx_f32x8_load,
avx_f32x8_setzero,
avx_f32x8_store,
avx_f32x8_add,
avx_f32x8_max,
avx_f32x8_set1,
avx_f32x8_divide,
avx_f32x8_sum,
avx_f32x8_scalar_max,
)
from hidet.lang import tensor, attrs, grid
from hidet.ir.stmt import DeclareScope
from hidet.lang.mapping import spatial
from hidet.utils import prod
from hidet.ir.dtypes import float32x8

shape = self.inputs[0].shape
head = shape[: self.axis]
tail = shape[self.axis :] if self.axis == len(shape) - 1 else shape[self.axis + 1 :]
head_size = prod(head)
tail_size = prod(tail)
axis_size = shape[self.axis]

with hidet.script_module() as module:

@hidet.script
def apply_exponent(vec: float32x8) -> float32x8:
attrs.func_kind = "cpu_internal"
arr = tensor(scope=DeclareScope.Default, dtype=float32, shape=[8])
avx_f32x8_store(arr, vec)
for n in range(8):
arr[n] = prim.exp(arr[n])
return avx_f32x8_load(arr)

@hidet.script
def softmax_cpu_kernel(x: float32[shape], out: float32[shape]):
attrs.func_kind = "cpu_kernel"
para = 'p' + str(nthreads)
for k in grid(head_size, attrs=para):
head_idx = spatial(*head).map(k)
if self.axis == len(shape) - 1: # last dim
temp_exp = tensor(dtype=float32, shape=tail)
max_val = x[head_idx][0]
if tail_size >= 8:
# vectorized find max value
max_vec = avx_f32x8_load(~x[head_idx][0])
for i in range(tail_size // 8):
data_vec = avx_f32x8_load(~x[head_idx][i * 8])
max_vec = avx_f32x8_max(max_vec, data_vec)
max_val = avx_f32x8_scalar_max(max_vec)
for i in range(tail_size % 8):
# max value of remaining unvectorized parts
max_val = (
max_val
if max_val > x[head_idx][tail_size - tail_size % 8 + i]
else x[head_idx][tail_size - tail_size % 8 + i]
)

# subtract max, take exp and find exp sum
sum_value = 0.0
if tail_size >= 8:
sum_exp_vec = avx_f32x8_setzero()
max_vec = avx_f32x8_set1(max_val)
for i in range(tail_size // 8):
val_vec = avx_f32x8_load(~x[head_idx][i * 8])
val_vec = avx_f32x8_subtract(val_vec, max_vec)
val_vec = apply_exponent(val_vec)
avx_f32x8_store(~temp_exp[i * 8], val_vec)
sum_exp_vec = avx_f32x8_add(sum_exp_vec, val_vec)
sum_value = avx_f32x8_sum(sum_exp_vec)
for i in range(tail_size % 8):
temp_exp[tail_size - tail_size % 8 + i] = prim.exp(
x[head_idx][tail_size - tail_size % 8 + i] - max_val
)
sum_value += temp_exp[tail_size - tail_size % 8 + i]

# divide by exp sum
if tail_size >= 8:
# divide
sum_vec8 = avx_f32x8_set1(sum_value)
for i in range(tail_size // 8):
avx_f32x8_store(
~temp_exp[i * 8], avx_f32x8_divide(avx_f32x8_load(~temp_exp[i * 8]), sum_vec8)
)
for i in range(tail_size % 8):
temp_exp[tail_size - tail_size % 8 + i] /= sum_value
for i in range(tail_size):
out[head_idx][i] = temp_exp[i]
else: # not last dim
temp_exp = tensor(dtype=float32, shape=[shape[self.axis]] + tail)
# vectorized operations across all contiguous memory for relevant axis
for g in range(tail_size // 8):
tail_idx = spatial(*tail).map(g * 8)
max_vec = avx_f32x8_load(~x[head_idx][0][tail_idx])
for i in range(axis_size):
data_vec = avx_f32x8_load(~x[head_idx][i][tail_idx])
max_vec = avx_f32x8_max(max_vec, data_vec)
sum_exp_vec = avx_f32x8_setzero()
for i in range(axis_size):
val_vec = avx_f32x8_load(~x[head_idx][i][tail_idx])
val_vec = avx_f32x8_subtract(val_vec, max_vec)
val_vec = apply_exponent(val_vec)
avx_f32x8_store(~temp_exp[i][tail_idx], val_vec)
sum_exp_vec = avx_f32x8_add(sum_exp_vec, val_vec)
for i in range(axis_size):
avx_f32x8_store(
~temp_exp[i][tail_idx],
avx_f32x8_divide(avx_f32x8_load(~temp_exp[i][tail_idx]), sum_exp_vec),
)
for j in range(8):
tail_end_idx = spatial(*tail).map(g * 8 + j)
out[head_idx][i][tail_end_idx] = temp_exp[i][tail_end_idx]
# unvectorized operations for the remaining elements
max_arr = tensor(scope=DeclareScope.Default, dtype=float32, shape=[tail_size % 8])
for j in range(tail_size % 8):
max_arr[j] = 0.0
for p in range(axis_size):
for j in range(tail_size % 8):
last_idx = spatial(*tail).map(tail_size - tail_size % 8 + j)
max_arr[j] = prim.max(max_arr[j], x[head_idx][p][last_idx]) # TODO: index
sum_exp_arr = tensor(scope=DeclareScope.Default, dtype=float32, shape=[tail_size % 8])
for j in range(tail_size % 8):
sum_exp_arr[j] = 0.0
for p in range(axis_size):
for j in range(tail_size % 8):
last_idx = spatial(*tail).map(tail_size - tail_size % 8 + j)
out[head_idx][p][last_idx] = prim.exp(x[head_idx][p][last_idx] - max_arr[j])
sum_exp_arr[j] += out[head_idx][p][last_idx]
for p in range(axis_size):
for j in range(tail_size % 8):
last_idx = spatial(*tail).map(tail_size - tail_size % 8 + j)
out[head_idx][p][last_idx] = out[head_idx][p][last_idx] / sum_exp_arr[j]

assert isinstance(softmax_cpu_kernel, hidet.ir.Function)
ir_module = module.ir_module()
return ir_module
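
For comparison, the reference semantics that schedule_softmax_cpu vectorizes is the usual numerically stable softmax: subtract the per-slice max, exponentiate, then divide by the per-slice sum. A short NumPy sketch (illustrative only, not part of the diff):

# Reference computation the AVX kernel reproduces (illustrative sketch, not part of the diff).
import numpy as np

def softmax_reference(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x_max = np.max(x, axis=axis, keepdims=True)       # max-subtraction for numerical stability
    e = np.exp(x - x_max)
    return e / np.sum(e, axis=axis, keepdims=True)    # normalize by the per-slice sum

The last-dim branch performs the max/exp/sum passes over each row in 8-float chunks with a scalar remainder loop; the other branch runs the same passes along the softmax axis while vectorizing over the contiguous tail dimension.
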
3 changes: 3 additions & 0 deletions python/hidet/ir/dtypes/vector.py
@@ -74,6 +74,9 @@ def max_value(self):

int8x4 = VectorType(int8, 4)
i8x4 = int8x4
float32x4 = VectorType(float32, 4)
float32x8 = VectorType(float32, 8)
float16x2 = VectorType(float16, 2)

uint8x4 = VectorType(uint8, 4)
u8x4 = uint8x4
1 change: 0 additions & 1 deletion python/hidet/ir/expr.py
@@ -439,7 +439,6 @@ class Call(Expr):
def __init__(self, func_var, args):
self.func_var: Var = func_var
self.args: Tuple[Expr, ...] = args

assert isinstance(func_var, Var) and isinstance(args, tuple)
for arg in args:
assert isinstance(arg, Expr)