Generative adversarial networks (GANs)

Intro

What are GANs:

A GAN is a neural network architecture made up of two networks, a generator and a discriminator, that are pitted against each other (see the original paper).

One way to think about GANs is as a competition between a cop and a counterfeiter: each keeps learning new tricks to gain an advantage over the other. This produces a double feedback loop.

Architecture (figure): see the great intro from deeplearning4j

The objective is to train a generator whose samples are hard to distinguish from real data.
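For reference, the original GAN paper formalizes this objective as a two-player minimax game. In the expression below, D(x) is the discriminator's estimated probability that x is real, and G(z) is the generator's output for a noise vector z:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

The discriminator pushes V up by classifying real and generated samples correctly; the generator pushes V down by making D(G(z)) close to 1.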

Samples generated by a model trained on the CelebA dataset (paper by NVIDIA)

Other applications:

  1. Text to Image generation
  2. Image to Image translation
  3. Super-resolution images

Other generative models?

Variational autoencoder

VAEs tend to produce somewhat blurry images, which is typically attributed to the difference in objective function.
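As a point of comparison, a VAE is usually trained by maximizing the evidence lower bound (ELBO) rather than by playing an adversarial game. The symbols below (encoder q_phi, decoder p_theta, prior p(z)) follow the standard VAE notation and are not taken from these notes:

```latex
\mathcal{L}(\theta, \phi; x) =
  \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big]
  - D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big)
```

The first term is a per-sample reconstruction likelihood. With a Gaussian decoder it reduces to a squared-error penalty, which is one common explanation for the blurriness: averaging over several plausible outputs is rewarded.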

Training pseudo-code:

(figure: training pseudo-code)
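The alternating training scheme can be sketched in a few lines of NumPy. This is a toy illustration under stated assumptions, not the pseudo-code from the paper: the data is 1-D, the generator is a linear map G(z) = a*z + b, the discriminator is a logistic unit D(x) = sigmoid(w*x + c), gradients are written out by hand, and every name (`train_toy_gan`, `d_steps`, the learning rate) is hypothetical.

```python
import numpy as np


def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))


def train_toy_gan(num_steps=200, batch_size=64, d_steps=1, lr=0.05, seed=0):
    """Alternate discriminator and generator updates on toy 1-D data."""
    rng = np.random.default_rng(seed)
    a, b = 1.0, 0.0   # generator parameters: G(z) = a*z + b
    w, c = 0.1, 0.0   # discriminator parameters: D(x) = sigmoid(w*x + c)
    history = []
    for step in range(num_steps):
        # Discriminator update(s): minimize
        #   -mean(log D(x)) - mean(log(1 - D(G(z)))).
        for _ in range(d_steps):
            x = rng.normal(4.0, 1.0, batch_size)   # "real" samples
            z = rng.normal(0.0, 1.0, batch_size)   # noise
            g = a * z + b                          # generated samples
            d_real = sigmoid(w * x + c)
            d_fake = sigmoid(w * g + c)
            d_loss = -np.mean(np.log(d_real)) - np.mean(np.log(1.0 - d_fake))
            # Derivatives w.r.t. the logit: -(1 - D) on real, +D on fake.
            grad_w = np.mean(-(1.0 - d_real) * x) + np.mean(d_fake * g)
            grad_c = np.mean(-(1.0 - d_real)) + np.mean(d_fake)
            w -= lr * grad_w
            c -= lr * grad_c

        # Generator update: minimize mean(log(1 - D(G(z)))) (minimax form).
        z = rng.normal(0.0, 1.0, batch_size)
        g = a * z + b
        d_fake = sigmoid(w * g + c)
        g_loss = np.mean(np.log(1.0 - d_fake))
        # Chain rule: d loss / d g = -D(g) * w, then dg/da = z and dg/db = 1.
        dg = -d_fake * w
        a -= lr * np.mean(dg * z)
        b -= lr * np.mean(dg)

        history.append((d_loss, g_loss))
    return (a, b), (w, c), history


(a, b), (w, c), history = train_toy_gan()
print("generator offset b after training:", b)
```

The `d_steps` argument plays the role of the "k discriminator steps per generator step" knob; the losses in `history` are only there to inspect how the two objectives move against each other.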

How it looks in real code, using the TFGAN library:

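import tensorflow as tf  # TensorFlow 1.x, where tf.contrib is available

tfgan = tf.contrib.gan
slim = tf.contrib.slim

# generator_fn, discriminator_fn, images, batch_size, noise_dims and
# num_steps are assumed to be defined earlier.

# Bundle the generator, the discriminator and their inputs into one model.
gan_model = tfgan.gan_model(
    generator_fn,
    discriminator_fn,
    real_data=images,
    generator_inputs=tf.random_normal([batch_size, noise_dims]))

# Minimax losses from the original GAN formulation.
vanilla_gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.minimax_generator_loss,
    discriminator_loss_fn=tfgan.losses.minimax_discriminator_loss)

generator_optimizer = tf.train.AdamOptimizer(0.001, beta1=0.5)
discriminator_optimizer = tf.train.AdamOptimizer(0.0001, beta1=0.5)

# Train ops that compute gradients and apply them to each network.
gan_train_ops = tfgan.gan_train_ops(
    gan_model,
    vanilla_gan_loss,
    generator_optimizer,
    discriminator_optimizer)

# Run the generator and discriminator train ops in sequence on every call.
train_step_fn = tfgan.get_sequential_train_steps()
global_step = tf.train.get_or_create_global_step()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    with slim.queues.QueueRunners(sess):
        for i in range(num_steps):
            cur_loss, _ = train_step_fn(
                sess, gan_train_ops, global_step, train_step_kwargs={})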


Difficulties in training, loss functions

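1. The two networks should ideally train at a similar pace: if the discriminator gets too far ahead, the generator's gradients can become very small.

2. The objective function keeps changing: both nets are training, so each one is optimizing against a moving target.

3. Mode collapse: the noise is mapped to the same small region of the output space, one that only marginally looks real.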
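The first difficulty can be made concrete by looking at the generator's gradient with respect to the discriminator logit s, where D = sigmoid(s). The sketch below compares the minimax generator loss log(1 - D) with the commonly used non-saturating alternative -log(D); the function names are illustrative and not part of any library.

```python
import numpy as np


def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))


def minimax_generator_grad(s):
    """d/ds of log(1 - sigmoid(s)): the minimax generator loss."""
    return -sigmoid(s)


def non_saturating_generator_grad(s):
    """d/ds of -log(sigmoid(s)): the non-saturating generator loss."""
    return -(1.0 - sigmoid(s))


# Very negative logits mean the discriminator confidently rejects the sample.
for s in [-8.0, -4.0, 0.0]:
    print(s, minimax_generator_grad(s), non_saturating_generator_grad(s))
```

When the discriminator is confident that a sample is fake (very negative s), the minimax gradient shrinks towards zero while the non-saturating gradient stays close to -1, which is why a discriminator that trains much faster than the generator can stall learning under the minimax loss.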

To learn more:

  • deeplearning4j: nice overview, comprehensive list of papers, research directions
  • TFGAN library: understandable examples built on TensorFlow
# just potentially useful animations
#![dec](images/deconvolution.gif)