
    TD 2: GAN & Diffusion

    MSO 3.4 Machine Learning

    Overview

    This project explores generative models for images, focusing on Generative Adversarial Networks (GANs) and Diffusion models. The objective is to understand their implementation, analyze specific architectures, and apply different training strategies for generating and denoising images, both with and without conditioning.


    Part 1: DC-GAN

    In this section, we study the fundamentals of Generative Adversarial Networks through a Deep Convolutional GAN (DCGAN). We follow the tutorial: DCGAN Tutorial.

    We generate handwritten digits using the MNIST dataset available in the torchvision package: MNIST Dataset.

    Implemented Modifications

    • Adapted the tutorial's code to work with the MNIST dataset (see the data-loading sketch after this list).
    • Displayed loss curves for both the generator and the discriminator over training steps.
    • Compared generated images with real MNIST dataset images.
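
    A minimal sketch of the data-loading adaptation, assuming the tutorial's networks expect 64×64 inputs and that the number of image channels is set to 1 for grayscale MNIST; the transform values and variable names are illustrative, not the exact lab code.

    ```python
    import torch
    import torchvision
    import torchvision.transforms as transforms

    # MNIST digits are 28x28 grayscale; the DCGAN tutorial's networks expect
    # 64x64 inputs, so we resize and normalize to [-1, 1] (single channel).
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    dataset = torchvision.datasets.MNIST(root="./data", train=True,
                                         download=True, transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                             shuffle=True, num_workers=2)

    # With nc = 1 instead of 3, the generator's last ConvTranspose2d and the
    # discriminator's first Conv2d must also use a single channel.
    ```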

    Examples of Generated Images:

    Example images of digits generated by DCGAN


    Question: How to Control the Generated Digit?

    To control which digit the generator produces, we implement a Conditional GAN (cGAN) with the following modifications:

    Generator Modifications

    • Instead of using only random noise, we concatenate a class label (one-hot encoded or embedded) with the noise vector, as sketched after this list.
    • This allows the generator to learn to produce specific digits based on the provided label.
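
    A minimal sketch of this conditioning, assuming the label is passed through an `nn.Embedding` and concatenated with the noise vector before the first layer (one-hot concatenation is the simpler alternative); the sizes `nz`, `embed_dim`, and the ConvTranspose2d stack are illustrative.

    ```python
    import torch
    import torch.nn as nn

    class ConditionalGenerator(nn.Module):
        def __init__(self, nz=100, n_classes=10, embed_dim=10, ngf=64, nc=1):
            super().__init__()
            # Learn a small embedding for the digit label.
            self.label_embed = nn.Embedding(n_classes, embed_dim)
            self.main = nn.Sequential(
                # Input: (nz + embed_dim) x 1 x 1
                nn.ConvTranspose2d(nz + embed_dim, ngf * 8, 4, 1, 0, bias=False),
                nn.BatchNorm2d(ngf * 8),
                nn.ReLU(True),
                nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 4),
                nn.ReLU(True),
                nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 2),
                nn.ReLU(True),
                nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf),
                nn.ReLU(True),
                nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
                nn.Tanh(),  # 64 x 64 output in [-1, 1]
            )

        def forward(self, noise, labels):
            # Concatenate the label embedding with the noise vector, then
            # reshape to a (B, nz + embed_dim, 1, 1) tensor for the conv stack.
            cond = torch.cat([noise, self.label_embed(labels)], dim=1)
            return self.main(cond.unsqueeze(-1).unsqueeze(-1))

    # Usage: generate sixteen images of requested digits.
    # fake = ConditionalGenerator()(torch.randn(16, 100), torch.randint(0, 10, (16,)))
    ```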

    Discriminator Modifications

    • Instead of just distinguishing real from fake, the discriminator is modified to classify images as digits (0-9) or as generated (fake).
    • It outputs a probability distribution over 11 classes (10 digits + 1 for generated images).
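
    A sketch of the 11-way output head, assuming the convolutional trunk from the tutorial is kept and only the final layer changes; reserving class index 10 for "fake" is an implementation choice, not something fixed by the lab.

    ```python
    import torch
    import torch.nn as nn

    class MultiClassDiscriminator(nn.Module):
        """Outputs logits over 11 classes: digits 0-9 plus one 'fake' class."""
        def __init__(self, nc=1, ndf=64, n_classes=11):
            super().__init__()
            self.features = nn.Sequential(
                nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),           # 64 -> 32
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),      # 32 -> 16
                nn.BatchNorm2d(ndf * 2),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),  # 16 -> 8
                nn.BatchNorm2d(ndf * 4),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),  # 8 -> 4
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True),
            )
            # A final 4x4 convolution collapses the spatial map to 11 class logits.
            self.classifier = nn.Conv2d(ndf * 8, n_classes, 4, 1, 0, bias=False)

        def forward(self, x):
            return self.classifier(self.features(x)).flatten(1)  # (B, 11)
    ```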

    Training Process Update

    • The generator is trained to fool the discriminator while generating images that match the correct class label.
    • A categorical cross-entropy loss is used for the discriminator instead of a binary loss since it performs multi-class classification.
    • The loss function encourages the generator to produce well-classified digits.
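
    A sketch of how the two losses could be wired with `nn.CrossEntropyLoss`, assuming the 11-class discriminator and conditional generator sketched above, with class index 10 reserved for generated images; helper names and the sampling of target digits are illustrative.

    ```python
    import torch
    import torch.nn as nn

    criterion = nn.CrossEntropyLoss()
    FAKE_CLASS = 10  # the 11th class is reserved for generated images

    def discriminator_step(disc, gen, real_images, real_digits, nz=100):
        # Real images should be classified as their true digit (0-9).
        loss_real = criterion(disc(real_images), real_digits)
        # Generated images should be classified as the fake class.
        noise = torch.randn(real_images.size(0), nz)
        fake_digits = torch.randint(0, 10, (real_images.size(0),))
        fake_images = gen(noise, fake_digits).detach()
        loss_fake = criterion(disc(fake_images),
                              torch.full_like(real_digits, FAKE_CLASS))
        return loss_real + loss_fake

    def generator_step(disc, gen, batch_size, nz=100):
        # The generator is rewarded when its samples are classified as the
        # *requested* digit, which both fools the discriminator and enforces
        # the conditioning on the class label.
        noise = torch.randn(batch_size, nz)
        target_digits = torch.randint(0, 10, (batch_size,))
        fake_images = gen(noise, target_digits)
        return criterion(disc(fake_images), target_digits)
    ```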

    Implementing a cGAN with a Multi-Class Discriminator

    To enhance image generation and reduce ambiguities between similar digits (e.g., 3 vs 7), we introduce a multi-class discriminator that classifies generated images into one of the 10 digit categories or as fake.

    Algorithm Comparison

    | Model | Description | Result |
    | --- | --- | --- |
    | cGAN | The generator learns to produce images conditioned on the class label. The discriminator only distinguishes real from fake. | Can generate realistic digits, but they are sometimes ambiguous (e.g., confusion between 3 and 7). |
    | cGAN with Multi-Class Discriminator | The generator produces class-conditioned digits, and the discriminator learns to classify images into one of 10 digit categories or as fake. | Improves image quality and reduces digit ambiguity. |

    Examples of Digit (3) Generated by cGAN with Real/Fake Discriminator:

    Images generated for digit (3) by cGAN with Real/Fake Discriminator

    Examples of Digit (3) Generated by cGAN with Multi-Class Discriminator:

    Images generated for digit (3) by cGAN with Multi-Class Discriminator


    Conclusion

    • GANs enable the generation of realistic handwritten digits.
    • Adding conditioning via a cGAN allows control over the generated digit.
    • Using a multi-class discriminator improves digit differentiation and reduces ambiguities.



    Part 2: Conditional GAN (cGAN) with U-Net

    Generator

    In the cGAN architecture, the generator chosen is a U-Net.

    U-Net Overview:

    • A U-Net takes an image as input and outputs another image.
    • It consists of two main parts: an encoder and a decoder.
      • The encoder reduces the image dimension to extract main features.
      • The decoder reconstructs the image using these features.
    • Unlike a simple encoder-decoder model, U-Net has skip connections that link encoder layers to corresponding decoder layers. These allow the decoder to use both high-frequency and low-frequency information.

    Architecture & Implementation:

    The encoder takes a colored picture (3 channels: RGB), processes it through a series of convolutional layers, and encodes the features. The decoder then reconstructs the image using transposed convolutional layers, utilizing skip connections to enhance details.
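
    A minimal sketch of this encoder/decoder wiring with skip connections, assuming pix2pix-style stride-2 convolutions; the depth and channel counts are illustrative and smaller than the full network used in the lab.

    ```python
    import torch
    import torch.nn as nn

    class TinyUNet(nn.Module):
        """Three-level U-Net sketch: encoder, decoder, and skip connections."""
        def __init__(self, in_ch=3, out_ch=3, base=64):
            super().__init__()
            down = lambda i, o: nn.Sequential(nn.Conv2d(i, o, 4, 2, 1),
                                              nn.BatchNorm2d(o),
                                              nn.LeakyReLU(0.2, inplace=True))
            up = lambda i, o: nn.Sequential(nn.ConvTranspose2d(i, o, 4, 2, 1),
                                            nn.BatchNorm2d(o),
                                            nn.ReLU(inplace=True))
            self.enc1 = down(in_ch, base)         # 256 -> 128
            self.enc2 = down(base, base * 2)      # 128 -> 64
            self.enc3 = down(base * 2, base * 4)  # 64  -> 32
            self.dec3 = up(base * 4, base * 2)                          # 32  -> 64
            self.dec2 = up(base * 2 * 2, base)                          # 64  -> 128 (skip doubles channels)
            self.dec1 = nn.ConvTranspose2d(base * 2, out_ch, 4, 2, 1)   # 128 -> 256

        def forward(self, x):
            e1 = self.enc1(x)
            e2 = self.enc2(e1)
            e3 = self.enc3(e2)
            d3 = self.dec3(e3)
            # Skip connections: concatenate encoder features with decoder features
            # at the same resolution before each upsampling stage.
            d2 = self.dec2(torch.cat([d3, e2], dim=1))
            return torch.tanh(self.dec1(torch.cat([d2, e1], dim=1)))
    ```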

    U-Net architecture diagram

    Question:

    Knowing that the input and output images have a shape of 256x256 with 3 channels, what will be the dimension of the feature map "x8"?

    Answer: The dimension of the feature map x8 is [numBatch, 512, 32, 32]. Each downsampling stage halves the spatial resolution, so 32×32 corresponds to three halvings of the 256×256 input (256 / 2³ = 32), and the encoder uses 512 channels at that depth.

    Question:

    Why are skip connections important in the U-Net architecture?

    Explanation:

    Skip connections link encoder and decoder layers, improving the model in several ways:

    • Preserving Spatial Resolution: Helps retain fine details that may be lost during encoding.
    • Preventing Information Loss: Transfers important features from the encoder to the decoder.
    • Improving Performance: Combines high-level and low-level features for better reconstruction.
    • Mitigating Vanishing Gradient: Eases training by allowing gradient flow through deeper layers.

    Discriminator

    In the cGAN architecture, we use a PatchGAN discriminator instead of a traditional binary classifier.

    PatchGAN Overview:

    • Instead of classifying the entire image as real or fake, PatchGAN classifies N × N patches of the image.
    • The size N depends on the number of convolutional layers in the network:
    | Layers | Patch Size |
    | --- | --- |
    | 1 | 16×16 |
    | 2 | 34×34 |
    | 3 | 70×70 |
    | 4 | 142×142 |
    | 5 | 286×286 |
    | 6 | 574×574 |

    For this project, we use a 70×70 PatchGAN.
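
    A minimal sketch of a patch-based discriminator, assuming a pix2pix-style stack of stride-2 4×4 convolutions whose receptive field grows with the number of layers as in the table above; feeding the conditioning image concatenated with the real or generated image (6 input channels) is an assumption about how the cGAN pairs are presented, and the exact layer count of the 70×70 variant follows the lab's configuration.

    ```python
    import torch
    import torch.nn as nn

    class PatchDiscriminator(nn.Module):
        """Classifies overlapping patches of the input as real or fake.

        The output is not a single scalar but an N x N map of logits, one per
        patch; the patch (receptive-field) size grows with the number of
        convolutional layers.
        """
        def __init__(self, in_ch=6, ndf=64):  # 6 = condition image + candidate image
            super().__init__()
            self.model = nn.Sequential(
                nn.Conv2d(in_ch, ndf, 4, 2, 1),        # 256 -> 128
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(ndf, ndf * 2, 4, 2, 1),      # 128 -> 64
                nn.BatchNorm2d(ndf * 2),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1),  # 64 -> 32
                nn.BatchNorm2d(ndf * 4),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(ndf * 4, 1, 4, 1, 1),        # one logit per patch
            )

        def forward(self, condition, image):
            # The discriminator sees the (condition, candidate image) pair.
            return self.model(torch.cat([condition, image], dim=1))
    ```

    Each spatial position in the output map is trained against a real/fake target, and the per-patch losses are averaged, so the discriminator focuses on local texture rather than global image structure.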