Backward register #423
base: master
Conversation
9f79739 to 01bee17
save_invstd=None,
train=False,
eps=1e-05,
output_mask=None,
The last argument should be grad_input_mask.
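A sketch of the rename being asked for; the wrapper name and the earlier parameters are assumed from the quoted hunk, not taken from the PR:

def batch_norm_backward(grad_out, input, weight=None, running_mean=None,
                        running_var=None, save_mean=None, save_invstd=None,
                        train=False, eps=1e-05, grad_input_mask=None):
    # 'grad_input_mask' replaces 'output_mask' as the last keyword argument
    ...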
affine: tl.constexpr,
input_grad_mask: tl.constexpr,
weight_grad_mask: tl.constexpr,
bias_grad_mask: tl.constexpr,
The backward kernel may also need an is_train arg, to distinguish between the train and non-train cases.
We can leave it for future work, though.
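A minimal sketch of how an is_train constexpr could specialize such a kernel; the kernel name, pointer arguments, and the invstd example are illustrative assumptions, not the PR's code:

import triton
import triton.language as tl

@triton.jit
def _invstd_kernel(save_invstd_ptr, running_var_ptr, out_ptr, eps, n,
                   is_train: tl.constexpr, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    if is_train:
        # train: the forward pass already saved 1 / sqrt(var + eps)
        invstd = tl.load(save_invstd_ptr + offs, mask=mask)
    else:
        # eval: recompute invstd from the running variance
        var = tl.load(running_var_ptr + offs, mask=mask)
        invstd = 1.0 / tl.sqrt(var + eps)
    tl.store(out_ptr + offs, invstd, mask=mask)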
running_var=None,
save_mean=None,
save_invstd=None,
train=False,
The kernel should be able to handle the train=True case.
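For reference, one way the wrapper could branch on train before launching the kernel (a sketch; the helper name is made up):

import torch

def _resolve_stats(running_mean, running_var, save_mean, save_invstd, train, eps=1e-05):
    # train=True: reuse the batch statistics saved by the forward pass.
    if train:
        return save_mean, save_invstd
    # train=False: fall back to the running statistics and recompute invstd.
    return running_mean, torch.rsqrt(running_var + eps)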
def native_dropout(x, p=0.5, train=True):
    return NativeDropout.apply(x, p, train)
def dropout(input, p, train):
Arg train is optional.
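A sketch of the signature with train made optional; the default values are assumptions chosen to mirror the removed native_dropout wrapper:

def dropout(input, p=0.5, train=True):
    # defaults let existing call sites omit p and train
    ...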
logging.debug("GEMS NATIVE DROPOUT FORWARD")
assert p > 0.0 and p < 1.0, "p must be in (0, 1)"
device = input.device
input = input.contiguous()
Add a note that we'll remove contiguous enforcement in the future.
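For example, a note along these lines could be attached to the enforcement (the wording is a suggestion, not the PR's text):

# TODO: drop this .contiguous() call once the kernel handles
# non-contiguous (strided) inputs directly.
input = input.contiguous()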
indices = indices.contiguous()
weight = weight.contiguous()
Add a TODO to refactor this.
mean = mean.contiguous()
rstd = rstd.contiguous()
weight = None if weight is None else weight.contiguous()
group_size = C // group
cdiv?
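Presumably this means switching to ceiling division, e.g. (assuming triton.cdiv is the intended helper):

group_size = triton.cdiv(C, group)  # safe even if C is not a multiple of group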
BLOCK_GROUP_SIZE=triton.next_power_of_2(C // num_groups),
BLOCK_HW_SIZE=triton.next_power_of_2(HW),
HxW,
BLOCK_GROUP_SIZE=triton.next_power_of_2(C // group),
cdiv(C, group)?
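The same suggestion applied to the launch argument (same assumption about triton.cdiv):

BLOCK_GROUP_SIZE=triton.next_power_of_2(triton.cdiv(C, group)),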
def native_dropout(x, p=0.5, train=True):
    return NativeDropout.apply(x, p, train)
def dropout(input, p, train):
I realized we didn't handle train=False correctly in the previous version. Let's fix that.
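For reference, a minimal sketch of the semantics being targeted, assuming the aten-style contract that dropout with train=False is the identity; the function below is illustrative, not the PR's implementation:

import torch

def native_dropout_ref(input, p=0.5, train=True):
    # train=False or p == 0: dropout is a no-op; return the input unchanged
    # together with an all-True keep mask.
    if not train or p == 0.0:
        return input.clone(), torch.ones_like(input, dtype=torch.bool)
    # train=True: keep each element with probability 1 - p and rescale.
    mask = torch.rand_like(input) >= p
    return input * mask / (1.0 - p), mask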
PR Category
Operator
Type of Change
New Feature
Description
Register backward functions as aten interfaces (a registration sketch follows below).
Implement the threshold operator incidentally.
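As an illustration of what registering a backward op as an aten interface can look like, a minimal sketch using torch.library; the op chosen here (native_dropout_backward) and the Python implementation are placeholders, not the PR's actual registration code:

import torch

aten_lib = torch.library.Library("aten", "IMPL")

def gems_native_dropout_backward(grad_output, mask, scale):
    # Reference semantics of aten::native_dropout_backward:
    # gradients flow only through the kept elements, rescaled by `scale`.
    return grad_output * mask * scale

# Override the CUDA dispatch key so aten calls route to the custom implementation.
aten_lib.impl("native_dropout_backward", gems_native_dropout_backward, "CUDA")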
Issue
Progress
Performance