Pointwise support inplace #448

Draft: wants to merge 2 commits into base: master
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -21,6 +21,7 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
(
("abs", abs, Autograd.disable),
("add.Tensor", add, Autograd.disable),
("add_.Tensor", inplace_add, Autograd.disable),
("addmm", addmm, Autograd.disable),
("arange.start_step", arange_start, Autograd.disable),
("arange.start", arange_start, Autograd.disable),
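
The new ("add_.Tensor", inplace_add, Autograd.disable) entry above wires the in-place overload into the same registration table that enable() walks for the out-of-place ops. As a rough, hypothetical illustration of what consuming such an entry looks like (the repository's actual enable()/registrar wiring may differ), an ATen override can be installed with torch.library:

import torch
from flag_gems.ops import inplace_add

# Hypothetical sketch only; flag_gems' real enable() handles this through its registrar.
lib = torch.library.Library("aten", "IMPL")   # provide implementations for existing aten ops
lib.impl("add_.Tensor", inplace_add, "CUDA")  # route aten::add_.Tensor on CUDA to the Triton path

Once such an override is active, tensor.add_(other) dispatches to inplace_add; the draft's test notes that doing this through flag_gems.use_gems() currently hits a recursion issue, so the sketch above is illustrative only.
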
3 changes: 2 additions & 1 deletion src/flag_gems/ops/__init__.py
@@ -1,5 +1,5 @@
from .abs import abs
from .add import add
from .add import add, inplace_add
from .addmm import addmm
from .all import all, all_dim, all_dims
from .amax import amax
@@ -148,6 +148,7 @@
"any_dim",
"any_dims",
"add",
"inplace_add",
"abs",
"addmm",
"arange",
42 changes: 42 additions & 0 deletions src/flag_gems/ops/add.py
@@ -28,6 +28,36 @@ def add_func_scalar_tensor(x, y, alpha):
return x + y * alpha


@pointwise_dynamic(
is_tensor=[True, True, False],
promotion_methods=[(0, 1, "DEFAULT")],
is_inplace=True,
)
@triton.jit
def inplace_add_func(x, y, alpha):
return x + y * alpha


@pointwise_dynamic(
is_tensor=[True, False, False],
promotion_methods=[(0, 1, "DEFAULT")],
is_inplace=True,
)
@triton.jit
def inplace_add_func_tensor_scalar(x, y, alpha):
return x + y * alpha


@pointwise_dynamic(
is_tensor=[False, True, False],
promotion_methods=[(0, 1, "DEFAULT")],
is_inplace=True,
)
@triton.jit
def inplace_add_func_scalar_tensor(x, y, alpha):
return x + y * alpha


def add(A, B, *, alpha=1):
logging.debug("GEMS ADD")
if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
@@ -38,3 +68,15 @@ def add(A, B, *, alpha=1):
return add_func_scalar_tensor(A, B, alpha)
else:
return torch.tensor(A + B * alpha)


def inplace_add(A, B, *, alpha=1):
logging.debug("GEMS INPLACE ADD")
if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
return inplace_add_func(A, B, alpha)
elif isinstance(A, torch.Tensor):
return inplace_add_func_tensor_scalar(A, B, alpha)
elif isinstance(B, torch.Tensor):
return inplace_add_func_scalar_tensor(A, B, alpha)
else:
raise NotImplementedError("Inplace add not implemented for scalars")
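
Taken together, inplace_add mirrors add's dispatch on which arguments are tensors, but the is_inplace kernels are meant to write the result back into the first tensor instead of allocating a new output. A minimal usage sketch, assuming a device supported by flag_gems (the actual behaviour depends on the is_inplace plumbing in pointwise_dynamic below):

import torch
import flag_gems

x = torch.randn(3, 3, device=flag_gems.device)
y = torch.randn(3, 3, device=flag_gems.device)
expected = x + 2.0 * y                       # out-of-place reference result

flag_gems.ops.inplace_add(x, y, alpha=2.0)   # intended to overwrite x with x + alpha * y

torch.testing.assert_close(x, expected)      # x now holds the in-place result
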
58 changes: 46 additions & 12 deletions src/flag_gems/utils/pointwise_dynamic.py
@@ -77,6 +77,7 @@ def __init__(
dtypes: Optional[List[Optional[type]]] = None,
num_outputs: Optional[int] = None,
promotion_methods=None,
is_inplace=False,
):
if is_tensor is not None:
_check_typed_list(is_tensor, bool)
@@ -137,6 +138,11 @@ def __init__(
self._num_input_tensors = sum(self._is_tensor)
self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors
self._input_id = self._compute_input_id()
self._is_inplace = is_inplace
if self._is_inplace:
assert (
self._num_outputs == 1
), "Inplace operation should have only single output"

@staticmethod
def canonicalize_promotion_methods(promotion_methods):
@@ -186,14 +192,16 @@ def signature(self, outputs_in_arg: bool = False) -> str:

output_types = []

if outputs_in_arg:
for i in range(self.num_outputs()):
output_types.append(f"StridedBuffer(a{1}!)")
input_types.extend(output_types)
else:
for _ in range(self.num_outputs()):
output_types.append("StridedBuffer")
if not self._is_inplace:
if outputs_in_arg:
for i in range(self.num_outputs()):
output_types.append(f"StridedBuffer(a{1}!)")
input_types.extend(output_types)
else:
for _ in range(self.num_outputs()):
output_types.append("StridedBuffer")
sig = f'Pointwise: {", ".join(input_types)} -> {", ".join(output_types)}'
print("signature is: ", sig)
return sig

def _compute_input_id(self):
@@ -786,9 +794,11 @@ def gen_signature(self, code: IndentedBuffer):
# tensors). We emphasize that these parameters are added in addition, and we enforce
# that they be passed by keyword. After all, out0, out1, ... do not clash with
# names from the scalar function, since it does not have output parameters.
# if not schema._is_inplace:
params.append("/")
params.append("*") # output params must be passed by keyword

# if not schema._is_inplace:
for i in range(schema.num_output_tensors()):
params.append(f"{self.output_name(i)}: Union[torch.Tensor, StridedBuffer]")
code.writeline(f"def {self.name}({_cs(params)}): ")
@@ -1081,7 +1091,13 @@ def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None)

def __call__(self, *args, **kwargs):
# inputs must be passed by position, outputs must be passed by keyword
print("Before prepare args is: ", args)
print("Before prepare kwargs is: ", kwargs)

ndim, args, kwargs = self.prepare_args(*args, **kwargs)
print("After prepare args is: ", args)
print("After prepare kwargs is: ", kwargs)

overload = self.instantiate(ndim)
out = overload(*args, **kwargs)
# NOTE: overload keeps the type of outputs:
@@ -1108,14 +1124,20 @@ def prepare_args(self, *args, **kwargs):
schema = self.fx
outputs_that_need_allocation: List[int] = []
out_tensors = []

# input arguments must be passed by position
in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)]

for i in range(schema.num_output_tensors()):
k = f"out{i}"
if k in kwargs:
out_tensors.append(kwargs[k])

if not schema._is_inplace:
if k in kwargs:
out_tensors.append(kwargs[k])
else:
outputs_that_need_allocation.append(i)
else:
outputs_that_need_allocation.append(i)
# input arguments must be passed by position
in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)]
out_tensors.append(in_tensors[0])

# output dtype promotions
outputs_dtypes_for_allocation = []
@@ -1125,7 +1147,11 @@ def prepare_args(self, *args, **kwargs):
_, dtype = type_promotion(*promote_args, type_promotion=method)
outputs_dtypes_for_allocation.append(dtype)

print("outputs_dtypes_for_allocation is: ", outputs_dtypes_for_allocation)
print("outputs_that_need_allocation is: ", outputs_that_need_allocation)

tensors = out_tensors + in_tensors

if self.use_fast_path(tensors): # dimension collapse & use physical ordering
allocated_outputs = [
torch.empty_like(tensors[0], dtype=dtype)
@@ -1197,6 +1223,12 @@ def prepare_args(self, *args, **kwargs):
task_shape,
broadcasted_stride(item.shape, item.stride(), task_shape),
)

if schema._is_inplace:
# Inplace operators have only a single output.
# TODO: check that args[0] is a tensor.
kwargs["out0"] = args[0]

return (ndim, args, kwargs)

def _unwrap(self, tensors):
@@ -1278,6 +1310,7 @@ def pointwise_dynamic(
num_outputs: Optional[int] = None,
promotion_methods: Optional[Tuple[int, ...]] = None,
config: Optional[CodeGenConfig] = None,
is_inplace=False,
):
def decorator(fn):
nonlocal num_inputs
@@ -1289,6 +1322,7 @@ def decorator(fn):
dtypes=dtypes,
num_outputs=num_outputs,
promotion_methods=promotion_methods,
is_inplace=is_inplace,
)
return PointwiseDynamicFunction(op_desc, fn, config)

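
The heart of the change is in prepare_args: when a schema is marked is_inplace, no output tensor is allocated; instead out0 is bound to the first positional argument so the generated kernel writes straight back into it. A simplified, standalone sketch of that routing (illustrative names only, not the actual PointwiseDynamicFunction internals):

import torch

def prepare_out_kwargs(args, kwargs, num_outputs, is_inplace):
    # Simplified model of the out0 routing added in this PR; not the real flag_gems code.
    if is_inplace:
        # In-place: the single output aliases the first (tensor) input, so the
        # kernel writes back into args[0] and nothing needs to be allocated.
        assert num_outputs == 1 and isinstance(args[0], torch.Tensor)
        return dict(kwargs, out0=args[0]), []
    # Out-of-place: any output not supplied as out<i>= must be allocated later.
    need_allocation = [i for i in range(num_outputs) if f"out{i}" not in kwargs]
    return dict(kwargs), need_allocation

x, y = torch.randn(4), torch.randn(4)
kw, alloc = prepare_out_kwargs((x, y, 1.0), {}, num_outputs=1, is_inplace=True)
assert kw["out0"] is x and alloc == []       # out0 aliases x; nothing to allocate
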
23 changes: 23 additions & 0 deletions tests/test_binary_pointwise_ops.py
@@ -43,6 +43,29 @@ def test_accuracy_add(shape, alpha, dtype):
gems_assert_close(res_out, ref_out, dtype)


@pytest.mark.inplace_add
# @pytest.mark.parametrize("shape", POINTWISE_SHAPES)
# @pytest.mark.parametrize("alpha", SCALARS)
# @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("shape", [(3, 3)])
@pytest.mark.parametrize("alpha", [1.0])
@pytest.mark.parametrize("dtype", [torch.float32])
def test_accuracy_inplace_add(shape, alpha, dtype):
inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
ref_inp1 = to_reference(inp1, False)
ref_inp2 = to_reference(inp2, False)

ref_inp1.add_(ref_inp2)
# TODO(zzk): maximum recursion
# with flag_gems.use_gems():
# inp1.add_(inp2)

flag_gems.ops.inplace_add(inp1, inp2)

gems_assert_close(ref_inp1, inp1, dtype)


@pytest.mark.add
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("scalar", SCALARS)
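
Because routing Tensor.add_ through flag_gems.use_gems() currently hits the recursion issue noted in the TODO, the test calls flag_gems.ops.inplace_add directly. One extra assertion the test could add (a suggestion, not part of this PR) is that the op really reuses the input's storage:

import torch
import flag_gems

# Suggested aliasing check, assuming a supported flag_gems device; not part of the PR.
a = torch.randn(3, 3, device=flag_gems.device)
b = torch.randn(3, 3, device=flag_gems.device)
ptr_before = a.data_ptr()

flag_gems.ops.inplace_add(a, b)

assert a.data_ptr() == ptr_before            # same storage: the add happened in place
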