You must be wondering who these random people in the above picture are. What if I told you they don’t exist and were created by a computer? I was pretty amazed when I saw this at an AI talk I attended a couple of years back, and it still amazes me to this day.
These fake humans were created by a particular type of deep learning model called a Generative Adversarial Network (GAN). The concept was conceived by Ian Goodfellow in 2014 and has since gained popularity in many artistic applications like bringing the Mona Lisa to life, creating new paintings from scratch, synthesizing DeepFakes, and so on.
GANs can be highly beneficial to the medical AI domain, specifically in medical imaging. Most state-of-the-art ML algorithms used to solve medical problems rely on large clinical and biomedical datasets to train their models effectively. However, because medical data is often proprietary, practitioners frequently run into obstacles around data ethics, data protection, and patient confidentiality, which prevent them from getting access to these datasets. Furthermore, some imaging modalities are costly to acquire, even more so if the disease we are dealing with is rare. Retinal imaging is an example of such an area: there are many rare inherited genetic disorders for which datasets are limited. GANs can be used to augment these datasets, allowing us to build better machine learning systems to diagnose and understand diseases. I’m excited to be working in this area for my master’s thesis, and I hope to share my learnings and progress in future posts.
In this article, I will give a simple, albeit slightly mathematical, introduction to GANs. I will then train a simple GAN on MNIST digits and discuss some of the challenges of training GANs. Further details are covered in Goodfellow’s original paper, titled Generative Adversarial Networks.
GANs: Looking under the hood
A GAN is essentially made up of two neural networks playing a game with each other (for a basic intro to neural networks, check out my article Deep Learning Demystified). The first network is called the generator, while the second is called the discriminator. The generator’s job is that of a counterfeiter: to create something that looks realistic but is not, for example, a 100 dollar bill. On the other hand, the discriminator’s job is that of an authentication expert: someone who checks the 100 dollar bills and judges whether they look genuine or not. The generator and discriminator are also known as “generative” and “discriminative” models, respectively. From a probabilistic perspective, if we let $X$ denote the 100 dollar bills in the world and $Y$ their classification label (real or fake), a generative model tries to learn the joint distribution $P(X, Y)$, or simply $P(X)$, while a discriminative model tries to learn the conditional distribution $P(Y \mid X)$.
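To make the distinction concrete, here is a toy sketch in plain Python. It assumes a made-up discrete world in which each bill has a single visible feature X (“crisp”, “worn”, or “blurry”) and a label Y; all probabilities are invented for illustration. Knowing the joint P(X, Y) lets you recover both P(X) and P(Y | X), whereas a discriminative model only ever estimates P(Y | X).

```python
# Toy joint distribution P(X, Y). X is a hypothetical "appearance" feature of a bill,
# Y is its label. All numbers are invented for illustration and sum to 1.
joint = {
    ("crisp", "real"): 0.40,
    ("crisp", "fake"): 0.05,
    ("worn", "real"): 0.30,
    ("worn", "fake"): 0.05,
    ("blurry", "real"): 0.05,
    ("blurry", "fake"): 0.15,
}

def marginal_x(joint):
    """P(X): sum the joint over labels (what a generative model of bills alone captures)."""
    p = {}
    for (x, _), prob in joint.items():
        p[x] = p.get(x, 0.0) + prob
    return p

def conditional_y_given_x(joint, x):
    """P(Y | X = x): the quantity a discriminative model learns directly."""
    px = marginal_x(joint)[x]
    return {y: prob / px for (xi, y), prob in joint.items() if xi == x}

print(marginal_x(joint))
print(conditional_y_given_x(joint, "blurry"))
```

A generative model could also sample new (X, Y) pairs from `joint`; the discriminative view throws that ability away and keeps only the real/fake verdict.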
The reason why this model setup is known as “adversarial” is as follows: generator G tries to create a synthetic data distribution that is as close to the actual data distribution as possible, i.e., it minimizes the distance between the distributions, while discriminator D tries to distinguish real from fake data, i.e., it maximizes the distance between the distributions. This is formally represented as a minimax objective:

$$\min_G \max_D V(D, G)$$
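where the function $V(D, G)$ is the minimax loss function as defined in the original paper:

$$V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]$$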
What does this loss function mean? Well, because D is a discriminative model, $D(x)$ is basically the probability of a real image $x$ being classified as real. $G(z)$ is a synthetic image generated from some random noise $z$, and $D(G(z))$ is the probability of that synthetic image being classified as real, so $1 - D(G(z))$ is the probability of it being classified as fake. The $\mathbb{E}$ represents an expectation, i.e., the average over all samples. The minimax now makes sense! The discriminator wants to maximize the (log) probability of real being classified as real plus fake being classified as fake. The generator wants to minimize the (log) probability of fake being classified as fake, i.e., the second term; the first term doesn’t depend on G, so the generator can ignore it.
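As a small illustration, here is a sketch that estimates the value function on a batch. It assumes the discriminator outputs probabilities (the model later in this post outputs logits, which a sigmoid would turn into probabilities), and the batch values are made up. A discriminator that separates real from fake scores higher than one that outputs 0.5 everywhere; the latter lands exactly on $-\log 4$.

```python
import math

def value_function(d_real, d_fake):
    """Monte Carlo estimate of V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))].

    d_real: D's output probabilities on a batch of real images x
    d_fake: D's output probabilities on a batch of generated images G(z)
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

# A sharp discriminator: real -> close to 1, fake -> close to 0 (made-up outputs)
sharp = value_function([0.9, 0.95, 0.8], [0.1, 0.05, 0.2])
# A confused discriminator that cannot tell the two apart
confused = value_function([0.5] * 3, [0.5] * 3)
print(sharp, confused)  # D wants this value large, G wants it small
```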
How to train a GAN
Now that we know the fundamental GAN components, we can examine how to train this model. As an example use case, I’ll generate handwritten digits. The dataset used to train the models will be the famous MNIST dataset, whose training split contains 60,000 grayscale images (28 × 28 pixels) of handwritten digits.
First, I’ll set up the basic generator architecture. This model takes a 64-dimensional random noise vector z and passes it through a series of transposed convolutions that progressively upsample it into a 28 × 28 grayscale image.
import torch
from torch import nn

class Generator(nn.Module):
    '''
    Generator Class
    Values:
        z_dim: the dimension of the noise vector, a scalar
        im_chan: the number of channels of the output image, a scalar
            (MNIST is black-and-white, so 1 channel is your default)
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, z_dim=64, im_chan=1, hidden_dim=64):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        # Build the neural network; spatial size goes 1 -> 3 -> 6 -> 13 -> 28
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to a generator block of DCGAN,
        corresponding to a transposed convolution, a batchnorm (except for in the last layer), and an activation.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: a boolean, true if it is the final layer and false otherwise
                (affects activation and batchnorm)
        '''
        # Build the neural block
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(num_features=output_channels),
                nn.ReLU()
            )
        else:  # Final Layer: Tanh keeps pixel values in [-1, 1]
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh()
            )

    def unsqueeze_noise(self, noise):
        '''
        Function for reshaping the noise: given a noise tensor, returns a copy of that noise
        with width and height = 1 and channels = z_dim.
        Parameters:
            noise: a noise tensor with dimensions (n_samples, z_dim)
        '''
        return noise.view(len(noise), self.z_dim, 1, 1)

    def forward(self, noise):
        '''
        Function for completing a forward pass of the generator: given a noise tensor,
        returns generated images.
        Parameters:
            noise: a noise tensor with dimensions (n_samples, z_dim)
        '''
        x = self.unsqueeze_noise(noise)
        return self.gen(x)
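A quick sanity check on the generator architecture: applying the output-size formula for `nn.ConvTranspose2d` from the PyTorch documentation to the four blocks (padding, output padding, and dilation left at their defaults, as in the class above) should end at 28. This is a plain-Python sketch; it does not build the model.

```python
def conv_transpose_out(size, kernel_size, stride, padding=0, output_padding=0, dilation=1):
    """Spatial output size of nn.ConvTranspose2d (formula from the PyTorch docs)."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

# (kernel_size, stride) of the four generator blocks, matching the defaults above
gen_blocks = [(3, 2), (4, 1), (3, 2), (4, 2)]

size = 1  # the noise vector is reshaped to (z_dim, 1, 1)
sizes = [size]
for k, s in gen_blocks:
    size = conv_transpose_out(size, k, s)
    sizes.append(size)
print(sizes)
```

The uneven kernel and stride choices (a stride-1 block in the second position) are what make the sizes land exactly on 28 rather than a power of two.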
Next, I create the discriminator model. This model takes an image, either real or fake, and classifies it accordingly. The architecture is like that of any binary classifier: a few strided convolutions downsample the image, and a final convolution reduces the result to a single output (a logit) indicating real or fake.
from torch import nn

class Discriminator(nn.Module):
    '''
    Discriminator Class
    Values:
        im_chan: the number of channels of the input image, a scalar
            (MNIST is black-and-white, so 1 channel is your default)
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, im_chan=1, hidden_dim=16):
        super(Discriminator, self).__init__()
        # Spatial size goes 28 -> 13 -> 5 -> 1
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, 1, final_layer=True),
        )

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to a discriminator block of DCGAN,
        corresponding to a convolution, a batchnorm (except for in the last layer), and an activation.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: a boolean, true if it is the final layer and false otherwise
                (affects activation and batchnorm)
        '''
        # Build the neural block
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(num_features=output_channels),
                nn.LeakyReLU(negative_slope=0.2)
            )
        else:  # Final Layer: a raw logit, no sigmoid (nn.BCEWithLogitsLoss applies it)
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride)
            )

    def forward(self, image):
        '''
        Function for completing a forward pass of the discriminator: given an image tensor,
        returns a tensor of shape (n_samples, 1) with one real/fake logit per image.
        Parameters:
            image: an image tensor with dimensions (n_samples, im_chan, 28, 28)
        '''
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)
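The same kind of check works for the discriminator. This sketch applies the `nn.Conv2d` output-size formula from the PyTorch docs (no padding, default dilation, as in the class above) to the three blocks; they take a 28 × 28 image down to a 1 × 1 map, which is why `forward` can flatten the result to one logit per image.

```python
def conv_out(size, kernel_size, stride, padding=0, dilation=1):
    """Spatial output size of nn.Conv2d (formula from the PyTorch docs, floored)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# All three discriminator blocks use kernel_size=4, stride=2
size = 28
sizes = [size]
for _ in range(3):
    size = conv_out(size, kernel_size=4, stride=2)
    sizes.append(size)
print(sizes)
```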
Now comes the main part: training. The basic algorithm proposed in the original GAN paper is as follows: with the minimax objective, the discriminator is first trained for k iterations by gradient ascent (because it wants to maximize!). Following this, the generator is trained by gradient descent (because it wants to minimize!) for a single iteration. The reason for this setup is that, for the generator to learn how to create an image from scratch, it first needs to receive some “signal” that gives it an incentive to improve its generative ability. Training the discriminator first makes it slightly better than the generator at telling a real image from a fake one. This information then gets backpropagated when updating the generator, which allows the generator to improve. Finally, if you’re interested, the original paper gives a mathematical proof that, under idealized assumptions (both networks have enough capacity, and the discriminator is allowed to reach its optimum at each step), this algorithm converges to the real data distribution.
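The convergence argument has a neat numerical side. For a fixed generator, Goodfellow et al. show that the optimal discriminator is $D^*(x) = p_{\text{data}}(x) / (p_{\text{data}}(x) + p_g(x))$, and plugging it in turns the objective into $-\log 4 + 2\,\mathrm{JSD}(p_{\text{data}} \,\|\, p_g)$, which is smallest exactly when $p_g = p_{\text{data}}$. The sketch below checks this with a Riemann sum, assuming one-dimensional Gaussians as stand-ins for the data and generator distributions (purely illustrative; image distributions look nothing like this).

```python
import math

def gaussian_pdf(x, mu, sigma):
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))

def value_at_optimal_d(mu_g, sigma_g, mu_data=0.0, sigma_data=1.0, lo=-12.0, hi=12.0, n=20001):
    """Approximate V(D*, G) on [lo, hi], with p_data = N(mu_data, sigma_data^2)
    and p_g = N(mu_g, sigma_g^2) as illustrative assumptions."""
    dx = (hi - lo) / (n - 1)
    total = 0.0
    for i in range(n):
        x = lo + i * dx
        p_d = gaussian_pdf(x, mu_data, sigma_data)
        p_g = gaussian_pdf(x, mu_g, sigma_g)
        denom = p_d + p_g
        if p_d > 0.0:
            # p_data(x) * log D*(x), where D*(x) = p_data / (p_data + p_g)
            total += p_d * math.log(p_d / denom) * dx
        if p_g > 0.0:
            # p_g(x) * log(1 - D*(x)), where 1 - D*(x) = p_g / (p_data + p_g)
            total += p_g * math.log(p_g / denom) * dx
    return total

v_match = value_at_optimal_d(mu_g=0.0, sigma_g=1.0)  # generator matches the data
v_far = value_at_optimal_d(mu_g=3.0, sigma_g=1.0)    # generator is off by 3 std devs
print(v_match, -math.log(4), v_far)
```

When the generator matches the data, the best discriminator can do no better than $D^*(x) = 1/2$ everywhere, and the value bottoms out at $-\log 4$; any mismatch raises it.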
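Below is the code which executes this algorithm. Keeping it simple, I update the discriminator and generator alternately (i.e., k=1). Also, notice that the discriminator’s loss is just binary cross-entropy: with label 1 on real images it gives $-\log D(x)$, and with label 0 on generated images it gives $-\log(1 - D(G(z)))$, so averaging the two works out to the negative of the minimax objective (scaled by 1/2), and minimizing it is the same as maximizing the objective! The generator step instead labels its fakes as real, a variant also suggested in the original paper.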
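import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm.auto import tqdm

# training requirements
# (z_dim matches the generator; lr and betas are assumed DCGAN-style values, not tuned ones)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_epochs = 50
z_dim = 64
batch_size = 128
lr = 0.0002
beta_1, beta_2 = 0.5, 0.999
criterion = nn.BCEWithLogitsLoss()

# MNIST scaled to [-1, 1] to match the generator's Tanh output
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataloader = DataLoader(MNIST('.', download=True, transform=transform), batch_size=batch_size, shuffle=True)

def get_noise(n_samples, z_dim, device='cpu'):
    # standard normal noise vectors z with shape (n_samples, z_dim)
    return torch.randn(n_samples, z_dim, device=device)

# generator and optimizer (Generator and Discriminator are the classes defined above)
gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
# prepare discriminator and optimizer
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))

# stores losses per epoch
generator_losses = []
discriminator_losses = []

for epoch in range(n_epochs):
    # losses averaged over all batches in the epoch
    mean_gen_loss = 0
    mean_disc_loss = 0
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        ## Update discriminator ##
        disc_opt.zero_grad()
        fake_noise = get_noise(cur_batch_size, z_dim, device=device)
        fake = gen(fake_noise)
        # detach so this step does not backpropagate into the generator
        disc_fake_pred = disc(fake.detach())
        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_pred = disc(real)
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2
        # Keep track of the average discriminator loss
        mean_disc_loss += disc_loss.item() / len(dataloader)
        # Compute gradients and update the discriminator
        disc_loss.backward()
        disc_opt.step()

        ## Update generator ##
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        disc_fake_pred = disc(fake_2)
        # the generator wants its fakes classified as real (label 1)
        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
        gen_loss.backward()
        gen_opt.step()
        # Keep track of the average generator loss
        mean_gen_loss += gen_loss.item() / len(dataloader)

    generator_losses.append(mean_gen_loss)
    discriminator_losses.append(mean_disc_loss)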
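To see the cross-entropy equivalence concretely, here is a dependency-free sketch (plain Python rather than PyTorch, with made-up logits). It implements the binary cross-entropy formula documented for `nn.BCEWithLogitsLoss` (mean reduction, no weights) and compares the discriminator loss from the training loop with the minimax value function on the same outputs. PyTorch evaluates this formula in a numerically stabler way; the naive version here is only for illustration.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def bce_with_logits(logits, target):
    """Mean binary cross-entropy on raw logits; target is 1.0 (real) or 0.0 (fake)."""
    losses = [-(target * math.log(sigmoid(a)) + (1 - target) * math.log(1 - sigmoid(a)))
              for a in logits]
    return sum(losses) / len(losses)

# Hypothetical discriminator logits for a small batch
real_logits = [2.0, 0.5, 1.2]
fake_logits = [-1.5, 0.3, -0.7]

# Discriminator loss exactly as in the training loop: average of the two BCE terms
disc_loss = (bce_with_logits(real_logits, 1.0) + bce_with_logits(fake_logits, 0.0)) / 2

# Minimax value V(D, G) estimated on the same discriminator outputs
d_real = [sigmoid(a) for a in real_logits]
d_fake = [sigmoid(a) for a in fake_logits]
v = sum(math.log(p) for p in d_real) / len(d_real) + sum(math.log(1 - p) for p in d_fake) / len(d_fake)

print(disc_loss, -v / 2)  # the two agree up to floating-point rounding
```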
After training, I get the following image generations. Some images look a bit unclear, but overall they look pretty good, I think!
Problems in GAN training
While the entire setup of the GAN is elegant, the actual training process is not! The handwritten digit generation task above is relatively simple; most real-world image datasets contain far more detailed features, which are harder to generate.
Also, commenting on the model training, we are now optimizing two neural networks simultaneously, so just observing the loss function is not enough; one also has to make sure that neither network becomes much more powerful than the other. Thinking about it intuitively, if the discriminator becomes so good that it can spot any fake image, then the generator is just going to think, “what’s the point of even improving if I’m always going to get caught?” But if the generator becomes so powerful that it fools the discriminator easily, that scenario is also bad for the generator because it just thinks, “I don’t even have to try so hard because the discriminator is crap!” In practice, the former scenario is much more likely to happen because generating a realistic sample is much more complicated than classifying it as real or fake, meaning the discriminator is more likely to win the minimax game. Computationally, this scenario leads to issues like training divergence and vanishing gradients for the generator, which we want to avoid.
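The original paper describes one concrete form of this vanishing-gradient problem: when D confidently rejects the generator’s samples, $\log(1 - D(G(z)))$ saturates, so the paper suggests training G to maximize $\log D(G(z))$ instead. The training code above does this by giving fakes the label 1 in the generator step. The sketch below compares the two gradients with respect to the discriminator’s logit $a$ on a fake sample (so $D(G(z)) = \sigma(a)$), using $\frac{d}{da}\log(1 - \sigma(a)) = -\sigma(a)$ and $\frac{d}{da}\left[-\log \sigma(a)\right] = \sigma(a) - 1$.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_minimax(a):
    # d/da log(1 - sigmoid(a)) = -sigmoid(a): vanishes when D(G(z)) -> 0
    return -sigmoid(a)

def grad_non_saturating(a):
    # d/da [-log(sigmoid(a))] = sigmoid(a) - 1: stays near -1 when D(G(z)) -> 0
    return sigmoid(a) - 1.0

# Increasingly negative logits = a discriminator that is increasingly sure the sample is fake
for logit in [0.0, -2.0, -5.0, -10.0]:
    print(f"D(G(z))={sigmoid(logit):.5f}  "
          f"minimax grad={grad_minimax(logit):+.5f}  "
          f"non-saturating grad={grad_non_saturating(logit):+.5f}")
```

Once the discriminator wins decisively, the original minimax loss gives the generator almost no gradient to learn from, while the non-saturating version still does.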
Another common issue with GANs is mode collapse: the generator finds a “cheat code” to defeat the discriminator and keeps using it to win the game. For example, the distribution of handwritten digits is multi-modal, with a mode for “1”, “2”, “3”, and so on. Generating a “1” is probably much easier than generating other digits, so over the course of training the generator may end up creating a “1” every time, as long as that keeps fooling the discriminator. We don’t want this to happen because we want the GAN to generate diverse data samples.
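One simple way to watch for this on MNIST, sketched here under an assumption, is to label a batch of generated digits with a separately trained digit classifier (hypothetical; it is not part of the GAN above) and measure how evenly the predicted classes are spread. A normalized entropy near 1 means all digits appear; a value near 0 suggests the generator has collapsed onto a few modes.

```python
import math
from collections import Counter

def label_entropy(predicted_labels, n_classes=10):
    """Normalized entropy (0 to 1) of the class histogram of generated samples.

    predicted_labels: digit labels assigned to generated images by some separately
    trained MNIST classifier (a hypothetical helper, not defined in this post).
    """
    counts = Counter(predicted_labels)
    total = len(predicted_labels)
    h = -sum((c / total) * math.log(c / total) for c in counts.values())
    return h / math.log(n_classes)

balanced = list(range(10)) * 50      # every digit equally often
collapsed = [1] * 480 + [7] * 20     # almost always a "1"
print(label_entropy(balanced), label_entropy(collapsed))
```

This only catches collapse across classes; a generator that draws every “1” in exactly the same style would still score well, so it is a rough diagnostic rather than a full diversity measure.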
Finally, the basic GAN offers no way to perform conditional generation (i.e., generating images of a specific class) or controllable generation (i.e., creating images with specific details in them). Over the past five years, there has been an explosion in GAN research, and many papers have come out discussing better ways to design and train GANs to achieve these goals. Do check out the GAN Zoo, a massive repository listing the many types of GANs that have been developed to date.
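As a taste of how conditional generation is commonly done, the conditional GAN of Mirza and Osindero (2014) feeds the class label to both networks. Below is a minimal sketch of the generator side, assuming the label is one-hot encoded and simply appended to the noise vector; it uses plain Python lists instead of tensors, and the helper names are hypothetical.

```python
def one_hot(label, n_classes=10):
    """One-hot vector for an integer class label."""
    vec = [0.0] * n_classes
    vec[label] = 1.0
    return vec

def conditional_generator_input(noise, label, n_classes=10):
    """Append a one-hot class vector to a noise vector, so the input grows by n_classes."""
    return list(noise) + one_hot(label, n_classes)

z = [0.1, -0.2, 0.3]                 # a tiny made-up noise vector
print(conditional_generator_input(z, label=2, n_classes=4))
```

In PyTorch the same step could be written as `torch.cat((noise, torch.nn.functional.one_hot(labels, 10).float()), dim=1)`, with the generator built with `z_dim=64 + 10`; the discriminator also needs the label, for example as extra one-hot image channels.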
Regardless, you now know the fundamentals of generative adversarial networks, congratulations! In my future posts, I hope to delve deeper into the fine details of GANs, discuss different state-of-the-art models, and most importantly, applications to the medical setting. Until then, stay tuned!