Skip to content

Commit

Permalink
Support preview images embedded in safetensors metadata (#6119)
Browse files Browse the repository at this point in the history
* Support preview images embedded in safetensors metadata

* Add unit test for safetensors embedded image previews
  • Loading branch information
catboxanon authored Dec 19, 2024
1 parent 2dda7c1 commit 3cacd3f
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 6 deletions.
29 changes: 23 additions & 6 deletions app/model_manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import os
import base64
import json
import time
import logging
import folder_paths
import glob
import comfy.utils
from aiohttp import web
from PIL import Image
from io import BytesIO
Expand Down Expand Up @@ -59,13 +62,13 @@ async def get_model_preview(request):
folder = folders[0][path_index]
full_filename = os.path.join(folder, filename)

preview_files = self.get_model_previews(full_filename)
default_preview_file = preview_files[0] if len(preview_files) > 0 else None
if default_preview_file is None or not os.path.isfile(default_preview_file):
previews = self.get_model_previews(full_filename)
default_preview = previews[0] if len(previews) > 0 else None
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404)

try:
with Image.open(default_preview_file) as img:
with Image.open(default_preview) as img:
img_bytes = BytesIO()
img.save(img_bytes, format="WEBP")
img_bytes.seek(0)
Expand Down Expand Up @@ -143,7 +146,7 @@ def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list

return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()

def get_model_previews(self, filepath: str) -> list[str]:
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)

if not os.path.exists(dirname):
Expand All @@ -152,15 +155,29 @@ def get_model_previews(self, filepath: str) -> list[str]:
basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
safetensors_metadata = {}

result: list[str] = []
result: list[str | BytesIO] = []

for filename in image_files:
_basename = os.path.splitext(filename)[0]
if _basename == basename:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)

if safetensors_file:
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
if header:
safetensors_metadata = json.loads(header)
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))

return result

def __exit__(self, exc_type, exc_value, traceback):
Expand Down
62 changes: 62 additions & 0 deletions tests-unit/app_test/model_manager_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest
import base64
import json
import struct
from io import BytesIO
from PIL import Image
from aiohttp import web
from unittest.mock import patch
from app.model_manager import ModelFileManager

pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module

@pytest.fixture
def model_manager():
return ModelFileManager()

@pytest.fixture
def app(model_manager):
app = web.Application()
routes = web.RouteTableDef()
model_manager.add_routes(routes)
app.add_routes(routes)
return app

async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
img = Image.new('RGB', (100, 100), 'white')
img_byte_arr = BytesIO()
img.save(img_byte_arr, format='PNG')
img_byte_arr.seek(0)
img_b64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')

safetensors_file = tmp_path / "test_model.safetensors"
header_bytes = json.dumps({
"__metadata__": {
"ssmd_cover_images": json.dumps([img_b64])
}
}).encode('utf-8')
length_bytes = struct.pack('<Q', len(header_bytes))
with open(safetensors_file, 'wb') as f:
f.write(length_bytes)
f.write(header_bytes)

with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/0/test_model.safetensors')

# Verify response
assert response.status == 200
assert response.content_type == 'image/webp'

# Verify the response contains valid image data
img_bytes = BytesIO(await response.read())
img = Image.open(img_bytes)
assert img.format
assert img.format.lower() == 'webp'

# Clean up
img.close()

0 comments on commit 3cacd3f

Please sign in to comment.