Skip to content

Commit

Permalink
Add content-type filter method to folder_paths (#4054)
Browse files Browse the repository at this point in the history
* Add content-type filter method to folder_paths

* Add unit tests

* Hardcode webp content-type

* Annotate content_types as Literal["image", "video", "audio"]
  • Loading branch information
christian-byrne authored Sep 11, 2024
1 parent 36c83cd commit e760bf5
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 8 deletions.
9 changes: 1 addition & 8 deletions comfy_extras/nodes_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,17 +183,10 @@ def INPUT_TYPES(s):
}

class LoadAudio:
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')

@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [
f for f in os.listdir(input_dir)
if (os.path.isfile(os.path.join(input_dir, f))
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
)
]
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}

CATEGORY = "audio"
Expand Down
28 changes: 28 additions & 0 deletions folder_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import os
import time
import mimetypes
import logging
from typing import Set, List, Dict, Tuple, Literal
from collections.abc import Collection

supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
Expand Down Expand Up @@ -44,6 +46,10 @@

filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}

extension_mimetypes_cache = {
"webp" : "image",
}

def map_legacy(folder_name: str) -> str:
legacy = {"unet": "diffusion_models"}
return legacy.get(folder_name, folder_name)
Expand Down Expand Up @@ -89,6 +95,28 @@ def get_directory_by_type(type_name: str) -> str | None:
return get_input_directory()
return None

def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]:
"""
Example:
files = os.listdir(folder_paths.get_input_directory())
filter_files_content_types(files, ["image", "audio", "video"])
"""
global extension_mimetypes_cache
result = []
for file in files:
extension = file.split('.')[-1]
if extension not in extension_mimetypes_cache:
mime_type, _ = mimetypes.guess_type(file, strict=False)
if not mime_type:
continue
content_type = mime_type.split('/')[0]
extension_mimetypes_cache[extension] = content_type
else:
content_type = extension_mimetypes_cache[extension]

if content_type in content_types:
result.append(file)
return result

# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir
Expand Down
Empty file.
52 changes: 52 additions & 0 deletions tests-unit/folder_paths_test/filter_by_content_types_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pytest
import os
import tempfile
from folder_paths import filter_files_content_types

@pytest.fixture(scope="module")
def file_extensions():
return {
'image': ['bmp', 'cdr', 'gif', 'heif', 'ico', 'jpeg', 'jpg', 'pcx', 'png', 'pnm', 'ppm', 'psd', 'sgi', 'svg', 'tiff', 'webp', 'xbm', 'xcf', 'xpm'],
'audio': ['aif', 'aifc', 'aiff', 'au', 'awb', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'sd2', 'smp', 'snd', 'wav'],
'video': ['avi', 'flv', 'm2v', 'm4v', 'mj2', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
}


@pytest.fixture(scope="module")
def mock_dir(file_extensions):
with tempfile.TemporaryDirectory() as directory:
for content_type, extensions in file_extensions.items():
for extension in extensions:
with open(f"{directory}/sample_{content_type}.{extension}", "w") as f:
f.write(f"Sample {content_type} file in {extension} format")
yield directory


def test_categorizes_all_correctly(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
for extension in extensions:
assert f"sample_{content_type}.{extension}" in filtered_files


def test_categorizes_all_uniquely(mock_dir, file_extensions):
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
assert len(filtered_files) == len(extensions)


def test_handles_bad_extensions():
files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == []


def test_handles_no_extension():
files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == []


def test_handles_no_files():
files = []
assert filter_files_content_types(files, ["image", "audio", "video"]) == []

0 comments on commit e760bf5

Please sign in to comment.