Super-Resolution (1): Image Super-Resolution Reconstruction with a GAN
Table of Contents
1. Project Overview
2. Project Walkthrough
2.1 Data Loading and Configuration
2.2 Building the Generator Network
2.3 Building the Discriminator Network
2.4 VGG Feature-Extraction Network
2.5 Loss Functions
3. Complete Code
4. Dataset
5. Testing the Network
1. Project Overview
Super-Resolution, abbreviated SR, is the process of recovering image detail and other information from a known image. Simply put, it increases an image's resolution without letting the image quality degrade the way naive enlargement would.
GAN stands for Generative Adversarial Network. A GAN generally consists of a generator (the generative network) and a discriminator (the discriminative network).
SRGAN applies this adversarial setup to image super-resolution reconstruction, and proposes a loss function composed of an Adversarial Loss and a Content Loss.
Paper: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network (IEEE Xplore): https://ieeexplore.ieee.org/document/8099502
Network architecture:
The model consists of two networks: a generator and a discriminator.
The generator learns the features of the training data and, under the discriminator's guidance, fits a random-noise distribution as closely as possible to the true distribution of the training data, thereby generating samples that share the training set's characteristics.
The discriminator's job is to distinguish real data from the generator's fakes, and to feed that signal back to the generator.
The two networks are trained alternately and improve in step, until the generator's outputs can pass for real and the two networks reach a rough equilibrium.
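Formally, the generator G and discriminator D play the standard GAN minimax game (Goodfellow et al., 2014); note that in SRGAN the generator's input z is a low-resolution image rather than random noise:

\min_G \max_D \; \mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]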
2. Project Walkthrough
2.1 Data Loading and Configuration
Parameter configuration:
from easydict import EasyDict as edict
import json

config = edict()
config.TRAIN = edict()

## Adam
config.TRAIN.batch_size = 4
config.TRAIN.lr_init = 1e-4
config.TRAIN.beta1 = 0.9

## initialize G
config.TRAIN.n_epoch_init = 100
# config.TRAIN.lr_decay_init = 0.1
# config.TRAIN.decay_every_init = int(config.TRAIN.n_epoch_init / 2)

## adversarial learning (SRGAN)
config.TRAIN.n_epoch = 2000
config.TRAIN.lr_decay = 0.1
config.TRAIN.decay_every = int(config.TRAIN.n_epoch / 2)

## train set location
config.TRAIN.hr_img_path = './srdata/DIV2K_train_HR'
config.TRAIN.lr_img_path = './srdata/DIV2K_train_LR_bicubic/X4'

config.VALID = edict()
## test set location
config.VALID.hr_img_path = './srdata/DIV2K_valid_HR'
config.VALID.lr_img_path = './srdata/DIV2K_valid_LR_bicubic/X4'

def log_config(filename, cfg):
    with open(filename, 'w') as f:
        f.write("================================================\n")
        f.write(json.dumps(cfg, indent=4))
        f.write("\n================================================\n")
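log_config simply dumps the EasyDict as JSON between separator lines, so each run's hyper-parameters can be kept next to its checkpoints. A typical call (the file name here is illustrative, not part of the original project) would be:

# Record the hyper-parameters used for this run (illustrative file name).
log_config('train_config.log', config)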
Data loading:
# tl.files.load_file_list returns the image file names:
# the first argument is the folder containing the images, regx filters by file type.
# sorted(...)[:x] keeps only x images (loading too many can cause a memory error).
train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))[:100]
train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))[:100]
valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))[:50]
valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))[:50]

# If your machine has enough memory, pre-load the whole train set.
# tl.vis.read_images reads the images themselves: the file names from above,
# the folder they live in, and the number of reader threads.
train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=8)
2.2 Building the Generator Network
tf.compat.v1.disable_eager_execution()
t_image = tf.compat.v1.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator')

# Build the generator network.
# reuse=False: create fresh variables instead of sharing an existing network.
net_g = SRGAN_g(t_image, is_train=True, reuse=False)
SRGAN_g:
def SRGAN_g(t_image, is_train=False, reuse=False):
    """ Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
    feature maps (n) and stride (s)
    """
    # weight initialization
    w_init = tf.random_normal_initializer(stddev=0.02)
    b_init = None  # tf.constant_initializer(value=0.0)
    # gamma initializer (a BatchNormalization parameter)
    g_init = tf.random_normal_initializer(1., 0.02)
    with tf.variable_scope("SRGAN_g", reuse=reuse) as vs:
        # tl.layers.set_name_reuse(reuse)  # remove for TL 1.8.0+
        # input layer
        n = InputLayer(t_image, name='in')
        # first convolution
        n = Conv2d(n, 64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='n64s1/c')
        temp = n

        # B residual blocks (16 of them)
        for i in range(16):
            nn = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c1/%s' % i)
            nn = BatchNormLayer(nn, act=tf.nn.relu, is_train=is_train, gamma_init=g_init, name='n64s1/b1/%s' % i)
            nn = Conv2d(nn, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c2/%s' % i)
            nn = BatchNormLayer(nn, is_train=is_train, gamma_init=g_init, name='n64s1/b2/%s' % i)
            # Residual (skip) connection: nn = n + nn, where n is the block input
            # and nn is the result of two convolutions and two BatchNorm layers.
            nn = ElementwiseLayer([n, nn], tf.add, name='b_residual_add/%s' % i)
            n = nn

        n = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c/m')
        n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n64s1/b/m')
        # One more skip connection, merging the block-stack output with the pre-block features.
        n = ElementwiseLayer([n, temp], tf.add, name='add3')
        # B residual blocks end

        # Upsampling: reconstruct the high-resolution image from the low-resolution features.
        n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/1')
        n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/1')

        n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/2')
        n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/2')

        # Final 1x1 convolution with tanh produces the output image in [-1, 1].
        n = Conv2d(n, 3, (1, 1), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=w_init, name='out')
        return n
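SubpixelConv2d is the pixel-shuffle upsampling of Shi et al. (ESPCN): a convolution first produces r²·C feature channels, which are then rearranged into an r-times larger spatial grid with C channels. TensorFlow exposes the rearrangement itself as tf.nn.depth_to_space; a standalone sketch (run eagerly, outside the graph built above) of what one scale=2 stage does to the shapes:

import tensorflow as tf

# Pixel shuffle: (H, W, r*r*C) -> (r*H, r*W, C). Here r=2 and C=64,
# matching the generator's 256-filter conv followed by pixelshufflerx2.
x = tf.random.normal([1, 96, 96, 256])     # output of the n256s1 convolution
y = tf.nn.depth_to_space(x, block_size=2)  # shape (1, 192, 192, 64)
print(y.shape)                             # two such stages give the full 4x upscale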
2.3 Building the Discriminator Network
tf.compat.v1.disable_eager_execution()
t_target_image = tf.compat.v1.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')

# Build the discriminator network.
# Teach the discriminator what is real: feed it the real target images.
# reuse=False: create the variables rather than share them.
net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
# Teach the discriminator what is fake: feed it the generator's outputs.
# reuse=True: share the variables created by the first call.
_, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)
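The two SRGAN_d calls build two graph branches over a single set of weights; TF1's variable_scope reuse is what makes this sharing work, so real and fake images are scored by the same discriminator. A minimal standalone demonstration of the mechanism (a toy network, not the project's code):

import tensorflow as tf
tf.compat.v1.disable_eager_execution()

def tiny_d(x, reuse=False):
    # reuse=False creates the variables; reuse=True looks up the existing ones.
    with tf.compat.v1.variable_scope("tiny_d", reuse=reuse):
        return tf.compat.v1.layers.dense(x, 1, name="out")

real_logit = tiny_d(tf.zeros([1, 8]), reuse=False)  # creates tiny_d/out/kernel, tiny_d/out/bias
fake_logit = tiny_d(tf.ones([1, 8]), reuse=True)    # reuses those same variables
# Only one kernel/bias pair exists in the graph:
print([v.name for v in tf.compat.v1.trainable_variables()])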
SRGAN_d:
def SRGAN_d(input_images, is_train=True, reuse=False):
    w_init = tf.random_normal_initializer(stddev=0.02)
    b_init = None  # tf.constant_initializer(value=0.0)
    gamma_init = tf.random_normal_initializer(1., 0.02)
    df_dim = 64
    lrelu = lambda x: tl.act.lrelu(x, 0.2)
    # Build the network: a stack of strided convolutions that shrink the image
    # while growing the channel count.
    with tf.variable_scope("SRGAN_d", reuse=reuse):
        tl.layers.set_name_reuse(reuse)
        net_in = InputLayer(input_images, name='input/images')
        net_h0 = Conv2d(net_in, df_dim, (4, 4), (2, 2), act=lrelu, padding='SAME', W_init=w_init, name='h0/c')

        net_h1 = Conv2d(net_h0, df_dim * 2, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h1/c')
        net_h1 = BatchNormLayer(net_h1, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h1/bn')
        net_h2 = Conv2d(net_h1, df_dim * 4, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h2/c')
        net_h2 = BatchNormLayer(net_h2, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h2/bn')
        net_h3 = Conv2d(net_h2, df_dim * 8, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h3/c')
        net_h3 = BatchNormLayer(net_h3, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h3/bn')
        net_h4 = Conv2d(net_h3, df_dim * 16, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h4/c')
        net_h4 = BatchNormLayer(net_h4, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h4/bn')
        net_h5 = Conv2d(net_h4, df_dim * 32, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h5/c')
        net_h5 = BatchNormLayer(net_h5, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h5/bn')
        net_h6 = Conv2d(net_h5, df_dim * 16, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h6/c')
        net_h6 = BatchNormLayer(net_h6, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h6/bn')
        net_h7 = Conv2d(net_h6, df_dim * 8, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h7/c')
        net_h7 = BatchNormLayer(net_h7, is_train=is_train, gamma_init=gamma_init, name='h7/bn')

        net = Conv2d(net_h7, df_dim * 2, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='res/c')
        net = BatchNormLayer(net, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='res/bn')
        net = Conv2d(net, df_dim * 2, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='res/c2')
        net = BatchNormLayer(net, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='res/bn2')
        net = Conv2d(net, df_dim * 8, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='res/c3')
        net = BatchNormLayer(net, is_train=is_train, gamma_init=gamma_init, name='res/bn3')
        net_h8 = ElementwiseLayer([net_h7, net], combine_fn=tf.add, name='res/add')
        net_h8.outputs = tl.act.lrelu(net_h8.outputs, 0.2)

        # Flatten the feature maps and pass them through a fully connected layer.
        net_ho = FlattenLayer(net_h8, name='ho/flatten')
        net_ho = DenseLayer(net_ho, n_units=1, act=tf.identity, W_init=w_init, name='ho/dense')
        logits = net_ho.outputs
        # A sigmoid turns the logit into a real/fake probability.
        net_ho.outputs = tf.nn.sigmoid(net_ho.outputs)
        return net_ho, logits
2.4 VGG Feature-Extraction Network
## vgg inference. method: 0, 1, 2, 3 = BILINEAR, NEAREST, BICUBIC, AREA
# Resize the images to the 224x224 input size VGG expects.
# Resize the original (target) image:
t_target_image_224 = tf.image.resize_images(
    t_target_image, size=[224, 224], method=0, align_corners=False)  # resize_target_image_for_vgg
# http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
# Resize the generated image:
t_predict_image_224 = tf.image.resize_images(net_g.outputs, size=[224, 224], method=0, align_corners=False)  # resize_generate_image_for_vgg

# (x + 1) / 2 maps the tanh output range [-1, 1] back to the [0, 1] range VGG expects.
net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
_, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)
Vgg19_simple_api:
def Vgg19_simple_api(rgb, reuse):
    """ Build the VGG 19 Model

    Parameters
    -----------
    rgb : rgb image placeholder [batch, height, width, 3] values scaled [0, 1]
    """
    VGG_MEAN = [103.939, 116.779, 123.68]
    with tf.variable_scope("VGG19", reuse=reuse) as vs:
        start_time = time.time()
        print("build model started")
        rgb_scaled = rgb * 255.0
        # Convert RGB to BGR
        red, green, blue = tf.split(rgb_scaled, 3, 3)
        assert red.get_shape().as_list()[1:] == [224, 224, 1]
        assert green.get_shape().as_list()[1:] == [224, 224, 1]
        assert blue.get_shape().as_list()[1:] == [224, 224, 1]
        # Mean subtraction: each color channel minus its own mean
        bgr = tf.concat(
            [
                blue - VGG_MEAN[0],
                green - VGG_MEAN[1],
                red - VGG_MEAN[2],
            ], axis=3)
        assert bgr.get_shape().as_list()[1:] == [224, 224, 3]

        """ input layer """
        net_in = InputLayer(bgr, name='input')
        """ conv1 """
        network = Conv2d(net_in, n_filter=64, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv1_1')
        network = Conv2d(network, n_filter=64, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv1_2')
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool1')
        """ conv2 """
        network = Conv2d(network, n_filter=128, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv2_1')
        network = Conv2d(network, n_filter=128, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv2_2')
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool2')
        """ conv3 """
        network = Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv3_1')
        network = Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv3_2')
        network = Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv3_3')
        network = Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv3_4')
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool3')
        """ conv4 """
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv4_1')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv4_2')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv4_3')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv4_4')
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool4')  # (batch_size, 14, 14, 512)
        conv = network
        """ conv5 """
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv5_1')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv5_2')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv5_3')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv5_4')
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool5')  # (batch_size, 7, 7, 512)
        """ fc 6~8 """
        # Flatten the feature maps and pass them through the fully connected layers
        network = FlattenLayer(network, name='flatten')
        network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc6')
        network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc7')
        network = DenseLayer(network, n_units=1000, act=tf.identity, name='fc8')
        print("build model finished: %fs" % (time.time() - start_time))
        return network, conv
2.5 Loss Functions
# ###========================== DEFINE TRAIN OPS ==========================###
# Discriminator loss:
# real images should be classified as real (ones_like)
d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
# fake images should be classified as fake (zeros_like)
d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
d_loss = d_loss1 + d_loss2

# Adversarial loss: the generator wants its images classified as real (ones_like)
g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
# Pixel-wise comparison between the generated result and the real image
mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
# Comparison of generated and real images in VGG feature space
vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
# Final generator loss
g_loss = mse_loss + vgg_loss + g_gan_loss
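In equation form, writing \ell for sigmoid cross-entropy and \phi for the VGG19 feature extractor, the ops above implement:

L_D = \ell\big(D(I^{HR}), 1\big) + \ell\big(D(G(I^{LR})), 0\big)

L_G = \underbrace{\mathrm{MSE}\big(G(I^{LR}), I^{HR}\big)}_{mse\_loss} + 2\times 10^{-6}\,\underbrace{\mathrm{MSE}\big(\phi(G(I^{LR})), \phi(I^{HR})\big)}_{vgg\_loss} + 10^{-3}\,\underbrace{\ell\big(D(G(I^{LR})), 1\big)}_{g\_gan\_loss}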
3. Complete Code
main.py
#! /usr/bin/python
# -*- coding: utf8 -*-
# http://tensorlayercn.readthedocs.io/zh/latest/user/installation.html

import os
import time
import pickle, random
# from datetime import datetime
import numpy as np
import logging, scipy
import tensorflow as tf
import tensorlayer as tl
from model import SRGAN_g, SRGAN_d, Vgg19_simple_api
from utils import *
from config import config, log_config

###====================== HYPER-PARAMETERS ===========================###
## Adam
batch_size = config.TRAIN.batch_size
lr_init = config.TRAIN.lr_init
beta1 = config.TRAIN.beta1
## initialize G
n_epoch_init = config.TRAIN.n_epoch_init
## adversarial learning (SRGAN)
n_epoch = config.TRAIN.n_epoch
lr_decay = config.TRAIN.lr_decay
decay_every = config.TRAIN.decay_every
ni = int(np.sqrt(batch_size))


def train():
    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    # tl.files.load_file_list returns the image file names:
    # folder path plus a file-type regex; sorted(...)[:x] keeps only x images
    # (loading too many can cause a memory error).
    train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))[:100]
    train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))[:100]
    valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))[:50]
    valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))[:50]

    # If your machine has enough memory, pre-load the whole train set.
    # tl.vis.read_images reads the images themselves: file names, folder path,
    # and number of reader threads.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=8)
    # for im in train_hr_imgs:
    #     print(im.shape)
    # valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32)
    # for im in valid_lr_imgs:
    #     print(im.shape)
    # valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32)
    # for im in valid_hr_imgs:
    #     print(im.shape)
    # exit()

    ###========================== DEFINE MODEL ============================###
    ## train inference
    tf.compat.v1.disable_eager_execution()
    t_image = tf.compat.v1.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.compat.v1.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')

    # Build the generator network (reuse=False: create fresh variables).
    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    # Build the discriminator network.
    # Real branch: feed real images (reuse=False creates the variables).
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    # Fake branch: feed the generator's outputs (reuse=True shares the variables).
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. method: 0, 1, 2, 3 = BILINEAR, NEAREST, BICUBIC, AREA
    # Resize both images to the 224x224 input size VGG expects.
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0, align_corners=False)  # resize_target_image_for_vgg
    # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(net_g.outputs, size=[224, 224], method=0, align_corners=False)  # resize_generate_image_for_vgg

    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    # Discriminator loss: real -> ones_like, fake -> zeros_like
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2

    # The generator wants its images classified as real -> ones_like
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
    # Pixel-wise comparison between generated and real images
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    # Comparison of generated and real images in VGG feature space
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
    # Final generator loss
    g_loss = mse_loss + vgg_loss + g_gan_loss

    # Collect the trainable variables of each sub-network
    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)

    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    if tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']), network=net_g) is False:
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']), network=net_g)
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/d_{}.npz'.format(tl.global_flag['mode']), network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print("Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg")
        exit()
    npz = np.load(vgg19_npy_path, encoding='latin1').item()

    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)
    # net_vgg.print_params(False)
    # net_vgg.print_layers()
    print('ok')

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]
    # sample_imgs = tl.vis.read_images(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32)  # if no pre-load train set
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(), sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(), sample_imgs_96.max())
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_ginit + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_ginit + '/_train_sample_384.png')
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan + '/_train_sample_384.png')

    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0

        ## If your machine cannot load all images into memory, you should use
        ## this one to load batch of images while training.
        # random.shuffle(train_hr_img_list)
        # for idx in range(0, len(train_hr_img_list), batch_size):
        #     step_time = time.time()
        #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
        #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
        #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

        ## If your machine have enough memory, please pre-load the whole train set.
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update G
            errM, _ = sess.run([mse_loss, g_optim_init], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
            total_mse_loss += errM
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter)
        print(log)

        ## quick evaluation on train set
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})  # ; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_ginit + '/train_%d.png' % epoch)

        ## save model
        if (epoch != 0) and (epoch % 10 == 0):
            tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']), sess=sess)

    ###========================= train GAN (SRGAN) =========================###
    for epoch in range(0, n_epoch + 1):
        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0

        ## If your machine cannot load all images into memory, you should use
        ## this one to load batch of images while training.
        # random.shuffle(train_hr_img_list)
        # for idx in range(0, len(train_hr_img_list), batch_size):
        #     step_time = time.time()
        #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
        #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
        #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

        ## If your machine have enough memory, please pre-load the whole train set.
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update D
            errD, _ = sess.run([d_loss, d_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            ## update G
            errG, errM, errV, errA, _ = sess.run([g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)" %
                  (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG, errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter, total_g_loss / n_iter)
        print(log)

        ## quick evaluation on train set
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})  # ; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_gan + '/train_%d.png' % epoch)

        ## save model
        if (epoch != 0) and (epoch % 10 == 0):
            tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']), sess=sess)
            tl.files.save_npz(net_d.all_params, name=checkpoint_dir + '/d_{}.npz'.format(tl.global_flag['mode']), sess=sess)


def evaluate():
    ## create folders to save result images
    save_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"

    ###====================== PRE-LOAD DATA ===========================###
    # train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
    # train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
    valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))

    ## If your machine have enough memory, please pre-load the whole train set.
    # train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)
    # for im in train_hr_imgs:
    #     print(im.shape)
    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=8)
    # for im in valid_lr_imgs:
    #     print(im.shape)
    valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=8)
    # for im in valid_hr_imgs:
    #     print(im.shape)
    # exit()

    ###========================== DEFINE MODEL ============================###
    imid = 64  # 0: penguin  81: butterfly  53: bird  64: castle
    valid_lr_img = valid_lr_imgs[imid]
    valid_hr_img = valid_hr_imgs[imid]
    # valid_lr_img = get_imgs_fn('test.png', 'data2017/')  # if you want to test your own image
    valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]
    # print(valid_lr_img.min(), valid_lr_img.max())

    size = valid_lr_img.shape
    # t_image = tf.placeholder('float32', [None, size[0], size[1], size[2]], name='input_image')  # the old version of TL need to specify the image size
    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')

    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_srgan.npz', network=net_g)

    ###======================= EVALUATION =============================###
    start_time = time.time()
    out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
    print("took: %4.4fs" % (time.time() - start_time))

    print("LR size: %s / generated HR size: %s" % (size, out.shape))  # LR size: (339, 510, 3) / gen HR size: (1, 1356, 2040, 3)
    print("[*] save images")
    tl.vis.save_image(out[0], save_dir + '/valid_gen.png')
    tl.vis.save_image(valid_lr_img, save_dir + '/valid_lr.png')
    tl.vis.save_image(valid_hr_img, save_dir + '/valid_hr.png')

    out_bicu = scipy.misc.imresize(valid_lr_img, [size[0] * 4, size[1] * 4], interp='bicubic', mode=None)
    tl.vis.save_image(out_bicu, save_dir + '/valid_bicubic.png')


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='srgan', help='srgan, evaluate')
    args = parser.parse_args()
    tl.global_flag['mode'] = args.mode

    if tl.global_flag['mode'] == 'srgan':
        train()
    elif tl.global_flag['mode'] == 'evaluate':
        evaluate()
    else:
        raise Exception("Unknown --mode")
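main.py also imports crop_sub_imgs_fn and downsample_fn from utils.py, which this post does not list. Judging from how they are used above (random 384x384 HR crops, 96x96 LR inputs, and the [-1, 1] value range matching the generator's tanh output), a rough reimplementation would look like the following sketch; this is an inference from usage, not the project's actual utils.py:

# utils.py (sketch): shapes and the [-1, 1] range are inferred from main.py.
from tensorlayer.prepro import crop, imresize

def crop_sub_imgs_fn(x, is_random=True):
    # Cut a 384x384 HR patch and rescale pixels from [0, 255] to [-1, 1]
    x = crop(x, wrg=384, hrg=384, is_random=is_random)
    x = x / (255. / 2.) - 1.
    return x

def downsample_fn(x):
    # Bicubic 4x downsample of the HR patch to the 96x96 LR input
    # (imresize returns uint8 in [0, 255], so rescale to [-1, 1] again)
    x = imresize(x, size=[96, 96], interp='bicubic', mode=None)
    x = x / (255. / 2.) - 1.
    return x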
config.py
from easydict import EasyDict as edict
import json

config = edict()
config.TRAIN = edict()

## Adam
config.TRAIN.batch_size = 4
config.TRAIN.lr_init = 1e-4
config.TRAIN.beta1 = 0.9

## initialize G
config.TRAIN.n_epoch_init = 100
# config.TRAIN.lr_decay_init = 0.1
# config.TRAIN.decay_every_init = int(config.TRAIN.n_epoch_init / 2)

## adversarial learning (SRGAN)
config.TRAIN.n_epoch = 2000
config.TRAIN.lr_decay = 0.1
config.TRAIN.decay_every = int(config.TRAIN.n_epoch / 2)

## train set location
config.TRAIN.hr_img_path = './srdata/DIV2K_train_HR'
config.TRAIN.lr_img_path = './srdata/DIV2K_train_LR_bicubic/X4'

config.VALID = edict()
## test set location
config.VALID.hr_img_path = './srdata/DIV2K_valid_HR'
config.VALID.lr_img_path = './srdata/DIV2K_valid_LR_bicubic/X4'

def log_config(filename, cfg):
    with open(filename, 'w') as f:
        f.write("================================================\n")
        f.write(json.dumps(cfg, indent=4))
        f.write("\n================================================\n")
download_imagenet.py
import argparse
import socket
import os
import urllib
import numpy as np
from PIL import Image
from joblib import Parallel, delayed


def download_image(download_str, save_dir):
    img_name, img_url = download_str.strip().split('\t')
    save_img = os.path.join(save_dir, "{}.jpg".format(img_name))
    downloaded = False
    try:
        if not os.path.isfile(save_img):
            print("Downloading {} to {}.jpg".format(img_url, img_name))
            urllib.urlretrieve(img_url, save_img)  # Python 2 API; under Python 3 use urllib.request.urlretrieve
        # Check size of the images
        downloaded = True
        with Image.open(save_img) as img:
            width, height = img.size
        img_size_bytes = os.path.getsize(save_img)
        img_size_KB = img_size_bytes / 1024
        if width  # <- the listing breaks off here in the original post

model.py
#! /usr/bin/python
# -*- coding: utf8 -*-

import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *
import time
import os
# from tensorflow.python.ops import variable_scope as vs
# from tensorflow.python.ops import math_ops, init_ops, array_ops, nn
# from tensorflow.python.util import nest
# from tensorflow.contrib.rnn.python.ops import core_rnn_cell
# https://github.com/david-gpu/srez/blob/master/srez_model.py


def SRGAN_g(t_image, is_train=False, reuse=False):
    """ Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
    feature maps (n) and stride (s)
    """
    # weight initialization
    w_init = tf.random_normal_initializer(stddev=0.02)
    b_init = None  # tf.constant_initializer(value=0.0)
    # gamma initializer (a BatchNormalization parameter)
    g_init = tf.random_normal_initializer(1., 0.02)
    # tf.compat.v1.disable_v2_behavior()
    with tf.compat.v1.variable_scope("SRGAN_g", reuse=reuse) as vs:
        # tl.layers.set_name_reuse(reuse)  # remove for TL 1.8.0+
        # input layer
        n = InputLayer(t_image, name='in')
        # first convolution
        n = Conv2d(n, 64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='n64s1/c')
        temp = n

        # B residual blocks (16 of them)
        for i in range(16):
            nn = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c1/%s' % i)
            nn = BatchNormLayer(nn, act=tf.nn.relu, is_train=is_train, gamma_init=g_init, name='n64s1/b1/%s' % i)
            nn = Conv2d(nn, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c2/%s' % i)
            nn = BatchNormLayer(nn, is_train=is_train, gamma_init=g_init, name='n64s1/b2/%s' % i)
            # Residual connection: nn = n + nn, where n is the block input and
            # nn is the result of two convolutions and two BatchNorm layers.
            nn = ElementwiseLayer([n, nn], tf.add, name='b_residual_add/%s' % i)
            n = nn

        n = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c/m')
        n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n64s1/b/m')
        # Merge the block-stack output with the pre-block features once more.
        n = ElementwiseLayer([n, temp], tf.add, name='add3')
        # B residual blocks end

        # Upsampling: reconstruct the high-resolution image from the low-resolution features.
        n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/1')
        n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/1')

        n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/2')
        n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/2')

        # Final 1x1 convolution with tanh produces the output image in [-1, 1].
        n = Conv2d(n, 3, (1, 1), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=w_init, name='out')
        return n


def SRGAN_g2(t_image, is_train=False, reuse=False):
    """ Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
    feature maps (n) and stride (s)

    96x96 --> 384x384

    Use Resize Conv
    """
    w_init = tf.random_normal_initializer(stddev=0.02)
    b_init = None  # tf.constant_initializer(value=0.0)
    g_init = tf.random_normal_initializer(1., 0.02)

    size = t_image.get_shape().as_list()
    with tf.variable_scope("SRGAN_g", reuse=reuse) as vs:
        # tl.layers.set_name_reuse(reuse)  # remove for TL 1.8.0+
        n = InputLayer(t_image, name='in')
        n = Conv2d(n, 64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='n64s1/c')
        temp = n

        # B residual blocks
        for i in range(16):
            nn = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c1/%s' % i)
            nn = BatchNormLayer(nn, act=tf.nn.relu, is_train=is_train, gamma_init=g_init, name='n64s1/b1/%s' % i)
            nn = Conv2d(nn, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c2/%s' % i)
            nn = BatchNormLayer(nn, is_train=is_train, gamma_init=g_init, name='n64s1/b2/%s' % i)
            nn = ElementwiseLayer([n, nn], tf.add, name='b_residual_add/%s' % i)
            n = nn

        n = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c/m')
        n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n64s1/b/m')
        n = ElementwiseLayer([n, temp], tf.add, name='add3')
        # B residual blocks end

        # n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/1')
        # n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/1')
        #
        # n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/2')
        # n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/2')

        ## 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
        n = UpSampling2dLayer(n, size=[size[1] * 2, size[2] * 2], is_scale=False, method=1, align_corners=False, name='up1/upsample2d')
        n = Conv2d(n, 64, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='up1/conv2d')  # <- the listing breaks off here in the original post