-
-
Notifications
You must be signed in to change notification settings - Fork 4.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
naibo
committed
Dec 23, 2024
1 parent
b4d7ddf
commit 5180f47
Showing 2 changed files with 144 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import copy
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
|
||
# ResNet model definition (ResNet18 is used as the example backbone)
class ResNetModel(nn.Module):
    """ResNet18 classifier whose final layer outputs ``num_classes`` values.

    The backbone comes from ``models.resnet18(pretrained=True)``; its ``fc``
    layer is replaced by a new ``nn.Linear`` with the same input width.
    """

    def __init__(self, num_classes):
        """Build the backbone and attach a ``num_classes``-wide head."""
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Replace the final fully connected layer so the output size matches
        # the target classification task.
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Linear(feature_dim, num_classes)
        self.resnet = backbone

    def forward(self, x):
        """Return the wrapped ResNet's output for ``x``."""
        return self.resnet(x)
|
||
# Custom dataset class
class WebpageDataset(Dataset):
    """Dataset of ``.png`` images stored directly inside one directory.

    Each item is an ``(image, label)`` pair. The label is parsed from the
    file name, which must start with an integer class id, optionally followed
    by an underscore and any other text (e.g. ``3_homepage.png``).
    """

    def __init__(self, image_dir, transform=None):
        """Collect the PNG file names found in ``image_dir``.

        Args:
            image_dir: Directory that contains the image files.
            transform: Optional callable applied to each loaded RGB image.
        """
        self.image_dir = image_dir
        self.transform = transform
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if f.endswith('.png')
        )

    def __len__(self):
        """Return the number of PNG files found at construction time."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(image, label)``.

        The image is converted to RGB and, when a transform was given, passed
        through it before being returned.
        """
        file_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, file_name)
        # The context manager closes the underlying file as soon as the RGB
        # copy has been made, instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.get_label_from_filename(file_name)
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def get_label_from_filename(self, filename):
        """Return the integer class id encoded at the start of ``filename``.

        The extension is dropped first so that a label-only name such as
        ``5.png`` parses as well as ``5_anything.png``.

        Raises:
            ValueError: If the leading part of the name is not an integer.
        """
        stem = os.path.splitext(filename)[0]
        return int(stem.split('_')[0])
|
||
# Image preprocessing: resize to 224x224, convert to a tensor, then normalize
# each channel with the per-channel mean/std given below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
|
||
# Client-side training routine
def train_local_model(model, dataloader, criterion, optimizer, epochs=5):
    """Train ``model`` on ``dataloader`` and return its state dict.

    Args:
        model: Module to train; it is switched to training mode first.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        criterion: Callable computing the loss from ``(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of full passes over ``dataloader``.

    Returns:
        ``model.state_dict()`` after the last optimization step.
    """
    model.train()
    for _ in range(epochs):
        for batch_images, batch_labels in dataloader:
            # Clear gradients left over from the previous step
            optimizer.zero_grad()
            predictions = model(batch_images)
            batch_loss = criterion(predictions, batch_labels)
            batch_loss.backward()
            optimizer.step()
    return model.state_dict()
|
||
# Federated averaging (FedAvg)
def federated_average(models_state_dicts):
    """Return the element-wise, unweighted mean of several state dicts.

    Args:
        models_state_dicts: Non-empty sequence of state dicts. Every dict
            must contain the keys of the first one, with tensors that can be
            added together.

    Returns:
        A new mapping (same type as the first state dict) from each key of
        the first state dict to the mean of that entry across all inputs.
        The input dicts and their tensors are left unmodified.

    Raises:
        ValueError: If ``models_state_dicts`` is empty.
    """
    if not models_state_dicts:
        raise ValueError("federated_average() needs at least one state dict")

    first, *rest = models_state_dicts
    count = len(models_state_dicts)

    avg_state_dict = type(first)()
    # state_dict() attaches versioning info as `_metadata`; carry it over so
    # load_state_dict() sees the same information as with a client dict.
    metadata = getattr(first, "_metadata", None)
    if metadata is not None:
        avg_state_dict._metadata = metadata

    for key, tensor in first.items():
        # Accumulate into a clone: state-dict tensors share storage with the
        # client's model, so an in-place add on them would corrupt it.
        total = tensor.clone()
        for other in rest:
            total += other[key]
        avg_state_dict[key] = torch.div(total, count)
    return avg_state_dict
