diff --git a/quantize.py b/quantize.py index fb566421..eb2df215 100644 --- a/quantize.py +++ b/quantize.py @@ -19,6 +19,8 @@ from model import Transformer +default_use_cuda = torch.cuda.is_available() + ##### Quantization Primitives ###### def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): @@ -530,6 +532,7 @@ def quantize( mode: str = 'int8', # following arguments only available when setting int4 quantization. groupsize: int = 128, + use_cuda: bool = default_use_cuda, # following arguments only used for GPTQ calibration_tasks: list = ["hellaswag"], calibration_limit: int = 1000, @@ -566,7 +569,7 @@ def quantize( elif mode == 'int4': print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization") quant_handler = WeightOnlyInt4QuantHandler(model, groupsize) - quantized_state_dict = quant_handler.create_quantized_state_dict() + quantized_state_dict = quant_handler.create_quantized_state_dict(use_cuda) dir_name = checkpoint_path.parent base_name = checkpoint_path.name @@ -610,6 +613,7 @@ def quantize( parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.') parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform') parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.') + parser.add_argument('--use_cuda', type=lambda s: s.lower() in ('1', 'true', 'yes'), default=default_use_cuda, help='Whether to use cuda for int4 quantization.') parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq') parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration') parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq
calibration') @@ -619,4 +623,4 @@ def quantize( parser.add_argument('--label', type=str, default='_', help='label to add to output filename') args = parser.parse_args() - quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label) + quantize(args.checkpoint_path, args.mode, args.groupsize, args.use_cuda, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label)