Skip to content

Commit

Permalink
Merge pull request #25 from agentd00nut/bugfix/key-valuerror-after-copy
Browse files Browse the repository at this point in the history
Fix key/value error after copy for ImagePreset
  • Loading branch information
Aedial committed Dec 3, 2023
2 parents 37f2250 + 52df5f2 commit d652e99
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
4 changes: 4 additions & 0 deletions novelai_api/ImagePreset.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ class ImagePreset:
}

_TYPE_MAPPING = {
"legacy": bool,
"quality_toggle": bool,
"resolution": (ImageResolution, tuple),
"uc_preset": (UCPreset, NoneType),
Expand All @@ -208,6 +209,7 @@ class ImagePreset:
"noise": (int, float),
"strength": (int, float),
"scale": (int, float),
"uncond_scale": (int, float),
"steps": int,
"uc": str,
"smea": bool,
Expand All @@ -219,6 +221,8 @@ class ImagePreset:
"decrisper": bool,
"add_original_image": bool,
"mask": str,
"cfg_rescale": float,
"noise_schedule": str,
}

# type completion for __setitem__ and __getitem__
Expand Down
17 changes: 11 additions & 6 deletions novelai_api/_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,17 @@ async def _generate(
global_settings.rep_pen_whitelist = repetition_penalty_default_whitelist

params = {
"repetition_penalty_whitelist": list(set(
item for sublist in [
global_params.pop("repetition_penalty_whitelist", []),
preset_params.pop("repetition_penalty_whitelist", []),
] for inner_list in sublist for item in inner_list
))
"repetition_penalty_whitelist": list(
set(
item
for sublist in [
global_params.pop("repetition_penalty_whitelist", []),
preset_params.pop("repetition_penalty_whitelist", []),
]
for inner_list in sublist
for item in inner_list
)
)
}

params.update(preset_params)
Expand Down
1 change: 1 addition & 0 deletions tests/api/test_imagegen_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ async def test_samplers(
logger.info(f"Testing model {model} with sampler {sampler}")

preset = ImagePreset(sampler=sampler)
preset.copy()

# Furry doesn't have UCPreset.Preset_Low_Quality_Bad_Anatomy
if model is ImageModel.Furry:
Expand Down

0 comments on commit d652e99

Please sign in to comment.