gan实战(DCGAN、)
创始人
2025-06-01 21:17:19

一、DCGAN

1.1 参数

(1)输入:会被放缩到6464
(2)输出:64
64
(3)数据集:

1.2 实现

import glob
import torch
from PIL import Image
from torch import nn
from torch.utils import data
from torchvision import transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import oslog_dir = "./model/dcgan.pth"
images_path = glob.glob('./data/xinggan_face/*.jpg')BATCH_SIZE = 32
dataset = FaceDataset(images_path)
data_loader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
image_batch = next(iter(data_loader))transform = transforms.Compose([transforms.Resize(64),transforms.ToTensor(),transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])class FaceDataset(data.Dataset):def __init__(self, images_path):self.images_path = images_pathdef __getitem__(self, index):image_path = self.images_path[index]pil_img = Image.open(image_path)pil_img = transform(pil_img)return pil_imgdef __len__(self):return len(self.images_path)# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.linear1 = nn.Linear(100, 256*16*16)self.bn1 = nn.BatchNorm1d(256*16*16)self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)  # 输出:128*16*16self.bn2 = nn.BatchNorm2d(128)self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)  # 输出:64*32*32self.bn3 = nn.BatchNorm2d(64)self.deconv3 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1)  # 输出:3*64*64def forward(self, x):x = F.relu(self.linear1(x))x = self.bn1(x)x = x.view(-1, 256, 16, 16)x = F.relu(self.deconv1(x))x = self.bn2(x)x = F.relu(self.deconv2(x))x = self.bn3(x)x = F.tanh(self.deconv3(x))return x# 定义判别器
class Discrimination(nn.Module):def __init__(self):super(Discrimination, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2)  # 64*31*31self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2)  # 128*15*15self.bn1 = nn.BatchNorm2d(128)self.fc = nn.Linear(128*15*15, 1)def forward(self, x):x = F.dropout(F.leaky_relu(self.conv1(x)), p=0.3)x = F.dropout(F.leaky_relu(self.conv2(x)), p=0.3)x = self.bn1(x)x = x.view(-1, 128*15*15)x = torch.sigmoid(self.fc(x))return x# 定义可视化函数
def generate_and_save_images(model, epoch, test_noise_):predictions = model(test_noise_).permute(0, 2, 3, 1).cpu().numpy()fig = plt.figure(figsize=(20, 160))for i in range(predictions.shape[0]):plt.subplot(1, 8, i+1)plt.imshow((predictions[i]+1)/2)# plt.axis('off')plt.show()# 训练函数
def train(gen, dis, loss_fn, gen_opti, dis_opti, start_epoch):print("开始训练")test_noise = torch.randn(8, 100, device=device)writer = SummaryWriter(r'D:\Project\PythonProject\Ttest\run')writer.add_graph(gen, test_noise)#############################D_loss = []G_loss = []# 开始训练for epoch in range(start_epoch, 500):D_epoch_loss = 0G_epoch_loss = 0batch_count = len(data_loader)   # 返回批次数for step, img, in enumerate(data_loader):img = img.to(device)size = img.shape[0]random_noise = torch.randn(size, 100, device=device)  # 生成随机输入# 固定生成器,训练判别器dis_opti.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output, device=device))d_real_loss.backward()generated_img = gen(random_noise)# print(generated_img)fake_output = dis(generated_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output, device=device))d_fake_loss.backward()dis_loss = d_real_loss + d_fake_lossdis_opti.step()# 固定判别器,训练生成器gen_opti.zero_grad()fake_output = dis(generated_img)gen_loss = loss_fn(fake_output, torch.ones_like(fake_output, device=device))gen_loss.backward()gen_opti.step()with torch.no_grad():D_epoch_loss += dis_loss.item()G_epoch_loss += gen_loss.item()writer.add_scalar("loss/dis_loss", D_epoch_loss / (epoch+1), epoch+1)writer.add_scalar("loss/gen_loss", G_epoch_loss / (epoch+1), epoch+1)with torch.no_grad():D_epoch_loss /= batch_countG_epoch_loss /= batch_countD_loss.append(D_epoch_loss)G_loss.append(G_epoch_loss)print("Epoch:{}, 判别器损失:{}, 生成器损失:{}.".format(epoch, dis_loss, gen_loss))generate_and_save_images(gen, epoch, test_noise)state = {"gen": gen.state_dict(),"dis": dis.state_dict(),"gen_opti": gen_opti.state_dict(),"dis_opti": dis_opti.state_dict(),"epoch": epoch}torch.save(state, log_dir)plt.plot(range(1, len(D_loss)+1), D_loss, label="D_loss")plt.plot(range(1, len(D_loss)+1), G_loss, label="G_loss")plt.xlabel('epoch')plt.legend()plt.show()if __name__ == '__main__':device = "cuda:0" if torch.cuda.is_available() else "cpu"gen = Generator().to(device)dis = Discrimination().to(device)loss_fn = torch.nn.BCELoss()gen_opti = torch.optim.Adam(gen.parameters(), lr=0.0001)dis_opti = torch.optim.Adam(dis.parameters(), lr=0.00001)start_epoch = 0if os.path.exists(log_dir):checkpoint = torch.load(log_dir)gen.load_state_dict(checkpoint["gen"])dis.load_state_dict(checkpoint["dis"])gen_opti.load_state_dict(checkpoint["gen_opti"])dis_opti.load_state_dict(checkpoint["dis_opti"])start_epoch = checkpoint["epoch"]print("模型加载成功,epoch从{}开始训练".format(start_epoch))train(gen, dis, loss_fn, gen_opti, dis_opti, start_epoch)

