-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinternvl_dataset.py
688 lines (583 loc) · 25.5 KB
/
internvl_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import sys
IGNORE_TOKEN_ID = -100 # LabelSmoother.ignore_index
import random
from typing import Dict
from collections.abc import Sequence
import paddle
import paddle.vision.transforms as T
from paddlemix.models.internvl2.conversation import get_conv_template
from PIL import Image
from paddle.io import ConcatDataset, WeightedRandomSampler
from paddlemix.models.internvl2.constants import (CLIP_MEAN, CLIP_STD, IMAGENET_MEAN, IMAGENET_STD,
IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN,
SIGLIP_MEAN, SIGLIP_STD)
class WeightedConcatDataset(ConcatDataset):
    """Concatenation of datasets that also owns a weighted random sampler.

    Args:
        datasets: datasets to concatenate (forwarded to ``ConcatDataset``).
        weights: sampling weights handed to ``WeightedRandomSampler`` as a
            float32 tensor.

    NOTE(review): ``WeightedRandomSampler`` draws indices in
    ``[0, len(weights))``. Unless ``weights`` has one entry per sample,
    iterating this object yields indices into ``weights`` rather than into
    the concatenated dataset -- confirm against callers.
    """

    def __init__(self, datasets, weights):
        super().__init__(datasets)
        self.weights = paddle.to_tensor(weights, dtype='float32')
        # Number of samples across all member datasets; also the number of
        # draws the sampler makes per pass.
        self.total_size = sum(map(len, datasets))
        self.sampler = WeightedRandomSampler(
            weights=self.weights,
            num_samples=self.total_size,
            replacement=True,
        )

    def __iter__(self):
        # Iteration is delegated to the sampler (yields sampled indices).
        sampler = self.sampler
        return iter(sampler)

    def __len__(self):
        return self.total_size
def pil_loader(img_str):
    """Decode in-memory encoded image bytes into an RGB PIL image.

    Args:
        img_str: raw encoded image bytes (anything ``io.BytesIO`` accepts).

    Returns:
        The decoded image converted to ``'RGB'`` mode.
    """
    return Image.open(io.BytesIO(img_str)).convert('RGB')
