CycleGAN workflow:

[1] - Take an image input_A from domain A and transform it with Generator A2B into a fake domain-B image Generated_B; compute the adversarial loss from discriminator Discriminator_B;

[2] - Pass Generated_B through Generator B2A to map it back to domain A, producing Cyclic_A; compute the cycle-consistency loss between Cyclic_A and input_A;

[3] - The mirrored network runs the same procedure in the opposite direction (B -> A -> B); the sketch below illustrates one direction.
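The two loss terms of steps [1] and [2] can be written compactly. A minimal sketch, not the full training step; one_direction_losses, generator_a2b, generator_b2a, and discriminator_b are illustrative names, not part of the script below:

import tensorflow as tf

def one_direction_losses(input_A, generator_a2b, generator_b2a,
                         discriminator_b, lambda_cycle=10.0):
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    # [1] A -> B: produce a fake domain-B image and score it with the discriminator.
    generated_B = generator_a2b(input_A, training=True)
    disc_fake_B = discriminator_b(generated_B, training=True)
    adversarial_loss = bce(tf.ones_like(disc_fake_B), disc_fake_B)
    # [2] B -> A: map the fake image back and penalize the L1 distance to input_A.
    cyclic_A = generator_b2a(generated_B, training=True)
    cycle_loss = lambda_cycle * tf.reduce_mean(tf.abs(input_A - cyclic_A))
    return adversarial_loss + cycle_loss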

1. CycleGAN

#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Pix2pix.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
from absl import app
from absl import flags

import tensorflow as tf
import tensorflow_datasets as tfds

FLAGS = flags.FLAGS

flags.DEFINE_integer('buffer_size', 400, 'Shuffle buffer size')
flags.DEFINE_integer('batch_size', 1, 'Batch Size')
flags.DEFINE_integer('epochs', 1, 'Number of epochs')
flags.DEFINE_string('path', None, 'Path to the data folder')
flags.DEFINE_boolean('enable_function', True, 'Enable Function?')

IMG_WIDTH = 256
IMG_HEIGHT = 256
AUTOTUNE = tf.data.experimental.AUTOTUNE # let tf.data tune the number of parallel calls dynamically based on available CPU


def random_crop(image):
  cropped_image = tf.image.random_crop(image, size=[IMG_HEIGHT, IMG_WIDTH, 3])
  return cropped_image


def normalize(image):
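  # Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output.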
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image


@tf.function
def random_jitter(image):
    # Resize to 286 x 286 x 3
    image = tf.image.resize(image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    # Randomly crop back to 256 x 256 x 3
    image = random_crop(image)
    # Random horizontal mirroring
    image = tf.image.random_flip_left_right(image)

    return image

def preprocess_image_train(image, label):
    image = random_jitter(image)
    image = normalize(image)
    return image


def preprocess_image_test(image, label):
    image = normalize(image)
    return image


def create_horse2zebra_dataset(buffer_size, batch_size):
    """
    Creates a tf.data Dataset.

    Args:
      path_to_train_images: Path to train images folder.
      path_to_test_images: Path to test images folder.
      buffer_size: Shuffle buffer size.
      batch_size: Batch size

    Returns:
      train dataset, test dataset
    """
    dataset, metadata = tfds.load('cycle_gan/horse2zebra', with_info=True, as_supervised=True)

    train_horses, train_zebras = dataset['trainA'], dataset['trainB']
    test_horses, test_zebras = dataset['testA'], dataset['testB']
    #
    train_horses = train_horses.map(preprocess_image_train, num_parallel_calls=AUTOTUNE).cache()
    train_horses = train_horses.shuffle(buffer_size)
    train_horses = train_horses.batch(batch_size)

    train_zebras = train_zebras.map(preprocess_image_train, num_parallel_calls=AUTOTUNE).cache()
    train_zebras = train_zebras.shuffle(buffer_size)
    train_zebras = train_zebras.batch(batch_size)

    test_horses = test_horses.map(preprocess_image_test, num_parallel_calls=AUTOTUNE).cache()
    test_horses = test_horses.shuffle(buffer_size)
    test_horses = test_horses.batch(batch_size)

    test_zebras = test_zebras.map(preprocess_image_test, num_parallel_calls=AUTOTUNE).cache()
    test_zebras = test_zebras.shuffle(buffer_size)
    test_zebras = test_zebras.batch(batch_size)

    #
    dataset_dict = {}
    dataset_dict['train_horses'] = train_horses
    dataset_dict['train_zebras'] = train_zebras
    dataset_dict['test_horses'] = test_horses
    dataset_dict['test_zebras'] = test_zebras

    return dataset_dict


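# Note: instance normalization computes mean/variance per example over the
# spatial axes (H, W) for each channel. With batch_size=1, batch statistics
# are degenerate, which is why CycleGAN uses instance norm rather than batch norm.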
class InstanceNormalization(tf.keras.layers.Layer):
  """
  Instance Normalization Layer (https://arxiv.org/abs/1607.08022).
  """

  def __init__(self, epsilon=1e-5):
    super(InstanceNormalization, self).__init__()
    self.epsilon = epsilon

  def build(self, input_shape):
    self.scale = self.add_weight(
        name='scale',
        shape=input_shape[-1:],
        initializer=tf.random_normal_initializer(1., 0.02),
        trainable=True)

    self.offset = self.add_weight(
        name='offset',
        shape=input_shape[-1:],
        initializer='zeros',
        trainable=True)

  def call(self, x):
    mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
    inv = tf.math.rsqrt(variance + self.epsilon)
    normalized = (x - mean) * inv
    return self.scale * normalized + self.offset


def downsample(filters, size, norm_type='batchnorm', apply_norm=True):
  """
  Downsamples an input.

  Conv2D => Batchnorm => LeakyRelu

  Args:
    filters: number of filters
    size: filter size
    norm_type: Normalization type; either 'batchnorm' or 'instancenorm'.
    apply_norm: If True, adds the batchnorm layer

  Returns:
    Downsample Sequential Model
  """
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
      tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))

  if apply_norm:
    if norm_type.lower() == 'batchnorm':
      result.add(tf.keras.layers.BatchNormalization())
    elif norm_type.lower() == 'instancenorm':
      result.add(InstanceNormalization())

  result.add(tf.keras.layers.LeakyReLU())

  return result


def upsample(filters, size, norm_type='batchnorm', apply_dropout=False):
  """
  Upsamples an input.

  Conv2DTranspose => Batchnorm => Dropout => Relu

  Args:
    filters: number of filters
    size: filter size
    norm_type: Normalization type; either 'batchnorm' or 'instancenorm'.
    apply_dropout: If True, adds the dropout layer

  Returns:
    Upsample Sequential Model
  """

  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
      tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                      padding='same',
                                      kernel_initializer=initializer,
                                      use_bias=False))

  if norm_type.lower() == 'batchnorm':
    result.add(tf.keras.layers.BatchNormalization())
  elif norm_type.lower() == 'instancenorm':
    result.add(InstanceNormalization())

  if apply_dropout:
    result.add(tf.keras.layers.Dropout(0.5))

  result.add(tf.keras.layers.ReLU())

  return result


def unet_generator(output_channels, norm_type='batchnorm'):
  """
  Modified u-net generator model (https://arxiv.org/abs/1611.07004).

  Args:
    output_channels: Output channels
    norm_type: Type of normalization. Either 'batchnorm' or 'instancenorm'.

  Returns:
    Generator model
  """

  down_stack = [
      downsample(64, 4, norm_type, apply_norm=False),  # (bs, 128, 128, 64)
      downsample(128, 4, norm_type),  # (bs, 64, 64, 128)
      downsample(256, 4, norm_type),  # (bs, 32, 32, 256)
      downsample(512, 4, norm_type),  # (bs, 16, 16, 512)
      downsample(512, 4, norm_type),  # (bs, 8, 8, 512)
      downsample(512, 4, norm_type),  # (bs, 4, 4, 512)
      downsample(512, 4, norm_type),  # (bs, 2, 2, 512)
      downsample(512, 4, norm_type),  # (bs, 1, 1, 512)
  ]

  up_stack = [
      upsample(512, 4, norm_type, apply_dropout=True),  # (bs, 2, 2, 1024)
      upsample(512, 4, norm_type, apply_dropout=True),  # (bs, 4, 4, 1024)
      upsample(512, 4, norm_type, apply_dropout=True),  # (bs, 8, 8, 1024)
      upsample(512, 4, norm_type),  # (bs, 16, 16, 1024)
      upsample(256, 4, norm_type),  # (bs, 32, 32, 512)
      upsample(128, 4, norm_type),  # (bs, 64, 64, 256)
      upsample(64, 4, norm_type),  # (bs, 128, 128, 128)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(
      output_channels, 4, strides=2,
      padding='same', kernel_initializer=initializer,
      activation='tanh')  # (bs, 256, 256, 3)

  concat = tf.keras.layers.Concatenate()

  inputs = tf.keras.layers.Input(shape=[None, None, 3])
  x = inputs

  # Downsampling through the model
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = concat([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)


def discriminator(norm_type='batchnorm', target=True):
  """
  PatchGan discriminator model (https://arxiv.org/abs/1611.07004).

  Args:
    norm_type: Type of normalization. Either 'batchnorm' or 'instancenorm'.
    target: Bool, indicating whether target image is an input or not.

  Returns:
    Discriminator model
  """

  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[None, None, 3], name='input_image')
  x = inp

  if target:
    tar = tf.keras.layers.Input(shape=[None, None, 3], name='target_image')
    x = tf.keras.layers.concatenate([inp, tar])  # (bs, 256, 256, channels*2)

  down1 = downsample(64, 4, norm_type, False)(x)  # (bs, 128, 128, 64)
  down2 = downsample(128, 4, norm_type)(down1)  # (bs, 64, 64, 128)
  down3 = downsample(256, 4, norm_type)(down2)  # (bs, 32, 32, 256)

  zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (bs, 34, 34, 256)
  conv = tf.keras.layers.Conv2D(
      512, 4, strides=1, kernel_initializer=initializer,
      use_bias=False)(zero_pad1)  # (bs, 31, 31, 512)

  if norm_type.lower() == 'batchnorm':
    norm1 = tf.keras.layers.BatchNormalization()(conv)
  elif norm_type.lower() == 'instancenorm':
    norm1 = InstanceNormalization()(conv)

  leaky_relu = tf.keras.layers.LeakyReLU()(norm1)

  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (bs, 33, 33, 512)

  last = tf.keras.layers.Conv2D(
      1, 4, strides=1,
      kernel_initializer=initializer)(zero_pad2)  # (bs, 30, 30, 1)
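  # PatchGAN: each value of the (bs, 30, 30, 1) output map classifies one
  # overlapping patch of the input as real or fake, rather than emitting a
  # single scalar for the whole image.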

  if target:
    return tf.keras.Model(inputs=[inp, tar], outputs=last)
  else:
    return tf.keras.Model(inputs=inp, outputs=last)


def get_checkpoint_prefix():
  checkpoint_dir = './training_checkpoints'
  checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')

  return checkpoint_prefix

#
class CycleGAN(object):
    def __init__(self, epochs, enable_function):
        self.epochs = epochs
        self.enable_function = enable_function
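        # Weight of the cycle-consistency loss (lambda = 10 in the CycleGAN paper).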
        self.lambda_value = 10
        self.loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

        self.generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

        self.discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

        self.generator_g = unet_generator(output_channels=3, norm_type='instancenorm')
        self.generator_f = unet_generator(output_channels=3, norm_type='instancenorm')
        self.discriminator_x = discriminator(norm_type='instancenorm', target=False)
        self.discriminator_y = discriminator(norm_type='instancenorm', target=False)

        self.checkpoint = tf.train.Checkpoint(
            generator_g=self.generator_g,
            generator_f=self.generator_f,
            discriminator_x=self.discriminator_x,
            discriminator_y=self.discriminator_y,
            generator_g_optimizer=self.generator_g_optimizer,
            generator_f_optimizer=self.generator_f_optimizer,
            discriminator_x_optimizer=self.discriminator_x_optimizer,
            discriminator_y_optimizer=self.discriminator_y_optimizer)


    def discriminator_loss(self, disc_real_output, disc_generated_output):
        real_loss = self.loss_object(tf.ones_like(disc_real_output), disc_real_output)

        generated_loss = self.loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

        total_disc_loss = real_loss + generated_loss

        return total_disc_loss * 0.5

    #
    def generator_loss(self, generated):
        gan_loss = self.loss_object(tf.ones_like(generated), generated)
        return gan_loss

    def calc_cycle_loss(self, real_image, cycled_image):
      loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))

      return self.lambda_value * loss1

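    # Identity loss: a generator fed an image that is already in its output
    # domain should return it (nearly) unchanged; this helps preserve color.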
    def identity_loss(self, real_image, same_image):
      loss = tf.reduce_mean(tf.abs(real_image - same_image))
      return self.lambda_value * 0.5 * loss

    def train_step(self, real_x, real_y):
        # persistent is set to True because the tape is used multiple times
        # to compute gradients for both generators and both discriminators.
        # (No @tf.function decorator here: train() wraps this method with
        # tf.function only when enable_function is set.)
        with tf.GradientTape(persistent=True) as tape:
            # Generator G translates X -> Y
            # Generator F translates Y -> X

            fake_y = self.generator_g(real_x, training=True)
            cycled_x = self.generator_f(fake_y, training=True)

            fake_x = self.generator_f(real_y, training=True)
            cycled_y = self.generator_g(fake_x, training=True)

            # same_x and same_y are used for the identity loss.
            same_x = self.generator_f(real_x, training=True)
            same_y = self.generator_g(real_y, training=True)

            disc_real_x = self.discriminator_x(real_x, training=True)
            disc_real_y = self.discriminator_y(real_y, training=True)

            disc_fake_x = self.discriminator_x(fake_x, training=True)
            disc_fake_y = self.discriminator_y(fake_y, training=True)

            # Compute the losses.
            gen_g_loss = self.generator_loss(disc_fake_y)
            gen_f_loss = self.generator_loss(disc_fake_x)

            total_cycle_loss = self.calc_cycle_loss(real_x, cycled_x) + self.calc_cycle_loss(real_y, cycled_y)

            # Total generator loss = adversarial loss + cycle loss + identity loss.
            total_gen_g_loss = gen_g_loss + total_cycle_loss + self.identity_loss(real_y, same_y)
            total_gen_f_loss = gen_f_loss + total_cycle_loss + self.identity_loss(real_x, same_x)

            disc_x_loss = self.discriminator_loss(disc_real_x, disc_fake_x)
            disc_y_loss = self.discriminator_loss(disc_real_y, disc_fake_y)

        # Compute the gradients for the generators and the discriminators.
        generator_g_gradients = tape.gradient(total_gen_g_loss, self.generator_g.trainable_variables)
        generator_f_gradients = tape.gradient(total_gen_f_loss, self.generator_f.trainable_variables)

        discriminator_x_gradients = tape.gradient(disc_x_loss, self.discriminator_x.trainable_variables)
        discriminator_y_gradients = tape.gradient(disc_y_loss, self.discriminator_y.trainable_variables)

        # Apply the gradients to the optimizers.
        self.generator_g_optimizer.apply_gradients(zip(generator_g_gradients, self.generator_g.trainable_variables))
        self.generator_f_optimizer.apply_gradients(zip(generator_f_gradients, self.generator_f.trainable_variables))

        self.discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients, self.discriminator_x.trainable_variables))
        self.discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients, self.discriminator_y.trainable_variables))

        return total_gen_g_loss, total_gen_f_loss, disc_x_loss, disc_y_loss

    def train(self, dataset, checkpoint_pr):
        """
        Train the GAN for x number of epochs.

        Args:
          dataset: train dataset.
          checkpoint_pr: prefix in which the checkpoints are stored.

        Returns:
          Time for each epoch.
        """
        time_list = []
        if self.enable_function:
            self.train_step = tf.function(self.train_step)

        #
        train_horses, train_zebras = dataset['train_horses'], dataset['train_zebras']
        for epoch in range(self.epochs):
            start_time = time.time()
            for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
                total_gen_g_loss, total_gen_f_loss, disc_x_loss, disc_y_loss = self.train_step(image_x, image_y)

            wall_time_sec = time.time() - start_time
            time_list.append(wall_time_sec)

            # Save a checkpoint every 10 epochs.
            if (epoch + 1) % 10 == 0:
                self.checkpoint.save(file_prefix=checkpoint_pr)

            template = 'Epoch {}, Generator G loss {}, Generator F loss {}, Discriminator X Loss {}, Discriminator Y Loss {}'
            print(template.format(epoch + 1, total_gen_g_loss, total_gen_f_loss, disc_x_loss, disc_y_loss))

        return time_list


def run_main(argv):
  del argv
  kwargs = {'epochs': FLAGS.epochs,
            'enable_function': FLAGS.enable_function,
            'path': FLAGS.path,
            'buffer_size': FLAGS.buffer_size,
            'batch_size': FLAGS.batch_size}
  main(**kwargs)


def main(epochs, enable_function, path, buffer_size, batch_size):
  del path  # Unused: horse2zebra is downloaded via tensorflow_datasets.

  cyclegan_object = CycleGAN(epochs, enable_function)

  dataset_dict = create_horse2zebra_dataset(buffer_size, batch_size)
  checkpoint_pr = get_checkpoint_prefix()
  print('Training ...')
  return cyclegan_object.train(dataset_dict, checkpoint_pr)


if __name__ == '__main__':
  app.run(run_main)
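
Training can be launched from the command line using the flags defined at the top; the filename cyclegan.py is only an assumed name for this script:

python3 cyclegan.py --epochs 40 --batch_size 1 --buffer_size 1000

After training, the saved checkpoint can be restored for inference. A minimal sketch, assuming cyclegan_object is a trained CycleGAN instance and sample_horse is a preprocessed (normalized, batched) test image:

checkpoint = tf.train.Checkpoint(generator_g=cyclegan_object.generator_g)
status = checkpoint.restore(tf.train.latest_checkpoint('./training_checkpoints'))
status.expect_partial()  # only the generator is needed at inference time
fake_zebra = cyclegan_object.generator_g(sample_horse, training=False)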