深度学习第54讲:训练一个深度卷积对抗网络DCGAN
自从GoodFellow提出GAN以后,GAN就存在着训练困难、生成器和判别器的loss无法指示训练进程、生成样本缺乏多样性等问题。为了解决这些问题,后来的研究...
自从GoodFellow提出GAN以后,GAN就存在着训练困难、生成器和判别器的loss无法指示训练进程、生成样本缺乏多样性等问题。为了解决这些问题,后来的研究者不断推陈出新,以至于现在有着各种各样的GAN变体和升级网络。比如 LSGAN,WGAN,WGAN-GP,DRAGAN,CGAN,infoGAN, ACGAN,EBGAN,BEGAN,DCGAN以及最近号称史上最强图像生成网络的BigGAN等等。本节仅选取其中的DCGAN——深度卷积对抗网络进行简单讲解并利用keras进行实现。
DCGAN的原始论文为 UNSUPERVISED REPRESENTATION LEARNING WITH DEEP CONVOLUTIONAL GENERATIVE ADVERSARIAL NETWORKS,所谓DCGAN,顾名思义就是生成器和判别器都是深度卷积神经网络的GAN。

搭建一个稳健的DCGAN要点在于:
-
所有的pooling层使用步幅卷积(判别网络)和微步幅度卷积(生成网络)进行替换。
-
在生成网络和判别网络上使用批处理规范化。
-
对于更深的架构移除全连接隐藏层。
-
在生成网络的所有层上使用ReLU激活函数,除了输出层使用Tanh激活函数。
-
在判别网络的所有层上使用LeakyReLU激活函数。

基于DCGAN生成的卧室图片:

下面就基于keras搭建一个DCGAN。
from keras.layers import Dense, Conv2D, LeakyReLU, Dropout, Input
from keras.layers import Reshape, Conv2DTranspose, Flatten
from keras.models import Model
from keras import optimizers
import kerasimport numpy as npimport warnings
warnings.filterwarnings('ignore')
设置相关参数:
# 潜变量维度
latent_dim = 32
# 输入像素维度
height = 32
width = 32
channels = 3
下面开始搭建生成器网络:
generator_input = Input(shape=(latent_dim,))
x = Dense(128 * 16 * 16)(generator_input)
x = LeakyReLU()(x)
x = Reshape((16, 16, 128))(x)
x = Conv2D(256, 5, padding='same')(x)
x = LeakyReLU()(x)
x = Conv2DTranspose(256, 4, strides=2, padding='same')(x)
x = LeakyReLU()(x)
x = Conv2D(256, 5, padding='same')(x)
x = LeakyReLU()(x)
x = Conv2D(256, 5, padding='same')(x)
x = LeakyReLU()(x)
x = Conv2D(channels, 7, activation='tanh', padding='same')(x)
generator = Model(generator_input, x)
generator.summary()
生成器网络概要如下:

然后搭建判别器网络:
discriminator_input = Input(shape=(height, width, channels))
x = Conv2D(128, 3)(discriminator_input)
x = LeakyReLU()(x)
x = Conv2D(128, 4, strides=2)(x)
x = LeakyReLU()(x)
x = Conv2D(128, 4, strides=2)(x)
x = LeakyReLU()(x)
x = Conv2D(128, 4, strides=2)(x)
x = LeakyReLU()(x)
x = Flatten()(x)
x = Dropout(0.4)(x)
x = Dense(1, activation='sigmoid')(x)
discriminator = Model(discriminator_input, x)
discriminator.summary()
discriminator_optimizer = optimizers.RMSprop(lr=0.0008,
clipvalue=1.0,
decay=1e-8)
discriminator.compile(optimizer=discriminator_optimizer,
loss='binary_crossentropy')
判别器网络概要如下:

将生成器网络和判别器网络进行组合成DCGAN:
# 将判别器参数设置为不可训练
discriminator.trainable = False
gan_input = Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
# 搭建对抗网络
gan = Model(gan_input, gan_output)
gan_optimizer = optimizers.RMSprop(lr=0.0004,
clipvalue=1.0,
decay=1e-8)
gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')
DCGAN搭建完成之后,我们使用CIFAR-10数据来进行训练,构建训练代码如下:
import os
from keras.preprocessing import image
# 加载cifar-10数据
(x_train, y_train), (_, _) = keras.datasets.cifar10.load_data()
# 指定青蛙图像(编号为6)
x_train = x_train[y_train.flatten() == 6]
x_train = x_train.reshape((x_train.shape[0],) +(height, width, channels)).astype('float32') / 255.
iterations = 10000
batch_size = 20
save_dir = './image'
start = 0
for step in range(iterations):
# 潜在空间随机采样
random_latent_vectors = np.random.normal(size=(batch_size, latent_dim)) # 解码生成虚假图像
generated_images = generator.predict(random_latent_vectors)
stop = start + batch_size
real_images = x_train[start: stop]
# 将虚假图像和真实图像混合
combined_images = np.concatenate([generated_images, real_images]) # 合并标签,区分真实和虚假图像
labels = np.concatenate([np.ones((batch_size, 1)), np.zeros((batch_size, 1))])
# 向标签中添加随机噪声
labels += 0.05 * np.random.random(labels.shape)
# 训练判别器
d_loss = discriminator.train_on_batch(combined_images, labels)
# 潜在空间随机采样
random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))
# 合并标签,以假乱真
misleading_targets = np.zeros((batch_size, 1))
# 通过gan模型来训练生成器模型,冻结判别器模型权重
a_loss = gan.train_on_batch(random_latent_vectors, misleading_targets)
start += batch_size
if start > len(x_train) - batch_size:
start = 0
# 每100步绘图并保存
if step % 100 == 0:
gan.save_weights('gan.h5')
print('discriminator loss:', d_loss)
print('adversarial loss:', a_loss)
img = image.array_to_img(generated_images[0] * 255., scale=False)
img.save(os.path.join(save_dir, 'generated_frog' + str(step) + '.png'))
img = image.array_to_img(real_images[0] * 255., scale=False)
img.save(os.path.join(save_dir, 'real_frog' + str(step) + '.png'))
训练过程如下:

DCGAN生成的青蛙图片和真实图片混在一起如下图所示,能否辨别出哪张是真实样本,哪张是DCGAN生成的样本?

受限于CIFAR-10数据本身的低像素性,DCGAN生成出来的图像虽然也很模糊,但基本上足以达到以假乱真的水平。上图图片中,每一列有两张是生成样本,有一张是真实样本,按列第2、1、3和2张图片是真实样本,其余都是DCGAN伪造出来的青蛙图片。
以上便是本节内容。
参考资料:
UNSUPERVISED REPRESENTATION LEARNING WITH DEEP CONVOLUTIONAL GENERATIVE ADVERSARIAL NETWORKS
https://blog.csdn.net/liuxiao214/article/details/74502975
thttp://www.twistedwg.com/2018/01/31/Various-GAN.html
Deep Learning with Python
往期精彩:
一个数据科学从业者的学习历程


长按二维码.关注机器学习实验室

更多推荐



所有评论(0)