1.3 实验效果

开始训练
Epoch:0, 判别器损失:1.6549043655395508, 生成器损失:0.7864767909049988.
在这里插入图片描述
Epoch:20, 判别器损失:1.3690211772918701, 生成器损失:0.6662370562553406.
在这里插入图片描述
Epoch:40, 判别器损失:1.413375735282898, 生成器损失:0.7497923970222473.
在这里插入图片描述
Epoch:60, 判别器损失:1.2889504432678223, 生成器损失:0.8668195009231567.
在这里插入图片描述
Epoch:80, 判别器损失:1.2824485301971436, 生成器损失:0.805076003074646.
在这里插入图片描述
Epoch:100, 判别器损失:1.3278448581695557, 生成器损失:0.7859240770339966.
在这里插入图片描述
Epoch:120, 判别器损失:1.39650297164917, 生成器损失:0.7616179585456848.
在这里插入图片描述
Epoch:140, 判别器损失:1.3387322425842285, 生成器损失:0.811163067817688.
在这里插入图片描述
Epoch:160, 判别器损失:1.1281094551086426, 生成器损失:0.7557946443557739.

在这里插入图片描述
Epoch:180, 判别器损失:1.369300365447998, 生成器损失:0.5207887887954712.

在这里插入图片描述

相关内容

热门资讯

