python调用tensorflow.keras实现超分辨率生成对抗网络
目录
程序简介
项目调用tensorflow.keras搭建超分辨率生成对抗网络来提高图片分辨率,训练用的数据集则是500张图片
程序输入:60x60的图片
程序输出:120x120的图片
超分辨率生成对抗网络(SRGAN):从其低分辨率(LR)对应物估计高分辨率(HR)图像的极具挑战性的任务被称为超分辨率(SR)。SRGAN是一种用于图像超分辨率(SR)的生成对抗网络(GAN)。
程序/数据集下载
图片迭代器 Module/Collect.py
导入模块、路径
# -*- coding: utf-8 -*- import matplotlib.pyplot as plt import pandas as pd import numpy as np import cv2 import os #路径目录 baseDir = ''#根目录 staticDir = os.path.join(baseDir,'Static')#静态文件目录 resultDir = os.path.join(baseDir,'Result')#结果文件目录 imgsDir = staticDir+'/图片'#图片目录 names = os.listdir(imgsDir)#图片名集合
图片增强函数,随机让图片旋转0-360度,测试查看图片旋转效果
def augment(name): ''' 读取图片,并做随机旋转操作,返回图片矩阵 name:图片名 ''' imgPath = imgsDir + '/' + name#图片路径 #图片矩阵 img = cv2.imdecode(np.fromfile(imgsDir+'/'+name, dtype=np.uint8),-1) #旋转图片矩阵 img = np.rot90(img,k=np.random.randint(4)) return img #同一张图片名输入两次,得到两张不同图片 img1 = augment(names[1]) img2 = augment(names[1]) combine = np.concatenate((img1,img2), axis=1) plt.matshow(combine)
图片缩小函数,原图为网络期望输出,缩小图片为网络输入,测试查看效果
def reduce(img): ''' 缩小图片的长宽为原来的一半,返回小图 img:图片矩阵 ''' #采用双线性插值算法缩小图片 miniImg = cv2.resize(img,(int(img.shape[0]/2),int(img.shape[1]/2)), interpolation=cv2.INTER_LINEAR) return miniImg img3 = reduce(img1) plt.matshow(img1) plt.title('原图') plt.matshow(img3) plt.title('缩小图')
因为图片像素区间为[0,255],而神经网络输入输出的区间最好是[-1,1],所以需要下文的函数对像素值进行归一化和还原操作,测试查看效果
def normalizeImg(img): '''将图片归一化到-1,1,这里可以有小数''' img = (img/255 - 0.5)*2 return img def reverseImg(img): '''将图片还原到原数量级,这里不能有小数''' img = (img + 1)*255/2 img = img.astype(np.uint8) return img print('原数',[0,255],'归一化后',normalizeImg(np.array([0,255])),'还原后',reverseImg(np.array([-1,1])))
原数 [0, 255] 归一化后 [-1. 1.] 还原后 [ 0 255]
图片迭代器,调用上文定义的函数,每次调用都会随机抽取批处理量的图片,并且对图片进行随机增强的操作函数,返回的数据为DataFrame,分为4列,未归一化的输入输出集,归一化的输入输出集
def collect(batchSize): ''' 随机批量抽取图片作为训练输入输出 batchSize:批量大小 ''' #随机选择batch张图片 choosNames = np.random.choice(names,batchSize,replace=False) data = pd.DataFrame({'name':choosNames}) #原输出集 data['output'] = data['name'].apply(augment) #原输入集 data['input'] = data['output'].apply(reduce) #归一化输入集 data['normalInput'] = data['input'].apply(normalizeImg) #归一化输出集 data['normalOutput'] = data['output'].apply(normalizeImg) return data
搭建SRGAN框架 Module/BuileNet.py
导入模块
# -*- coding: utf-8 -*- from tensorflow.keras.layers import Input,Dense,Conv2D,Flatten,BatchNormalization,UpSampling2D from tensorflow.keras.layers import PReLU,Add,Concatenate,LeakyReLU from tensorflow.keras.models import Model from tensorflow.keras.optimizers import Adam,RMSprop from tensorflow.keras.losses import mean_squared_error as mse from tensorflow.keras.losses import mean_absolute_error as mae from tensorflow.keras.applications import VGG19 from tensorflow.keras.applications.vgg19 import preprocess_input import tensorflow.keras.backend as K import tensorflow as tf import numpy as np
生成器构建函数,即将60x60的图片超分为120x120的图片的神经网络,其中比较重要的结构被称为残差块,即程序中的resBlock函数,这里没配置损失函数和优化器,是因为生成器的训练在下文的对抗网络训练过程中
def resBlock(xIn,filterNum): '''残差块''' x = Conv2D(filters=filterNum,kernel_size=3,padding='same')(xIn) x = BatchNormalization()(x) x = LeakyReLU()(x) x = Conv2D(filters=filterNum,kernel_size=3,padding='same')(x) x = BatchNormalization()(x) x = Add()([xIn, x]) x = LeakyReLU()(x) return x def createGenerator(layerNum,filterNum): ''' 创建生成器 layerNum:残差块数 filterNum:残差块卷积核数 ''' #输入层 inputLayer = Input(shape=(None,None,3)) #第一层 firstLayer = Conv2D(filters=filterNum,kernel_size=3,padding='same')(inputLayer) firstLayer = BatchNormalization()(firstLayer) firstLayer = LeakyReLU()(firstLayer) #中间层 middle = firstLayer for num in range(layerNum): middle = resBlock(middle,filterNum) middle = Conv2D(filters=filterNum,kernel_size=3,padding='same')(middle) middle = BatchNormalization()(middle) middle = LeakyReLU()(middle) middle = Add()([firstLayer,middle]) middle = UpSampling2D(size=2)(middle) #输出层 outputLayer = Conv2D(filters=3,kernel_size=9,padding="same",activation='tanh')(middle) #建模 model = Model(inputs=inputLayer,outputs=outputLayer) return model
判别器构建函数,即判断生成器生成的高清图片是否为真实图片,判别器差不多就是普通的分类卷积神经网络,输出在[0,1]区间,损失函数则是二分类损失
def block(xIn,filterNum): '''卷积+标准化+激活块''' x = Conv2D(filterNum,kernel_size=3,strides=3,padding='same')(xIn) x = BatchNormalization()(x) x = LeakyReLU(0.2)(x) return x def createDiscriminator(layerNum,filterNum,lr): ''' 创建判别器 layerNum:中间块数 filterNum:中间块卷积核数 lr:学习率 ''' #输入层 inputLayer = Input(shape=(120,120,3)) #中间层 middle = inputLayer for num in range(layerNum): middle = block(middle,filterNum) middle = Flatten()(middle) middle = Dense(1000)(middle) middle = LeakyReLU()(middle) #输出层 outputLayer = Dense(1, activation='sigmoid')(middle) #建模 model = Model(inputs=inputLayer,outputs=outputLayer) #优化器 optimizer = RMSprop(lr=lr) model.compile(optimizer=optimizer, loss='binary_crossentropy') return model
构建对抗网络,即组合生成器和判别器,形成新的网络,构成这个网络的原因是为了训练生成器,生成器的目的就是迷惑判别器,组合后的网络先将判别器部分的参数固定,然后训练生成器部分的参数,使得判别器分不清真实和生成图片
注意,对抗网络的输入和输出都在[-1,1]的区间内,而在计算内容损失时需要将图片还原,所以这里定义一个reverseImg图片还原函数,内容损失是对抗网络的损失函数
def reverseImg(img): '''将图片还原到原数量级''' img = (img + 1)*255/2 return img print('处理前',np.array([-1,1]),'处理后',reverseImg(np.array([-1,1])))
处理前 [-1 1] 处理后 [ 0. 255.]
对抗网络的输出有两部分,第一部分为组合判别网判断生成图片是否为真实图片,第二部分为组合生成器生成的图片
与之对应的损失函数也有两部分,第一部分为组合后判别器的二分类损失,第二部分为内容损失,内容损失这里不是将原高清图片与生成图片进行MSE计算,而是需要用VGG19网络进行特征提取,然后对两张图片的特征进行MSE计算
#特征提取器 vgg19 = VGG19(include_top=False, weights='imagenet') vgg19 = Model(vgg19.input, vgg19.output) def contentLoss(y_true, y_pred): '''内容损失''' y_true = reverseImg(y_true) y_pred = reverseImg(y_pred) y_true = preprocess_input(y_true) y_pred = preprocess_input(y_pred) sr = vgg19(y_pred) hr = vgg19(y_true) return mse(y_true, y_pred) def createGan(generator,discriminator,lr): '''构建对抗网''' discriminator.trainable = False #生成器输入 lowImg = generator.input #生成器输出 fakeHighImg = generator(lowImg) #生成器判断 judge = discriminator(fakeHighImg) model = Model(inputs=lowImg,outputs=[judge,fakeHighImg]) optimizer = RMSprop(lr=lr) model.compile(optimizer=optimizer, loss=['binary_crossentropy', contentLoss],loss_weights=[1, 1e-1]) model.summary() return model
实例化生成器、判别器、对抗网络
generator = createGenerator(5,100)#生成器 discriminator = createDiscriminator(10,100,1e-5)#判别器 print('打印内容为对抗网络结构') gan = createGan(generator,discriminator,1e-5)#对抗网络
打印内容为对抗网络结构 _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_9 (InputLayer) (None, None, None, 3) 0 _________________________________________________________________ model_10 (Model) (None, None, None, 3) 1023003 _________________________________________________________________ model_11 (Model) (None, 1) 919701 ================================================================= Total params: 1,942,704 Trainable params: 1,020,603 Non-trainable params: 922,101 _________________________________________________________________
训练网络,查看效果 Main.py
导入模块、路径、预设参数
# -*- coding: utf-8 -*- from Module.BuidModel import createGenerator,createDiscriminator,createGan from Module.Collect import collect,reverseImg import cv2 import numpy as np import os ############################可调整参数########################## batchSize = 10#批处理量 genFilters = 100#生成器核数 disFilters = 100#判别器核数 genLayers = 5#生成器残差块层数 disLayers = 10#判别器卷积块数 genLearnRate = 5e-5#生成学习率 disLearnRate = 1e-4#判别学习率 ############################################################## #路径目录 baseDir = ''#当前目录 staticDir = os.path.join(baseDir,'Static')#静态文件目录 resultDir = os.path.join(baseDir,'Result')#结果文件目录
实例化生成器、判别器、对抗网络,效果与上文的实例化演示一致
generator = createGenerator(genLayers,genFilters)#生成器 discriminator = createDiscriminator(disLayers,disFilters,disLearnRate)#判别器 gan = createGan(generator,discriminator,genLearnRate)#对抗网络
进入训练的无限循环,每个epoch随机抽取图片,首先用生成器从低清图生成到伪高清图,然后将伪高清图和原高清图作为输入到判别器
判别器的训练目的在于给伪高清图打标签0,给原高清图片打标签1
最后组合生成器和判别器,固定判别器的参数,训练对抗网络(即生成器的参数),目的是让生成器混淆判别器的识别能力,使生成器的生成图片尽可能的被判别器打为标签1
每100个epoch保存效果图,下文查看效果
epochs = 0#迭代次数 loss = {'dLoss':[],'gLoss':[],'cLoss':[]} while True: epochs += 1 imgs = collect(batchSize)#随机抽取图片 #归一化的输入 try: low = np.array(imgs['normalInput'].values.tolist()).reshape(batchSize,60,60,3) except: epochs -= 1 print('error') continue #生成高清图(-1,1) fakeHigh = generator.predict(low) #原高清图(-1,1) realHigh = np.array(imgs['normalOutput'].values.tolist()).reshape(-1,120,120,3) #真伪标签 realBool = np.random.uniform(0.7,1,size=(batchSize,)) fakeBool = np.random.uniform(0,0.3,size=(batchSize,)) #鉴别器训练 discriminator.trainable = True dRealLoss = discriminator.train_on_batch(x=realHigh, y=realBool) dFakeLoss = discriminator.train_on_batch(x=fakeHigh, y=fakeBool) #判别损失 loss['dLoss'].append(0.5 * (dRealLoss + dFakeLoss)) #生成器训练 discriminator.trainable = False ganLoss = gan.train_on_batch(x=low, y=[realBool,realHigh]) #对抗损失 loss['gLoss'].append(ganLoss[1]) #内容损失 loss['cLoss'].append(ganLoss[2]) if epochs%100==0: #打印损失 dLoss = np.array(loss['dLoss'][-100:]).mean() gLoss = np.array(loss['gLoss'][-100:]).mean() cLoss = np.array(loss['cLoss'][-100:]).mean() print('epoch:%d dLoss:%.4f gLoss:%.4f cLoss:%.4f'%(epochs,dLoss,gLoss,cLoss)) #保存模型 generator.save_weights(resultDir+'/generator.h5') discriminator.save_weights(resultDir+'/discriminator.h5') #原低清图(-1,1) lowImg = low[0] #生成高清图 fakeHigh = generator.predict(lowImg[np.newaxis,:]).reshape((120,120,3)) fakeHigh = reverseImg(fakeHigh) #传统双线性插值放大 lineHigh = cv2.resize(reverseImg(lowImg),(120,120), interpolation=cv2.INTER_LINEAR) #原图像 originHigh = reverseImg(realHigh[0]) #组合图片 combine = np.concatenate((lineHigh,fakeHigh,originHigh), axis=1) cv2.imencode('.jpg',combine)[1].tofile(resultDir+'/compare.jpg')
从左往右数,图1为传统的双线性插值法的超分结果,图2为超分对抗网络的超分结果,图3为原高清图