-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support preview images embedded in safetensors metadata (#6119)
* Support preview images embedded in safetensors metadata * Add unit test for safetensors embedded image previews
- Loading branch information
1 parent
2dda7c1
commit 3cacd3f
Showing
2 changed files
with
85 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import pytest | ||
import base64 | ||
import json | ||
import struct | ||
from io import BytesIO | ||
from PIL import Image | ||
from aiohttp import web | ||
from unittest.mock import patch | ||
from app.model_manager import ModelFileManager | ||
|
||
pytestmark = ( | ||
pytest.mark.asyncio | ||
) # This applies the asyncio mark to all test functions in the module | ||
|
||
@pytest.fixture | ||
def model_manager(): | ||
return ModelFileManager() | ||
|
||
@pytest.fixture | ||
def app(model_manager): | ||
app = web.Application() | ||
routes = web.RouteTableDef() | ||
model_manager.add_routes(routes) | ||
app.add_routes(routes) | ||
return app | ||
|
||
async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path): | ||
img = Image.new('RGB', (100, 100), 'white') | ||
img_byte_arr = BytesIO() | ||
img.save(img_byte_arr, format='PNG') | ||
img_byte_arr.seek(0) | ||
img_b64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8') | ||
|
||
safetensors_file = tmp_path / "test_model.safetensors" | ||
header_bytes = json.dumps({ | ||
"__metadata__": { | ||
"ssmd_cover_images": json.dumps([img_b64]) | ||
} | ||
}).encode('utf-8') | ||
length_bytes = struct.pack('<Q', len(header_bytes)) | ||
with open(safetensors_file, 'wb') as f: | ||
f.write(length_bytes) | ||
f.write(header_bytes) | ||
|
||
with patch('folder_paths.folder_names_and_paths', { | ||
'test_folder': ([str(tmp_path)], None) | ||
}): | ||
client = await aiohttp_client(app) | ||
response = await client.get('/experiment/models/preview/test_folder/0/test_model.safetensors') | ||
|
||
# Verify response | ||
assert response.status == 200 | ||
assert response.content_type == 'image/webp' | ||
|
||
# Verify the response contains valid image data | ||
img_bytes = BytesIO(await response.read()) | ||
img = Image.open(img_bytes) | ||
assert img.format | ||
assert img.format.lower() == 'webp' | ||
|
||
# Clean up | ||
img.close() |