火龙果怎么剥皮,火龙果可以徒手... 火龙果怎么剥皮目录火龙果怎么剥皮火龙果可以徒手剥皮吗?火龙果的皮怎么剥?火龙果怎么剥皮最方便火龙果怎...
木桂是什么 极速百科网 极速百... 木桂是什么目录木桂是什么木桂是什么什么是木马病毒
激励自己努力的短句狠一点的,奋... 激励自己努力的短句狠一点的目录激励自己努力的短句狠一点的奋斗励志的句子简短 励志短句致自己奋斗比较励...
房产销售员好做吗,房产销售好做... 房产销售员好做吗目录房产销售员好做吗房产销售好做吗听一听业内人士怎么说房地产业务员好做吗做房地产销售...
关于女王节的句子,2022女王... 关于女王节的句子目录关于女王节的句子2022女王节发朋友圈的句子有哪些关于三八节文案句子?女王节唯美...
小鲜肉是什么意思,小鲜肉什么意... 小鲜肉是什么意思目录小鲜肉是什么意思小鲜肉什么意思,是什么 小鲜肉老有人说我是小鲜肉什么意思啊?小鲜...
手表镜面划痕怎么办,手表表镜划... 手表镜面划痕怎么办目录手表镜面划痕怎么办手表表镜划痕修复手表镜面刮花了怎么办手表表镜刮花怎么办紧急处...
可转债多久上市,转债几天后上市... 可转债多久上市目录可转债多久上市转债几天后上市中了转债几天上市?可转债上市时间可转债多久上市 ...
6寸照片尺寸(6寸照片尺寸大小... 今天给各位分享6寸照片尺寸的知识,其中也会对6寸照片尺寸大小图片进行解释,如果能碰巧解决你现在面临的...
全日制和在职的是什么意思 极速... 全日制和在职的是什么意思目录全日制和在职的是什么意思全日制和在职的是什么意思在职研究生是什么意思?和...
如何退货怎么寄回去,退货怎么寄... 如何退货怎么寄回去目录如何退货怎么寄回去退货怎么寄回去给商家买家怎样寄快递回商家退货要怎么寄回去,有...
葬爱家族怎么来的,当别人说你是... 葬爱家族怎么来的目录葬爱家族怎么来的当别人说你是葬爱家族你怎么回答?葬爱小龙是葬爱家族的吗? 有人说...
如何为孩子选购儿童脚踏车(儿童... 本篇文章极速百科给大家谈谈如何为孩子选购儿童脚踏车,以及儿童脚踏车什么牌子比较好对应的知识点,希望对...
学好高中语文的方法,怎样学好高... 学好高中语文的方法目录学好高中语文的方法怎样学好高中语文?如何学好高中语文如何学好高中语文学好高中语...
家用轿车哪款比较好?家用轿车排... 今天给各位分享家用轿车哪款比较好?家用轿车排行榜前十名2022的知识,其中也会对家用轿车哪款车最实用...
海尔电视怎么调出电视模式,海尔... 海尔电视怎么调出电视模式目录海尔电视怎么调出电视模式海尔电视怎么调出电视模式?海尔电视户户通怎么切换...
北京x7的优势有哪些hhhh8... 本篇文章极速百科给大家谈谈北京x7的优势有哪些hhhh88,以及北京x7好么对应的知识点,希望对各位...
日土县属于哪个市,西藏日土县是... 日土县属于哪个市目录日土县属于哪个市西藏日土县是哪个市?西藏包括哪些地方?日土县属于哪个市日土县属于...
排放标准国一到国六的符号是什么... 排放标准国一到国六的符号是什么目录排放标准国一到国六的符号是什么排放标准国一到国六的符号分别是什么?...
处暑下雨谚语,处暑下雨的谚语 ... 处暑下雨谚语目录处暑下雨谚语处暑下雨的谚语关于下雨的农谚或俗语有哪些?关于处暑的谚语大全处暑下雨谚语...
黄骅市邮编号码市多少,河北省黄... 黄骅市邮编号码市多少目录黄骅市邮编号码市多少河北省黄骅市羊二庄镇的邮编号码是多少黄骅市旧城镇的邮政编...
书记处书记是什么职位,书记处书... 书记处书记是什么职位目录书记处书记是什么职位书记处书记是干什么的学生会书记处是做什么的?主要负责什么...
埃及的货币叫什么,在埃及可以使... 埃及的货币叫什么目录埃及的货币叫什么在埃及可以使用什么货币,可以用美金吗?非洲各国的钱叫什么名称?埃...
怎样能快速学好英语,如何有效快... 怎样能快速学好英语目录怎样能快速学好英语如何有效快速学好英语怎么才能快速轻易学好英语?英语怎么才能快...
蔷薇属植物有哪些 极速百科网 ... 蔷薇属植物有哪些目录蔷薇属植物有哪些蔷薇属植物有哪些蔷薇属植物有哪些?在城市绿化中常用的有哪几种?蔷...
econ什么专业(eco是什么... 今天给各位分享econ什么专业的知识,其中也会对eco是什么专业进行解释,如果能碰巧解决你现在面临的...
oppo手机关不了机怎么办,o... 3. 检查电源键:确认电源键是否正常使用,可锁屏亮屏排除尝试。 5. 卸载近期下载的软件:若无...
快狗打车加入条件,加入快狗打车... 快狗打车加入条件目录快狗打车加入条件加入快狗打车需要什么条件?没有车可以加入快狗打车么?快狗打车加盟...
耳顺之年是指什么年龄(耳顺之年... 本篇文章极速百科给大家谈谈耳顺之年是指什么年龄,以及耳顺之年是指什么年龄段对应的知识点,希望对各位有...
海尔洗衣机出现e1故障怎么处理... 海尔洗衣机出现e1故障怎么处理目录海尔洗衣机出现e1故障怎么处理海尔洗衣机海尔玫瑰钻XQG50-E7...