Skip to content

Commit

Permalink
[API] Add inpainting and /ai/generate-prompt
Browse files Browse the repository at this point in the history
No idea what this new endpoint is for
  • Loading branch information
Aedial committed May 25, 2023
1 parent abdca64 commit aa3a503
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 18 deletions.
4 changes: 3 additions & 1 deletion example/boilerplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class API:
logger: Logger
api: Optional[NovelAIAPI]

def __init__(self):
def __init__(self, base_address: Optional[str] = None):
dotenv.load_dotenv()

if "NAI_USERNAME" not in env or "NAI_PASSWORD" not in env:
Expand All @@ -35,6 +35,8 @@ def __init__(self):
self.logger.addHandler(StreamHandler())

self.api = NovelAIAPI(logger=self.logger)
if base_address is not None:
self.api.BASE_ADDRESS = base_address

@property
def encryption_key(self):
Expand Down
44 changes: 44 additions & 0 deletions example/generate_image_with_inpainting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
{filename}
==============================================================================
| Example of how to generate an image with inpainting
|
| The resulting image will be placed in a folder named "results"
"""

import asyncio
import base64
from pathlib import Path

from example.boilerplate import API
from novelai_api.ImagePreset import ImageGenerationType, ImageModel, ImagePreset


async def main():
d = Path("results")
d.mkdir(exist_ok=True)

async with API() as api_handler:
api = api_handler.api

image = base64.b64encode((d / "image.png").read_bytes()).decode()
mask = base64.b64encode((d / "inpainting_mask.png").read_bytes()).decode()

preset = ImagePreset()
preset.noise = 0.1
# note that steps = 28, not 50, which mean strength needs to be adjusted accordingly
preset.strength = 0.5
preset.image = image
preset.mask = mask
preset.add_original_image = False
preset.seed = 42

async for _, img in api.high_level.generate_image(
"1girl", ImageModel.Inpainting_Anime_Full, preset, ImageGenerationType.INPAINTING
):
(d / "image_with_inpainting.png").write_bytes(img)


if __name__ == "__main__":
asyncio.run(main())
38 changes: 29 additions & 9 deletions novelai_api/ImagePreset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ class ImageModel(enum.Enum):
Anime_Full = "nai-diffusion"
Furry = "nai-diffusion-furry"

Anime_Inpainting = "anime-diffusion-inpainting"
Inainting_Anime_Curated = "safe-diffusion-inpainting"
Inpainting_Anime_Full = "nai-diffusion-inpainting"
Inpainting_Furry = "furry-diffusion-inpainting"


class ControlNetModel(enum.Enum):
Expand Down Expand Up @@ -93,7 +95,7 @@ class ImageGenerationType(enum.Enum):

NORMAL = "generate"
IMG2IMG = "img2img"
# inpainting should go there
INPAINTING = "infill"


class ImagePreset:
Expand Down Expand Up @@ -124,6 +126,24 @@ class ImagePreset:
"{{unfinished}}, deformed, outline, pattern, simple background",
UCPreset.Preset_None: "low res",
},
ImageModel.Inainting_Anime_Curated: {
UCPreset.Preset_Low_Quality_Bad_Anatomy: "",
UCPreset.Preset_Bad_Anatomy: None,
UCPreset.Preset_Low_Quality: "",
UCPreset.Preset_None: "",
},
ImageModel.Inpainting_Anime_Full: {
UCPreset.Preset_Low_Quality_Bad_Anatomy: "",
UCPreset.Preset_Bad_Anatomy: None,
UCPreset.Preset_Low_Quality: "",
UCPreset.Preset_None: "",
},
ImageModel.Inpainting_Furry: {
UCPreset.Preset_Low_Quality_Bad_Anatomy: None,
UCPreset.Preset_Bad_Anatomy: "",
UCPreset.Preset_Low_Quality: "",
UCPreset.Preset_None: "",
},
}

_CONTROLNET_MODELS = {
Expand Down Expand Up @@ -153,9 +173,8 @@ class ImagePreset:
"controlnet_model": ControlNetModel,
"controlnet_strength": (int, float),
"decrisper": bool,
# TODO
# "dynamic_thresholding_mimic_scale": (int, float),
# "dynamic_thresholding_percentile": (int, float),
"add_original_image": bool,
"mask": str,
}

# type completion for __setitem__ and __getitem__
Expand Down Expand Up @@ -196,10 +215,10 @@ class ImagePreset:
controlnet_strength: float
#: Reduce the deepfrying effects of high scale (https://twitter.com/Birchlabs/status/1582165379832348672)
decrisper: bool

# TODO
# dynamic_thresholding_mimic_scale: float
# dynamic_thresholding_percentile: float
#: Prevent seams along the edges of the mask, but may change the image slightly
add_original_image: bool
#: Mask for inpainting (b64-encoded black and white png image, white is the inpainting area)
mask: str

_DEFAULT = {
"legacy": False,
Expand All @@ -217,6 +236,7 @@ class ImagePreset:
"smea_dyn": False,
"decrisper": False,
"controlnet_strength": 1.0,
"add_original_image": False,
}

_settings: Dict[str, Any]
Expand Down
7 changes: 5 additions & 2 deletions novelai_api/NovelAIError.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@ class NovelAIError(Exception):
Expected raised by the NAI API when a problem occurs
"""

