From: tensorflow_examples/models/dcgan

1. dcgan.py

#!/usr/bin/python3
#!--*-- coding: utf-8 --*--
"""
DCGAN.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
from absl import app
from absl import flags
import tensorflow as tf
import tensorflow_datasets as tfds

# Parsed command-line flag values; read by run_main.
FLAGS = flags.FLAGS

# Command-line flags; run_main forwards their values to main().
flags.DEFINE_integer('buffer_size', 10000, 'Shuffle buffer size')
flags.DEFINE_integer('batch_size', 64, 'Batch Size')
flags.DEFINE_integer('epochs', 1, 'Number of epochs')
flags.DEFINE_boolean('enable_function', True, 'Enable Function?')

# tf.data sentinel asking the runtime to pick the value dynamically
# (used by create_dataset's input pipeline).
AUTOTUNE = tf.data.experimental.AUTOTUNE


def scale(image, label):
  """Rescales an image's pixel values; the label is passed through unchanged.

  Computes (x - 127.5) / 127.5 in float32, so 0 maps to -1 and 255 maps to 1.

  Args:
    image: Image tensor of any numeric dtype.
    label: Label, returned as-is.

  Returns:
    Tuple of (rescaled float32 image, label).
  """
  midpoint = 127.5
  pixels = tf.cast(image, tf.float32)
  return (pixels - midpoint) / midpoint, label


def create_dataset(buffer_size, batch_size):
  """Builds the shuffled, batched MNIST training input pipeline.

  Args:
    buffer_size: Number of elements held in the shuffle buffer.
    batch_size: Number of examples per batch. The final batch may be smaller
      because the remainder is not dropped.

  Returns:
    A `tf.data.Dataset` of `(image, label)` batches, with images normalized
    by `scale`.
  """
  train_dataset = tfds.load('mnist', split='train', as_supervised=True,
                            shuffle_files=True)
  train_dataset = train_dataset.map(scale, num_parallel_calls=AUTOTUNE)
  train_dataset = train_dataset.shuffle(buffer_size).batch(batch_size)
  # Prepare the next batches while the current training step runs, instead
  # of leaving the input pipeline idle between steps.
  return train_dataset.prefetch(AUTOTUNE)


def make_generator_model():
  """Builds the generator network.

  A bias-free dense projection is reshaped to a 7x7x256 feature map. It is
  then upsampled by transposed convolutions (7x7 -> 7x7 -> 14x14 -> 28x28)
  to a single-channel image with a tanh activation.

  Returns:
    Keras Sequential model
  """
  layers = tf.keras.layers
  model = tf.keras.Sequential()

  def add_norm_and_activation():
    # Every hidden layer is followed by batch norm and a LeakyReLU.
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

  model.add(layers.Dense(7 * 7 * 256, use_bias=False))
  add_norm_and_activation()
  model.add(layers.Reshape((7, 7, 256)))

  # (filters, stride) for the hidden transposed-conv stages.
  for filters, stride in ((128, 1), (64, 2)):
    model.add(layers.Conv2DTranspose(filters, 5, strides=(stride, stride),
                                     padding='same', use_bias=False))
    add_norm_and_activation()

  model.add(layers.Conv2DTranspose(1, 5, strides=(2, 2), padding='same',
                                   use_bias=False, activation='tanh'))
  return model


def make_discriminator_model():
  """Builds the discriminator network.

  Two stride-2 convolution stages (each followed by LeakyReLU and 30% dropout)
  are flattened into a single output unit with no activation.

  Returns:
    Keras Sequential model
  """
  layers = tf.keras.layers
  model = tf.keras.Sequential()
  for filters in (64, 128):
    model.add(layers.Conv2D(filters, 5, strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
  model.add(layers.Flatten())
  model.add(layers.Dense(1))
  return model


def get_checkpoint_prefix(checkpoint_dir='./training_checkpoints'):
  """Returns the file prefix under which training checkpoints are saved.

  Args:
    checkpoint_dir: Directory that holds the checkpoints. Defaults to
      './training_checkpoints', relative to the current working directory.

  Returns:
    `checkpoint_dir` joined with the 'ckpt' file prefix.
  """
  return os.path.join(checkpoint_dir, 'ckpt')


class Dcgan(object):
  """Holds the DCGAN networks, optimizers and checkpoint, and trains them.

  Args:
    epochs: Number of epochs.
    enable_function: If true, train step is decorated with tf.function.
    batch_size: Nominal batch size of the training data. `train_step` sizes
      its noise batch from each incoming real batch instead, so a smaller
      final batch is handled consistently.
  """

  def __init__(self, epochs, enable_function, batch_size):
    self.epochs = epochs
    self.enable_function = enable_function
    self.batch_size = batch_size
    # Length of the random noise vector fed to the generator.
    self.noise_dim = 100
    # The discriminator's final Dense(1) has no activation, hence from_logits.
    self.loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    self.generator_optimizer = tf.keras.optimizers.Adam(1e-4)
    self.discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
    self.generator = make_generator_model()
    self.discriminator = make_discriminator_model()
    self.checkpoint = tf.train.Checkpoint(
        generator_optimizer=self.generator_optimizer,
        discriminator_optimizer=self.discriminator_optimizer,
        generator=self.generator,
        discriminator=self.discriminator)

  def generator_loss(self, generated_output):
    """Loss that is low when the discriminator scores fakes as real (1).

    Args:
      generated_output: Discriminator logits for generated images.

    Returns:
      Scalar binary cross-entropy against an all-ones target.
    """
    return self.loss_object(tf.ones_like(generated_output), generated_output)

  def discriminator_loss(self, real_output, generated_output):
    """Loss for labelling real images as 1 and generated images as 0.

    Args:
      real_output: Discriminator logits for real images.
      generated_output: Discriminator logits for generated images.

    Returns:
      Sum of the real-image and generated-image cross-entropy losses.
    """
    real_loss = self.loss_object(tf.ones_like(real_output), real_output)
    generated_loss = self.loss_object(tf.zeros_like(generated_output),
                                      generated_output)

    total_loss = real_loss + generated_loss

    return total_loss

  def train_step(self, image):
    """
    One train step over the generator and discriminator model.

    Args:
      image: Batch of real input images.

    Returns:
      generator loss, discriminator loss.
    """
    # Draw one noise vector per real image actually present in this batch.
    # Using the fixed self.batch_size would mismatch the final, smaller batch
    # (e.g. 64 fakes vs 32 reals) and skew the discriminator loss.
    num_examples = tf.shape(image)[0]
    noise = tf.random.normal([num_examples, self.noise_dim])

    # Two tapes: each network's loss is differentiated w.r.t. its own weights.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = self.generator(noise, training=True)

      real_output = self.discriminator(image, training=True)
      generated_output = self.discriminator(generated_images, training=True)

      gen_loss = self.generator_loss(generated_output)
      disc_loss = self.discriminator_loss(real_output, generated_output)

    gradients_of_generator = gen_tape.gradient(
        gen_loss, self.generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(
        disc_loss, self.discriminator.trainable_variables)

    self.generator_optimizer.apply_gradients(zip(
        gradients_of_generator, self.generator.trainable_variables))
    self.discriminator_optimizer.apply_gradients(zip(
        gradients_of_discriminator, self.discriminator.trainable_variables))

    return gen_loss, disc_loss

  def train(self, dataset, checkpoint_pr):
    """
    Train the GAN for `self.epochs` epochs.

    A checkpoint is written every 15 epochs, and the losses of the last batch
    are printed after every epoch.

    Args:
      dataset: train dataset yielding `(image, label)` batches; labels are
        ignored.
      checkpoint_pr: prefix in which the checkpoints are stored.

    Returns:
      List of the wall-clock time in seconds of each epoch (checkpoint
      saving is not included).
    """
    time_list = []
    # Wrap locally instead of rebinding self.train_step, so calling train()
    # more than once does not nest tf.function wrappers on the instance.
    if self.enable_function:
      train_step = tf.function(self.train_step)
    else:
      train_step = self.train_step

    template = 'Epoch {}, Generator loss {}, Discriminator Loss {}'
    # Bound up front so the epoch summary cannot raise UnboundLocalError when
    # the dataset yields no batches.
    gen_loss = disc_loss = None

    for epoch in range(self.epochs):
      start_time = time.time()
      for image, _ in dataset:
        gen_loss, disc_loss = train_step(image)

      wall_time_sec = time.time() - start_time
      time_list.append(wall_time_sec)

      # saving (checkpoint) the model every 15 epochs
      if (epoch + 1) % 15 == 0:
        self.checkpoint.save(file_prefix=checkpoint_pr)

      print(template.format(epoch, gen_loss, disc_loss))

    return time_list


def run_main(argv):
  """absl entry point: forwards the parsed flag values to `main`.

  The result of `main` is deliberately not returned, because absl's `app.run`
  passes the entry point's return value to `sys.exit`.

  Args:
    argv: Remaining command-line arguments; unused.
  """
  del argv
  main(epochs=FLAGS.epochs,
       enable_function=FLAGS.enable_function,
       buffer_size=FLAGS.buffer_size,
       batch_size=FLAGS.batch_size)


def main(epochs, enable_function, buffer_size, batch_size):
  """Trains a DCGAN on MNIST.

  Args:
    epochs: Number of training epochs.
    enable_function: Whether the train step is wrapped in tf.function.
    buffer_size: Shuffle buffer size for the input pipeline.
    batch_size: Batch size for the input pipeline.

  Returns:
    List of per-epoch wall-clock times from `Dcgan.train`.
  """
  dataset = create_dataset(buffer_size, batch_size)
  trainer = Dcgan(epochs, enable_function, batch_size)
  print('Training ...')
  return trainer.train(dataset, get_checkpoint_prefix())


if __name__ == '__main__':
  app.run(run_main)

2. dcgan_test.py

"""
DCGAN tests.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
import tensorflow as tf
from tensorflow_examples.models.dcgan import dcgan

