Remove calibration args from generate.py (#1258)
Removed the calibration args from the generate.py script because quantization speedup analysis does not require "real" calibration.
vayuda authored Nov 20, 2024
1 parent f87fb56 commit 129316d
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions torchao/_models/llama/generate.py
@@ -164,8 +164,6 @@ def main(
temperature: float = 0.8,
checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
quantization: Optional[str] = None,
- calibration_limit: int = 10,
- calibration_seq_length: int = 256,
kv_cache_quantization: bool = False,
cache_size: Optional[int] = None,
linear_causal_mask: bool=False,
@@ -268,16 +266,16 @@ def main(
quant_dtype = getattr(torch, quant_dtype, torch.uint8)
model=model.to(device)
# get calibration data
- insert_awq_observer_(model, calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size)
+ insert_awq_observer_(model, 1, 256, quant_dtype=quant_dtype, group_size=group_size)
TransformerEvalWrapper(
model=model.to(device),
tokenizer=tokenizer,
- max_seq_length=calibration_seq_length,
+ max_seq_length=256,
input_prep_func=prepare_inputs_for_model,
device=device,
).run_eval(
tasks=['wikitext'],
- limit=calibration_limit,
+ limit=1,
)
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
use_hqq = "hqq" in quantization
@@ -491,8 +489,6 @@ def callback(x):
+'embed-int8wo, marlin_qqq'
)
)
parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")
parser.add_argument("--calibration_seq_length", type=int, default=256, help="Sequence length for calibration")
parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
@@ -508,5 +504,5 @@ def callback(x):
args = parser.parse_args()
main(
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
- args.temperature, args.checkpoint_path, args.quantization, args.calibration_limit, args.calibration_seq_length, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
+ args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
)
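For reference, a minimal sketch of what the AWQ branch of generate.py does after this commit: the calibration limit and calibration sequence length are fixed constants (1 example, 256 tokens) rather than CLI arguments, since only quantization speedup is being measured. The import paths, the group_size value, and the function wrapper below are assumptions for illustration and are not part of this diff.

import torch
from torchao.prototype.awq import insert_awq_observer_            # assumed import path
from torchao._models._eval import TransformerEvalWrapper          # assumed import path
from torchao._models.llama.model import prepare_inputs_for_model  # assumed import path

def calibrate_awq(model, tokenizer, device, quant_dtype=torch.uint8, group_size=64):
    # Attach AWQ observers with the now-hardcoded calibration settings:
    # 1 calibration example, sequence length 256 (enough for speedup benchmarking).
    model = model.to(device)
    insert_awq_observer_(model, 1, 256, quant_dtype=quant_dtype, group_size=group_size)
    # Run a single wikitext example through the model so the observers collect statistics.
    TransformerEvalWrapper(
        model=model,
        tokenizer=tokenizer,
        max_seq_length=256,
        input_prep_func=prepare_inputs_for_model,
        device=device,
    ).run_eval(tasks=["wikitext"], limit=1)
    return model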
