diff --git a/requirements.txt b/requirements.txt index 2fc6dbe8..14056409 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,6 +34,7 @@ numpy==1.26.4 # WD Tagger huggingface-hub==0.26.2 onnxruntime==1.19.2 +onnxruntime-directml==1.19.2; platform_system == "Windows" # FlashAttention (Florence-2, Phi-3-Vision) flash-attn==2.6.3; platform_system == "Linux" diff --git a/taggui/auto_captioning/models/wd_tagger.py b/taggui/auto_captioning/models/wd_tagger.py index a2d5dc30..3614bd64 100644 --- a/taggui/auto_captioning/models/wd_tagger.py +++ b/taggui/auto_captioning/models/wd_tagger.py @@ -37,7 +37,7 @@ def __init__(self, model_id: str): if not tags_path.is_file(): tags_path = huggingface_hub.hf_hub_download( model_id, filename='selected_tags.csv') - self.inference_session = InferenceSession(model_path) + self.inference_session = InferenceSession(model_path, providers=['DmlExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']) self.tags = [] self.rating_tags_indices = [] self.general_tags_indices = []