# NOTE(review): FLAGS is not referenced anywhere else in this test module.
FLAGS = flags.FLAGS


class DcganTest(tf.test.TestCase):
  """Smoke tests: one epoch of DCGAN training on a single random image."""

  def _train_one_epoch(self, enable_function):
    """Trains a fresh Dcgan for one epoch on a one-example dataset.

    Args:
      enable_function: Whether the train step is wrapped in tf.function.

    Returns:
      The list of per-epoch wall times returned by `Dcgan.train`.
    """
    epochs = 1
    batch_size = 1

    input_image = tf.random.uniform((28, 28, 1))
    label = tf.zeros((1,))
    train_dataset = tf.data.Dataset.from_tensors(
        (input_image, label)).batch(batch_size)
    checkpoint_pr = dcgan.get_checkpoint_prefix()

    dcgan_obj = dcgan.Dcgan(epochs, enable_function, batch_size)
    time_list = dcgan_obj.train(train_dataset, checkpoint_pr)
    # train() reports exactly one timing per epoch.
    self.assertEqual(len(time_list), epochs)
    return time_list

  def test_one_epoch_with_function(self):
    self._train_one_epoch(enable_function=True)

  def test_one_epoch_without_function(self):
    self._train_one_epoch(enable_function=False)


class DCGANBenchmark(tf.test.Benchmark):
  """Benchmarks DCGAN training on MNIST with and without tf.function."""

  def __init__(self, output_dir=None, **kwargs):
    """Initializes the benchmark.

    Args:
      output_dir: Optional output directory; only stored on the instance.
      **kwargs: Accepted but ignored.
    """
    # Chain to the base initializer so any setup that tf.test.Benchmark
    # performs in its own __init__ still runs.
    super(DCGANBenchmark, self).__init__()
    self.output_dir = output_dir

  def benchmark_with_function(self):
    """Six epochs with the train step wrapped in tf.function."""
    self._run_and_report_benchmark(
        epochs=6, enable_function=True, buffer_size=10000, batch_size=64)

  def benchmark_without_function(self):
    """Six epochs with the train step run eagerly."""
    self._run_and_report_benchmark(
        epochs=6, enable_function=False, buffer_size=10000, batch_size=64)

  def _run_and_report_benchmark(self, **kwargs):
    """Runs `dcgan.main(**kwargs)` and reports the mean epoch wall time."""
    time_list = dcgan.main(**kwargs)
    self.report_benchmark(wall_time=self._mean_epoch_time(time_list))

  @staticmethod
  def _mean_epoch_time(time_list):
    """Returns the mean of `time_list` as a Python float.

    The 1st epoch is the warmup epoch, so it is skipped whenever later epochs
    exist. A single-epoch run falls back to that epoch instead of averaging an
    empty slice (which would give NaN). Returns None for an empty list.
    """
    timed = time_list[1:] or time_list
    if not timed:
      return None
    return float(sum(timed)) / len(timed)


if __name__ == "__main__":
  tf.test.main()
Last modification: November 10th, 2020 at 03:13 pm