#: Url that caused the error
url: str
#: Provided status code, or -1 if no status code was provided
status: int
#: Provided error message
message: str

def __init__(self, status: int, message: str) -> None:
def __init__(self, url: str, status: int, message: str) -> None:
self.url = url
self.status = status
self.message = message

def __str__(self) -> str:
return f"{self.status} - {self.message}"
return f"{self.url} ({self.status}) - {self.message}"
50 changes: 45 additions & 5 deletions novelai_api/_low_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ def print_with_parameters(args: Dict[str, Any]):
if "parameters" in a:
a["parameters"] = {k: str(v) for k, v in a["parameters"].items()}

for k in ["image", "mask", "controlnet_condition"]:
if k in a["parameters"]:
a["parameters"][k] = (
f"{a['parameters'][k][:10]}...{a['parameters'][k][-10:]}"
if 30 < len(a["parameters"][k])
else a["parameters"][k]
)

print(json.dumps(a, indent=4))


Expand All @@ -50,23 +58,25 @@ def __init__(self, parent: "NovelAIAPI"): # noqa: F821

@staticmethod
def _treat_response_object(rsp: ClientResponse, content: Any, status: int) -> Any:
url: str = rsp.url if isinstance(rsp.url, str) else rsp.url.human_repr()

# error is an unexpected fail and usually come with a success status
if isinstance(content, dict) and "error" in content:
raise NovelAIError(rsp.status, content["error"])
raise NovelAIError(url, rsp.status, content["error"])

# success
if rsp.status == status:
return content

# not success, but valid response
if isinstance(content, dict) and "message" in content: # NovelAI REST API error
raise NovelAIError(rsp.status, content["message"])
raise NovelAIError(url, rsp.status, content["message"])

# HTTPException error
if hasattr(rsp, "reason"):
raise NovelAIError(rsp.status, str(rsp.reason))
raise NovelAIError(url, rsp.status, str(rsp.reason))

raise NovelAIError(rsp.status, "Unknown error")
raise NovelAIError(url, rsp.status, "Unknown error")

def _treat_response_bool(self, rsp: ClientResponse, content: Any, status: int) -> bool:
if rsp.status == status:
Expand Down Expand Up @@ -164,7 +174,8 @@ async def _parse_response(cls, rsp: ClientResponse):
yield e["data"]

else:
raise NovelAIError(-1, f"Unsupported type: {rsp.content_type}")
url: str = rsp.url if isinstance(rsp.url, str) else rsp.url.human_repr()
raise NovelAIError(url, -1, f"Unsupported type: {rsp.content_type}")

async def request(self, method: str, endpoint: str, data: Optional[Union[Dict[str, Any], str]] = None):
"""
Expand Down Expand Up @@ -796,6 +807,35 @@ async def generate_image(

yield content

async def generate_prompt(self, model: Model, prompt: str, temp: float, length: int) -> Dict[str, Any]:
"""
Generate a prompt
:param model: Model to use for the prompt
:param prompt: Prompt to base the generation on
:param temp: Temperature for the generation
:param length: Length of the returned prompt
:return: Generated prompt
"""

assert_type(Model, model=model)
assert_type(str, prompt=prompt)
assert_type(float, temp=temp)
assert_type(int, length=length)

args = {
"model": model.value,
"prompt": prompt,
"temp": temp,
"tokens_to_generate": length,
}

async for rsp, content in self.request("post", "/ai/generate-prompt", args):
self._treat_response_object(rsp, content, 201)

return content

async def generate_controlnet_mask(self, model: ControlNetModel, image: str) -> Tuple[str, bytes]:
"""
Get the ControlNet's mask for the given image. Used for ImageSampler.controlnet_condition
Expand Down
2 changes: 1 addition & 1 deletion novelai_api/poetry_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,5 @@ def bump_version():

print(
"You can now push the commit and the tag with `git push --follow-tags`.\n"
"Ensure you're really ready to push with `git status` and `git log`."
"Ensure you're ready to push with `git status` and `git log`."
)

0 comments on commit aa3a503

Please sign in to comment.