Skip to content

Commit 3cd6a44

Browse files
authored
fix error: conv_general_qt requires the same bwd_qtype as weight_qtype (#245)
* fix error: conv_general_qt requires the same bwd_qtype as weight_qtype * fix test error
1 parent fd53cde commit 3cd6a44

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def get_fp8_config(cls, quantization_calibration_method: str):
265265
module_path=".*", # Apply to all modules
266266
weight_qtype=jnp.float8_e4m3fn,
267267
act_qtype=jnp.float8_e4m3fn,
268-
bwd_qtype=jnp.float8_e5m2,
268+
bwd_qtype=jnp.float8_e4m3fn,
269269
bwd_use_original_residuals=True,
270270
disable_channelwise_axes=True, # per_tensor calibration
271271
weight_calibration_method=quantization_calibration_method,

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def test_get_qt_provider(self, mock_qt_rule):
316316
module_path=".*", # Apply to all modules
317317
weight_qtype=jnp.float8_e4m3fn,
318318
act_qtype=jnp.float8_e4m3fn,
319-
bwd_qtype=jnp.float8_e5m2,
319+
bwd_qtype=jnp.float8_e4m3fn,
320320
bwd_use_original_residuals=True,
321321
disable_channelwise_axes=True, # per_tensor calibration
322322
weight_calibration_method=config_fp8_full.quantization_calibration_method,

0 commit comments

Comments
 (0)