11import operator
2+ import re
23from fnmatch import fnmatchcase
34
45from PySide6 .QtCore import QModelIndex , QSortFilterProxyModel , Qt
89from utils .image import Image
910import utils .target_dimension as target_dimension
1011
12+ comparison_operators = {
13+ '=' : operator .eq ,
14+ '==' : operator .eq ,
15+ '!=' : operator .ne ,
16+ '<' : operator .lt ,
17+ '>' : operator .gt ,
18+ '<=' : operator .le ,
19+ '>=' : operator .ge
20+ }
21+
22+
1123class ProxyImageListModel (QSortFilterProxyModel ):
1224 def __init__ (self , image_list_model : ImageListModel ,
1325 tokenizer : PreTrainedTokenizerBase , tag_separator : str ):
@@ -34,7 +46,23 @@ def does_image_match_filter(self, image: Image,
3446 caption = self .tag_separator .join (image .tags )
3547 return fnmatchcase (caption , f'*{ filter_ [1 ]} *' )
3648 if filter_ [0 ] == 'marking' :
37- return any (fnmatchcase (markings .label , filter_ [1 ]) for markings in image .markings )
49+ last_colon_index = filter_ [1 ].rfind (':' )
50+ if last_colon_index < 0 :
51+ return any (fnmatchcase (marking .label , filter_ [1 ])
52+ for marking in image .markings )
53+ else :
54+ label = filter_ [1 ][:last_colon_index ]
55+ confidence = filter_ [1 ][last_colon_index + 1 :]
56+ pattern = r'^(<=|>=|==|<|>|=)\s*(0?[.,][0-9]+)'
57+ match = re .match (pattern , confidence )
58+ if not match or len (match .group (2 )) == 0 :
59+ return False
60+ comparison_operator = comparison_operators [match .group (1 )]
61+ confidence_target = float (match .group (2 ).replace (',' , '.' ))
62+ return any ((fnmatchcase (marking .label , label ) and
63+ comparison_operator (marking .confidence ,
64+ confidence_target ))
65+ for marking in image .markings )
3866 if filter_ [0 ] == 'name' :
3967 return fnmatchcase (image .path .name , f'*{ filter_ [1 ]} *' )
4068 if filter_ [0 ] == 'path' :
@@ -48,7 +76,7 @@ def does_image_match_filter(self, image: Image,
4876 if filter_ [0 ] == 'target' :
4977 # accept any dimension separator of [x:]
5078 dimension = (filter_ [1 ]).replace (':' , 'x' ).split ('x' )
51- if image .target_dimension == None :
79+ if image .target_dimension is None :
5280 image .target_dimension = target_dimension .get (image .dimensions )
5381 return (len (dimension ) == 2
5482 and dimension [0 ] == str (image .target_dimension .width ())
@@ -59,15 +87,6 @@ def does_image_match_filter(self, image: Image,
5987 if filter_ [1 ] == 'OR' :
6088 return (self .does_image_match_filter (image , filter_ [0 ])
6189 or self .does_image_match_filter (image , filter_ [2 :]))
62- comparison_operators = {
63- '=' : operator .eq ,
64- '==' : operator .eq ,
65- '!=' : operator .ne ,
66- '<' : operator .lt ,
67- '>' : operator .gt ,
68- '<=' : operator .le ,
69- '>=' : operator .ge
70- }
7190 comparison_operator = comparison_operators [filter_ [1 ]]
7291 number_to_compare = None
7392 if filter_ [0 ] == 'tags' :
@@ -79,9 +98,9 @@ def does_image_match_filter(self, image: Image,
7998 caption = self .tag_separator .join (image .tags )
8099 # Subtract 2 for the `<|startoftext|>` and `<|endoftext|>` tokens.
81100 number_to_compare = len (self .tokenizer (caption ).input_ids ) - 2
82- elif filter_ [0 ] == 'x ' :
101+ elif filter_ [0 ] == 'width ' :
83102 number_to_compare = image .dimensions [0 ]
84- elif filter_ [0 ] == 'y ' :
103+ elif filter_ [0 ] == 'height ' :
85104 number_to_compare = image .dimensions [1 ]
86105 return comparison_operator (number_to_compare , int (filter_ [2 ]))
87106
0 commit comments