Skip to content

Commit 85f3fa3

Browse files
authored
Merge pull request #340 from youguohui/patch-1
Update dataset_transform.py
2 parents 8cbef86 + 4e117f1 commit 85f3fa3

File tree

1 file changed

+23
-14
lines changed

1 file changed

+23
-14
lines changed

dataset_transform.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,24 @@
66
from io import BytesIO
77
from sklearn.model_selection import train_test_split
88

9-
10-
# 读取csv文件
119
'''
1210
original_dataset原始数据的路径文件夹,需修改为实际的路径
1311
'''
12+
13+
#训练和验证集文本数据的文件
1414
data1 = pd.read_csv('original_dataset/data1/ImageWordData.csv')
15+
#训练和验证集图像数据的目录
16+
data1_images_folder='original_dataset/data1/ImageData'
1517

1618
# 先将文本及对应图像id划分训练集和验证集
1719
train_data, val_data = train_test_split(data1, test_size=0.2, random_state=42)
1820

1921
# 创建函数来处理数据集,使文本关联到其对应图像id的图像
20-
def process_train_valid(data, img_file, txt_file):
22+
def process_train_valid(data, images_folder, img_file, txt_file):
2123
with open(img_file, 'w') as f_img, open(txt_file, 'w') as f_txt:
2224
for index, row in data.iterrows():
2325
# 图片内容需要被编码为base64格式
24-
img_path = os.path.join('original_dataset/data1/ImageData', row['image_id'])
26+
img_path = os.path.join(images_folder, row['image_id'])
2527
with open(img_path, 'rb') as f_img_file:
2628
img = Image.open(f_img_file)
2729
img_buffer = BytesIO()
@@ -36,21 +38,24 @@ def process_train_valid(data, img_file, txt_file):
3638
f_txt.write(json.dumps(text_data) + '\n')
3739

3840
# 处理训练集和验证集
39-
process_train_valid(train_data, 'Chinese-CLIP/datasets/DatasetName/train_imgs.tsv', 'Chinese-CLIP/datasets/DatasetName/train_texts.jsonl')
40-
process_train_valid(val_data, 'Chinese-CLIP/datasets/DatasetName/valid_imgs.tsv', 'Chinese-CLIP/datasets/DatasetName/valid_texts.jsonl')
41+
# datasets/DatasetName为在Chinese-CLIP项目目录下新建的存放转换后数据集的文件夹
42+
process_train_valid(train_data, data1_images_folder, 'Chinese-CLIP/datasets/DatasetName/train_imgs.tsv', 'Chinese-CLIP/datasets/DatasetName/train_texts.jsonl')
43+
process_train_valid(val_data, data1_images_folder, 'Chinese-CLIP/datasets/DatasetName/valid_imgs.tsv', 'Chinese-CLIP/datasets/DatasetName/valid_texts.jsonl')
4144

4245

4346

44-
#制作从文本到图像(Text_to_Image)检索时的,测试集。data2为Text_to_Image测试数据文件夹名
47+
# 制作从文本到图像(Text_to_Image)检索时的,测试集。data2为Text_to_Image测试数据文件夹名
4548
image_data2 = pd.read_csv('original_dataset/data2/image_data.csv')
4649
word_test2 = pd.read_csv('original_dataset/data2/word_test.csv')
50+
# 原始图像测试集目录
51+
data2_images_folder='original_dataset/data2/ImageData'
4752

4853
# 处理Text_to_Image测试集
49-
def process_text_to_image(image_data, word_test, img_file, txt_file):
54+
def process_text_to_image(image_data, images_folder, word_test, img_file, txt_file):
5055
with open(img_file, 'w') as f_img, open(txt_file, 'w') as f_txt:
5156
for index, row in image_data.iterrows():
5257
# 图片内容需要被编码为base64格式
53-
img_path = os.path.join('../dataset/data2/ImageData', row['image_id'])
58+
img_path = os.path.join(images_folder, row['image_id'])
5459
with open(img_path, 'rb') as f_img_file:
5560
img = Image.open(f_img_file)
5661
img_buffer = BytesIO()
@@ -65,20 +70,23 @@ def process_text_to_image(image_data, word_test, img_file, txt_file):
6570
text_data = {"text_id": row["text_id"], "text": row["caption"], "image_ids": []}
6671
f_txt.write(json.dumps(text_data) + '\n')
6772

68-
process_text_to_image(image_data2, word_test2, 'Chinese-CLIP/datasets/DatasetName/test2_imgs.tsv', 'Chinese-CLIP/datasets/DatasetName/test2_texts.jsonl')
73+
# datasets/DatasetName为在Chinese-CLIP项目目录下新建的存放转换后数据集的文件夹
74+
process_text_to_image(image_data2, data2_images_folder, word_test2, 'Chinese-CLIP/datasets/DatasetName/test2_imgs.tsv', 'Chinese-CLIP/datasets/DatasetName/test2_texts.jsonl')
6975

7076

7177

72-
#制作从图像到文本(Image_to_Text)检索时的,测试集。data3为Image_to_Text测试数据文件夹名
78+
# 制作从图像到文本(Image_to_Text)检索时的,测试集。data3为Image_to_Text测试数据文件夹名
7379
image_test3 = pd.read_csv('original_dataset/data3/image_test.csv')
7480
word_data3 = pd.read_csv('original_dataset/data3/word_data.csv')
81+
# 原始图像测试集目录
82+
data3_images_folder='original_dataset/data3/ImageData'
7583

7684
# 处理Image_to_Text测试集
77-
def process_image_to_text(image_data, word_test, img_file, txt_file):
85+
def process_image_to_text(image_data, images_folder, word_test, img_file, txt_file):
7886
with open(img_file, 'w') as f_img, open(txt_file, 'w') as f_txt:
7987
for index, row in image_data.iterrows():
8088
# 图片内容需要被编码为base64格式
81-
img_path = os.path.join('../dataset/data3/ImageData', row['image_id'])
89+
img_path = os.path.join(images_folder, row['image_id'])
8290
with open(img_path, 'rb') as f_img_file:
8391
img = Image.open(f_img_file)
8492
img_buffer = BytesIO()
@@ -93,7 +101,8 @@ def process_image_to_text(image_data, word_test, img_file, txt_file):
93101
text_data = {"text_id": row["text_id"], "text": row["caption"], "image_ids": []}
94102
f_txt.write(json.dumps(text_data) + '\n')
95103

96-
process_image_to_text(image_test3, word_data3, 'Chinese-CLIP/datasets/DatasetName/test3_imgs.tsv', 'Chinese-CLIP/datasets/DatasetName/test3_texts.jsonl')
104+
# datasets/DatasetName为在Chinese-CLIP项目目录下新建的存放转换后数据集的文件夹
105+
process_image_to_text(image_test3, data3_images_folder, word_data3, 'Chinese-CLIP/datasets/DatasetName/test3_imgs.tsv', 'Chinese-CLIP/datasets/DatasetName/test3_texts.jsonl')
97106

98107

99108
'''

0 commit comments

Comments
 (0)