Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backward register #423

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open

Backward register #423

wants to merge 20 commits into from

Conversation

StrongSpoon
Copy link
Collaborator

@StrongSpoon StrongSpoon commented Jan 16, 2025

PR Category

Operator

Type of Change

New Feature

Description

register backward functions as aten interfaces
implement threshold operator incidentally

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change is responded to an issue.
  • Change is fully covered by a UT.

Performance

@StrongSpoon StrongSpoon force-pushed the bwd branch 2 times, most recently from 9f79739 to 01bee17 Compare February 6, 2025 09:26
@StrongSpoon StrongSpoon marked this pull request as ready for review February 11, 2025 02:04
save_invstd=None,
train=False,
eps=1e-05,
output_mask=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The last argument should be grad_input_mask.

affine: tl.constexpr,
input_grad_mask: tl.constexpr,
weight_grad_mask: tl.constexpr,
bias_grad_mask: tl.constexpr,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The backward kernel may need is_train arg also, to distinguish between train and non-train cases.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can leave it for future work tho.

running_var=None,
save_mean=None,
save_invstd=None,
train=False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kernel should be able to handle train=True case.


def native_dropout(x, p=0.5, train=True):
return NativeDropout.apply(x, p, train)
def dropout(input, p, train):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arg train is optional.

logging.debug("GEMS NATIVE DROPOUT FORWARD")
assert p > 0.0 and p < 1.0, "p must be in (0, 1)"
device = input.device
input = input.contiguous()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a note that we'll remove contiguous enforcement in the future.

Comment on lines +119 to +120
indices = indices.contiguous()
weight = weight.contiguous()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor this in TODOs.

mean = mean.contiguous()
rstd = rstd.contiguous()
weight = None if weight is None else weight.contiguous()
group_size = C // group
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cdiv?

BLOCK_GROUP_SIZE=triton.next_power_of_2(C // num_groups),
BLOCK_HW_SIZE=triton.next_power_of_2(HW),
HxW,
BLOCK_GROUP_SIZE=triton.next_power_of_2(C // group),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cdiv(C, group)?


def native_dropout(x, p=0.5, train=True):
return NativeDropout.apply(x, p, train)
def dropout(input, p, train):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realized we didn't handle we train=False correctly in the previous version. Let's fix that.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants