diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index c50d66bea..3045ed1f4 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -21,6 +21,7 @@ def enable(lib=aten_lib, unused=None, registrar=registrar): ( ("abs", abs, Autograd.disable), ("add.Tensor", add, Autograd.disable), + ("add_.Tensor", inplace_add, Autograd.disable), ("addmm", addmm, Autograd.disable), ("arange.start_step", arange_start, Autograd.disable), ("arange.start", arange_start, Autograd.disable), diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index f0a1cc9fe..d5f939ef4 100755 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -1,5 +1,5 @@ from .abs import abs -from .add import add +from .add import add, inplace_add from .addmm import addmm from .all import all, all_dim, all_dims from .amax import amax @@ -148,6 +148,7 @@ "any_dim", "any_dims", "add", + "inplace_add", "abs", "addmm", "arange", diff --git a/src/flag_gems/ops/add.py b/src/flag_gems/ops/add.py index 9e4a816fa..545c832e6 100644 --- a/src/flag_gems/ops/add.py +++ b/src/flag_gems/ops/add.py @@ -28,6 +28,36 @@ def add_func_scalar_tensor(x, y, alpha): return x + y * alpha +@pointwise_dynamic( + is_tensor=[True, True, False], + promotion_methods=[(0, 1, "DEFAULT")], + is_inplace=True, +) +@triton.jit +def inplace_add_func(x, y, alpha): + return x + y * alpha + + +@pointwise_dynamic( + is_tensor=[True, False, False], + promotion_methods=[(0, 1, "DEFAULT")], + is_inplace=True, +) +@triton.jit +def inplace_add_func_tensor_scalar(x, y, alpha): + return x + y * alpha + + +@pointwise_dynamic( + is_tensor=[False, True, False], + promotion_methods=[(0, 1, "DEFAULT")], + is_inplace=True, +) +@triton.jit +def inplace_add_func_scalar_tensor(x, y, alpha): + return x + y * alpha + + def add(A, B, *, alpha=1): logging.debug("GEMS ADD") if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): @@ -38,3 +68,15 @@ def add(A, B, *, alpha=1): return add_func_scalar_tensor(A, B, alpha) else: return torch.tensor(A + B * alpha) + + +def inplace_add(A, B, *, alpha=1): + logging.debug("GEMS ADD") + if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): + return inplace_add_func(A, B, alpha) + elif isinstance(A, torch.Tensor): + return inplace_add_func_tensor_scalar(A, B, alpha) + elif isinstance(B, torch.Tensor): + return inplace_add_func_scalar_tensor(A, B, alpha) + else: + raise NotImplementedError("Inplace add not implemented for scalars") diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py index 2c151e584..fc8814321 100644 --- a/src/flag_gems/utils/pointwise_dynamic.py +++ b/src/flag_gems/utils/pointwise_dynamic.py @@ -77,6 +77,7 @@ def __init__( dtypes: Optional[List[Optional[type]]] = None, num_outputs: Optional[int] = None, promotion_methods=None, + is_inplace=False, ): if is_tensor is not None: _check_typed_list(is_tensor, bool) @@ -137,6 +138,11 @@ def __init__( self._num_input_tensors = sum(self._is_tensor) self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors self._input_id = self._compute_input_id() + self._is_inplace = is_inplace + if self._is_inplace: + assert ( + self._num_outputs == 1 + ), "Inplace operation should have only single output" @staticmethod def canonicalize_promotion_methods(promotion_methods): @@ -186,14 +192,16 @@ def signature(self, outputs_in_arg: bool = False) -> str: output_types = [] - if outputs_in_arg: - for i in range(self.num_outputs()): - output_types.append(f"StridedBuffer(a{1}!)") - input_types.extend(output_types) - else: - for _ in range(self.num_outputs()): - output_types.append("StridedBuffer") + if not self._is_inplace: + if outputs_in_arg: + for i in range(self.num_outputs()): + output_types.append(f"StridedBuffer(a{1}!)") + input_types.extend(output_types) + else: + for _ in range(self.num_outputs()): + output_types.append("StridedBuffer") sig = f'Pointwise: {", ".join(input_types)} -> {", ".join(output_types)}' + print("signature is: ", sig) return sig def _compute_input_id(self): @@ -786,9 +794,11 @@ def gen_signature(self, code: IndentedBuffer): # tensors). We emphasize that these parameters are added in-addition, we enforce # that they be passed by keyword. After all, out0, out1, ... does not mismatch # names form the scalar function, since it does not have output parameters. + # if not schema._is_inplace: params.append("/") params.append("*") # output params must be passed by keyword + # if not schema._is_inplace: for i in range(schema.num_output_tensors()): params.append(f"{self.output_name(i)}: Union[torch.Tensor, StridedBuffer]") code.writeline(f"def {self.name}({_cs(params)}): ") @@ -1081,7 +1091,13 @@ def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None) def __call__(self, *args, **kwargs): # inputs must be passed by position, outputs must be passed by keyword + print("Before prepare args is: ", args) + print("Before prepare kwargs is: ", kwargs) + ndim, args, kwargs = self.prepare_args(*args, **kwargs) + print("After prepare args is: ", args) + print("After prepare kwargs is: ", kwargs) + overload = self.instantiate(ndim) out = overload(*args, **kwargs) # NOTE: overload keeps the type of outputs: @@ -1108,14 +1124,20 @@ def prepare_args(self, *args, **kwargs): schema = self.fx outputs_that_need_allocation: List[int] = [] out_tensors = [] + + # input arguments must be passed by position + in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)] + for i in range(schema.num_output_tensors()): k = f"out{i}" - if k in kwargs: - out_tensors.append(kwargs[k]) + + if not schema._is_inplace: + if k in kwargs: + out_tensors.append(kwargs[k]) + else: + outputs_that_need_allocation.append(i) else: - outputs_that_need_allocation.append(i) - # input arguments must be passed by position - in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)] + out_tensors.append(in_tensors[0]) # output dtype promotions outputs_dtypes_for_allocation = [] @@ -1125,7 +1147,11 @@ def prepare_args(self, *args, **kwargs): _, dtype = type_promotion(*promote_args, type_promotion=method) outputs_dtypes_for_allocation.append(dtype) + print("outputs_dtypes_for_allocation is: ", outputs_dtypes_for_allocation) + print("outputs_that_need_allocation is: ", outputs_that_need_allocation) + tensors = out_tensors + in_tensors + if self.use_fast_path(tensors): # dimension collapse & use physical ordering allocated_outputs = [ torch.empty_like(tensors[0], dtype=dtype) @@ -1197,6 +1223,12 @@ def prepare_args(self, *args, **kwargs): task_shape, broadcasted_stride(item.shape, item.stride(), task_shape), ) + + if schema._is_inplace: + # Inplace operator only have single output. + # need to check args[0] is tensor + kwargs["out0"] = args[0] + return (ndim, args, kwargs) def _unwrap(self, tensors): @@ -1278,6 +1310,7 @@ def pointwise_dynamic( num_outputs: Optional[int] = None, promotion_methods: Optional[Tuple[int, ...]] = None, config: Optional[CodeGenConfig] = None, + is_inplace=False, ): def decorator(fn): nonlocal num_inputs @@ -1289,6 +1322,7 @@ def decorator(fn): dtypes=dtypes, num_outputs=num_outputs, promotion_methods=promotion_methods, + is_inplace=is_inplace, ) return PointwiseDynamicFunction(op_desc, fn, config) diff --git a/tests/test_binary_pointwise_ops.py b/tests/test_binary_pointwise_ops.py index b77f85b67..0537b6b15 100644 --- a/tests/test_binary_pointwise_ops.py +++ b/tests/test_binary_pointwise_ops.py @@ -43,6 +43,29 @@ def test_accuracy_add(shape, alpha, dtype): gems_assert_close(res_out, ref_out, dtype) +@pytest.mark.inplace_add +# @pytest.mark.parametrize("shape", POINTWISE_SHAPES) +# @pytest.mark.parametrize("alpha", SCALARS) +# @pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("shape", [(3, 3)]) +@pytest.mark.parametrize("alpha", [1.0]) +@pytest.mark.parametrize("dtype", [torch.float32]) +def test_accuracy_inplace_add(shape, alpha, dtype): + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + ref_inp1 = to_reference(inp1, False) + ref_inp2 = to_reference(inp2, False) + + ref_inp1.add_(ref_inp2) + # TODO(zzk): maximum recursion + # with flag_gems.use_gems(): + # inp1.add_(inp2) + + flag_gems.ops.inplace_add(inp1, inp2) + + gems_assert_close(ref_inp1, inp1, dtype) + + @pytest.mark.add @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("scalar", SCALARS)