Skip to content

Commit 6c16fff

Browse files
committed
Fix set_sat
1 parent b7d665c commit 6c16fff

File tree

2 files changed

+52
-6
lines changed

2 files changed

+52
-6
lines changed

pilgram/css/blending/nonseparable.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,10 @@ def set_sat(c, s):
230230
cmid = r + g + b - cmax - cmin
231231
new_cmid = ((cmid - cmin) * s) / (cmax - cmin)
232232

233-
cmid_r = (cmax > cmin) * (cmid == r) * new_cmid
234-
cmid_g = (cmax > cmin) * (cmid == g) * new_cmid
235-
cmid_b = (cmax > cmin) * (cmid == b) * new_cmid
233+
# NOTE: use cmax if cmax == cmid
234+
cmid_r = (cmax > cmin) * (cmax > cmid) * (cmid == r) * new_cmid
235+
cmid_g = (cmax > cmin) * (cmax > cmid) * (cmid == g) * new_cmid
236+
cmid_b = (cmax > cmin) * (cmax > cmid) * (cmid == b) * new_cmid
236237

237238
cmax_r = (cmax > cmin) * (cmax == r) * s
238239
cmax_g = (cmax > cmin) * (cmax == g) * s

pilgram/css/blending/tests/test_nonseparable.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from pilgram.css.blending.nonseparable import _min3, _max3
2323
from pilgram.css.blending.nonseparable import _clip_color
2424
from pilgram.css.blending.nonseparable import lum, lum_im, set_lum
25-
from pilgram.css.blending.nonseparable import sat
25+
from pilgram.css.blending.nonseparable import sat, set_sat
2626

2727

2828
def test_min3():
@@ -110,5 +110,50 @@ def test_sat():
110110
assert list(im_sat.getdata()) == [120]
111111

112112

113-
def test_set_sat():
114-
pass # TODO
113+
def test_set_sat_cmax_gt_cmin():
114+
im1 = util.fill((1, 1), [0, 128, 255])
115+
im2 = util.fill((1, 1), [64, 96, 128]) # sat = 64
116+
r1, g1, b1 = im1.split()
117+
r2, g2, b2 = im2.split()
118+
bands = ImageMath.eval(
119+
'set_sat((r1, g1, b1), sat((r2, g2, b2)))',
120+
set_sat=set_sat, sat=sat,
121+
r1=r1, g1=g1, b1=b1,
122+
r2=r2, g2=g2, b2=b2)
123+
124+
expected = [
125+
[0],
126+
[pytest.approx(32.12549019607843, abs=1)],
127+
[64],
128+
]
129+
assert [list(band.im.getdata()) for band in bands] == expected
130+
131+
132+
def test_set_sat_cmax_eq_cmid_gt_cmin():
133+
im1 = util.fill((1, 1), [0, 128, 128])
134+
im2 = util.fill((1, 1), [64, 96, 128]) # sat = 64
135+
r1, g1, b1 = im1.split()
136+
r2, g2, b2 = im2.split()
137+
bands = ImageMath.eval(
138+
'set_sat((r1, g1, b1), sat((r2, g2, b2)))',
139+
set_sat=set_sat, sat=sat,
140+
r1=r1, g1=g1, b1=b1,
141+
r2=r2, g2=g2, b2=b2)
142+
143+
expected = [[0], [64], [64]]
144+
assert [list(band.im.getdata()) for band in bands] == expected
145+
146+
147+
def test_set_sat_cmax_eq_cmin():
148+
im1 = util.fill((1, 1), [128, 128, 128])
149+
im2 = util.fill((1, 1), [64, 96, 128]) # sat = 64
150+
r1, g1, b1 = im1.split()
151+
r2, g2, b2 = im2.split()
152+
bands = ImageMath.eval(
153+
'set_sat((r1, g1, b1), sat((r2, g2, b2)))',
154+
set_sat=set_sat, sat=sat,
155+
r1=r1, g1=g1, b1=b1,
156+
r2=r2, g2=g2, b2=b2)
157+
158+
expected = [[0], [0], [0]]
159+
assert [list(band.im.getdata()) for band in bands] == expected

0 commit comments

Comments
 (0)