def expand2square(pil_img, background_color):
    """Pad a PIL image to a square canvas, centering the original.

    The shorter side is padded with ``background_color`` so both sides equal
    ``max(width, height)``. An already-square image is returned unchanged
    (the same object, not a copy).
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Only the shorter dimension gets an offset, which centers the image.
    if width > height:
        offset = (0, (width - height) // 2)
    else:
        offset = ((height - width) // 2, 0)
    canvas.paste(pil_img, offset)
    return canvas
def simulate_jpeg_degradation(quality):
    """Return a transform that round-trips an image through JPEG at ``quality``.

    The returned callable re-encodes its input as an RGB JPEG in an in-memory
    buffer and decodes it again, introducing compression artifacts.
    """
    def jpeg_degrade(img):
        with io.BytesIO() as buffer:
            img.convert('RGB').save(buffer, format='JPEG', quality=quality)
            buffer.seek(0)  # rewind so the encoded bytes can be read back
            # copy() forces the pixels to load before the buffer is closed.
            degraded = Image.open(buffer).copy()
        return degraded

    return jpeg_degrade
# JPEG quality levels used for augmentation (75..100 inclusive), with one
# pre-built degradation transform per level so none is created per sample.
qualities = list(range(75, 101))
jpeg_degrade_functions = dict(zip(qualities, map(simulate_jpeg_degradation, qualities)))
class Lambda:
    """Wrap an arbitrary callable so it can be used as a transform step.

    Args:
        lambd (function): callable applied to each input.

    Raises:
        TypeError: if ``lambd`` is not callable.
    """

    def __init__(self, lambd):
        if callable(lambd):
            self.lambd = lambd
            return
        raise TypeError(f"Argument lambd should be callable, got {repr(type(lambd).__name__)}")

    def __call__(self, img):
        fn = self.lambd
        return fn(img)

    def __repr__(self) -> str:
        return f"{type(self).__name__}()"
class RandomTransforms:
    """Base class for transforms that pick among several sub-transforms at random.

    Args:
        transforms (sequence): candidate transforms.

    Raises:
        TypeError: if ``transforms`` is not a ``collections.abc.Sequence``.
    """

    def __init__(self, transforms):
        if isinstance(transforms, Sequence):
            self.transforms = transforms
        else:
            raise TypeError("Argument transforms should be a sequence")

    def __call__(self, *args, **kwargs):
        # Subclasses define how the random selection is applied.
        raise NotImplementedError()

    def __repr__(self) -> str:
        body = "".join("\n" + f" {t}" for t in self.transforms)
        return f"{self.__class__.__name__}({body}\n)"
class RandomChoice(RandomTransforms):
    """Apply a single transform drawn at random from ``transforms``.

    Args:
        transforms (sequence): candidate transforms.
        p (sequence, optional): relative weights passed to ``random.choices``;
            ``None`` means uniform selection.

    Raises:
        TypeError: if ``p`` is given but is not a sequence.
    """

    def __init__(self, transforms, p=None):
        super().__init__(transforms)
        if not (p is None or isinstance(p, Sequence)):
            raise TypeError("Argument p should be a sequence")
        self.p = p

    def __call__(self, *args):
        [chosen] = random.choices(self.transforms, weights=self.p)
        return chosen(*args)

    def __repr__(self) -> str:
        return f"{super().__repr__()}(p={self.p})"
def build_transform(is_train, input_size, pad2square=False, normalize_type='imagenet'):
    """Build the image preprocessing pipeline for training or evaluation.

    Every pipeline converts the image to RGB, resizes it to
    ``(input_size, input_size)`` with bicubic interpolation, converts it to a
    tensor and normalizes it with the selected mean/std.

    Args:
        is_train: if true, insert a random JPEG-quality degradation
            (data augmentation) before resizing.
        input_size: side length of the square output.
        pad2square: evaluation only; if true, pad the image to a square filled
            with the normalization mean color before resizing, instead of
            stretching it.
        normalize_type: ``'imagenet'``, ``'clip'`` or ``'siglip'``.

    Returns:
        A ``T.Compose`` transform.

    Raises:
        NotImplementedError: if ``normalize_type`` is not supported.
    """
    if normalize_type == 'imagenet':
        MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    elif normalize_type == 'clip':
        MEAN, STD = CLIP_MEAN, CLIP_STD
    elif normalize_type == 'siglip':
        MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
    else:
        raise NotImplementedError(
            f"Unsupported normalize_type {normalize_type!r}; "
            "expected 'imagenet', 'clip' or 'siglip'."
        )

    steps = [Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img)]
    if is_train:
        # Data augmentation: re-encode as JPEG at a randomly chosen quality.
        steps.append(RandomChoice([Lambda(jpeg_degrade_functions[quality]) for quality in qualities]))
    elif pad2square:
        # Pad with the mean color in 0-255 units so the padded area is close
        # to the normalization mean.
        background = tuple(int(x * 255) for x in MEAN)
        steps.append(Lambda(lambda img: expand2square(img, background)))
    steps += [
        T.Resize((input_size, input_size), interpolation='bicubic'),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD),
    ]
    return T.Compose(steps)
def preprocess(
    template_name,
    sources,
    tokenizer,
    num_image_token_list: list,
    text_only: bool = False,
    group_by_length: bool = False,
    use_packed_ds: bool = False,
    ds_name: str = None,
    num_image: int = 1,
):
    """Render, tokenize and label conversations for ``sep2``-delimited templates.

    Each conversation is rendered with the ``template_name`` template, the
    ``<image>`` placeholders are expanded into image tokens, and the text is
    tokenized. Labels are a copy of the input ids in which everything except
    the assistant replies is set to ``IGNORE_TOKEN_ID``. If a row was not
    truncated but its reconstructed length differs from its non-pad token
    count, the whole row is masked and a warning is printed.

    Args:
        template_name: name passed to ``get_conv_template``.
        sources: list of conversations; each is a list of dicts with keys
            ``'from'`` (``'human'`` or ``'gpt'``) and ``'value'``.
        tokenizer: called with ``return_tensors='pd'``; must expose
            ``model_max_length``, ``pad_token_id`` and ``legacy``.
        num_image_token_list: context-token count for the i-th ``<image>``.
        text_only: if true, ``<image>`` placeholders are not expanded.
        group_by_length: if true, do not pad to ``model_max_length``.
        use_packed_ds: if true, do not pad to ``model_max_length``.
        ds_name: dataset name shown in the mismatch warning.
        num_image: number of ``<image>`` placeholders to expand per conversation.

    Returns:
        dict with ``input_ids``, ``labels`` and ``attention_mask``.
    """
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        new_conversations = []
        for conversation in conversations:
            # Replace placeholders left to right so the i-th image gets
            # num_image_token_list[i] context tokens.
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors='pd',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()
    pad_id = paddle.to_tensor(tokenizer.pad_token_id)

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1] + ': '
    new_targets = []
    for conversation, target in zip(conversations, targets):
        total_len = int(target.not_equal(pad_id).sum())

        turns = conversation.split(conv.sep2)
        cur_len = 1
        # Mask the first token (presumably BOS -- confirm for the tokenizer in use).
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(turns):
            if turn == '':
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                instruction_len -= 1

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                cur_len -= 1

        # Mask everything after the last processed turn (padding included).
        target[cur_len:] = IGNORE_TOKEN_ID

        # Only untruncated rows are checked; on a length mismatch the computed
        # offsets cannot be trusted, so the whole row is excluded from the loss.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()
        new_targets.append(target)

    # Rebuild labels from the edited rows so the result does not depend on
    # whether iterating ``targets`` yields views or copies.
    labels = paddle.stack(new_targets, axis=0) if new_targets else targets
    return dict(
        input_ids=input_ids,
        labels=labels,
        attention_mask=input_ids.not_equal(pad_id),
    )
def preprocess_mpt(
    template_name,
    sources,
    tokenizer,
    num_image_token_list: list,
    text_only: bool = False,
    group_by_length: bool = False,
    use_packed_ds: bool = False,
    ds_name: str = None,
    num_image: int = 1
) -> Dict:
    """Render, tokenize and label conversations for MPT/ChatML-style templates.

    Like ``preprocess`` but the rendered text is split on ``conv.sep`` and
    regrouped into rounds (system + user + gpt first, then user + gpt). Only
    assistant replies keep their ids in ``labels``; everything else is
    ``IGNORE_TOKEN_ID``. An untruncated row whose reconstructed length
    differs from its non-pad token count is fully masked with a warning.

    Args:
        template_name: name passed to ``get_conv_template``.
        sources: list of conversations; each is a list of dicts with keys
            ``'from'`` (``'human'`` or ``'gpt'``) and ``'value'``.
        tokenizer: called with ``return_tensors='pd'``; must expose
            ``model_max_length`` and ``pad_token_id``.
        num_image_token_list: context-token count for the i-th ``<image>``.
        text_only: if true, ``<image>`` placeholders are not expanded.
        group_by_length: if true, do not pad to ``model_max_length``.
        use_packed_ds: if true, do not pad to ``model_max_length``.
        ds_name: dataset name shown in the mismatch warning.
        num_image: number of ``<image>`` placeholders to expand per conversation.

    Returns:
        dict with ``input_ids``, ``labels`` and ``attention_mask``.
    """
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        new_conversations = []
        for conversation in conversations:
            # Replace placeholders left to right so the i-th image gets
            # num_image_token_list[i] context tokens.
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors='pd',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()
    pad_id = paddle.to_tensor(tokenizer.pad_token_id)

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1]  # <|im_end|><|im_start|>assistant\n
    new_targets = []
    for conversation, target in zip(conversations, targets):
        total_len = int(target.not_equal(pad_id).sum())

        turns = conversation.split(conv.sep)
        re_turns = [conv.sep.join(turns[:3])]  # system + user + gpt
        for conv_idx in range(3, len(turns), 2):
            re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2]))  # user + gpt
        # No leading token is skipped for this template: offsets start at 0.
        cur_len = 0
        for turn in re_turns:
            if turn == '':
                break
            # "+1" presumably accounts for the separator dropped between rounds
            # when the text was split and re-joined -- TODO confirm it is one token.
            turn_len = len(tokenizer(turn).input_ids) + 1

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            instruction_len = len(tokenizer(parts[0]).input_ids)

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

        # Mask everything after the last processed round (padding included).
        target[cur_len:] = IGNORE_TOKEN_ID

        # Only untruncated rows are checked; on a length mismatch the computed
        # offsets cannot be trusted, so the whole row is excluded from the loss.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()
        new_targets.append(target)

    # Rebuild labels from the edited rows so the result does not depend on
    # whether iterating ``targets`` yields views or copies.
    labels = paddle.stack(new_targets, axis=0) if new_targets else targets
    return dict(
        input_ids=input_ids,
        labels=labels,
        attention_mask=input_ids.not_equal(pad_id),
    )
def preprocess_phi3(
    template_name,
    sources,
    tokenizer,
    num_image_token_list: list,
    text_only: bool = False,
    group_by_length: bool = False,
    use_packed_ds: bool = False,
    ds_name: str = None,
    num_image: int = 1
) -> Dict:
    """Render, tokenize and label conversations for Phi-3-style templates.

    The rendered text is split on ``conv.sep`` and regrouped into rounds
    (system + user + gpt first, then user + gpt). Only assistant replies keep
    their ids in ``labels``; every ``<|endoftext|>`` token and everything else
    is set to ``IGNORE_TOKEN_ID``. An untruncated row whose reconstructed
    length differs from its non-pad token count is fully masked with a warning.

    Side effect: sets ``tokenizer.padding_side = 'right'`` on the tokenizer
    passed in, and the setting persists after the call.

    Args:
        template_name: name passed to ``get_conv_template``.
        sources: list of conversations; each is a list of dicts with keys
            ``'from'`` (``'human'`` or ``'gpt'``) and ``'value'``.
        tokenizer: called with ``return_tensors='pd'``; must expose
            ``model_max_length``, ``pad_token_id`` and ``convert_tokens_to_ids``.
        num_image_token_list: context-token count for the i-th ``<image>``.
        text_only: if true, ``<image>`` placeholders are not expanded.
        group_by_length: if true, do not pad to ``model_max_length``.
        use_packed_ds: if true, do not pad to ``model_max_length``.
        ds_name: dataset name shown in the mismatch warning.
        num_image: number of ``<image>`` placeholders to expand per conversation.

    Returns:
        dict with ``input_ids``, ``labels`` and ``attention_mask``.
    """
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        new_conversations = []
        for conversation in conversations:
            # Replace placeholders left to right so the i-th image gets
            # num_image_token_list[i] context tokens.
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    tokenizer.padding_side = 'right'
    input_ids = tokenizer(
        conversations,
        return_tensors='pd',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()
    pad_id = paddle.to_tensor(tokenizer.pad_token_id)
    endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1]  # <|end|>\n<|assistant|>
    new_targets = []
    for conversation, target in zip(conversations, targets):
        total_len = int(target.not_equal(pad_id).sum())

        turns = conversation.split(conv.sep)
        re_turns = [conv.sep.join(turns[:3])]  # system + user + gpt
        for conv_idx in range(3, len(turns), 2):
            re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2]))  # user + gpt
        cur_len = 1
        # Mask the first token (presumably BOS -- confirm for the tokenizer in use).
        target[:cur_len] = IGNORE_TOKEN_ID
        target[target == endoftext_id] = IGNORE_TOKEN_ID

        for i, turn in enumerate(re_turns):
            if turn == '':
                break
            # Later rounds are tokenized on their own, so the leading token the
            # tokenizer adds to each call is subtracted (offsets tuned per round).
            if i == 0:
                turn_len = len(tokenizer(turn).input_ids)
            else:
                turn_len = len(tokenizer(turn).input_ids) - 1
            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if i == 0:
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1
            else:
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

        # Mask everything after the last processed round (padding included).
        target[cur_len:] = IGNORE_TOKEN_ID

        # Only untruncated rows are checked; on a length mismatch the computed
        # offsets cannot be trusted, so the whole row is excluded from the loss.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()
        new_targets.append(target)

    # Rebuild labels from the edited rows so the result does not depend on
    # whether iterating ``targets`` yields views or copies.
    labels = paddle.stack(new_targets, axis=0) if new_targets else targets
    return dict(
        input_ids=input_ids,
        labels=labels,
        attention_mask=input_ids.not_equal(pad_id),
    )
def preprocess_internlm(
    template_name,
    sources,
    tokenizer,
    num_image_token_list: list,
    text_only: bool = False,
    group_by_length: bool = False,
    use_packed_ds: bool = False,
    ds_name: str = None,
    num_image: int = 1
) -> Dict:
    """Render, tokenize and label conversations for InternLM-style templates.

    Message values are stripped of surrounding whitespace before rendering.
    The caller's ``sources`` are not modified. Labels are built by splitting
    the rendered text on the assistant role marker ``conv.roles[1]``:
    - the system prompt and user turns, including the role markers, are masked;
    - assistant replies keep their ids.
    An untruncated row whose reconstructed length differs from its non-pad
    token count is fully masked with a warning.

    Args:
        template_name: name passed to ``get_conv_template``.
        sources: list of conversations; each is a list of dicts with keys
            ``'from'`` (``'human'`` or ``'gpt'``) and ``'value'``.
        tokenizer: called with ``return_tensors='pd'``; must expose
            ``model_max_length`` and ``pad_token_id``.
        num_image_token_list: context-token count for the i-th ``<image>``.
        text_only: if true, ``<image>`` placeholders are not expanded.
        group_by_length: if true, do not pad to ``model_max_length``.
        use_packed_ds: if true, do not pad to ``model_max_length``.
        ds_name: dataset name shown in the mismatch warning.
        num_image: number of ``<image>`` placeholders to expand per conversation.

    Returns:
        dict with ``input_ids``, ``labels`` and ``attention_mask``.
    """
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            # Strip locally instead of writing back into the caller's dict.
            conv.append_message(role, sentence['value'].strip())
        conversations.append(conv.get_prompt())

    if not text_only:
        new_conversations = []
        for conversation in conversations:
            # Replace placeholders left to right so the i-th image gets
            # num_image_token_list[i] context tokens.
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors='pd',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()
    pad_id = paddle.to_tensor(tokenizer.pad_token_id)
    new_targets = []
    for conversation, target in zip(conversations, targets):
        # For InternLM, pad_token_id == eos_token_id.
        total_len = int(target.not_equal(pad_id).sum())
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID  # <s>
        parts = conversation.split(conv.roles[1])  # [UNUSED_TOKEN_146]assistant\n

        # System prompt + first user turn + assistant marker: masked.
        info = parts[0] + conv.roles[1]
        temp_len = len(tokenizer(info).input_ids) - 1  # drop the <s> the tokenizer prepends
        target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
        cur_len = cur_len + temp_len

        # Middle chunks: "<assistant reply><user marker><user turn>".
        # The reply is kept; the user turn and the next assistant marker are masked.
        for index in range(1, len(parts) - 1):
            info = parts[index]
            part1, part2 = info.split(conv.roles[0])
            temp_len = len(tokenizer(part1).input_ids) - 1
            cur_len = cur_len + temp_len
            part = conv.roles[0] + part2 + conv.roles[1]
            temp_len = len(tokenizer(part).input_ids) - 1
            target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
            cur_len = cur_len + temp_len

        # Final chunk: the last assistant reply is kept.
        last_info = parts[-1]
        temp_len = len(tokenizer(last_info).input_ids) - 1
        cur_len = cur_len + temp_len

        target[cur_len:] = IGNORE_TOKEN_ID

        # Only untruncated rows are checked; on a length mismatch the computed
        # offsets cannot be trusted, so the whole row is excluded from the loss.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. This dataset is {ds_name}.')
                sys.stdout.flush()
        new_targets.append(target)

    new_targets = paddle.stack(new_targets, axis=0)
    return dict(
        input_ids=input_ids,
        labels=new_targets,
        attention_mask=input_ids.not_equal(pad_id),
    )
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the ``(cols, rows)`` grid whose ``cols / rows`` best matches ``aspect_ratio``.

    Candidates are scanned in the given order. A candidate with a strictly
    smaller ``|aspect_ratio - cols / rows|`` becomes the best. A candidate with
    an equal difference replaces the current best only when
    ``width * height > 0.5 * image_size**2 * cols * rows``.

    Returns:
        The chosen candidate, or ``(1, 1)`` if ``target_ratios`` is empty.
    """
    best, best_diff = (1, 1), float('inf')
    area = width * height
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        diff = abs(aspect_ratio - cols / rows)
        if diff < best_diff:
            best, best_diff = candidate, diff
        elif diff == best_diff and area > 0.5 * image_size * image_size * cols * rows:
            best = candidate
    return best
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, return_target_aspect_ratio=False):
    """Cut an image into a grid of ``image_size`` x ``image_size`` tiles.

    The grid ``(cols, rows)`` is chosen by ``find_closest_aspect_ratio`` among
    all grids with ``min_num <= cols * rows <= max_num``, ordered by tile count.
    The image is resized to exactly ``(cols * image_size, rows * image_size)``
    and cropped into tiles in row-major order.

    Args:
        image: PIL image to tile.
        min_num: minimum number of tiles.
        max_num: maximum number of tiles.
        image_size: side length of each square tile.
        use_thumbnail: if true and more than one tile was produced, append the
            whole image resized to ``(image_size, image_size)``.
        return_target_aspect_ratio: if true, also return the chosen grid.

    Returns:
        The list of tiles, or ``(tiles, (cols, rows))`` when
        ``return_target_aspect_ratio`` is true.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # All candidate grids within the tile budget, fewest tiles first.
    candidates = sorted(
        set(
            (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
            i * j <= max_num and i * j >= min_num
        ),
        key=lambda grid: grid[0] * grid[1],
    )
    chosen = find_closest_aspect_ratio(aspect_ratio, candidates, orig_width, orig_height, image_size)
    cols, rows = chosen[0], chosen[1]

    resized = image.resize((image_size * cols, image_size * rows))
    tiles = [
        resized.crop((c * image_size, r * image_size, (c + 1) * image_size, (r + 1) * image_size))
        for r in range(rows)
        for c in range(cols)
    ]
    assert len(tiles) == cols * rows

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return (tiles, chosen) if return_target_aspect_ratio else tiles
def dynamic_preprocess2(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, prior_aspect_ratio=None):
    """Tile an image like ``dynamic_preprocess``, avoiding grids that align with a prior tiling.

    Candidate grids ``(cols, rows)`` satisfy ``min_num <= cols * rows <= max_num``.
    If ``prior_aspect_ratio`` (a ``(cols, rows)`` pair) is given, only candidates
    with ``prior[0] % cols != 0 and prior[1] % rows != 0`` are kept. When it is
    ``None``, every candidate is considered. If no candidate survives,
    ``find_closest_aspect_ratio`` falls back to a single ``(1, 1)`` tile.

    NOTE(review): the filter drops a grid when EITHER dimension divides the
    prior's; confirm ``and`` (rather than ``or``) is the intended condition.

    Args:
        image: PIL image to tile.
        min_num: minimum number of tiles.
        max_num: maximum number of tiles.
        image_size: side length of each square tile.
        use_thumbnail: if true and more than one tile was produced, append the
            whole image resized to ``(image_size, image_size)``.
        prior_aspect_ratio: grid of an earlier tiling to avoid, or ``None``.

    Returns:
        List of PIL tiles in row-major order (plus the optional thumbnail).
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # All candidate grids within the tile budget, fewest tiles first.
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    if prior_aspect_ratio is not None:
        target_ratios = [
            ratio for ratio in target_ratios
            if prior_aspect_ratio[0] % ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0
        ]

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)
    cols, rows = target_aspect_ratio[0], target_aspect_ratio[1]

    resized_img = image.resize((image_size * cols, image_size * rows))
    processed_images = [
        resized_img.crop((c * image_size, r * image_size, (c + 1) * image_size, (r + 1) * image_size))
        for r in range(rows)
        for c in range(cols)
    ]
    assert len(processed_images) == cols * rows

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images