# TD 2: GAN & Diffusion
## MSO 3.4 Machine Learning
### Overview
This project explores generative models for images, focusing on Generative Adversarial Networks (GANs) and Diffusion models. The objective is to understand their implementation, analyze specific architectures, and apply different training strategies for generating and denoising images, both with and without conditioning.
---
## Part 1: DC-GAN
In this section, we study the fundamentals of Generative Adversarial Networks through a Deep Convolutional GAN (DCGAN). We follow the tutorial: [DCGAN Tutorial](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html).
We generate handwritten digits using the MNIST dataset available in the `torchvision` package: [MNIST Dataset](https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html#torchvision.datasets.MNIST).
### Implemented Modifications
- Adapted the tutorial's code to work with the MNIST dataset.
- Displayed loss curves for both the generator and the discriminator over training steps.
- Compared generated images with real MNIST dataset images.
#### Examples of Generated Images:
![Example images of digits generated by DCGAN](images/generated_mnist1.png)
---
## Question: How to Control the Generated Digit?
To control which digit the generator produces, we implement a Conditional GAN (cGAN) with the following modifications:
### Generator Modifications
- Instead of using only random noise, we concatenate a class label (one-hot encoded or embedded) with the noise vector, as sketched below.
- This allows the generator to learn to produce specific digits based on the provided label.
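A minimal sketch of this conditioning for 28×28 MNIST images. The class name `ConditionalGenerator`, the layer sizes, and the latent dimension `nz=100` are illustrative assumptions, not the exact code used in this project:

```python
import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
    def __init__(self, nz=100, n_classes=10, img_channels=1):
        super().__init__()
        # A learned embedding plays the role of a one-hot label vector.
        self.label_emb = nn.Embedding(n_classes, n_classes)
        self.net = nn.Sequential(
            nn.ConvTranspose2d(nz + n_classes, 256, 7, 1, 0), nn.BatchNorm2d(256), nn.ReLU(True),  # 1x1 -> 7x7
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(True),             # 7x7 -> 14x14
            nn.ConvTranspose2d(128, img_channels, 4, 2, 1), nn.Tanh(),                             # 14x14 -> 28x28
        )

    def forward(self, noise, labels):
        # Concatenate the label embedding with the noise vector along the channel axis.
        cond = self.label_emb(labels).unsqueeze(-1).unsqueeze(-1)
        return self.net(torch.cat([noise, cond], dim=1))

# Usage: G = ConditionalGenerator(); fake = G(torch.randn(64, 100, 1, 1), torch.randint(0, 10, (64,)))
```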
### Discriminator Modifications
- Instead of just distinguishing real from fake, the discriminator is modified to classify images as one of the digits 0-9 or as generated (fake), as in the sketch below.
- It outputs a probability distribution over 11 classes (10 digits + 1 for generated images).
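A corresponding sketch of the 11-class discriminator, where index 10 stands for the extra "generated image" class (the name `MultiClassDiscriminator` and the layer sizes are illustrative):

```python
import torch.nn as nn

class MultiClassDiscriminator(nn.Module):
    def __init__(self, img_channels=1, n_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(img_channels, 64, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),               # 28 -> 14
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True),   # 14 -> 7
        )
        # 10 digit classes + 1 extra class for generated (fake) images.
        self.classifier = nn.Linear(128 * 7 * 7, n_classes + 1)

    def forward(self, img):
        h = self.features(img).flatten(1)
        return self.classifier(h)  # raw logits over 11 classes
```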
### Training Process Update
- The generator is trained to fool the discriminator while generating images that match the correct class label.
- A categorical cross-entropy loss is used for the discriminator instead of a binary loss, since it performs multi-class classification (see the training-step sketch below).
- The loss function encourages the generator to produce well-classified digits.
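A hedged sketch of one training step under this setup. The generator `G`, discriminator `D`, and optimizers are assumed to follow the sketches above; all names are illustrative, not the project's actual code:

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
FAKE_CLASS = 10  # index of the extra "generated image" class

def train_step(G, D, g_opt, d_opt, real_images, labels, nz=100):
    z = torch.randn(real_images.size(0), nz, 1, 1)

    # --- Discriminator step: real digits keep labels 0-9, generated images get class 10 ---
    d_opt.zero_grad()
    d_loss = criterion(D(real_images), labels) + \
             criterion(D(G(z, labels).detach()), torch.full_like(labels, FAKE_CLASS))
    d_loss.backward()
    d_opt.step()

    # --- Generator step: push generated images toward their target digit class ---
    g_opt.zero_grad()
    g_loss = criterion(D(G(z, labels)), labels)
    g_loss.backward()
    g_opt.step()
    return d_loss.item(), g_loss.item()
```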
---
## Implementing a cGAN with a Multi-Class Discriminator
To enhance image generation and reduce ambiguities between similar digits (e.g., 3 vs 7), we introduce a multi-class discriminator that classifies generated images into one of the 10 digit categories or as fake.
### Algorithm Comparison
| Model | Description | Result |
|--------|------------|----------|
| **cGAN** | The generator learns to produce images conditioned on the class label. The discriminator only distinguishes real from fake. | Can generate realistic digits, but they are sometimes ambiguous (e.g., confusion between 3 and 7). |
| **cGAN with Multi-Class Discriminator** | The generator produces class-conditioned digits, and the discriminator learns to classify images into one of the 10 digit categories or as fake. | Improves image quality and reduces digit ambiguity. |
#### Examples of Digit (3) Generated by cGAN with Real/Fake Discriminator:
![Images generated for digit (3) by cGAN with Real/Fake Discriminator](images/generated_mnist_num3_1_.png)
#### Examples of Digit (3) Generated by cGAN with Multi-Class Discriminator:
![Images generated for digit (3) by cGAN with Multi-Class Discriminator](images/generated_mnist_num3_2_.png)
---
## Conclusion
- GANs enable the generation of realistic handwritten digits.
- Adding conditioning via a cGAN allows control over the generated digit.
- Using a multi-class discriminator improves digit differentiation and reduces ambiguities.
### References
- [DCGAN Tutorial](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html)
- [MNIST Dataset](https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html#torchvision.datasets.MNIST)
## Part 2: Conditional GAN (cGAN) with U-Net
Diffusion models are a fascinating category of generative models that focus on iteratively transforming random noise into realistic data. The reverse diffusion process starts from noisy data, and with the help of a trained neural network, it gradually denoises the data, ultimately generating high-quality, detailed images. These models have been gaining popularity due to their ability to surpass GANs in generating diverse and high-quality images.
### Overview of Diffusion Models
In this project, we focus on **DDPMs** (Denoising Diffusion Probabilistic Models), which are widely used to generate images from noise. The key idea is to apply noise progressively over several timesteps and then train a neural network to reverse this process: the model learns to predict the noise contained in a noisy image so that it can be removed, step by step, to recover a clean image.
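For reference, the standard DDPM formulation (not specific to this repository's code): given a noise schedule beta_t, with alpha_t = 1 - beta_t and alpha_bar_t the cumulative product of the alpha_s, a noisy sample at timestep t and the simplified training objective are

```math
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I),
\qquad
\mathcal{L}_{\text{simple}} = \mathbb{E}_{x_0,\, \epsilon,\, t}\big[\, \lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2 \,\big]
```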
### The Diffusion Process
- **Forward Diffusion Process**: Starting from a real image, noise is gradually added at each timestep, so the image becomes progressively noisier; at the maximum timestep, it is essentially pure noise.
### Noise Scheduler
To control the diffusion process, we will create a **noise scheduler**. This scheduler defines how much noise is added to the image at each timestep. We will train the model on the MNIST dataset, which is also used in Part 1 of this project.
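A minimal sketch of such a scheduler using the `diffusers` library; the `num_train_timesteps` value and the random batch standing in for MNIST images are assumptions:

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)   # defines the beta (noise) schedule

clean_images = torch.randn(16, 1, 28, 28)             # stand-in for a batch of MNIST images
noise = torch.randn_like(clean_images)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (clean_images.shape[0],))

# Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
```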
### Architecture for the Diffusion Model: UNet2DModel
For both the generator in the **Conditional GAN (cGAN)** and the **Diffusion model**, we use a **U-Net architecture**. It is a popular model for image-to-image tasks and is well suited to tasks requiring pixel-level precision, such as image generation and denoising.
#### **U-Net Overview:**
- **Encoder-Decoder Structure**: The encoder progressively reduces the image size to extract features, while the decoder reconstructs the image from these features.
- **Skip Connections**: U-Net has skip connections that link corresponding layers of the encoder and decoder. This allows the decoder to access both high-level and low-level features, improving reconstruction quality.
For the diffusion model we use `UNet2DModel`, a U-Net designed for DDPMs that progressively remove noise from images. It differs from a plain U-Net in several ways:
- **Time Conditioning**: The model receives the current timestep as input in addition to the noisy image; the `time_proj` and `time_embedding` modules encode this timestep, allowing the network to adjust its denoising to how much noise is present.
- **ResNet Blocks Instead of Simple Conv Layers**: Each downsampling and upsampling step uses `ResnetBlock2D` (GroupNorm + SiLU activation), making the model more robust and stable than plain convolution layers.
- **SiLU (Swish) Activation**: Used instead of LeakyReLU/ReLU, offering smooth gradients.
- **GroupNorm Instead of BatchNorm**: More stable for diffusion-based models.
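A sketch of how such a model can be instantiated with `diffusers` for 28×28 MNIST images; the block sizes, block types, and counts are assumptions, not the exact configuration used here:

```python
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=28,                    # MNIST resolution
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(64, 128, 128),
    down_block_types=("DownBlock2D", "DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D", "UpBlock2D"),
)

# The timestep is passed alongside the noisy image; the model predicts the noise:
# noise_pred = model(noisy_images, timesteps).sample
```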
#### **PatchGAN Discriminator for Diffusion Model**
In contrast to traditional GANs, the PatchGAN discriminator works by classifying patches of the image rather than the whole image at once. This allows the model to focus on local details, leading to more precise image generation.
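A minimal sketch of a PatchGAN-style discriminator; the input channel count and layer sizes are illustrative assumptions:

```python
import torch.nn as nn

class PatchDiscriminator(nn.Module):
    """Outputs a grid of real/fake logits, one per local image patch, instead of a single scalar."""
    def __init__(self, in_channels=2):  # e.g., condition image concatenated with target/generated image
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 1, 4, 1, 1),  # each output unit scores one patch
        )

    def forward(self, x):
        return self.net(x)  # shape: (B, 1, H', W') map of patch logits
```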
### Training the Model
We will train the diffusion model on the MNIST dataset using the **diffusers** library, which provides tools for training and using diffusion models. We will compare the results of training for different epochs and assess the quality of the generated images.
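A condensed sketch of this training loop, assuming the `model` and `scheduler` sketched above and a standard MNIST `DataLoader`; `dataloader`, `num_epochs`, and the optimizer settings are illustrative:

```python
import torch
import torch.nn.functional as F

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for epoch in range(num_epochs):
    for clean_images, _ in dataloader:                        # labels unused for unconditional DDPM
        noise = torch.randn_like(clean_images)
        timesteps = torch.randint(0, scheduler.config.num_train_timesteps,
                                  (clean_images.shape[0],), device=clean_images.device)
        noisy_images = scheduler.add_noise(clean_images, noise, timesteps)

        noise_pred = model(noisy_images, timesteps).sample    # UNet predicts the added noise
        loss = F.mse_loss(noise_pred, noise)                  # simple DDPM objective

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```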
### Bonus: UNet in the cGAN Generator
As a bonus, we also train the UNet used as the generator in a Conditional GAN (cGAN), typically for image-to-image translation tasks. Its key characteristics are listed below (a minimal block sketch follows the list):
- **Encoder-Decoder Structure**: Uses downsampling (`down1` to `down7`) with Conv2D + BatchNorm + LeakyReLU layers and upsampling (`up7` to `up1`) with ConvTranspose2D + BatchNorm + ReLU.
- **Skip Connections**: Each downsampling layer has a corresponding upsampling layer that concatenates feature maps (e.g., `up6` receives the outputs of `down6`).
- **Dropout in Some Layers**: Helps regularize training.
- **LeakyReLU Activation in Downsampling**: Helps learn stable representations.
- **No Explicit Time Embedding**: Since it is not designed for diffusion models, it does not incorporate timestep embeddings.
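A sketch of one down/up block pair following these conventions; the class names `Down` and `Up` are illustrative, not the project's actual module names:

```python
import torch
import torch.nn as nn

class Down(nn.Module):
    """Downsampling block: Conv2D + BatchNorm + LeakyReLU (halves spatial size)."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 4, 2, 1), nn.BatchNorm2d(out_ch), nn.LeakyReLU(0.2, inplace=True)
        )
    def forward(self, x):
        return self.block(x)

class Up(nn.Module):
    """Upsampling block: ConvTranspose2D + BatchNorm + ReLU (+ optional Dropout), with a skip connection."""
    def __init__(self, in_ch, out_ch, dropout=False):
        super().__init__()
        layers = [nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1), nn.BatchNorm2d(out_ch), nn.ReLU(True)]
        if dropout:
            layers.append(nn.Dropout(0.5))
        self.block = nn.Sequential(*layers)
    def forward(self, x, skip):
        # Concatenate the feature map from the matching encoder layer.
        return torch.cat([self.block(x), skip], dim=1)
```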
### Comparison of the UNet Architecture for cGAN and UNet2DModel (Diffusion Model)
Comparing the cGAN UNet with the diffusion UNet:
| Feature | cGAN UNet | Diffusion UNet (DDPM) |
|-----------------------------|-----------------------------------------------|--------------------------------------------------|
| **Task** | Image-to-image translation | Image denoising (diffusion) |
### Results
Here are visual results from the U-Net models:
**Diffusion U-Net2D**
![Diffusion U-Net Example result](images/diffuse_denoise_mnist.png)
**Conditional GAN U-Net (cGAN)**
![cGAN U-Net Example result](images/unet_denoise_mnist.png)