From 5180f47b708aea29573084e73908893235c9dec5 Mon Sep 17 00:00:00 2001
From: naibo
Date: Tue, 24 Dec 2024 00:14:35 +0800
Subject: [PATCH] Add llm and fl beta code

---
 ExecuteStage/fl_beta.py  | 108 +++++++++++++++++++++++++++++++++++++++
 ExecuteStage/llm_beta.py |  36 +++++++++++++
 2 files changed, 144 insertions(+)
 create mode 100644 ExecuteStage/fl_beta.py
 create mode 100644 ExecuteStage/llm_beta.py

diff --git a/ExecuteStage/fl_beta.py b/ExecuteStage/fl_beta.py
new file mode 100644
index 00000000..a1a44391
--- /dev/null
+++ b/ExecuteStage/fl_beta.py
@@ -0,0 +1,108 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torchvision import models, transforms
+from torch.utils.data import DataLoader, Dataset
+from PIL import Image
+import os
+
+# Define the ResNet model (ResNet18 as an example)
+class ResNetModel(nn.Module):
+    def __init__(self, num_classes):
+        super(ResNetModel, self).__init__()
+        self.resnet = models.resnet18(pretrained=True)
+        # Replace the final fully connected layer to match the classification task
+        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
+
+    def forward(self, x):
+        return self.resnet(x)
+
+# Custom dataset class
+class WebpageDataset(Dataset):
+    def __init__(self, image_dir, transform=None):
+        self.image_dir = image_dir
+        self.transform = transform
+        self.image_files = [f for f in os.listdir(image_dir) if f.endswith('.png')]
+
+    def __len__(self):
+        return len(self.image_files)
+
+    def __getitem__(self, idx):
+        img_name = os.path.join(self.image_dir, self.image_files[idx])
+        image = Image.open(img_name).convert('RGB')
+        label = self.get_label_from_filename(self.image_files[idx])
+        if self.transform:
+            image = self.transform(image)
+        return image, label
+
+    def get_label_from_filename(self, filename):
+        # Assumes filenames follow the format 'class_label.png'
+        return int(filename.split('_')[0])
+
+# Image preprocessing
+transform = transforms.Compose([
+    transforms.Resize((224, 224)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+])
+
+# Local training function for a single client
+def train_local_model(model, dataloader, criterion, optimizer, epochs=5):
+    model.train()
+    for epoch in range(epochs):
+        for images, labels in dataloader:
+            outputs = model(images)
+            loss = criterion(outputs, labels)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+    return model.state_dict()
+
+# Federated averaging (FedAvg): element-wise mean of the client state dicts
+def federated_average(models_state_dicts):
+    # Clone the first client's tensors so the averaging does not mutate them in place
+    avg_state_dict = {key: value.clone().float() for key, value in models_state_dicts[0].items()}
+    for key in avg_state_dict.keys():
+        for i in range(1, len(models_state_dicts)):
+            avg_state_dict[key] += models_state_dicts[i][key].float()
+        avg_state_dict[key] = torch.div(avg_state_dict[key], len(models_state_dicts))
+    return avg_state_dict
+
+# Simulated data for multiple clients
+client_data_dirs = ['client1_data', 'client2_data', 'client3_data']  # one data directory per client
+num_classes = 10  # set according to the actual task
+
+# Initialize the global model
+global_model = ResNetModel(num_classes=num_classes)
+
+# Define the loss function
+criterion = nn.CrossEntropyLoss()
+
+# Federated learning loop
+num_rounds = 10
+for round_num in range(num_rounds):
+    local_models = []
+    for client_dir in client_data_dirs:
+        # Load the client's data
+        dataset = WebpageDataset(image_dir=client_dir, transform=transform)
+        dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
+
+        # Initialize the client model from the current global weights
+        local_model = ResNetModel(num_classes=num_classes)
+        local_model.load_state_dict(global_model.state_dict())
+
+        # Define the optimizer
+        optimizer = optim.SGD(local_model.parameters(), lr=0.01, momentum=0.9)
+
+        # Train the local model
+        local_state_dict = train_local_model(local_model, dataloader, criterion, optimizer)
+        local_models.append(local_state_dict)
+
+    # Aggregate the model parameters
+    global_state_dict = federated_average(local_models)
+    global_model.load_state_dict(global_state_dict)
+
+    print(f'Round {round_num+1}/{num_rounds} completed.')
+
+# Save the global model
+torch.save(global_model.state_dict(), 'federated_resnet_model.pth')
diff --git a/ExecuteStage/llm_beta.py b/ExecuteStage/llm_beta.py
new file mode 100644
index 00000000..93cfb6f7
--- /dev/null
+++ b/ExecuteStage/llm_beta.py
@@ -0,0 +1,36 @@
+from transformers import AutoProcessor, AutoModelForVision2Seq
+from PIL import Image
+import torch
+
+# Load the Llama 3.2 vision model and processor
+model_name = "meta-llama/Llama-3.2-11B-Vision"  # replace with the actual model path if needed
+processor = AutoProcessor.from_pretrained(model_name)
+model = AutoModelForVision2Seq.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
+
+# Process a webpage screenshot and extract its structure
+def predict_structure_from_image(image_path):
+    # Load the image
+    image = Image.open(image_path).convert("RGB")
+
+    # Preprocess: the base Vision model expects a text prompt that starts with the <|image|> token
+    prompt = "<|image|><|begin_of_text|>Describe the structure of this webpage."
+    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
+    # Generate the (structure) description
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=512,
+        num_beams=5,
+        early_stopping=True
+    )
+    description = processor.decode(outputs[0], skip_special_tokens=True)
+    return description
+
+# Example usage
+if __name__ == "__main__":
+    # Path to the webpage screenshot
+    image_path = "webpage_screenshot.png"  # replace with the actual image file path
+
+    # Predict the structure
+    predicted_structure = predict_structure_from_image(image_path)
+
+    print("Predicted structure:", predicted_structure)