Skip to content

Commit

Permalink
[CPU] AVX implementation for Softmax, Norm (#357)
Browse files Browse the repository at this point in the history
Add softmax and norm operator template
  • Loading branch information
fishingguy456 committed Jan 9, 2024
1 parent 2f5e3f1 commit 52fe368
Show file tree
Hide file tree
Showing 10 changed files with 426 additions and 6 deletions.
1 change: 1 addition & 0 deletions python/hidet/backend/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def compile(
# support avx intrinsics
'-mavx2',
'-m64',
'-ffast-math',
'-march={arch}'.format(arch=arch),
# compile into position independent code.
'-fPIC',
Expand Down
10 changes: 8 additions & 2 deletions python/hidet/graph/ops/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from hidet.ir.expr import if_then_else, BitwiseAnd
from .utils import Tensor, Operator, normalize_dim, input_like
from .arithmetic import UnaryElementwiseOp, BinaryElementwiseOp
from .softmax import SoftmaxTask
from .softmax import CPUSoftmaxTask, SoftmaxTask


class ReluOp(UnaryElementwiseOp):
Expand Down Expand Up @@ -189,7 +189,13 @@ def __init__(self, x: Tensor, lambda_val: float = 0.5):
class SoftmaxOp(Operator):
# Operator wrapper for softmax over one axis of x; the axis is passed through normalize_dim
# together with the rank of x before it is stored in the attributes and handed to the task.
def __init__(self, x: Tensor, axis: int = 1):
axis = normalize_dim(axis, len(x.shape))
# NOTE(review): this listing is a diff view without +/- markers. The single-line call below
# appears to be the removed pre-change version, superseded by the multi-line call that
# follows it -- confirm against the actual file before relying on either.
super().__init__(inputs=[x], attributes={'axis': axis}, task=SoftmaxTask(input_like(x, 'x'), axis))
# Task selection: CPUSoftmaxTask when x.device.is_cpu() is true, SoftmaxTask otherwise.
super().__init__(
inputs=[x],
attributes={'axis': axis},
task=CPUSoftmaxTask(input_like(x, 'x'), axis)
if x.device.is_cpu()
else SoftmaxTask(input_like(x, 'x'), axis),
)


def relu(x) -> Tensor:
Expand Down
122 changes: 120 additions & 2 deletions python/hidet/graph/ops/normalize/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import List, Union
from hidet.ir import primitives as prim
from hidet.ir.library import tune
from hidet.ir.module import IRModule
Expand All @@ -26,6 +26,7 @@
from hidet.graph.ops.utils import Task, Operator, Tensor, TensorNode
from hidet.graph.ops.utils import compute, input_like, normalize_dim
from hidet.utils import prod
from hidet.lang import float32


class NormalizeTask(Task):
Expand Down Expand Up @@ -353,14 +354,131 @@ def norm_kernel(x: dtype[x.shape], y: dtype[y.shape]):
return ir_module


# CPU variant of NormalizeTask: reports that prologues and epilogues are not allowed and
# provides a float32 schedule built from 8-lane AVX primitives.
class CPUNormalizeTask(NormalizeTask):
def allow_prologue(self) -> bool:
return False

def allow_epilogue(self) -> bool:
return False

# Returns NotImplemented unless the last entry of self.dims is the final axis of the input
# and the input dtype is float32; otherwise returns the IR modules extracted from
# schedule_norm_cpu over its tuning space.
def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
if self.dims[-1] != len(self.inputs[0].shape) - 1 or self.inputs[0].type.dtype != float32:
return NotImplemented
return tune.extract_ir_modules(self.schedule_norm_cpu)

# Builds a script module containing one "cpu_kernel" function. nthreads is stringified and
# appended to "p" to form the attrs string of the outer grid loop.
# NOTE(review): attrs, grid, spatial and cast are used below but not imported in this method;
# presumably they come from the module-level imports, which are not visible here -- confirm.
@tune.space(2, nthreads=['', 4, 8, 16, 32, 64, 96])
@tune.space(1, nthreads=['', 8, 16])
def schedule_norm_cpu(self, nthreads='') -> IRModule:
import hidet
from hidet.ir.primitives.cpu.avx import (
avx_f32x8_subtract,
avx_f32x8_load,
avx_f32x8_setzero,
avx_f32x8_store,
avx_f32x8_add,
avx_f32x8_set1,
avx_f32x8_divide,
avx_f32x8_multiply,
avx_f32x8_sqrt,
)
from hidet.ir.primitives.cpu.avx_helper import avx_f32x8_sum
from hidet.lang import tensor

# head: all dims before the trailing len(self.dims) dims; tail: those trailing dims.
# The kernel iterates over head_size rows and reduces over tail_size elements per row.
shape = self.inputs[0].shape
head = shape[: -len(self.dims)]
tail = shape[-len(self.dims) :]
head_size = prod(head)
tail_size = prod(tail)
with hidet.script_module() as module:

@hidet.script
def norm_cpu_kernel(x: float32[shape], out: float32[shape]):
attrs.func_kind = "cpu_kernel"
para = "p" + str(nthreads)
for k in grid(head_size, attrs=para):
head_idx = spatial(*head).map(k)

# Per-lane running mean and sum of squared deviations (M2).
mean_vec = avx_f32x8_setzero()
M2_vec = avx_f32x8_setzero()
epsilon_vec = avx_f32x8_set1(self.attrs['epsilon'])

# Scalar statistics of the vectorized portion; stay 0.0 when tail_size < 8.
mean_combined = 0.0
M2_combined = 0.0
if tail_size >= 8:
# One Welford update per 8-element chunk; after iteration i every lane has
# consumed i + 1 values (assumes avx_f32x8_load reads 8 consecutive floats).
for i in range(tail_size // 8):
tail_idx = spatial(*tail).map(i * 8)
# welford algorithm
n_vec = avx_f32x8_set1(cast(i + 1, float32))
data_vec = avx_f32x8_load(~x[head_idx][tail_idx])
delta = avx_f32x8_subtract(data_vec, mean_vec)
mean_vec = avx_f32x8_add(mean_vec, avx_f32x8_divide(delta, n_vec))
delta2 = avx_f32x8_subtract(data_vec, mean_vec)
M2_vec = avx_f32x8_add(M2_vec, avx_f32x8_multiply(delta, delta2))

# welford combine
# TODO: case for numerical stability? (number too high for large matrix)
# TODO: look at the cascade thing in pytorch github
# All 8 lanes hold the same count (tail_size // 8), so the combined mean is the
# plain average of the lane means and M2 gains count * sum((lane_mean - mean)^2).
mean_combined = avx_f32x8_sum(mean_vec) / 8
mean_combined_vec = avx_f32x8_set1(mean_combined)
delta_vec = avx_f32x8_subtract(mean_vec, mean_combined_vec)
M2_combined = avx_f32x8_sum(M2_vec) + avx_f32x8_sum(
avx_f32x8_multiply(delta_vec, delta_vec)
) * (tail_size // 8)
mean_tail = 0.0
M2_tail = 0.0
# welford on remaining parts past 8
for i in range(tail_size % 8):
tail_idx = spatial(*tail).map(tail_size - tail_size % 8 + i)
delta_tail = x[head_idx][tail_idx] - mean_tail
mean_tail += delta_tail / cast(i + 1, float32)
delta_tail2 = x[head_idx][tail_idx] - mean_tail
M2_tail += delta_tail * delta_tail2
# welford combine vectorized and unvectorized
# Counts are (tail_size - tail_size % 8) for the vectorized part and (tail_size % 8)
# for the scalar remainder; the cross term vanishes when either count is zero.
delta_end = mean_tail - mean_combined
mean = (mean_combined * (tail_size - tail_size % 8) + mean_tail * (tail_size % 8)) / tail_size
# var divides the total M2 by tail_size (the element count, not count - 1).
var = (
M2_combined
+ M2_tail
+ delta_end * delta_end * (tail_size - tail_size % 8) * (tail_size % 8) / tail_size
) / tail_size
# mean_vec is reused here as the broadcast final mean for the output pass.
mean_vec = avx_f32x8_set1(mean)
var_vec = avx_f32x8_set1(var)
if tail_size >= 8:
for i in range(tail_size // 8):
# norm calculation
# (x - mean) / sqrt(var + epsilon) is stored to an 8-element temporary and
# then copied element by element into out.
tail_idx = spatial(*tail).map(i * 8)
temp_out = tensor(dtype=float32, shape=[8])
avx_f32x8_store(
temp_out,
avx_f32x8_divide(
avx_f32x8_subtract(avx_f32x8_load(~x[head_idx][tail_idx]), mean_vec),
avx_f32x8_sqrt(avx_f32x8_add(var_vec, epsilon_vec)),
),
)
for j in range(8):
tail_idx = spatial(*tail).map(i * 8 + j)
out[head_idx][tail_idx] = temp_out[j]
# Scalar remainder: same formula, written as a multiply by prim.rsqrt instead of a
# divide by sqrt.
for i in range(tail_size % 8):
tail_idx = spatial(*tail).map(tail_size - tail_size % 8 + i)
out[head_idx][tail_idx] = (x[head_idx][tail_idx] - mean) * prim.rsqrt(
var + self.attrs['epsilon']
)

assert isinstance(norm_cpu_kernel, hidet.ir.Function)
ir_module = module.ir_module()
return ir_module


# Operator wrapper for normalization over `dims`; dims are passed through normalize_dim with
# the rank of x before being stored in the attributes and handed to the task.
class NormalizeOp(Operator):
def __init__(self, x: Tensor, dims, epsilon: float, accumulate_dtype: str):
rank = len(x.shape)
dims = normalize_dim(dims, rank=rank)
super().__init__(
inputs=[x],
attributes={'dims': dims, 'epsilon': epsilon, 'accumulate_dtype': accumulate_dtype},
# NOTE(review): this listing is a diff view without +/- markers. The first `task=` line below
# appears to be the removed pre-change argument, superseded by the conditional `task=`
# expression after it -- confirm against the actual file.
task=NormalizeTask(input_like(x, 'x'), dims, epsilon, accumulate_dtype),
# Task selection: CPUNormalizeTask when x.device.is_cpu() is true, NormalizeTask otherwise.
task=CPUNormalizeTask(input_like(x, 'x'), dims, epsilon, accumulate_dtype)
if x.device.is_cpu()
else NormalizeTask(input_like(x, 'x'), dims, epsilon, accumulate_dtype),
)


Expand Down
158 changes: 158 additions & 0 deletions python/hidet/graph/ops/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
from hidet.ir.module import IRModule
from hidet.ir import primitives as prim
from hidet.ir.expr import is_constant
from hidet.ir.stmt import Stmt, AssignStmt
from hidet.ir.builders import StmtBuilder
from hidet.ir.primitives import active_mask, shfl_down_sync, shfl_sync
from hidet.ir.dtypes import float32
from hidet.ir.library import tune
from .utils import Task, TensorNode, compute, reduce


Expand Down Expand Up @@ -153,3 +156,158 @@ def softmax_kernel(xs: xdtype[shape], ys: xdtype[shape]):
ir_module = module.ir_module()

return ir_module

def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
    """Return the tuned CPU softmax IR modules, or NotImplemented to fall back.

    The hand-written schedule (``schedule_softmax_cpu``) is only defined on the
    CPU-specific subclass, so a task without it must not try to use it.

    Parameters
    ----------
    working_dir: str
        Not used by this method.

    Returns
    -------
    ret: Union[IRModule, List[IRModule]]
        The result of ``tune.extract_ir_modules`` on the schedule, or
        ``NotImplemented`` when no schedule is available on this task or the
        input dtype is not float32.
    """
    # Look the schedule up defensively: accessing it unconditionally raises
    # AttributeError on a task class that does not define it.
    schedule = getattr(self, 'schedule_softmax_cpu', None)
    if schedule is None or self.inputs[0].type.dtype != float32:
        return NotImplemented  # use auto-scheduler
    return tune.extract_ir_modules(schedule)


class CPUSoftmaxTask(SoftmaxTask):
    """Softmax task with a hand-written float32 CPU schedule using 8-lane AVX primitives.

    Reports that prologues and epilogues are not allowed for this task.
    """

    def allow_epilogue(self) -> bool:
        return False

    def allow_prologue(self) -> bool:
        return False

    @tune.space(2, nthreads=['', 4, 8, 16, 32, 64, 96])
    @tune.space(1, nthreads=['', 8, 16])
    def schedule_softmax_cpu(self, nthreads='') -> IRModule:
        """Build the IR module for softmax along ``self.axis`` on CPU.

        Parameters
        ----------
        nthreads:
            Tuning parameter; it is stringified and appended to ``'p'`` to form the
            ``attrs`` string of the outer ``grid`` loop.

        Returns
        -------
        ret: IRModule
            A module with a ``cpu_internal`` helper (``apply_exponent``) and the
            ``cpu_kernel`` entry point (``softmax_cpu_kernel``).
        """
        import hidet
        from hidet.ir.primitives.cpu.avx import (
            avx_f32x8_subtract,
            avx_f32x8_load,
            avx_f32x8_setzero,
            avx_f32x8_store,
            avx_f32x8_add,
            avx_f32x8_max,
            avx_f32x8_set1,
            avx_f32x8_divide,
        )
        from hidet.ir.primitives.cpu.avx_helper import avx_f32x8_sum, avx_f32x8_scalar_max
        from hidet.lang import tensor, attrs, grid
        from hidet.ir.stmt import DeclareScope
        from hidet.lang.mapping import spatial
        from hidet.utils import prod
        from hidet.ir.dtypes import float32x8

        # head: dims before the softmax axis. tail: the axis itself when it is the last
        # dim, otherwise the dims after the axis.
        shape = self.inputs[0].shape
        head = shape[: self.axis]
        tail = shape[self.axis :] if self.axis == len(shape) - 1 else shape[self.axis + 1 :]
        head_size = prod(head)
        tail_size = prod(tail)
        axis_size = shape[self.axis]

        with hidet.script_module() as module:

            @hidet.script
            def apply_exponent(vec: float32x8) -> float32x8:
                # Element-wise exp of an 8-lane vector via a scalar round trip.
                attrs.func_kind = "cpu_internal"
                arr = tensor(scope=DeclareScope.Default, dtype=float32, shape=[8])
                avx_f32x8_store(arr, vec)
                for n in range(8):
                    arr[n] = prim.exp(arr[n])
                return avx_f32x8_load(arr)

            @hidet.script
            def softmax_cpu_kernel(x: float32[shape], out: float32[shape]):
                attrs.func_kind = "cpu_kernel"
                para = 'p' + str(nthreads)
                for k in grid(head_size, attrs=para):
                    head_idx = spatial(*head).map(k)
                    if self.axis == len(shape) - 1:  # last dim
                        temp_exp = tensor(dtype=float32, shape=tail)
                        max_val = x[head_idx][0]
                        if tail_size >= 8:
                            # vectorized find max value
                            max_vec = avx_f32x8_load(~x[head_idx][0])
                            for i in range(tail_size // 8):
                                data_vec = avx_f32x8_load(~x[head_idx][i * 8])
                                max_vec = avx_f32x8_max(max_vec, data_vec)
                            max_val = avx_f32x8_scalar_max(max_vec)
                        for i in range(tail_size % 8):
                            # max value of remaining unvectorized parts
                            max_val = (
                                max_val
                                if max_val > x[head_idx][tail_size - tail_size % 8 + i]
                                else x[head_idx][tail_size - tail_size % 8 + i]
                            )

                        # subtract max, take exp and find exp sum
                        sum_value = 0.0
                        if tail_size >= 8:
                            sum_exp_vec = avx_f32x8_setzero()
                            max_vec = avx_f32x8_set1(max_val)
                            for i in range(tail_size // 8):
                                val_vec = avx_f32x8_load(~x[head_idx][i * 8])
                                val_vec = avx_f32x8_subtract(val_vec, max_vec)
                                val_vec = apply_exponent(val_vec)
                                avx_f32x8_store(~temp_exp[i * 8], val_vec)
                                sum_exp_vec = avx_f32x8_add(sum_exp_vec, val_vec)
                            sum_value = avx_f32x8_sum(sum_exp_vec)
                        for i in range(tail_size % 8):
                            temp_exp[tail_size - tail_size % 8 + i] = prim.exp(
                                x[head_idx][tail_size - tail_size % 8 + i] - max_val
                            )
                            sum_value += temp_exp[tail_size - tail_size % 8 + i]

                        # divide by exp sum
                        if tail_size >= 8:
                            # divide
                            sum_vec8 = avx_f32x8_set1(sum_value)
                            for i in range(tail_size // 8):
                                avx_f32x8_store(
                                    ~temp_exp[i * 8], avx_f32x8_divide(avx_f32x8_load(~temp_exp[i * 8]), sum_vec8)
                                )
                        for i in range(tail_size % 8):
                            temp_exp[tail_size - tail_size % 8 + i] /= sum_value
                        for i in range(tail_size):
                            out[head_idx][i] = temp_exp[i]
                    else:  # not last dim
                        temp_exp = tensor(dtype=float32, shape=[shape[self.axis]] + tail)
                        # vectorized operations across all contiguous memory for relevant axis
                        for g in range(tail_size // 8):
                            tail_idx = spatial(*tail).map(g * 8)
                            max_vec = avx_f32x8_load(~x[head_idx][0][tail_idx])
                            for i in range(axis_size):
                                data_vec = avx_f32x8_load(~x[head_idx][i][tail_idx])
                                max_vec = avx_f32x8_max(max_vec, data_vec)
                            sum_exp_vec = avx_f32x8_setzero()
                            for i in range(axis_size):
                                val_vec = avx_f32x8_load(~x[head_idx][i][tail_idx])
                                val_vec = avx_f32x8_subtract(val_vec, max_vec)
                                val_vec = apply_exponent(val_vec)
                                avx_f32x8_store(~temp_exp[i][tail_idx], val_vec)
                                sum_exp_vec = avx_f32x8_add(sum_exp_vec, val_vec)
                            for i in range(axis_size):
                                avx_f32x8_store(
                                    ~temp_exp[i][tail_idx],
                                    avx_f32x8_divide(avx_f32x8_load(~temp_exp[i][tail_idx]), sum_exp_vec),
                                )
                                for j in range(8):
                                    tail_end_idx = spatial(*tail).map(g * 8 + j)
                                    out[head_idx][i][tail_end_idx] = temp_exp[i][tail_end_idx]
                        # unvectorized operations for the remaining elements
                        max_arr = tensor(scope=DeclareScope.Default, dtype=float32, shape=[tail_size % 8])
                        for j in range(tail_size % 8):
                            last_idx = spatial(*tail).map(tail_size - tail_size % 8 + j)
                            # Seed the running max with the first element along the axis, as the
                            # vectorized path does. Seeding with 0.0 yields a "max" of 0 for
                            # all-negative columns, so exp(x - 0) can underflow to 0 for every
                            # element and the final division becomes 0 / 0.
                            max_arr[j] = x[head_idx][0][last_idx]
                        for p in range(axis_size):
                            for j in range(tail_size % 8):
                                last_idx = spatial(*tail).map(tail_size - tail_size % 8 + j)
                                max_arr[j] = prim.max(max_arr[j], x[head_idx][p][last_idx])  # TODO: index
                        sum_exp_arr = tensor(scope=DeclareScope.Default, dtype=float32, shape=[tail_size % 8])
                        for j in range(tail_size % 8):
                            sum_exp_arr[j] = 0.0
                        for p in range(axis_size):
                            for j in range(tail_size % 8):
                                last_idx = spatial(*tail).map(tail_size - tail_size % 8 + j)
                                out[head_idx][p][last_idx] = prim.exp(x[head_idx][p][last_idx] - max_arr[j])
                                sum_exp_arr[j] += out[head_idx][p][last_idx]
                        for p in range(axis_size):
                            for j in range(tail_size % 8):
                                last_idx = spatial(*tail).map(tail_size - tail_size % 8 + j)
                                out[head_idx][p][last_idx] = out[head_idx][p][last_idx] / sum_exp_arr[j]

        assert isinstance(softmax_cpu_kernel, hidet.ir.Function)
        ir_module = module.ir_module()
        return ir_module
3 changes: 3 additions & 0 deletions python/hidet/ir/dtypes/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def max_value(self):

# Module-level vector type instances; each is built as VectorType(<scalar dtype>, <count>).
# NOTE(review): the second argument is presumably the number of lanes -- confirm in VectorType.
int8x4 = VectorType(int8, 4)
i8x4 = int8x4  # short alias
# float32x8 is imported by the CPU softmax schedule as the type of its 8-element AVX values.
float32x4 = VectorType(float32, 4)
float32x8 = VectorType(float32, 8)
float16x2 = VectorType(float16, 2)

uint8x4 = VectorType(uint8, 4)
u8x4 = uint8x4  # short alias
Expand Down
1 change: 0 additions & 1 deletion python/hidet/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,6 @@ class Call(Expr):
def __init__(self, func_var, args):
self.func_var: Var = func_var
self.args: Tuple[Expr, ...] = args

assert isinstance(func_var, Var) and isinstance(args, tuple)
for arg in args:
assert isinstance(arg, Expr)
Expand Down
Loading

0 comments on commit 52fe368

Please sign in to comment.