Skip to content

Commit

Permalink
temporarily hardcode torch_dtype=float16 for SimpleGenerativeModel
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Feb 6, 2025
1 parent 2e4fd26 commit 520d6ab
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions src/pie_modules/models/simple_generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(
# Note: We do not set expected_super_type=PreTrainedModel for resolve_type() because
# AutoModel* classed such as AutoModelForSeq2SeqLM do not inherit from that.
resolved_base_model_type: Type[PreTrainedModel] = resolve_type(base_model_type)
base_model_config["torch_dtype"] = torch.float16
self.model = resolved_base_model_type.from_pretrained(**base_model_config)
self.generation_config = self.configure_generation(**(override_generation_kwargs or {}))

Expand Down

0 comments on commit 520d6ab

Please sign in to comment.