-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathencoder_pipeline.py
117 lines (88 loc) · 3.67 KB
/
encoder_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from typing import List, Optional
import numbers
import os
import torch
from torchvision import transforms
from torchvision.transforms.functional import pad
from PIL import Image
import skimage.io
from looseless_compressors import LooselessCompressor, Huffman
class PadDivisibleBy32(object):
    """Pad an image so both spatial dimensions become multiples of 32.

    Intended for resnet-style autoencoders: the decoded output shape only
    matches the input when width and height are divisible by 32 (see the
    comment in `inference_data_transform`).
    """

    def __init__(self, fill=0, padding_mode='constant'):
        """
        Args:
            fill: Pixel fill value used for 'constant' padding.
            padding_mode: One of 'constant', 'edge', 'reflect', 'symmetric'
                (the modes accepted by torchvision's `pad`).
        """
        assert isinstance(fill, (numbers.Number, str, tuple))
        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
        self.fill = fill
        self.padding_mode = padding_mode

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be padded.
        Returns:
            PIL Image: Padded image.
        """
        return pad(img, self._get_padding(img), self.fill, self.padding_mode)

    def __repr__(self):
        # BUG FIX: the old format string had three placeholders but only two
        # arguments, so calling repr() (e.g. printing a Compose) raised
        # IndexError.
        return '{0}(fill={1}, padding_mode={2!r})'.format(
            self.__class__.__name__, self.fill, self.padding_mode)

    @staticmethod
    def _get_padding(image):
        """Return (left, top, right, bottom) padding that makes the image
        dimensions divisible by 32.

        BUG FIXES vs. the old implementation:
        - It padded by `w % 32` / `h % 32`, which overshoots the original
          size without reaching the next multiple of 32 (e.g. 33 -> 34).
          The correct amount is `(32 - dim % 32) % 32` (e.g. 33 -> 64).
        - It unpacked `image.shape`, but `__call__` receives a PIL Image,
          which exposes `.size`, not `.shape`. Both are now supported.
        """
        if hasattr(image, 'shape'):
            # Tensor-like input, assumed channels-first (C, H, W) per the
            # torch convention -- TODO confirm against callers.
            _, h, w = image.shape
        else:
            # PIL Image: .size is (width, height).
            w, h = image.size
        w_pad = (32 - w % 32) % 32
        h_pad = (32 - h % 32) % 32
        # Split each total evenly; the extra pixel (odd totals) goes to the
        # left/top edge.
        l_pad = w_pad // 2 + w_pad % 2
        r_pad = w_pad // 2
        t_pad = h_pad // 2 + h_pad % 2
        b_pad = h_pad // 2
        return int(l_pad), int(t_pad), int(r_pad), int(b_pad)
# Per-channel mean/std of the ImageNet training set -- the de-facto standard
# normalization for models pretrained on ImageNet.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
imagenet_normalize = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
# Transform applied to every inference image: PIL -> float tensor in [0, 1],
# then ImageNet normalization.
inference_data_transform = transforms.Compose([
    # If we use resnet autoencoder the image size should be a multiple of 32.
    # Otherwise the decoded shape would be different from original.
    # PadDivisibleBy32(),
    transforms.ToTensor(),
    imagenet_normalize
])
def img_path_to_model_input(img_path: str,
                            inference_data_transform = inference_data_transform) -> torch.Tensor:
    """Load an image from disk and turn it into a normalized model input.

    Args:
        img_path: Path to the image file.
        inference_data_transform: Transform applied to the loaded PIL image;
            defaults to the module-level inference pipeline.
    Returns:
        torch.Tensor: The transformed (unbatched) image tensor.
    """
    raw_pixels = skimage.io.imread(img_path)
    pil_image = Image.fromarray(raw_pixels)
    return inference_data_transform(pil_image)
def quantize(encoder_out: torch.Tensor, B: int):
    """Uniformly quantize encoder activations to B fractional bits.

    Values are scaled by 2**B, rounded to the nearest integer, and cast to
    int8. NOTE(review): scaled values outside [-128, 127] wrap around on the
    int8 cast -- presumably the encoder output range makes this safe; verify.
    """
    scale = 2.0 ** B
    return torch.round(encoder_out * scale).to(torch.int8)
def save_binary_string_to_file(binary_string, filename):
    """
    Converts a string of 0 and 1 to a bytearray and writes it into file
    """
    # A trailing '1' sentinel marks the true end of the bitstring: without
    # it the decoder could not tell final zero bits from byte-alignment
    # padding.
    terminated = binary_string + '1'
    remainder = len(terminated) % 8
    if remainder:
        # Zero-fill up to the next byte boundary.
        terminated += '0' * (8 - remainder)
    payload = bytearray(
        int(terminated[pos:pos + 8], 2)
        for pos in range(0, len(terminated), 8))
    with open(filename, 'wb') as out_file:
        out_file.write(payload)
def encoder_pipeline(encoder, img_path: str, B: int,
                     compressor_state_path: Optional[str] = None,
                     compressed_img_path: Optional[str] = None,
                     looseless_compressor: Optional[LooselessCompressor] = None):
    """Encode an image file into a losslessly-compressed binary string.

    Args:
        encoder: Trained encoder module; called on a batched (1, C, H, W)
            tensor produced from the image.
        img_path: Path to the input image on disk.
        B: Number of fractional bits used by `quantize`.
        compressor_state_path: If given, the compressor state is saved there
            (via `save_state_to_file`) so a decoder can rebuild it.
        compressed_img_path: If given, the encoded bitstring is written
            there via `save_binary_string_to_file`.
        looseless_compressor: Lossless entropy coder; defaults to a fresh
            `Huffman()` per call. BUG FIX: this used to be a mutable default
            argument (`= Huffman()`), so every call sharing the default
            mutated one module-lifetime instance via `init_from_sequence`.

    Returns:
        Tuple (raw encoder output tensor, encoded bitstring) -- for debug
        purposes only.
    """
    if looseless_compressor is None:
        looseless_compressor = Huffman()
    encoder.eval()
    img = img_path_to_model_input(img_path)
    unsqueezed = img.unsqueeze(0)
    # Inference only: no need to build an autograd graph for the forward pass.
    with torch.no_grad():
        encoder_out = encoder(unsqueezed)
    quantized = quantize(encoder_out, B)
    flat_out = [int(x) for x in quantized.flatten()]
    looseless_compressor.init_from_sequence(flat_out)
    if compressor_state_path is not None:
        looseless_compressor.save_state_to_file(compressor_state_path)
    encoded = looseless_compressor.encode(flat_out)
    if compressed_img_path is not None:
        save_binary_string_to_file(encoded, compressed_img_path)
    return encoder_out, encoded  # for debug purposes only