diff --git a/quantize.py b/quantize.py index af17a698..0a91938b 100644 --- a/quantize.py +++ b/quantize.py @@ -539,7 +539,7 @@ def quantize( device: str = default_device, ) -> None: assert checkpoint_path.is_file(), checkpoint_path - device = 'cpu' + print(f"Using device={device}") precision = torch.bfloat16 print("Loading model ...")