🎉 75% of content is free forever — Unlock Premium from $10/mo →
CW
Search courses…
💼 Services · ℹ️ About · ✉️ Contact · View Pricing Plans from $10

GAN Fundamentals

⭐ Premium

Advertisement

GAN Fundamentals

Generative Adversarial Networks learn to generate realistic data by pitting two neural networks against each other: a generator creates fake samples, and a discriminator tries to tell real from fake. Through this adversarial game, both improve until the generator produces indistinguishable data.

GAN Architecture

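Generative Adversarial Network (diagram): noise $z \sim p(z)$ is fed to the generator $G$, which outputs a fake sample $G(z)$. The discriminator $D$ receives either a fake sample or a real sample $x \sim p_{\text{data}}$, classifies it as real or fake, and outputs $P(\text{real})$. That decision is backpropagated to $G$, which is trained to fool $D$: $\min_G \max_D V(D, G)$.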

The GAN Framework

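The generator $G$ maps noise $z \sim p(z)$ to data space, while the discriminator $D$ outputs the probability that a sample is real rather than generated. They play a minimax game:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p(z)}\left[\log\left(1 - D(G(z))\right)\right]$$

In practice, as in the code below, $G$ is trained with the non-saturating loss. It maximizes $\log D(G(z))$ (BCE against the "real" label) instead of minimizing $\log(1 - D(G(z)))$.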

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import warnings
# Silence only library deprecation chatter. A blanket 'ignore' would also hide
# numerical RuntimeWarnings, e.g. NumPy's warning when a covariance is estimated
# from too few samples.
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Simple GAN on 2D Data

def generate_real_data(n=1000, centers=None, std=0.5, rng=None):
    """Sample ``n`` points from an equal-weight mixture of isotropic Gaussians.

    With the defaults this is the two-mode 2-D mixture centred at (-2, -2) and
    (2, 2) with per-axis standard deviation 0.5.

    Args:
        n: number of samples to draw (0 yields an empty (0, d) array).
        centers: (k, d) array-like of component means; defaults to
            [[-2, -2], [2, 2]].
        std: standard deviation of every component along each axis.
        rng: object providing ``randint`` and ``randn``, such as a
            ``np.random.RandomState``, for reproducible sampling. Defaults to
            NumPy's global ``np.random`` state, exactly as before.

    Returns:
        float32 array of shape (n, d).
    """
    rng = np.random if rng is None else rng
    if centers is None:
        centers = np.array([[-2, -2], [2, 2]])
    else:
        centers = np.asarray(centers, dtype=float)
    # Pick a component per sample uniformly, then add isotropic Gaussian noise.
    labels = rng.randint(0, len(centers), n)
    data = centers[labels] + rng.randn(n, centers.shape[1]) * std
    return data.astype(np.float32)

class Generator(nn.Module):
    """MLP generator mapping latent noise vectors to 2-D points.

    Two hidden blocks of ``Linear -> LeakyReLU(0.2) -> BatchNorm1d(hidden)``
    are followed by a linear projection to 2 output coordinates.

    Args:
        latent_dim: size of the input noise vector z.
        hidden: width of both hidden layers.
    """

    def __init__(self, latent_dim=2, hidden=64):
        super().__init__()
        layers = []
        n_in = latent_dim
        for _ in range(2):
            layers += [nn.Linear(n_in, hidden), nn.LeakyReLU(0.2), nn.BatchNorm1d(hidden)]
            n_in = hidden
        layers.append(nn.Linear(hidden, 2))
        # One Sequential named `net`, so parameter names stay net.0 ... net.6.
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Map noise ``z`` of shape (batch, latent_dim) to (batch, 2) samples."""
        return self.net(z)

class Discriminator(nn.Module):
    """MLP discriminator scoring 2-D points with a probability of being real.

    Two hidden blocks of ``Linear -> LeakyReLU(0.2) -> Dropout(0.3)`` feed a
    single-unit linear head squashed by a Sigmoid, matching ``nn.BCELoss``.

    Args:
        hidden: width of both hidden layers.
    """

    def __init__(self, hidden=64):
        super().__init__()

        def hidden_block(n_in):
            return [nn.Linear(n_in, hidden), nn.LeakyReLU(0.2), nn.Dropout(0.3)]

        self.net = nn.Sequential(
            *hidden_block(2),
            *hidden_block(hidden),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return P(real) of shape (batch, 1) for inputs ``x`` of shape (batch, 2)."""
        return self.net(x)

# Training loop
def train_gan(n_epochs=2000, batch_size=64, latent_dim=2, lr=0.0002, real_data=None):
    """Train the vanilla GAN (``Generator`` vs. ``Discriminator``) with BCE losses.

    Each mini-batch first updates D on real samples (target 1) and detached
    fake samples (target 0). It then updates G with the non-saturating loss,
    i.e. BCE of D(G(z)) against target 1.

    Args:
        n_epochs: number of passes over the training set.
        batch_size: mini-batch size; must be >= 2 because the generator uses
            BatchNorm1d, which cannot normalize a single sample in train mode.
        latent_dim: dimensionality of the noise z fed to G.
        lr: Adam learning rate shared by G and D.
        real_data: optional (N, 2) array of training samples; defaults to
            ``generate_real_data(1000)``.

    Returns:
        (G, D, d_losses, g_losses): the trained networks (left in train mode)
        and the per-epoch mean discriminator / generator losses.

    Raises:
        ValueError: if ``batch_size < 2`` or the data yields no batches.
    """
    if batch_size < 2:
        raise ValueError("batch_size must be >= 2 (Generator uses BatchNorm1d)")

    G = Generator(latent_dim).to(device)
    D = Discriminator().to(device)

    opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    criterion = nn.BCELoss()

    if real_data is None:
        real_data = generate_real_data(1000)
    dataset = TensorDataset(torch.tensor(real_data, dtype=torch.float32))
    # A trailing batch of exactly one sample would crash BatchNorm1d in G,
    # so drop it only in that case (the default 1000/64 split is unchanged).
    drop_last = len(dataset) % batch_size == 1
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
    if len(loader) == 0:
        raise ValueError("real_data yields no training batches")

    d_losses, g_losses = [], []

    for epoch in range(n_epochs):
        d_epoch, g_epoch = 0.0, 0.0
        for (real_batch,) in loader:
            real_batch = real_batch.to(device)
            n = real_batch.size(0)

            # Discriminator step: push D(x) -> 1 and D(G(z)) -> 0.
            z = torch.randn(n, latent_dim, device=device)
            fake = G(z).detach()  # D's update must not backpropagate into G
            d_real = D(real_batch)
            d_fake = D(fake)
            d_loss = criterion(d_real, torch.ones_like(d_real)) + \
                     criterion(d_fake, torch.zeros_like(d_fake))

            opt_D.zero_grad()
            d_loss.backward()
            opt_D.step()

            # Generator step (non-saturating): push D(G(z)) -> 1.
            # g_loss.backward() also fills D's grads; opt_D.zero_grad() clears them next batch.
            z = torch.randn(n, latent_dim, device=device)
            d_fake = D(G(z))
            g_loss = criterion(d_fake, torch.ones_like(d_fake))

            opt_G.zero_grad()
            g_loss.backward()
            opt_G.step()

            d_epoch += d_loss.item()
            g_epoch += g_loss.item()

        # Record epoch means rather than a single (noisy) last-batch value.
        d_losses.append(d_epoch / len(loader))
        g_losses.append(g_epoch / len(loader))

        if (epoch + 1) % 500 == 0:
            print(f"Epoch {epoch+1}: D_loss={d_losses[-1]:.4f}, G_loss={g_losses[-1]:.4f}")

    return G, D, d_losses, g_losses

G, D, d_losses, g_losses = train_gan()

Mode Collapse

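Mode collapse occurs when the generator learns to produce only a subset of the real data distribution. For the two-Gaussian data above, for example, a collapsed generator emits points around only one of the two centers.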

def detect_mode_collapse(generator, n_samples=5000, latent_dim=2, n_bins=50):
    """Check if generator covers all modes."""
    z = torch.randn(n_samples, latent_dim).to(device)
    generated = generator(z).detach().cpu().numpy()
    
    # Check coverage using histogram
    hist, xedges, yedges = np.histogram2d(
        generated[:, 0], generated[:, 1], bins=n_bins, range=[[-5, 5], [-5, 5]]
    )
    
    # Modes are cells with significant density
    threshold = hist.max() * 0.01
    active_modes = (hist > threshold).sum()
    
    print(f"Active modes (out of {n_bins**2} cells): {active_modes}")
    print(f"Mode coverage: {active_modes / (n_bins**2):.3f}")
    
    return generated

generated = detect_mode_collapse(G)

WGAN: Wasserstein GAN

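WGAN replaces the Jensen–Shannon divergence implicit in the original GAN loss with an estimate of the Wasserstein-1 (Earth Mover's) distance, which gives more useful gradients and more stable training. The critic, which has no sigmoid and an unbounded output, must be approximately 1-Lipschitz. The original WGAN enforces this by clipping the critic's weights, as the code below does.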

class WGAN_Generator(nn.Module):
    """Plain MLP generator for the WGAN: latent noise -> 2-D point.

    Two hidden ``Linear -> LeakyReLU(0.2)`` layers of width ``hidden`` and a
    final linear projection to 2 coordinates; no normalization layers.

    Args:
        latent_dim: size of the input noise vector z.
        hidden: width of both hidden layers.
    """

    def __init__(self, latent_dim=2, hidden=64):
        super().__init__()
        widths = [latent_dim, hidden, hidden]
        layers = []
        for n_in, n_out in zip(widths, widths[1:]):
            layers.extend([nn.Linear(n_in, n_out), nn.LeakyReLU(0.2)])
        layers.append(nn.Linear(hidden, 2))
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Map noise ``z`` to generated samples through ``self.net``."""
        return self.net(z)

class WGAN_Critic(nn.Module):
    """WGAN critic: scores 2-D points with an unbounded real number.

    Unlike the BCE discriminators there is no final Sigmoid. The WGAN losses
    use the raw difference between mean critic scores on real and fake data.

    Args:
        hidden: width of both hidden layers.
    """

    def __init__(self, hidden=64):
        super().__init__()
        layers = [
            nn.Linear(2, hidden), nn.LeakyReLU(0.2),
            nn.Linear(hidden, hidden), nn.LeakyReLU(0.2),
        ]
        layers.append(nn.Linear(hidden, 1))  # raw score head, deliberately no Sigmoid
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return unbounded critic scores of shape (batch, 1)."""
        return self.net(x)

def train_wgan(n_epochs=3000, batch_size=64, latent_dim=2, n_critic=5, clip_value=0.01,
               lr=0.00005, real_data=None):
    """Train a weight-clipped WGAN (``WGAN_Generator`` vs. ``WGAN_Critic``).

    For every real mini-batch, the critic is updated ``n_critic`` times to
    maximize E[C(x)] - E[C(G(z))]. After each critic step every critic
    parameter is clamped to [-clip_value, clip_value]. The generator is then
    updated once to maximize E[C(G(z))].

    RMSprop is used because the defaults here (lr=5e-5, clip 0.01, batch 64,
    n_critic=5) are those of Algorithm 1 in the original WGAN paper. That
    paper uses RMSProp and reports that momentum-based optimizers such as
    Adam make weight-clipped training unstable.

    Args:
        n_epochs: number of passes over the training set.
        batch_size: mini-batch size.
        latent_dim: dimensionality of the noise z.
        n_critic: critic updates per generator update (must be >= 1).
        clip_value: weight-clipping bound for the critic.
        lr: RMSprop learning rate for both networks.
        real_data: optional (N, 2) array of training samples; defaults to
            ``generate_real_data(1000)``.

    Returns:
        (G, C): the trained generator and critic.

    Raises:
        ValueError: if ``n_critic < 1`` or the data yields no batches.
    """
    if n_critic < 1:
        raise ValueError("n_critic must be >= 1")

    G = WGAN_Generator(latent_dim).to(device)
    C = WGAN_Critic().to(device)

    opt_G = optim.RMSprop(G.parameters(), lr=lr)
    opt_C = optim.RMSprop(C.parameters(), lr=lr)

    if real_data is None:
        real_data = generate_real_data(1000)
    dataset = TensorDataset(torch.tensor(real_data, dtype=torch.float32))
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    if len(loader) == 0:
        raise ValueError("real_data yields no training batches")

    for epoch in range(n_epochs):
        for (real_batch,) in loader:
            real_batch = real_batch.to(device)
            n = real_batch.size(0)

            # Critic: n_critic steps per generator step. NOTE: the same real batch
            # is reused for all n_critic steps (Algorithm 1 draws a fresh one each time).
            for _ in range(n_critic):
                z = torch.randn(n, latent_dim, device=device)
                fake = G(z).detach()  # critic step must not update G

                c_real = C(real_batch).mean()
                c_fake = C(fake).mean()
                # Negated Wasserstein estimate: minimizing it maximizes c_real - c_fake.
                c_loss = c_fake - c_real

                opt_C.zero_grad()
                c_loss.backward()
                opt_C.step()

                # Weight clipping: crude enforcement of the critic's Lipschitz constraint.
                with torch.no_grad():
                    for p in C.parameters():
                        p.clamp_(-clip_value, clip_value)

            # Generator: maximize E[C(G(z))] by minimizing its negation.
            z = torch.randn(n, latent_dim, device=device)
            g_loss = -C(G(z)).mean()

            opt_G.zero_grad()
            g_loss.backward()
            opt_G.step()

        if (epoch + 1) % 500 == 0:
            print(f"Epoch {epoch+1}: Critic loss={c_loss.item():.4f}")

    return G, C

G_wgan, C_wgan = train_wgan()

Conditional GAN

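Generate data conditioned on class labels. Both networks receive a learned embedding of the label concatenated to their input, so the class of each generated sample can be chosen at sampling time.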

class ConditionalGenerator(nn.Module):
    """Generator conditioned on an integer class label.

    The label is mapped through a learned ``nn.Embedding`` and concatenated to
    the noise vector; an MLP (two hidden ``Linear -> LeakyReLU(0.2)`` layers)
    then produces a 2-D sample.

    Args:
        latent_dim: size of the noise vector z.
        n_classes: number of labels; valid labels are 0 .. n_classes - 1.
        hidden: width of both hidden layers.
        embed_dim: size of the label embedding (default 2, as before). The
            first Linear layer's input width is derived from it instead of
            being hard-coded.
    """

    def __init__(self, latent_dim=2, n_classes=2, hidden=64, embed_dim=2):
        super().__init__()
        self.label_embed = nn.Embedding(n_classes, embed_dim)
        self.net = nn.Sequential(
            nn.Linear(latent_dim + embed_dim, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, 2)
        )

    def forward(self, z, labels):
        """Generate samples from noise ``z`` (batch, latent_dim) and long ``labels`` (batch,)."""
        label_emb = self.label_embed(labels)
        x = torch.cat([z, label_emb], dim=1)
        return self.net(x)

class ConditionalDiscriminator(nn.Module):
    """Discriminator conditioned on an integer class label.

    A 2-D sample is concatenated with a learned label embedding and scored by
    an MLP ending in a Sigmoid, giving P(real | label).

    Args:
        n_classes: number of labels; valid labels are 0 .. n_classes - 1.
        hidden: width of both hidden layers.
        embed_dim: size of the label embedding (default 2, as before). The
            first Linear layer's input width is derived from it instead of
            the hard-coded 4.
    """

    def __init__(self, n_classes=2, hidden=64, embed_dim=2):
        super().__init__()
        self.label_embed = nn.Embedding(n_classes, embed_dim)
        self.net = nn.Sequential(
            nn.Linear(2 + embed_dim, hidden),  # 2-D sample + label embedding
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, 1),
            nn.Sigmoid()
        )

    def forward(self, x, labels):
        """Score samples ``x`` (batch, 2) given long ``labels`` (batch,); returns (batch, 1)."""
        label_emb = self.label_embed(labels)
        x = torch.cat([x, label_emb], dim=1)
        return self.net(x)

# Conditional generation demo: ask a freshly constructed (untrained)
# conditional generator for 100 samples of class 0 and report the output shape.
latent_dim = 2
G_cond = ConditionalGenerator(latent_dim).to(device)
z = torch.randn(100, latent_dim).to(device)
labels = torch.full((100,), 0, dtype=torch.long, device=device)  # class index 0 for every sample
generated = G_cond(z, labels)
print(f"Conditional generation shape: {generated.shape}")

GAN Training Best Practices

# Spectral Normalization for stable training
class SN_Discriminator(nn.Module):
    """BCE discriminator whose Linear layers are wrapped in spectral normalization.

    Layout: ``SN(Linear) -> LeakyReLU(0.2)`` twice, then ``SN(Linear(hidden, 1))
    -> Sigmoid``. Uses the legacy ``nn.utils.spectral_norm`` wrapper, so each
    wrapped layer exposes ``weight_orig`` plus the ``weight_u``/``weight_v``
    power-iteration buffers.

    Args:
        hidden: width of both hidden layers.
    """

    def __init__(self, hidden=64):
        super().__init__()
        sn = nn.utils.spectral_norm
        self.net = nn.Sequential(
            sn(nn.Linear(2, hidden)), nn.LeakyReLU(0.2),
            sn(nn.Linear(hidden, hidden)), nn.LeakyReLU(0.2),
            sn(nn.Linear(hidden, 1)), nn.Sigmoid(),
        )

    def forward(self, x):
        """Return P(real) of shape (batch, 1) for inputs ``x`` of shape (batch, 2)."""
        return self.net(x)

# Two Time-Scale Update Rule (TTUR): the discriminator's Adam optimizer gets a
# 4x larger learning rate (4e-4) than the generator's (1e-4).
# NOTE(review): no training loop visible in this section consumes these optimizers.
G_ttur = Generator().to(device)
D_ttur = SN_Discriminator().to(device)

opt_G, opt_D = (
    optim.Adam(G_ttur.parameters(), lr=1e-4, betas=(0.5, 0.999)),  # generator: slower
    optim.Adam(D_ttur.parameters(), lr=4e-4, betas=(0.5, 0.999)),  # discriminator: faster
)

print("Spectral normalization and TTUR applied")

Evaluating GANs

# FID-like simplified metric
def compute_fid(real_data, generated_data, eps=1e-6):
    """Fréchet distance between Gaussians fitted to two sample sets ("simplified FID").

    Real FID applies this formula to Inception-network features; here it is
    applied directly to the raw sample coordinates:

        ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^{1/2})

    Args:
        real_data: array of shape (n_real, d).
        generated_data: array of shape (n_gen, d) with the same d.
        eps: ridge added to both covariance diagonals so the matrix square
            root stays well defined.

    Returns:
        The distance as a float (about 0 for identical sets; lower is better).

    Raises:
        ValueError: if either input is not 2-D, the feature dimensions differ,
            or either set has fewer than 2 samples. np.cov would otherwise
            silently produce NaN.
    """
    from scipy.linalg import sqrtm

    real = np.asarray(real_data, dtype=np.float64)
    fake = np.asarray(generated_data, dtype=np.float64)
    if real.ndim != 2 or fake.ndim != 2 or real.shape[1] != fake.shape[1]:
        raise ValueError("inputs must be 2-D arrays with the same number of features")
    if len(real) < 2 or len(fake) < 2:
        raise ValueError("need at least 2 samples in each set to estimate a covariance")

    d = real.shape[1]
    # atleast_2d: np.cov returns a 0-d array when there is a single feature.
    real_cov = np.atleast_2d(np.cov(real, rowvar=False)) + np.eye(d) * eps
    fake_cov = np.atleast_2d(np.cov(fake, rowvar=False)) + np.eye(d) * eps

    diff = real.mean(axis=0) - fake.mean(axis=0)
    # sqrtm may return a complex matrix whose imaginary part is round-off noise.
    covmean = np.real(sqrtm(real_cov @ fake_cov))

    fid = diff @ diff + np.trace(real_cov + fake_cov - 2 * covmean)
    return float(fid)

# Evaluate the vanilla GAN's generator G with the simplified FID.
# Sample in eval mode without autograd so BatchNorm uses (and does not update)
# its running statistics, then restore G's previous train/eval mode.
real_data = generate_real_data(1000)
_g_was_training = G.training
G.eval()
with torch.no_grad():
    z = torch.randn(1000, 2).to(device)
    generated_data = G(z).cpu().numpy()
G.train(_g_was_training)

fid = compute_fid(real_data, generated_data)
print(f"Simplified FID: {fid:.4f} (lower is better)")

Best Practices

  1. Use WGAN-GP – gradient penalty is more stable than weight clipping
  2. Spectral normalization – stabilizes discriminator training
  3. TTUR – different learning rates for G and D
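  4. Normalization – batch normalization is common in the generator. Avoid it in the critic when using a gradient penalty (the penalty is computed per sample); use layer normalization or no normalization there instead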
  5. Monitor mode collapse – check if generated samples cover all modes
  6. Evaluate with FID/IS – quantitative metrics, not just visual inspection

Summary

GANs learn to generate realistic data through adversarial training. Understand the minimax game, mode collapse, and stabilization techniques (WGAN, spectral norm, TTUR) to train reliable generators for data augmentation, style transfer, and creative applications.

Advertisement