From f1edacbe2be025f305d6146c2b77d53a7a8e346a Mon Sep 17 00:00:00 2001 From: Aedial Date: Fri, 9 Jun 2023 19:14:39 +0200 Subject: [PATCH] [API] Slight reorganization + add missing preambles --- novelai_api/Keystore.py | 2 + novelai_api/Preset.py | 4 + novelai_api/_high_level.py | 14 +- novelai_api/_low_level.py | 286 ++++++++++++++++++++----------------- 4 files changed, 168 insertions(+), 138 deletions(-) diff --git a/novelai_api/Keystore.py b/novelai_api/Keystore.py index 3ddd489..5ea78d5 100644 --- a/novelai_api/Keystore.py +++ b/novelai_api/Keystore.py @@ -92,6 +92,7 @@ def decrypt(self, key: bytes): keystore = self.data.copy() + # keystore is empty, create a new one if "keystore" in keystore and keystore["keystore"] is None: # keystore is null when empty self._nonce = random(SecretBox.NONCE_SIZE) self._version = 2 @@ -110,6 +111,7 @@ def decrypt(self, key: bytes): return + # keystore is not empty, decrypt it keystore = loads(b64decode(self.data["keystore"]).decode()) SchemaValidator.validate("schema_keystore_encrypted", keystore) diff --git a/novelai_api/Preset.py b/novelai_api/Preset.py index be17f95..ffb9b33 100644 --- a/novelai_api/Preset.py +++ b/novelai_api/Preset.py @@ -100,8 +100,12 @@ class Model(StrEnum): Inline = "infillmodel" +#: Prompt sent to the model when the context is empty PREAMBLE = { + # Model.Calliope: "⁂\n", Model.Sigurd: "⁂\n", + Model.Genji: [60, 198, 198], # "]\n\n" - impossible combination, so it is pre-tokenized + Model.Snek: "<|endoftext|>\n", Model.Euterpe: "\n***\n", Model.Krake: "<|endoftext|>[ Prologue ]\n", Model.Clio: "[ Author: Various ]\n[ Prologue ]\n", diff --git a/novelai_api/_high_level.py b/novelai_api/_high_level.py index 45a0c9b..cbc39e5 100644 --- a/novelai_api/_high_level.py +++ b/novelai_api/_high_level.py @@ -195,12 +195,12 @@ async def upload_user_content( object_data = data["data"] if encrypt: - if object_type in ("stories", "storycontent", "aimodules", "shelf"): + if object_type in ("stories", "storycontent", "aimodules"): if keystore is None: raise ValueError("'keystore' is not set, cannot encrypt data") encrypt_user_data(data, keystore) - elif object_type in ("presets",): + elif object_type in ("shelf", "presets"): compress_user_data(data) # clean data introduced by decrypt_user_data @@ -255,7 +255,7 @@ async def _generate( **kwargs, ): """ - Generate text with streaming support + Generate text with streaming support. :param prompt: Context to give to the AI (raw text or list of tokens) :param model: Model to use for the AI @@ -328,6 +328,10 @@ async def generate( To decode the text, the :meth:`novelai_api.utils.b64_to_tokens` and :meth:`novelai_api.Tokenizer.Tokenizer.decode` methods should be used. + As the model accepts a complete prompt, the context building must be done before calling this function. + Any content going beyond the tokens limit will be truncated, starting from the top. + + :param prompt: Context to give to the AI (raw text or list of tokens) :param model: Model to use for the AI :param preset: Preset to use for the generation settings @@ -367,6 +371,10 @@ async def generate_stream( """ Generate text. The text is returned one token at a time, as it is generated. + As the model accepts a complete prompt, the context building must be done before calling this function. + Any content going beyond the tokens limit will be truncated, starting from the top. + + :param prompt: Context to give to the AI (raw text or list of tokens) :param model: Model to use for the AI :param preset: Preset to use for the generation settings diff --git a/novelai_api/_low_level.py b/novelai_api/_low_level.py index 8047745..7e87139 100644 --- a/novelai_api/_low_level.py +++ b/novelai_api/_low_level.py @@ -592,6 +592,8 @@ async def set_settings(self, value: str) -> bool: async for rsp, content in self.request("put", "/user/clientsettings", value): return self._treat_response_bool(rsp, content, 200) + # TODO: add submission endpoints + async def bind_subscription(self, payment_processor: str, subscription_id: str) -> bool: """ Bind payment information to the account to renew subscription monthly @@ -616,7 +618,9 @@ async def change_subscription(self, new_plan: str) -> bool: async def generate(self, prompt: Union[List[int], str], model: Model, params: Dict[str, Any], stream: bool = False): """ - Generate text with streaming support + Generate text with streaming support. As the model accepts a complete prompt, + the context building must be done before calling this function. + Any content going beyond the tokens limit will be truncated, starting from the top. :param prompt: Input to be sent the AI :param model: Model of the AI @@ -635,98 +639,146 @@ async def generate(self, prompt: Union[List[int], str], model: Model, params: Di prompt = Tokenizer.encode(model, prompt) prompt = tokens_to_b64(prompt) - args = {"input": prompt, "model": model.value, "parameters": params} + data = {"input": prompt, "model": model.value, "parameters": params} endpoint = "/ai/generate-stream" if stream else "/ai/generate" - async for rsp, content in self.request("post", endpoint, args): + async for rsp, content in self.request("post", endpoint, data): self._treat_response_object(rsp, content, 201) yield content - async def classify(self) -> NoReturn: + async def generate_image( + self, prompt: str, model: ImageModel, action: ImageGenerationType, parameters: Dict[str, Any] + ) -> AsyncIterator[Tuple[str, bytes]]: """ - Not implemented + Generate one or multiple image(s) + + :param prompt: Prompt for the image + :param model: Model to generate the image + :param action: Type of image generation to use + :param parameters: Parameters for the images + + :return: (name, data) pairs for the raw PNG image(s) """ - raise NotImplementedError("Function is not implemented yet") + assert_type(str, prompt=prompt) + assert_type(ImageModel, model=model) + assert_type(dict, parameters=parameters) - async def train_module(self, data: str, rate: int, steps: int, name: str, desc: str) -> Dict[str, Any]: + data = { + "input": prompt, + "model": model.value, + "action": action.value, + "parameters": parameters, + } + + async for rsp, content in self.request("post", "/ai/generate-image", data): + self._treat_response_object(rsp, content, 200) + + yield content + + async def generate_prompt(self, model: Model, prompt: str, temp: float, length: int) -> Dict[str, Any]: """ - Train a module for text gen + Generate a prompt - :param data: Dataset of the module, in one single string - :param rate: Learning rate of the training - :param steps: Number of steps to train the module for - :param name: Name of the module - :param desc: Description of the module + :param model: Model to use for the prompt + :param prompt: Prompt to base the generation on + :param temp: Temperature for the generation + :param length: Length of the returned prompt - :return: Status of the module being trained + :return: Generated prompt """ - assert_type(str, data=data, name=name, desc=desc) - assert_type(int, rate=rate, steps=steps) + assert_type(Model, model=model) + assert_type(str, prompt=prompt) + assert_type(float, temp=temp) + assert_type(int, length=length) - params = { - "data": data, - "lr": rate, - "steps": steps, - "name": name, - "description": desc, + data = { + "model": model.value, + "prompt": prompt, + "temp": temp, + "tokens_to_generate": length, } - async for rsp, content in self.request("post", "/ai/module/train", params): + async for rsp, content in self.request("post", "/ai/generate-prompt", data): self._treat_response_object(rsp, content, 201) - # TODO: verify response ? - return content - async def get_trained_modules(self) -> List[Dict[str, Any]]: + async def generate_controlnet_mask(self, model: ControlNetModel, image: str) -> Tuple[str, bytes]: """ - Get the modules currently in training or that finished training + Get the ControlNet's mask for the given image. Used for ImageSampler.controlnet_condition + + :param model: ControlNet model to use + :param image: b64 encoded PNG image to get the mask of + + :return: A pair (name, data) for the raw PNG image """ - async for rsp, content in self.request("get", "/ai/module/all"): - self._treat_response_object(rsp, content, 200) + assert_type(ControlNetModel, model=model) + assert_type(str, image=image) - if self.is_schema_validation_enabled: - SchemaValidator.validate("schema_AiModuleDtos", content) + data = {"model": model.value, "parameters": {"image": image}} + + async for rsp, content in self.request("post", "/ai/annotate-image", data): + self._treat_response_object(rsp, content, 200) return content - async def get_trained_module(self, module_id: str) -> Dict[str, Any]: + async def upscale_image(self, image: str, width: int, height: int, scale: int) -> Tuple[str, bytes]: """ - Get a module currently in training or that finished training + Upscale the given image. Afaik, the only allowed values for scale are 2 and 4. - :param module_id: Id of the selected module + :param image: b64 encoded PNG image to upscale + :param width: Width of the starting image + :param height: Height of the starting image + :param scale: Upscaling factor (final width = starting width * scale, final height = starting height * scale) + + :return: A pair (name, data) for the raw PNG image """ - assert_type(str, module_id=module_id) + assert_type(str, image=image) + assert_type(int, width=width, height=height, scale=scale) - async for rsp, content in self.request("get", f"/ai/module/{module_id}"): - self._treat_response_object(rsp, content, 200) + data = {"image": image, "width": width, "height": height, "scale": scale} - if self.is_schema_validation_enabled: - SchemaValidator.validate("schema_AiModuleDto", content) + async for rsp, content in self.request("post", "/ai/upscale", data): + self._treat_response_object(rsp, content, 200) return content - async def delete_module(self, module_id: str) -> Dict[str, Any]: + async def classify(self) -> NoReturn: + """ + Not implemented """ - Delete a module currently in training or that finished training - :param module_id: Id of the selected module + raise NotImplementedError("Function is not implemented yet") - :return: Module that got deleted + async def suggest_tags(self, tag: str, model: ImageModel) -> Dict[str, Any]: """ + Suggest tags with a certain confidence, considering how much the tag is used in the dataset - assert_type(str, module_id=module_id) + :param tag: Tag to suggest others of + :param model: Image model to get the tags from - async for rsp, content in self.request("delete", f"/ai/module/{module_id}"): - self._treat_response_object(rsp, content, 200) + :return: List of similar tags with a confidence level + """ - # TODO: verify response ? + assert_type(str, tag=tag) + assert_type(ImageModel, model=model) + + query = urlencode( + { + "model": model.value, + "prompt": tag, + }, + quote_via=quote, + ) + + async for rsp, content in self.request("get", f"/ai/generate-image/suggest-tags?{query}"): + self._treat_response_object(rsp, content, 200) return content @@ -735,8 +787,8 @@ async def generate_voice(self, text: str, seed: str, voice: int, opus: bool, ver Generate the Text To Speech of the given text :param text: Text to synthesize into voice (text will be cut to 1000 characters backend-side) - :param seed: Voice to use - :param voice: Index of the voice to use + :param seed: Voice to use (TTS v2 only) + :param voice: Index of the voice to use (TTS v1 only) :param opus: True for WebM format, False for mp3 format :param version: Version of the TTS ("v1" or "v2") @@ -765,129 +817,93 @@ async def generate_voice(self, text: str, seed: str, voice: int, opus: bool, ver return content - async def suggest_tags(self, tag: str, model: ImageModel) -> Dict[str, Any]: + async def train_module(self, data: str, rate: int, steps: int, name: str, desc: str) -> Dict[str, Any]: """ - Suggest tags with a certain confidence, considering how much the tag is used in the dataset + Train a module for text gen - :param tag: Tag to suggest others of - :param model: Image model to get the tags from + :param data: Dataset of the module, in one single string + :param rate: Learning rate of the training + :param steps: Number of steps to train the module for + :param name: Name of the module + :param desc: Description of the module - :return: List of similar tags with a confidence level + :return: Status of the module being trained """ - assert_type(str, tag=tag) - assert_type(ImageModel, model=model) + assert_type(str, data=data, name=name, desc=desc) + assert_type(int, rate=rate, steps=steps) - query = urlencode( - { - "model": model.value, - "prompt": tag, - }, - quote_via=quote, - ) + params = { + "data": data, + "lr": rate, + "steps": steps, + "name": name, + "description": desc, + } - async for rsp, content in self.request("get", f"/ai/generate-image/suggest-tags?{query}"): - self._treat_response_object(rsp, content, 200) + async for rsp, content in self.request("post", "/ai/module/train", params): + self._treat_response_object(rsp, content, 201) + + # TODO: verify response ? return content - async def generate_image( - self, prompt: str, model: ImageModel, action: ImageGenerationType, parameters: Dict[str, Any] - ) -> AsyncIterator[Tuple[str, bytes]]: + async def get_trained_modules(self) -> List[Dict[str, Any]]: """ - Generate one or multiple image(s) - - :param prompt: Prompt for the image - :param model: Model to generate the image - :param action: Type of image generation to use - :param parameters: Parameters for the images - - :return: (name, data) pairs for the raw PNG image(s) + Get the modules currently in training or that finished training """ - assert_type(str, prompt=prompt) - assert_type(ImageModel, model=model) - assert_type(dict, parameters=parameters) - - args = { - "input": prompt, - "model": model.value, - "action": action.value, - "parameters": parameters, - } - - async for rsp, content in self.request("post", "/ai/generate-image", args): + async for rsp, content in self.request("get", "/ai/module/all"): self._treat_response_object(rsp, content, 200) - yield content + if self.is_schema_validation_enabled: + SchemaValidator.validate("schema_AiModuleDtos", content) - async def generate_prompt(self, model: Model, prompt: str, temp: float, length: int) -> Dict[str, Any]: - """ - Generate a prompt + return content - :param model: Model to use for the prompt - :param prompt: Prompt to base the generation on - :param temp: Temperature for the generation - :param length: Length of the returned prompt + async def get_trained_module(self, module_id: str) -> Dict[str, Any]: + """ + Get a module currently in training or that finished training - :return: Generated prompt + :param module_id: Id of the selected module """ - assert_type(Model, model=model) - assert_type(str, prompt=prompt) - assert_type(float, temp=temp) - assert_type(int, length=length) + assert_type(str, module_id=module_id) - args = { - "model": model.value, - "prompt": prompt, - "temp": temp, - "tokens_to_generate": length, - } + async for rsp, content in self.request("get", f"/ai/module/{module_id}"): + self._treat_response_object(rsp, content, 200) - async for rsp, content in self.request("post", "/ai/generate-prompt", args): - self._treat_response_object(rsp, content, 201) + if self.is_schema_validation_enabled: + SchemaValidator.validate("schema_AiModuleDto", content) return content - async def generate_controlnet_mask(self, model: ControlNetModel, image: str) -> Tuple[str, bytes]: + async def delete_module(self, module_id: str) -> Dict[str, Any]: """ - Get the ControlNet's mask for the given image. Used for ImageSampler.controlnet_condition + Delete a module currently in training or that finished training - :param model: ControlNet model to use - :param image: b64 encoded PNG image to get the mask of + :param module_id: Id of the selected module - :return: A pair (name, data) for the raw PNG image + :return: Module that got deleted """ - assert_type(ControlNetModel, model=model) - assert_type(str, image=image) - - args = {"model": model.value, "parameters": {"image": image}} + assert_type(str, module_id=module_id) - async for rsp, content in self.request("post", "/ai/annotate-image", args): + async for rsp, content in self.request("delete", f"/ai/module/{module_id}"): self._treat_response_object(rsp, content, 200) - return content - - async def upscale_image(self, image: str, width: int, height: int, scale: int) -> Tuple[str, bytes]: - """ - Upscale the given image. Afaik, the only allowed values for scale are 2 and 4. - - :param image: b64 encoded PNG image to upscale - :param width: Width of the starting image - :param height: Height of the starting image - :param scale: Upscaling factor (final width = starting width * scale, final height = starting height * scale) + # TODO: verify response ? - :return: A pair (name, data) for the raw PNG image - """ + return content - assert_type(str, image=image) - assert_type(int, width=width, height=height, scale=scale) + async def buy_steps(self, amount: int): + assert_type(int, amount=amount) - args = {"image": image, "width": width, "height": height, "scale": scale} + data = {"amount": amount} - async for rsp, content in self.request("post", "/ai/upscale", args): + async for rsp, content in self.request("delete", "/ai/module/buy-training-steps", data): self._treat_response_object(rsp, content, 200) + # TODO: verify response ? + return content