Skip to content

Deep Learning: Generative Adversarial Networks

Stepping into the Realm of Generative Adversarial Networks

In the mesmerizing universe of deep learning, Generative Adversarial Networks (GANs) have taken center stage with their uncanny ability to generate content. From crafting lifelike images to producing music, GANs have opened up avenues previously considered the domain of human creativity. By pitting two neural networks against each other in a captivating dance, GANs have transformed our understanding of what machines can create.

The Genesis of GANs

GANs were introduced by Ian Goodfellow and his colleagues in 2014. Their revolutionary approach to generative modeling caught the AI community's imagination, leading to an explosion of research and applications.

Understanding the Yin and Yang of GANs

At its core, a GAN consists of two neural networks – the Generator and the Discriminator. These networks are trained together in a captivating game of cat and mouse.

1. The Generator: The Artist

The generator's primary role is to produce content. Starting with a random noise, it crafts data that closely resembles real-world data. However, in the early stages of training, its creations can be far from perfect.

2. The Discriminator: The Critic

The discriminator is the evaluator. It reviews the content produced by the generator and distinguishes between real data and the data produced by the generator. Its objective is to get better at identifying fake data.

Training GANs: A Game of Deception and Detection

The training process of GANs is akin to a forger trying to create a masterpiece, while an art detective tries to detect which one is fake. Over time, the forger becomes so skilled that the detective can't tell real from fake.

1. The Generator's Perspective

The generator tries to produce data that the discriminator can't distinguish from real data. Its goal is to fool the discriminator.

2. The Discriminator's Perspective

The discriminator's objective is to correctly identify whether the data it's reviewing is coming from the generator or from a real dataset.

3. The Balance

For GANs to be successful, neither the generator nor the discriminator can overpower the other. If the discriminator is too good, the generator struggles to improve. If the generator is too effective, the discriminator can't differentiate real from fake.

Applications of GANs: The Art of Creation

The potential applications of GANs are vast and continue to grow as the technology matures.

  1. Image Generation: GANs can produce high-resolution and lifelike images.
  2. Style Transfer: Transferring the style of one image to another, like painting your photo in the style of Van Gogh.
  3. Data Augmentation: Generating more data for training models, especially useful when real data is limited.
  4. Super-Resolution: Enhancing the resolution of images, turning low-res photos into high-res masterpieces.
  5. Drug Discovery: Crafting molecular structures for new potential drugs.

Challenges in Training GANs

While GANs hold immense promise, they aren't without challenges.

1. Mode Collapse

This occurs when the generator produces limited varieties of samples, thereby fooling the discriminator but not being genuinely diverse or useful.

2. Training Instability

The dynamic between the generator and discriminator can lead to oscillations, where neither network improves.

3. Hyperparameter Sensitivity

GANs can be particularly sensitive to the choice of hyperparameters, making the training process occasionally finicky.

Generating Handwritten Digits with GANs

In this tutorial, we'll craft a GAN that learns to generate images resembling handwritten digits from the famous MNIST dataset.

1. Setting Up the Environment

Ensure you have TensorFlow installed:

pip install tensorflow

2. Loading the MNIST Dataset

The MNIST dataset comprises grayscale images of handwritten digits, each of size 28x28.

from tensorflow.keras.datasets import mnist

(train_images, _), (_, _) = mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]

3. Building the GAN

We'll define the generator and discriminator networks.

Generator Network

The generator starts with a dense layer and upsamples several times until it produces a 28x28x1 image.

from tensorflow.keras.layers import Dense, LeakyReLU, Reshape, Conv2DTranspose

def build_generator():
    model = tf.keras.Sequential()
    model.add(Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(LeakyReLU(alpha=0.3))
    model.add(Reshape((7, 7, 256)))

    model.add(Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    model.add(LeakyReLU(alpha=0.3))

    model.add(Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(LeakyReLU(alpha=0.3))

    model.add(Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    return model

Discriminator Network

The discriminator is a basic CNN that classifies the images as real or fake.

from tensorflow.keras.layers import Conv2D, Flatten

def build_discriminator():
    model = tf.keras.Sequential()
    model.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]))
    model.add(LeakyReLU(alpha=0.3))
    model.add(Flatten())
    model.add(Dense(1))
    return model

4. Compiling the GAN

Now, we'll compile both networks. The discriminator will be trained to classify real and fake images, while the combined model will be trained to generate convincing fake images.

generator = build_generator()
discriminator = build_discriminator()

discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# For the combined model, we will only train the generator.
discriminator.trainable = False

gan_input = tf.keras.Input(shape=(100,))
fake_image = generator(gan_input)
gan_output = discriminator(fake_image)

gan = tf.keras.Model(gan_input, gan_output)
gan.compile(optimizer='adam', loss='binary_crossentropy')

5. Training the GAN

Training a GAN involves alternating between training the discriminator to differentiate real vs. fake and training the generator to produce convincing fake images.

import numpy as np

def train_gan(epochs, batch_size):
    for epoch in range(epochs):
        for _ in range(train_images.shape[0] // batch_size):
            noise = np.random.normal(0, 1, (batch_size, 100))
            generated_images = generator.predict(noise)
            real_images = train_images[np.random.randint(0, train_images.shape[0], batch_size)]
            labels_real = np.ones((batch_size, 1))
            labels_fake = np.zeros((batch_size, 1))

            d_loss_real = discriminator.train_on_batch(real_images, labels_real)
            d_loss_fake = discriminator.train_on_batch(generated_images, labels_fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            noise = np.random.normal(0, 1, (batch_size, 100))
            valid_y = np.array([1] * batch_size)
            g_loss = gan.train_on_batch(noise, valid_y)

        print(f"Epoch: {epoch}, D Loss: {d_loss[0]}, Accuracy: {100 * d_loss[1]}, G Loss: {g_loss}")

train_gan(epochs=10000, batch_size=32)

6. Generating New Images

Once trained, the generator can be used to create new handwritten digits.

import matplotlib.pyplot as plt

noise = np.random.normal(0, 1, (1, 100))
generated_image = generator.predict(noise)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')
plt.show()

Version 1.0

This is currently an early version of the learning material and it will be updated over time with more detailed information.

A video will be provided with the learning material as well.

Be sure to subscribe to stay up-to-date with the latest updates.

Need help mastering Machine Learning?

Don't just follow along — join me! Get exclusive access to me, your instructor, who can help answer any of your questions. Additionally, get access to a private learning group where you can learn together and support each other on your AI journey.