Skip to content

Commit

Permalink
add cli test
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Dec 11, 2024
1 parent a0e4184 commit 2e89b60
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 5 deletions.
6 changes: 3 additions & 3 deletions optimum/commands/export/neuronx.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def parse_args_neuronx(parser: "ArgumentParser"):
optional_group.add_argument(
"--torch_dtype",
type=str,
default="auto",
choices=["auto", "bfloat16", "float16", "float32"],
help="Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the dtype will be automatically derived from the model's weights.",
default=None,
choices=["bfloat16", "float16", "float32"],
help="Override the default `torch.dtype` and load the model under this dtype. If `None` is passed, the dtype will be automatically derived from the model's weights.",
)
optional_group.add_argument(
"--tensor_parallel_size",
Expand Down
1 change: 0 additions & 1 deletion optimum/exporters/neuron/model_configs/traced_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,6 @@ def patch_model_for_export(
return super().patch_model_for_export(model=model, dummy_inputs=dummy_inputs, forward_with_tuple=True)


@register_in_tasks_manager("t5-encoder", "text2text-generation")
class T5EncoderBaseNeuronConfig(TextSeq2SeqNeuronConfig):
ATOL_FOR_VALIDATION = 1e-3
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
Expand Down
32 changes: 31 additions & 1 deletion tests/cli/test_export_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,36 @@ def test_stable_diffusion(self):
check=True,
)

@requires_neuronx
def test_pixart(self):
model_ids = ["hf-internal-testing/tiny-pixart-alpha-pipe"]
for model_id in model_ids:
with tempfile.TemporaryDirectory() as tempdir:
subprocess.run(
[
"optimum-cli",
"export",
"neuron",
"--model",
model_id,
"--batch_size",
"1",
"--height",
"8",
"--width",
"8",
"--sequence_length",
"16",
"--num_images_per_prompt",
"1",
"--torch_dtype",
"bfloat16",
tempdir,
],
shell=False,
check=True,
)

@requires_neuronx
def test_stable_diffusion_multi_lora(self):
model_id = "hf-internal-testing/tiny-stable-diffusion-torch"
Expand All @@ -196,7 +226,7 @@ def test_stable_diffusion_multi_lora(self):
lora_model_id,
"--lora_weight_names",
lora_weight_name,
"lora_adapter_names",
"--lora_adapter_names",
adpater_name,
"--lora_scales",
"0.9",
Expand Down

0 comments on commit 2e89b60

Please sign in to comment.