Skip to content

Commit

Permalink
Add llm and fl beta code
Browse files Browse the repository at this point in the history
  • Loading branch information
naibo committed Dec 23, 2024
1 parent b4d7ddf commit 5180f47
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 0 deletions.
108 changes: 108 additions & 0 deletions ExecuteStage/fl_beta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
from PIL import Image
import os

# ResNet model definition (ResNet18 is used here).
class ResNetModel(nn.Module):
    """ResNet18 wrapper whose final fully connected layer outputs ``num_classes`` values.

    The backbone is created with ``pretrained=True`` and stored on
    ``self.resnet``.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Replace the final fully connected layer so the output size matches
        # the target classification task.
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Linear(feature_dim, num_classes)
        self.resnet = backbone

    def forward(self, x):
        """Return the wrapped ResNet's output for ``x``."""
        return self.resnet(x)

# Custom dataset of labelled images stored in a single directory.
class WebpageDataset(Dataset):
    """Dataset over the ``.png`` files found directly inside ``image_dir``.

    The label of every image is the integer that precedes the first
    underscore of its file name, e.g. ``3_homepage.png`` has label ``3``.

    Args:
        image_dir: Directory that is listed for ``.png`` files.
        transform: Optional callable applied to each loaded RGB image.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs and machines.
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if f.endswith('.png')
        )

    def __len__(self):
        """Return the number of ``.png`` files found at construction time."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the file at position ``idx``.

        The image is converted to RGB and, when a transform was supplied,
        passed through it before being returned.
        """
        file_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, file_name)
        # convert() returns a new, fully loaded image, so the source file can
        # be closed deterministically as soon as the with block ends.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        label = self.get_label_from_filename(file_name)
        if self.transform:
            image = self.transform(image)
        return image, label

    def get_label_from_filename(self, filename):
        """Parse the integer label from a name such as ``'3_homepage.png'``.

        Raises:
            ValueError: If the text before the first underscore is not an
                integer.
        """
        prefix = filename.split('_')[0]
        try:
            return int(prefix)
        except ValueError as err:
            raise ValueError(
                f"Cannot parse an integer label from file name {filename!r}; "
                "expected '<int>_<name>.png'"
            ) from err

# Image preprocessing: resize to 224x224, convert to a tensor, then normalise
# each channel with the fixed mean/std values below.
# NOTE(review): the mean/std values look like the usual ImageNet statistics — confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

# Client-side training routine.
def train_local_model(model, dataloader, criterion, optimizer, epochs=5):
    """Train ``model`` on ``dataloader`` and return its ``state_dict()``.

    Args:
        model: Module to train; it is switched to training mode first.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        criterion: Callable computing the loss from ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        epochs: Number of full passes over ``dataloader``.

    Returns:
        The state dict of ``model`` after training.
    """
    model.train()
    for _ in range(epochs):
        for batch_inputs, batch_targets in dataloader:
            # Clear gradients left over from the previous step before the
            # new backward pass accumulates into them.
            optimizer.zero_grad()
            batch_loss = criterion(model(batch_inputs), batch_targets)
            batch_loss.backward()
            optimizer.step()
    return model.state_dict()

# Federated averaging (FedAvg with equal client weights).
def federated_average(models_state_dicts):
    """Return the element-wise mean of several model state dicts.

    The inputs are left untouched: every averaged tensor is freshly
    allocated, so the clients' state dicts (and the model parameters they
    share storage with) are not modified.

    Args:
        models_state_dicts: Non-empty sequence of state dicts that all
            contain the same keys with tensors of matching shapes.

    Returns:
        An ordered mapping with the keys of the first state dict, each value
        being the mean over all inputs, cast back to the dtype of the first
        input's tensor. Integer entries are therefore truncated instead of
        being returned as floats.

    Raises:
        ValueError: If ``models_state_dicts`` is empty.
        KeyError: If a later state dict lacks a key of the first one.
    """
    from collections import OrderedDict

    if not models_state_dicts:
        raise ValueError("federated_average requires at least one state dict")

    num_models = len(models_state_dicts)
    first_state = models_state_dicts[0]
    avg_state_dict = OrderedDict()

    # Carry over the metadata that Module.state_dict() attaches to its result
    # (absent on plain dicts) so load_state_dict sees the same information.
    metadata = getattr(first_state, "_metadata", None)
    if metadata is not None:
        avg_state_dict._metadata = metadata

    for key, first_tensor in first_state.items():
        # Accumulate into a copy so the first client's tensors stay intact.
        total = first_tensor.clone()
        for other_state in models_state_dicts[1:]:
            total += other_state[key]
        # torch.div performs true division and yields a float result for
        # integer inputs; cast back to keep the original dtype.
        avg_state_dict[key] = torch.div(total, num_models).to(first_tensor.dtype)
    return avg_state_dict

# Simulated clients: one data directory per client.
client_data_dirs = ['client1_data', 'client2_data', 'client3_data']
num_classes = 10  # set according to the actual task

# Initialise the global model.
global_model = ResNetModel(num_classes=num_classes)

# Loss function shared by every client.
criterion = nn.CrossEntropyLoss()

# Federated learning loop: in every round each client starts from the current
# global weights, trains locally, and the resulting weights are averaged into
# the new global model.
num_rounds = 10
# The loop variable is deliberately not called `round`, which would shadow the
# builtin for the rest of the module.
for round_number in range(1, num_rounds + 1):
    local_models = []
    for client_dir in client_data_dirs:
        # Load this client's data.
        dataset = WebpageDataset(image_dir=client_dir, transform=transform)
        dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

        # Create the client model and copy the global weights into it.
        local_model = ResNetModel(num_classes=num_classes)
        local_model.load_state_dict(global_model.state_dict())

        # A fresh optimizer per client and round, so no momentum state is
        # carried over between clients or rounds.
        optimizer = optim.SGD(local_model.parameters(), lr=0.01, momentum=0.9)

        # Train locally and collect the resulting weights.
        local_state_dict = train_local_model(local_model, dataloader, criterion, optimizer)
        local_models.append(local_state_dict)

    # Aggregate the client weights into the global model.
    global_state_dict = federated_average(local_models)
    global_model.load_state_dict(global_state_dict)

    print(f'Round {round_number}/{num_rounds} completed.')

# Save the final global model weights.
torch.save(global_model.state_dict(), 'federated_resnet_model.pth')
36 changes: 36 additions & 0 deletions ExecuteStage/llm_beta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image
import torch

# Load the Llama 3.2 vision model and its processor.
# These are module-level statements, so both from_pretrained calls run as soon
# as this module is imported or executed.
# NOTE(review): presumably this downloads (or reads from a local cache) a large
# gated checkpoint — confirm access and available memory before running.
model_name = "meta-llama/Llama-3.2-11B-Vision" # replace with the actual model path if needed
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForVision2Seq.from_pretrained(model_name)

# Process a webpage screenshot and extract a description of its structure.
def predict_structure_from_image(image_path, prompt=None):
    """Generate a text description for the image stored at ``image_path``.

    Args:
        image_path: Path of the image file to describe.
        prompt: Optional text handed to the processor together with the
            image. When omitted, the processor is called with the image only,
            exactly as before this parameter existed.
            NOTE(review): some vision-language checkpoints reportedly require
            a prompt containing an image placeholder token — confirm against
            the model card of the checkpoint in use.

    Returns:
        The first generated sequence decoded to a string, with special
        tokens skipped.
    """
    # convert() returns a new, fully loaded image, so the source file can be
    # closed as soon as the with block ends.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # Preprocess the image (and the prompt, when one was given).
    processor_kwargs = {"images": image, "return_tensors": "pt"}
    if prompt is not None:
        processor_kwargs["text"] = prompt
    inputs = processor(**processor_kwargs)

    # Forward everything the processor produced, not only "pixel_values", so
    # any additional tensors it returns also reach the model.
    outputs = model.generate(
        **inputs,
        max_length=512,
        num_beams=5,
        early_stopping=True,
    )
    return processor.decode(outputs[0], skip_special_tokens=True)

# Example usage.
if __name__ == "__main__":
    # Path of the webpage screenshot; replace with a real image file.
    image_path = "webpage_screenshot.png"

    # Predict the structure and print it.
    print("预测的结构:", predict_structure_from_image(image_path))

0 comments on commit 5180f47

Please sign in to comment.