diff --git a/example/boilerplate.py b/example/boilerplate.py index e531eaf..dfb01d7 100644 --- a/example/boilerplate.py +++ b/example/boilerplate.py @@ -22,7 +22,7 @@ class API: logger: Logger api: Optional[NovelAIAPI] - def __init__(self): + def __init__(self, base_address: Optional[str] = None): dotenv.load_dotenv() if "NAI_USERNAME" not in env or "NAI_PASSWORD" not in env: @@ -35,6 +35,8 @@ def __init__(self): self.logger.addHandler(StreamHandler()) self.api = NovelAIAPI(logger=self.logger) + if base_address is not None: + self.api.BASE_ADDRESS = base_address @property def encryption_key(self): diff --git a/example/generate_image_with_inpainting.py b/example/generate_image_with_inpainting.py new file mode 100644 index 0000000..b876009 --- /dev/null +++ b/example/generate_image_with_inpainting.py @@ -0,0 +1,44 @@ +""" +{filename} +============================================================================== + +| Example of how to generate an image with inpainting +| +| The resulting image will be placed in a folder named "results" +""" + +import asyncio +import base64 +from pathlib import Path + +from example.boilerplate import API +from novelai_api.ImagePreset import ImageGenerationType, ImageModel, ImagePreset + + +async def main(): + d = Path("results") + d.mkdir(exist_ok=True) + + async with API() as api_handler: + api = api_handler.api + + image = base64.b64encode((d / "image.png").read_bytes()).decode() + mask = base64.b64encode((d / "inpainting_mask.png").read_bytes()).decode() + + preset = ImagePreset() + preset.noise = 0.1 + # note that steps = 28, not 50, which mean strength needs to be adjusted accordingly + preset.strength = 0.5 + preset.image = image + preset.mask = mask + preset.add_original_image = False + preset.seed = 42 + + async for _, img in api.high_level.generate_image( + "1girl", ImageModel.Inpainting_Anime_Full, preset, ImageGenerationType.INPAINTING + ): + (d / "image_with_inpainting.png").write_bytes(img) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/novelai_api/ImagePreset.py b/novelai_api/ImagePreset.py index d785d7e..6a1f722 100644 --- a/novelai_api/ImagePreset.py +++ b/novelai_api/ImagePreset.py @@ -18,7 +18,9 @@ class ImageModel(enum.Enum): Anime_Full = "nai-diffusion" Furry = "nai-diffusion-furry" - Anime_Inpainting = "anime-diffusion-inpainting" + Inainting_Anime_Curated = "safe-diffusion-inpainting" + Inpainting_Anime_Full = "nai-diffusion-inpainting" + Inpainting_Furry = "furry-diffusion-inpainting" class ControlNetModel(enum.Enum): @@ -93,7 +95,7 @@ class ImageGenerationType(enum.Enum): NORMAL = "generate" IMG2IMG = "img2img" - # inpainting should go there + INPAINTING = "infill" class ImagePreset: @@ -124,6 +126,24 @@ class ImagePreset: "{{unfinished}}, deformed, outline, pattern, simple background", UCPreset.Preset_None: "low res", }, + ImageModel.Inainting_Anime_Curated: { + UCPreset.Preset_Low_Quality_Bad_Anatomy: "", + UCPreset.Preset_Bad_Anatomy: None, + UCPreset.Preset_Low_Quality: "", + UCPreset.Preset_None: "", + }, + ImageModel.Inpainting_Anime_Full: { + UCPreset.Preset_Low_Quality_Bad_Anatomy: "", + UCPreset.Preset_Bad_Anatomy: None, + UCPreset.Preset_Low_Quality: "", + UCPreset.Preset_None: "", + }, + ImageModel.Inpainting_Furry: { + UCPreset.Preset_Low_Quality_Bad_Anatomy: None, + UCPreset.Preset_Bad_Anatomy: "", + UCPreset.Preset_Low_Quality: "", + UCPreset.Preset_None: "", + }, } _CONTROLNET_MODELS = { @@ -153,9 +173,8 @@ class ImagePreset: "controlnet_model": ControlNetModel, "controlnet_strength": (int, float), "decrisper": bool, - # TODO - # "dynamic_thresholding_mimic_scale": (int, float), - # "dynamic_thresholding_percentile": (int, float), + "add_original_image": bool, + "mask": str, } # type completion for __setitem__ and __getitem__ @@ -196,10 +215,10 @@ class ImagePreset: controlnet_strength: float #: Reduce the deepfrying effects of high scale (https://twitter.com/Birchlabs/status/1582165379832348672) decrisper: bool - - # TODO - # dynamic_thresholding_mimic_scale: float - # dynamic_thresholding_percentile: float + #: Prevent seams along the edges of the mask, but may change the image slightly + add_original_image: bool + #: Mask for inpainting (b64-encoded black and white png image, white is the inpainting area) + mask: str _DEFAULT = { "legacy": False, @@ -217,6 +236,7 @@ class ImagePreset: "smea_dyn": False, "decrisper": False, "controlnet_strength": 1.0, + "add_original_image": False, } _settings: Dict[str, Any] diff --git a/novelai_api/NovelAIError.py b/novelai_api/NovelAIError.py index 94881f2..de000a9 100644 --- a/novelai_api/NovelAIError.py +++ b/novelai_api/NovelAIError.py @@ -3,14 +3,17 @@ class NovelAIError(Exception): Expected raised by the NAI API when a problem occurs """ + #: Url that caused the error + url: str #: Provided status code, or -1 if no status code was provided status: int #: Provided error message message: str - def __init__(self, status: int, message: str) -> None: + def __init__(self, url: str, status: int, message: str) -> None: + self.url = url self.status = status self.message = message def __str__(self) -> str: - return f"{self.status} - {self.message}" + return f"{self.url} ({self.status}) - {self.message}" diff --git a/novelai_api/_low_level.py b/novelai_api/_low_level.py index 4fc8c5d..c4d5df3 100644 --- a/novelai_api/_low_level.py +++ b/novelai_api/_low_level.py @@ -34,6 +34,14 @@ def print_with_parameters(args: Dict[str, Any]): if "parameters" in a: a["parameters"] = {k: str(v) for k, v in a["parameters"].items()} + for k in ["image", "mask", "controlnet_condition"]: + if k in a["parameters"]: + a["parameters"][k] = ( + f"{a['parameters'][k][:10]}...{a['parameters'][k][-10:]}" + if 30 < len(a["parameters"][k]) + else a["parameters"][k] + ) + print(json.dumps(a, indent=4)) @@ -50,9 +58,11 @@ def __init__(self, parent: "NovelAIAPI"): # noqa: F821 @staticmethod def _treat_response_object(rsp: ClientResponse, content: Any, status: int) -> Any: + url: str = rsp.url if isinstance(rsp.url, str) else rsp.url.human_repr() + # error is an unexpected fail and usually come with a success status if isinstance(content, dict) and "error" in content: - raise NovelAIError(rsp.status, content["error"]) + raise NovelAIError(url, rsp.status, content["error"]) # success if rsp.status == status: @@ -60,13 +70,13 @@ def _treat_response_object(rsp: ClientResponse, content: Any, status: int) -> An # not success, but valid response if isinstance(content, dict) and "message" in content: # NovelAI REST API error - raise NovelAIError(rsp.status, content["message"]) + raise NovelAIError(url, rsp.status, content["message"]) # HTTPException error if hasattr(rsp, "reason"): - raise NovelAIError(rsp.status, str(rsp.reason)) + raise NovelAIError(url, rsp.status, str(rsp.reason)) - raise NovelAIError(rsp.status, "Unknown error") + raise NovelAIError(url, rsp.status, "Unknown error") def _treat_response_bool(self, rsp: ClientResponse, content: Any, status: int) -> bool: if rsp.status == status: @@ -164,7 +174,8 @@ async def _parse_response(cls, rsp: ClientResponse): yield e["data"] else: - raise NovelAIError(-1, f"Unsupported type: {rsp.content_type}") + url: str = rsp.url if isinstance(rsp.url, str) else rsp.url.human_repr() + raise NovelAIError(url, -1, f"Unsupported type: {rsp.content_type}") async def request(self, method: str, endpoint: str, data: Optional[Union[Dict[str, Any], str]] = None): """ @@ -796,6 +807,35 @@ async def generate_image( yield content + async def generate_prompt(self, model: Model, prompt: str, temp: float, length: int) -> Dict[str, Any]: + """ + Generate a prompt + + :param model: Model to use for the prompt + :param prompt: Prompt to base the generation on + :param temp: Temperature for the generation + :param length: Length of the returned prompt + + :return: Generated prompt + """ + + assert_type(Model, model=model) + assert_type(str, prompt=prompt) + assert_type(float, temp=temp) + assert_type(int, length=length) + + args = { + "model": model.value, + "prompt": prompt, + "temp": temp, + "tokens_to_generate": length, + } + + async for rsp, content in self.request("post", "/ai/generate-prompt", args): + self._treat_response_object(rsp, content, 201) + + return content + async def generate_controlnet_mask(self, model: ControlNetModel, image: str) -> Tuple[str, bytes]: """ Get the ControlNet's mask for the given image. Used for ImageSampler.controlnet_condition diff --git a/novelai_api/poetry_scripts.py b/novelai_api/poetry_scripts.py index 63733db..f5a57ef 100644 --- a/novelai_api/poetry_scripts.py +++ b/novelai_api/poetry_scripts.py @@ -114,5 +114,5 @@ def bump_version(): print( "You can now push the commit and the tag with `git push --follow-tags`.\n" - "Ensure you're really ready to push with `git status` and `git log`." + "Ensure you're ready to push with `git status` and `git log`." )