diff --git a/tritonbench/operators/embedding/operator.py b/tritonbench/operators/embedding/operator.py
index 8c7ff41..00fd469 100644
--- a/tritonbench/operators/embedding/operator.py
+++ b/tritonbench/operators/embedding/operator.py
@@ -48,7 +48,7 @@ def liger_embedding(self, V, D, input) -> Callable:
     @register_benchmark()
     def inductor_embedding(self, V, D, input) -> Callable:
         self.baseline_op = Embedding(V, D).to(self.device).to(self.dtype)
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(input)
 
     @register_x_val(label="(B, T, D, V)")
diff --git a/tritonbench/operators/fused_linear_cross_entropy/operator.py b/tritonbench/operators/fused_linear_cross_entropy/operator.py
index 2d38478..9f35934 100644
--- a/tritonbench/operators/fused_linear_cross_entropy/operator.py
+++ b/tritonbench/operators/fused_linear_cross_entropy/operator.py
@@ -104,7 +104,7 @@ def liger_lm_head_ce(self, input, target) -> Callable:
 
     @register_benchmark()
     def inductor_fused_linear_cross_entropy(self, input, target) -> Callable:
-        compiled = torch.compile(self.baseline_model, dynamic=False)
+        compiled = torch.compile(self.baseline_model)
         return lambda: compiled(input, target)
 
     @register_x_val(label="(B*T, H)")
diff --git a/tritonbench/operators/fused_linear_jsd/operator.py b/tritonbench/operators/fused_linear_jsd/operator.py
index 758c0c9..7ebdcc3 100644
--- a/tritonbench/operators/fused_linear_jsd/operator.py
+++ b/tritonbench/operators/fused_linear_jsd/operator.py
@@ -166,7 +166,7 @@ def liger_lm_head_jsd(self, student_input, teacher_input) -> Callable:
 
     @register_benchmark()
     def inductor_lm_head_jsd(self, student_input, teacher_input) -> Callable:
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(student_input, teacher_input)
 
     @register_x_val(label="(B*T, H)")
diff --git a/tritonbench/operators/geglu/operator.py b/tritonbench/operators/geglu/operator.py
index fa2b2f0..237f850 100644
--- a/tritonbench/operators/geglu/operator.py
+++ b/tritonbench/operators/geglu/operator.py
@@ -59,7 +59,7 @@ def liger_geglu(self, input) -> Callable:
 
     @register_benchmark()
     def inductor_geglu(self, input) -> Callable:
-        compiled = torch.compile(self.baseline_model, dynamic=False)
+        compiled = torch.compile(self.baseline_model)
         return lambda: compiled(input)
 
     @register_x_val(label="(B, T, H)")
diff --git a/tritonbench/operators/jsd/operator.py b/tritonbench/operators/jsd/operator.py
index 881a8d5..5a42f29 100644
--- a/tritonbench/operators/jsd/operator.py
+++ b/tritonbench/operators/jsd/operator.py
@@ -87,7 +87,7 @@ def liger_jsd(self, _input, target) -> Callable:
 
     @register_benchmark()
     def inductor_jsd(self, _input, target) -> Callable:
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(_input, target)
 
     @register_x_val(label="(B, T, V)")
diff --git a/tritonbench/operators/kl_div/operator.py b/tritonbench/operators/kl_div/operator.py
index 129ac5c..0d600cc 100644
--- a/tritonbench/operators/kl_div/operator.py
+++ b/tritonbench/operators/kl_div/operator.py
@@ -47,7 +47,7 @@ def liger_kl_div(self, input, target) -> Callable:
 
     @register_benchmark()
    def inductor_kl_div(self, input, target) -> Callable:
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(input, target)
 
     @register_x_val(label="(B, T, V)")
diff --git a/tritonbench/operators/rms_norm/operator.py b/tritonbench/operators/rms_norm/operator.py
index 492f7de..0c62d39 100644
--- a/tritonbench/operators/rms_norm/operator.py
+++ b/tritonbench/operators/rms_norm/operator.py
@@ -65,7 +65,11 @@ def liger_rms(self, H, input) -> Callable:
 
     @register_benchmark()
     def inductor_rms(self, H, input) -> Callable:
-        compiled = torch.compile(self.llama_rms_op, dynamic=False)
+        if self.llama_rms_op is None:
+            self.llama_rms_op = LlamaRMSNorm(hidden_size=H, eps=self.eps).to(
+                self.device
+            )
+        compiled = torch.compile(self.llama_rms_op)
         return lambda: compiled(input)
 
     @register_x_val(label="(M, H)")
diff --git a/tritonbench/operators/rope/operator.py b/tritonbench/operators/rope/operator.py
index ab2b847..174626a 100644
--- a/tritonbench/operators/rope/operator.py
+++ b/tritonbench/operators/rope/operator.py
@@ -88,11 +88,9 @@ def liger_rotary_pos_emb(self, hidden_size, seq_length) -> Callable:
     def inductor_rotary_pos_emb_full_op(self, hidden_size, seq_length) -> Callable:
         q, k, cos, sin, pos_ids = self.prepare_input(hidden_size, seq_length)
         head_dim = hidden_size // self.num_q_heads
-        compiled = torch.compile(
-            LlamaRotaryEmbedding(head_dim, device=self.device), dynamic=False
-        )
+        compiled = torch.compile(LlamaRotaryEmbedding(head_dim, device=self.device))
         cos, sin = compiled(k, pos_ids)
-        compiled_func = torch.compile(apply_rotary_pos_emb, dynamic=False)
+        compiled_func = torch.compile(apply_rotary_pos_emb)
         return lambda: compiled_func(q, k, cos, sin, pos_ids)
 
     @register_x_val(label="(H, T)")
diff --git a/tritonbench/operators/swiglu/operator.py b/tritonbench/operators/swiglu/operator.py
index 7808da8..b21fede 100644
--- a/tritonbench/operators/swiglu/operator.py
+++ b/tritonbench/operators/swiglu/operator.py
@@ -60,7 +60,7 @@ def liger_swiglu(self, input) -> Callable:
 
     @register_benchmark()
     def inductor_swiglu(self, input) -> Callable:
-        compiled = torch.compile(self.baseline_op, dynamic=False)
+        compiled = torch.compile(self.baseline_op)
         return lambda: compiled(input)
 
     @register_x_val(label="(B, T, H)")
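
For illustration only, not part of the patch: a minimal sketch of what dropping dynamic=False changes, assuming PyTorch 2.x where torch.compile defaults to dynamic=None (automatic dynamic shapes). With dynamic=False the compiled module is specialized to each distinct input shape and recompiles for every new shape; with the default, the first shape change triggers a single recompile into a dynamic-shape kernel that is reused for later shapes. The module and sizes below are assumptions chosen for the sketch.

import torch

# Toy module standing in for the benchmarked baseline ops (assumption).
mod = torch.nn.Linear(64, 64)

static = torch.compile(mod, dynamic=False)  # specializes per input shape
auto = torch.compile(mod)                   # default: automatic dynamic shapes

for bs in (8, 16, 32):
    x = torch.randn(bs, 64)
    static(x)  # recompiles for every new batch size
    auto(x)    # recompiles once on the first shape change, then reuses a dynamic kernel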