Ops bug fix and args clean (#76)
Summary:
Fix a bug in the rope operator that shows up under a specific nsys profile. Clean up the torch.compile args.

Pull Request resolved: #76

Reviewed By: adamomainz

Differential Revision: D66461602

Pulled By: FindHao

fbshipit-source-id: 8f56e3e60826a6d712ee3ea338e3be5dda65b6ab
FindHao authored and facebook-github-bot committed Nov 25, 2024
1 parent 2f12e83 commit ebdb921
Showing 9 changed files with 14 additions and 12 deletions.
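
Note: the recurring change across these operator files is dropping the explicit dynamic=False argument so the inductor benchmarks fall back to torch.compile's default dynamic-shape handling. A minimal sketch of the before/after pattern, using a toy nn.Linear as a stand-in for the benchmarked modules (the module and tensor names here are illustrative, not taken from the repository):

import torch
import torch.nn as nn

# Toy stand-in for the modules compiled in the hunks below (illustrative only).
model = nn.Linear(128, 128)

# Old pattern in the inductor_* benchmarks: dynamic shapes explicitly disabled.
compiled_static = torch.compile(model, dynamic=False)

# New pattern after this commit: rely on torch.compile's defaults, which assume
# static shapes first and switch to dynamic only if a new input shape forces a recompile.
compiled_default = torch.compile(model)

x = torch.randn(4, 128)
print(torch.allclose(compiled_static(x), compiled_default(x), atol=1e-6))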
2 changes: 1 addition & 1 deletion tritonbench/operators/embedding/operator.py
@@ -48,7 +48,7 @@ def liger_embedding(self, V, D, input) -> Callable:
     @register_benchmark()
     def inductor_embedding(self, V, D, input) -> Callable:
         self.baseline_op = Embedding(V, D).to(self.device).to(self.dtype)
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(input)

     @register_x_val(label="(B, T, D, V)")
2 changes: 1 addition & 1 deletion tritonbench/operators/fused_linear_cross_entropy/operator.py
@@ -104,7 +104,7 @@ def liger_lm_head_ce(self, input, target) -> Callable:

     @register_benchmark()
     def inductor_fused_linear_cross_entropy(self, input, target) -> Callable:
-        compiled = torch.compile(self.baseline_model, dynamic=False)
+        compiled = torch.compile(self.baseline_model)
         return lambda: compiled(input, target)

     @register_x_val(label="(B*T, H)")
2 changes: 1 addition & 1 deletion tritonbench/operators/fused_linear_jsd/operator.py
@@ -166,7 +166,7 @@ def liger_lm_head_jsd(self, student_input, teacher_input) -> Callable:

     @register_benchmark()
     def inductor_lm_head_jsd(self, student_input, teacher_input) -> Callable:
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(student_input, teacher_input)

     @register_x_val(label="(B*T, H)")
2 changes: 1 addition & 1 deletion tritonbench/operators/geglu/operator.py
@@ -59,7 +59,7 @@ def liger_geglu(self, input) -> Callable:

     @register_benchmark()
     def inductor_geglu(self, input) -> Callable:
-        compiled = torch.compile(self.baseline_model, dynamic=False)
+        compiled = torch.compile(self.baseline_model)
         return lambda: compiled(input)

     @register_x_val(label="(B, T, H)")
2 changes: 1 addition & 1 deletion tritonbench/operators/jsd/operator.py
@@ -87,7 +87,7 @@ def liger_jsd(self, _input, target) -> Callable:

     @register_benchmark()
     def inductor_jsd(self, _input, target) -> Callable:
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(_input, target)

     @register_x_val(label="(B, T, V)")
2 changes: 1 addition & 1 deletion tritonbench/operators/kl_div/operator.py
@@ -47,7 +47,7 @@ def liger_kl_div(self, input, target) -> Callable:

     @register_benchmark()
     def inductor_kl_div(self, input, target) -> Callable:
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(input, target)

     @register_x_val(label="(B, T, V)")
6 changes: 5 additions & 1 deletion tritonbench/operators/rms_norm/operator.py
@@ -65,7 +65,11 @@ def liger_rms(self, H, input) -> Callable:

     @register_benchmark()
     def inductor_rms(self, H, input) -> Callable:
-        compiled = torch.compile(self.llama_rms_op, dynamic=False)
+        if self.llama_rms_op is None:
+            self.llama_rms_op = LlamaRMSNorm(hidden_size=H, eps=self.eps).to(
+                self.device
+            )
+        compiled = torch.compile(self.llama_rms_op)
         return lambda: compiled(input)

     @register_x_val(label="(M, H)")
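Beyond the argument cleanup, the rms_norm hunk above adds a lazy-initialization guard so the inductor benchmark no longer assumes the baseline LlamaRMSNorm was already constructed. A rough, self-contained sketch of that guard, with a toy RMSNorm standing in for transformers' LlamaRMSNorm (class and variable names below are placeholders):

import torch
import torch.nn as nn

class ToyRMSNorm(nn.Module):
    # Stand-in for transformers' LlamaRMSNorm used by the real operator.
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)

class ToyOperator:
    # Minimal sketch of the guard added to inductor_rms.
    def __init__(self, device="cpu", eps=1e-6):
        self.device = device
        self.eps = eps
        self.llama_rms_op = None  # may not have been built by an earlier benchmark

    def inductor_rms(self, H, input):
        # Build the module on demand instead of relying on prior state.
        if self.llama_rms_op is None:
            self.llama_rms_op = ToyRMSNorm(hidden_size=H, eps=self.eps).to(self.device)
        compiled = torch.compile(self.llama_rms_op)
        return lambda: compiled(input)

op = ToyOperator()
fn = op.inductor_rms(64, torch.randn(8, 64))
print(fn().shape)  # torch.Size([8, 64])
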
6 changes: 2 additions & 4 deletions tritonbench/operators/rope/operator.py
@@ -88,11 +88,9 @@ def liger_rotary_pos_emb(self, hidden_size, seq_length) -> Callable:
     def inductor_rotary_pos_emb_full_op(self, hidden_size, seq_length) -> Callable:
         q, k, cos, sin, pos_ids = self.prepare_input(hidden_size, seq_length)
         head_dim = hidden_size // self.num_q_heads
-        compiled = torch.compile(
-            LlamaRotaryEmbedding(head_dim, device=self.device), dynamic=False
-        )
+        compiled = torch.compile(LlamaRotaryEmbedding(head_dim, device=self.device))
         cos, sin = compiled(k, pos_ids)
-        compiled_func = torch.compile(apply_rotary_pos_emb, dynamic=False)
+        compiled_func = torch.compile(apply_rotary_pos_emb)
         return lambda: compiled_func(q, k, cos, sin, pos_ids)

     @register_x_val(label="(H, T)")
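One detail the rope hunk leans on is that torch.compile wraps both an nn.Module instance (LlamaRotaryEmbedding) and a plain function (apply_rotary_pos_emb) the same way. A small sketch of that dual usage with stand-in objects, since the real classes come from transformers (the names below are placeholders, not the benchmark's code):

import torch
import torch.nn as nn

# Placeholders for LlamaRotaryEmbedding and apply_rotary_pos_emb.
rotary_like_module = nn.Identity()

def apply_like_function(q, k):
    return q + k

# torch.compile accepts module instances and free functions alike.
compiled_module = torch.compile(rotary_like_module)
compiled_func = torch.compile(apply_like_function)

q = torch.randn(2, 4)
k = torch.randn(2, 4)
print(compiled_module(q).shape, compiled_func(q, k).shape)
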
2 changes: 1 addition & 1 deletion tritonbench/operators/swiglu/operator.py
@@ -60,7 +60,7 @@ def liger_swiglu(self, input) -> Callable:

     @register_benchmark()
     def inductor_swiglu(self, input) -> Callable:
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(input)

     @register_x_val(label="(B, T, H)")
