Skip to content

Commit

Permalink
Make stable diffusion unet and vae number of channels static (hugging…
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova authored May 15, 2024
1 parent b3ecb6c commit 02c6ed5
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,7 @@ class UNetOnnxConfig(VisionOnnxConfig):
@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {
"sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
"sample": {0: "batch_size", 2: "height", 3: "width"},
"timestep": {0: "steps"},
"encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
}
Expand All @@ -998,7 +998,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return {
"out_sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
"out_sample": {0: "batch_size", 2: "height", 3: "width"},
}

@property
Expand Down Expand Up @@ -1045,13 +1045,13 @@ class VaeEncoderOnnxConfig(VisionOnnxConfig):
@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {
"sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
"sample": {0: "batch_size", 2: "height", 3: "width"},
}

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return {
"latent_sample": {0: "batch_size", 1: "num_channels_latent", 2: "height_latent", 3: "width_latent"},
"latent_sample": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
}


Expand All @@ -1069,13 +1069,13 @@ class VaeDecoderOnnxConfig(VisionOnnxConfig):
@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {
"latent_sample": {0: "batch_size", 1: "num_channels_latent", 2: "height_latent", 3: "width_latent"},
"latent_sample": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
}

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return {
"sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
"sample": {0: "batch_size", 2: "height", 3: "width"},
}


Expand Down

0 comments on commit 02c6ed5

Please sign in to comment.