-
Notifications
You must be signed in to change notification settings - Fork 1
/
filter_utils.py
101 lines (77 loc) · 2.7 KB
/
filter_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
class Filter(object):
    """Common interface shared by the five PNG filter types.

    Reference: https://www.w3.org/TR/2003/REC-PNG-20031110/#9-table91

    Both methods take the same four byte values:

    x -- the byte being processed (the raw byte for ``filter``, the
         filtered byte for ``reconstruct``);
    a -- the byte corresponding to x in the pixel before the pixel
         containing x;
    b -- the byte corresponding to x in the previous scanline;
    c -- the byte corresponding to b in the pixel before the pixel
         containing b.

    Concrete subclasses override both methods.
    """

    def filter(self, x, a, b, c):
        """Return the filtered form of ``x``; must be overridden."""
        raise NotImplementedError()

    def reconstruct(self, x, a, b, c):
        """Undo ``filter`` and return the original byte; must be overridden."""
        raise NotImplementedError()
class NoneFilter(Filter):
    """Filter type 0 (None): every byte is passed through untouched."""

    def filter(self, x, a, b, c):
        # The neighbouring bytes a, b and c play no part in this filter.
        unchanged = x
        return unchanged

    def reconstruct(self, x, a, b, c):
        # Filtering was the identity, so undoing it is the identity as well.
        unchanged = x
        return unchanged
class SubFilter(Filter):
    """Filter type 1 (Sub): predict each byte from ``a``, the byte one pixel to the left."""

    def filter(self, x, a, b, c):
        delta = x - a
        # Wrap the difference back into the 0..255 byte range.
        return delta % 256

    def reconstruct(self, x, a, b, c):
        restored = x + a
        return restored % 256
class UpFilter(Filter):
    """Filter type 2 (Up): predict each byte from ``b``, the byte directly above."""

    def filter(self, x, a, b, c):
        delta = x - b
        # Wrap the difference back into the 0..255 byte range.
        return delta % 256

    def reconstruct(self, x, a, b, c):
        restored = x + b
        return restored % 256
class AverageFilter(Filter):
    """Filter type 3 (Average): predict each byte from the floored mean of ``a`` and ``b``."""

    def filter(self, x, a, b, c):
        predicted = (a + b) // 2  # floor division, as the PNG spec requires
        return (x - predicted) % 256

    def reconstruct(self, x, a, b, c):
        predicted = (a + b) // 2
        return (x + predicted) % 256
class PaethFilter(Filter):
    """Filter type 4 (Paeth): predict each byte with the Paeth predictor over a, b, c."""

    def filter(self, x, a, b, c):
        predicted = self.paeth_predictor(a, b, c)
        return (x - predicted) % 256

    def reconstruct(self, x, a, b, c):
        predicted = self.paeth_predictor(a, b, c)
        return (x + predicted) % 256

    @staticmethod
    def paeth_predictor(a, b, c):
        """Return whichever of a, b, c is closest to ``a + b - c``.

        Ties are broken in the order a, then b, then c.
        https://www.w3.org/TR/2003/REC-PNG-20031110/#9-figure91
        """
        estimate = a + b - c
        dist_a = abs(estimate - a)
        dist_b = abs(estimate - b)
        dist_c = abs(estimate - c)
        # The check order below is what implements the spec's tie-breaking.
        if dist_a <= dist_b and dist_a <= dist_c:
            return a
        if dist_b <= dist_c:
            return b
        return c
# Lookup from the PNG filter-type code (0-4) to a shared filter instance.
# The tuple order follows the spec's filter-type table:
# None, Sub, Up, Average, Paeth.
FILTER = dict(enumerate((
    NoneFilter(),
    SubFilter(),
    UpFilter(),
    AverageFilter(),
    PaethFilter(),
)))
def filter_scanline(filter, scanline, prev_scanline=None, pixel_size=1):
    """Apply a filter to one scanline (aka line of pixels).

    :param filter: object exposing ``filter(x, a, b, c)``, such as one of the
        ``Filter`` subclasses. (The name shadows the builtin but is kept for
        keyword-argument compatibility.)
    :param scanline: indexable sequence of byte values (ints) to filter.
    :param prev_scanline: the previous row, or ``None`` / an empty sequence
        for the first row. In that case ``b`` and ``c`` are 0. For a correct
        PNG round trip this should be the previous row's unfiltered bytes,
        since ``a`` is likewise read from the unfiltered ``scanline``.
    :param pixel_size: bytes per pixel. ``a`` and ``c`` are read this many
        bytes to the left and are 0 when there is no such byte.
    :returns: list of filtered byte values, one per input byte.
    :raises ValueError: if ``pixel_size`` is less than 1.
    """
    if pixel_size < 1:
        # pixel_size 0 would make ``a`` the byte itself, and a negative value
        # would look to the right; both silently produce wrong output.
        raise ValueError("pixel_size must be >= 1, got %r" % (pixel_size,))
    # Use an explicit None/length test instead of truthiness so that
    # array-like rows (e.g. numpy arrays, whose truth value is ambiguous) are
    # accepted. An empty row is still treated as "no previous scanline".
    has_prev = prev_scanline is not None and len(prev_scanline) > 0
    res = []
    for i, x in enumerate(scanline):
        has_left = i >= pixel_size
        a = scanline[i - pixel_size] if has_left else 0
        b = prev_scanline[i] if has_prev else 0
        c = prev_scanline[i - pixel_size] if has_prev and has_left else 0
        res.append(filter.filter(x, a, b, c))
    return res
def reconstruct_scanline(filter, scanline, prev_scanline=None, pixel_size=1):
    """Reconstruct a filtered scanline (aka line of pixels).

    :param filter: object exposing ``reconstruct(x, a, b, c)``, such as one of
        the ``Filter`` subclasses. (The name shadows the builtin but is kept
        for keyword-argument compatibility.)
    :param scanline: indexable sequence of filtered byte values (ints).
    :param prev_scanline: the previous row, or ``None`` / an empty sequence
        for the first row. In that case ``b`` and ``c`` are 0. This should
        be the already-reconstructed previous row, matching how ``a`` is read
        from this row's reconstructed output.
    :param pixel_size: bytes per pixel. ``a`` and ``c`` are read this many
        bytes to the left and are 0 when there is no such byte.
    :returns: list of reconstructed byte values, one per input byte.
    :raises ValueError: if ``pixel_size`` is less than 1.
    """
    if pixel_size < 1:
        # pixel_size 0 would index a byte not yet produced (IndexError), and
        # a negative value would look to the right; reject both up front.
        raise ValueError("pixel_size must be >= 1, got %r" % (pixel_size,))
    # Use an explicit None/length test instead of truthiness so that
    # array-like rows (e.g. numpy arrays, whose truth value is ambiguous) are
    # accepted. An empty row is still treated as "no previous scanline".
    has_prev = prev_scanline is not None and len(prev_scanline) > 0
    res = []
    for i, x in enumerate(scanline):
        has_left = i >= pixel_size
        # ``a`` comes from the output built so far: reconstruction needs the
        # unfiltered left neighbour, not the filtered input byte.
        a = res[i - pixel_size] if has_left else 0
        b = prev_scanline[i] if has_prev else 0
        c = prev_scanline[i - pixel_size] if has_prev and has_left else 0
        res.append(filter.reconstruct(x, a, b, c))
    return res