Skip to content

Commit

Permalink
[API] Migrate Image Gen to new endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
Aedial committed Mar 20, 2024
1 parent ab93e75 commit 1541fbc
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 20 deletions.
2 changes: 2 additions & 0 deletions novelai_api/ImagePreset.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ class ImagePreset:

#: Use the old behavior of prompt separation at the 75 tokens mark (can cut words in half)
legacy_v3_extend: bool
#: ???
params_version: int

_settings: Dict[str, Any]

Expand Down
1 change: 1 addition & 0 deletions novelai_api/NovelAI_API.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
class NovelAIAPI:
# Constants

# TODO: might want to make the base endpoint configurable
#: The base address for the API
BASE_ADDRESS: str = "https://api.novelai.net"
LIB_ROOT: str = dirname(abspath(__file__))
Expand Down
44 changes: 28 additions & 16 deletions novelai_api/_low_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

PRINT_WITH_PARAMETERS = os.environ.get("NAI_PRINT", False)

IMAGE_API_ADDRESS = "https://image.novelai.net"


# === INTERNALS === #
SSE_FIELDS = ["event", "data", "id", "retry"]
Expand All @@ -38,15 +40,15 @@ def print_with_parameters(args: Dict[str, Any]):
if "parameters" in a:
a["parameters"] = {k: str(v) for k, v in a["parameters"].items()}

for k in ["image", "mask", "controlnet_condition"]:
if k in a["parameters"]:
a["parameters"][k] = (
f"{a['parameters'][k][:10]}...{a['parameters'][k][-10:]}"
if 30 < len(a["parameters"][k])
else a["parameters"][k]
)
for k in ["image", "mask", "controlnet_condition"]:
if k in a["parameters"]:
a["parameters"][k] = (
f"{a['parameters'][k][:10]}...{a['parameters'][k][-10:]}"
if 30 < len(a["parameters"][k])
else a["parameters"][k]
)

print(json.dumps(a, indent=4))
print(json.dumps(a, indent=4, sort_keys=True))


# === API === #
Expand Down Expand Up @@ -181,7 +183,7 @@ async def _parse_response(cls, rsp: ClientResponse):
elif content_type in ("audio/mpeg", "audio/webm"):
yield await rsp.read()

elif content_type == "application/x-zip-compressed":
elif content_type in ("application/x-zip-compressed", "binary/octet-stream"):
z = zipfile.ZipFile(io.BytesIO(await rsp.read()))
for name in z.namelist():
yield name, z.read(name)
Expand All @@ -195,16 +197,29 @@ async def _parse_response(cls, rsp: ClientResponse):
url: str = rsp.url if isinstance(rsp.url, str) else rsp.url.human_repr()
raise NovelAIError(url, -1, f"Unsupported type: {rsp.content_type}")

async def request(self, method: str, endpoint: str, data: Optional[Union[Dict[str, Any], str]] = None):
async def request(
self,
method: str,
endpoint: str,
data: Optional[Union[Dict[str, Any], str]] = None,
custom_base_address: Union[str, None] = None,
):
"""
Send request with support for data streaming
:param method: Method of the request (get, post, delete)
:param endpoint: Endpoint of the request
:param data: Data to pass to the method if needed
:param custom_base_address: Custom address to use for the request
"""

url = f"{self._parent.BASE_ADDRESS}{endpoint}"
if PRINT_WITH_PARAMETERS:
print_with_parameters(data)

if custom_base_address is None:
custom_base_address = self._parent.BASE_ADDRESS

url = f"{custom_base_address}{endpoint}"

is_sync = self._parent.session is None
session = ClientSession() if is_sync else self._parent.session
Expand Down Expand Up @@ -647,9 +662,6 @@ async def generate(self, prompt: Union[List[int], str], model: Model, params: Di

endpoint = "/ai/generate-stream" if stream else "/ai/generate"

if PRINT_WITH_PARAMETERS:
print_with_parameters(data)

async for rsp, content in self.request("post", endpoint, data):
self._treat_response_object(rsp, content, 201)

Expand Down Expand Up @@ -680,7 +692,7 @@ async def generate_image(
"parameters": parameters,
}

async for rsp, content in self.request("post", "/ai/generate-image", data):
async for rsp, content in self.request("post", "/ai/generate-image", data, IMAGE_API_ADDRESS):
self._treat_response_object(rsp, content, 200)

yield content
Expand Down Expand Up @@ -784,7 +796,7 @@ async def suggest_tags(self, tag: str, model: ImageModel) -> Dict[str, Any]:
quote_via=quote,
)

async for rsp, content in self.request("get", f"/ai/generate-image/suggest-tags?{query}"):
async for rsp, content in self.request("get", f"/ai/generate-image/suggest-tags?{query}", IMAGE_API_ADDRESS):
self._treat_response_object(rsp, content, 200)

return content
Expand Down
5 changes: 1 addition & 4 deletions novelai_api/image_presets/presets_v1/default.preset
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,14 @@
"noise": 0,
"uc_preset": "Preset_Low_Quality_Bad_Anatomy",
"quality_toggle": true,
"auto_smea": true,
"smea": false,
"smea_dyn": false,
"decrisper": false,
"controlnet_strength": 1,
"legacy": false,
"add_original_image": true,
"uncond_scale": 1,
"cfg_rescale": 0,
"noise_schedule": "native",
"legacy_v3_extend": false,
"params_version": 1,

"seed": 0,
"uc": ""
Expand Down
1 change: 1 addition & 0 deletions novelai_api/image_presets/presets_v2/default.preset
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"cfg_rescale": 0,
"noise_schedule": "native",
"legacy_v3_extend": false,
"params_version": 1,

"seed": 0,
"uc": ""
Expand Down
1 change: 1 addition & 0 deletions novelai_api/image_presets/presets_v3/default.preset
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"legacy_v3_extend": false,
"reference_information_extracted": 1,
"reference_strength": 0.6,
"params_version": 1,

"seed": 0,
"uc": ""
Expand Down

0 comments on commit 1541fbc

Please sign in to comment.