Skip to content

Commit 281bf29

Browse files
committed
Allow filtering of marking confidence.
Also rename the `x` and `y` filters to `width` and `height`, as these names better describe their behaviour
1 parent 77854a3 commit 281bf29

File tree

4 files changed

+41
-18
lines changed

4 files changed

+41
-18
lines changed

README.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ apply:
144144
tag `catastrophe`.
145145
- `marking`: Images that contain at least one marking with this label. It
146146
doesn't matter whether it is a _hint_, _include_ or _exclude_ marking.
147+
- `marking:cat:>0.789`: when the label is followed by a colon, a
148+
relational operator and a number, only those markings are matched
149+
whose confidence satisfies the comparison. In this example,
150+
a `cat` marking must have a confidence higher than 0.789.
147151
- `name`: Images that contain the filter term in the file name
148152
- `name:cat` will match images such as `cat-1.jpg` or `large_cat.png`.
149153
- `path`: Images that contain the filter term in the full file path
@@ -174,9 +178,9 @@ comparison.
174178
caption.
175179
- `tokens:<=50` will match images that have 50 or fewer tokens in the
176180
caption.
177-
- `x` and `y`: will match images with the specified x or y dimension.
178-
- `x:>512` will match images where the width is greater than 512 pixels.
179-
- `y:=1024` will match images where the height is exactly 1024 pixels.
181+
- `width` and `height`: will match images with the specified width or height.
182+
- `width:>512` will match images where the width is greater than 512 pixels.
183+
- `height:=1024` will match images where the height is exactly 1024 pixels.
180184

181185
### Spaces and quotes
182186

taggui/auto_marking/marking_thread.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,6 @@ def generate_output(self, image_index, image: Image, image_prompt, model_inputs)
5858
markings.append({'box': box,
5959
'label': marking[0],
6060
'type': marking[1],
61-
'confidence': confidence})
61+
'confidence': round(confidence, 3)})
6262
self.marking_generated.emit(image_index, markings)
6363
return f'Found {len(markings)} marking(s).'

taggui/models/proxy_image_list_model.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import operator
2+
import re
23
from fnmatch import fnmatchcase
34

45
from PySide6.QtCore import QModelIndex, QSortFilterProxyModel, Qt
@@ -8,6 +9,17 @@
89
from utils.image import Image
910
import utils.target_dimension as target_dimension
1011

12+
comparison_operators = {
13+
'=': operator.eq,
14+
'==': operator.eq,
15+
'!=': operator.ne,
16+
'<': operator.lt,
17+
'>': operator.gt,
18+
'<=': operator.le,
19+
'>=': operator.ge
20+
}
21+
22+
1123
class ProxyImageListModel(QSortFilterProxyModel):
1224
def __init__(self, image_list_model: ImageListModel,
1325
tokenizer: PreTrainedTokenizerBase, tag_separator: str):
@@ -34,7 +46,23 @@ def does_image_match_filter(self, image: Image,
3446
caption = self.tag_separator.join(image.tags)
3547
return fnmatchcase(caption, f'*{filter_[1]}*')
3648
if filter_[0] == 'marking':
37-
return any(fnmatchcase(markings.label, filter_[1]) for markings in image.markings)
49+
last_colon_index = filter_[1].rfind(':')
50+
if last_colon_index < 0:
51+
return any(fnmatchcase(marking.label, filter_[1])
52+
for marking in image.markings)
53+
else:
54+
label = filter_[1][:last_colon_index]
55+
confidence = filter_[1][last_colon_index + 1:]
56+
pattern =r'^(<=|>=|==|<|>|=)\s*(0?[.,][0-9]+)'
57+
match = re.match(pattern, confidence)
58+
if not match or len(match.group(2)) == 0:
59+
return False
60+
comparison_operator = comparison_operators[match.group(1)]
61+
confidence_target = float(match.group(2).replace(',', '.'))
62+
return any((fnmatchcase(marking.label, label) and
63+
comparison_operator(marking.confidence,
64+
confidence_target))
65+
for marking in image.markings)
3866
if filter_[0] == 'name':
3967
return fnmatchcase(image.path.name, f'*{filter_[1]}*')
4068
if filter_[0] == 'path':
@@ -48,7 +76,7 @@ def does_image_match_filter(self, image: Image,
4876
if filter_[0] == 'target':
4977
# accept any dimension separator of [x:]
5078
dimension = (filter_[1]).replace(':', 'x').split('x')
51-
if image.target_dimension == None:
79+
if image.target_dimension is None:
5280
image.target_dimension = target_dimension.get(image.dimensions)
5381
return (len(dimension) == 2
5482
and dimension[0] == str(image.target_dimension.width())
@@ -59,15 +87,6 @@ def does_image_match_filter(self, image: Image,
5987
if filter_[1] == 'OR':
6088
return (self.does_image_match_filter(image, filter_[0])
6189
or self.does_image_match_filter(image, filter_[2:]))
62-
comparison_operators = {
63-
'=': operator.eq,
64-
'==': operator.eq,
65-
'!=': operator.ne,
66-
'<': operator.lt,
67-
'>': operator.gt,
68-
'<=': operator.le,
69-
'>=': operator.ge
70-
}
7190
comparison_operator = comparison_operators[filter_[1]]
7291
number_to_compare = None
7392
if filter_[0] == 'tags':
@@ -79,9 +98,9 @@ def does_image_match_filter(self, image: Image,
7998
caption = self.tag_separator.join(image.tags)
8099
# Subtract 2 for the `<|startoftext|>` and `<|endoftext|>` tokens.
81100
number_to_compare = len(self.tokenizer(caption).input_ids) - 2
82-
elif filter_[0] == 'x':
101+
elif filter_[0] == 'width':
83102
number_to_compare = image.dimensions[0]
84-
elif filter_[0] == 'y':
103+
elif filter_[0] == 'height':
85104
number_to_compare = image.dimensions[1]
86105
return comparison_operator(number_to_compare, int(filter_[2]))
87106

taggui/widgets/image_list.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(self):
5656
+ optionally_quoted_string)
5757
for key in string_filter_keys]
5858
comparison_operator = one_of('= == != < > <= >=')
59-
number_filter_keys = ['tags', 'chars', 'tokens', 'x', 'y']
59+
number_filter_keys = ['tags', 'chars', 'tokens', 'width', 'height']
6060
number_filter_expressions = [Group(CaselessLiteral(key) + Suppress(':')
6161
+ comparison_operator + Word(nums))
6262
for key in number_filter_keys]

0 commit comments

Comments
 (0)