|
||
# Simulated client data: one data directory per client
client_data_dirs = ['client1_data', 'client2_data', 'client3_data']
num_classes = 10  # set according to the actual task

# Initialize the global model
global_model = ResNetModel(num_classes=num_classes)

# Loss function shared by all clients
criterion = nn.CrossEntropyLoss()

# Build each client's DataLoader once; the directories are the same in every
# round, so there is no need to re-list the files each time.
client_dataloaders = [
    DataLoader(
        WebpageDataset(image_dir=client_dir, transform=transform),
        batch_size=32,
        shuffle=True,
    )
    for client_dir in client_data_dirs
]

# Federated learning loop
num_rounds = 10
for round_idx in range(num_rounds):
    local_models = []
    for dataloader in client_dataloaders:
        # Start every client from the current global weights. A deep copy
        # gives an independent model without constructing (and re-loading the
        # pretrained weights of) a fresh ResNetModel for each client.
        local_model = copy.deepcopy(global_model)

        # Each client gets its own optimizer bound to its local parameters
        optimizer = optim.SGD(local_model.parameters(), lr=0.01, momentum=0.9)

        # Train locally and keep the resulting weights for aggregation
        local_state_dict = train_local_model(local_model, dataloader, criterion, optimizer)
        local_models.append(local_state_dict)

    # Aggregate the client weights into the global model
    global_state_dict = federated_average(local_models)
    global_model.load_state_dict(global_state_dict)

    print(f'Round {round_idx + 1}/{num_rounds} completed.')

# Save the global model
torch.save(global_model.state_dict(), 'federated_resnet_model.pth')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from transformers import AutoProcessor, AutoModelForVision2Seq | ||
from PIL import Image | ||
import torch | ||
|
||
# Load the Llama 3.2 vision model and its processor
model_name = "meta-llama/Llama-3.2-11B-Vision"  # replace with the actual model path as needed
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForVision2Seq.from_pretrained(model_name)
|
||
# Process a webpage screenshot and extract its structure
def predict_structure_from_image(
    image_path,
    prompt="<|image|><|begin_of_text|>The structure of this webpage is",
):
    """Generate a textual structure description for a webpage screenshot.

    Args:
        image_path: Path of the screenshot image file.
        prompt: Text prompt passed to the processor together with the image.
            The default follows the Llama 3.2 Vision convention of marking
            the image position with ``<|image|>``.
            NOTE(review): adjust the prompt if ``model_name`` is changed to a
            model with a different prompt format.

    Returns:
        The decoded text generated after the prompt, with special tokens
        removed.
    """
    # Load the image; the context manager closes the file once the RGB copy
    # has been created.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Preprocess image and prompt together so the processor returns the
    # token ids and every image-related tensor the model expects, then move
    # them to the model's device.
    inputs = processor(images=image, text=prompt, return_tensors="pt")
    inputs = inputs.to(model.device)

    # Generate the description. All processor outputs are passed by keyword;
    # a lone positional tensor would be taken as the model's main input.
    outputs = model.generate(
        **inputs,
        max_length=512,
        num_beams=5,
        early_stopping=True,
    )

    # The returned sequence starts with the prompt tokens; decode only what
    # was generated after them.
    prompt_length = inputs["input_ids"].shape[-1]
    description = processor.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )
    return description
|
||
# Example usage
if __name__ == "__main__":
    # Path of the webpage screenshot (replace with a real image file)
    image_path = "webpage_screenshot.png"

    # Predict the structure and print it
    predicted_structure = predict_structure_from_image(image_path)
    print("预测的结构:", predicted_structure)