This commit is contained in:
John 2023-05-27 18:19:31 +02:00
parent 2e237094ae
commit 8ad143a1bb
4 changed files with 110 additions and 120 deletions

View File

@ -1,11 +1,11 @@
# Wasserstein GANs with Gradient Penalty
Goals:
- Examine the cause and effect of an issue in GAN training known as mode collapse.
- Implement a Wasserstein GAN with Gradient Penalty to remedy mode collapse.
- Understand the motivation and condition needed for Wasserstein-Loss.
## Mode Collapse
Modes are a property of data distributions.
@ -17,39 +17,41 @@ A mode in a distribution of data is just an area with a high concentration of ob
Modes are peaks in the distribution of features and are typical of real-world datasets.
![Modes](./Images/ModeCollapse.png)
Mode collapse means the generated distribution collapses to one mode or just a few modes, with the other modes disappearing.
**Mode collapse occurs when the generator gets stuck generating one mode.** (The generator is stuck at a local minimum.) When this happens, the discriminator eventually learns to differentiate the generator's fakes and outskills it, ending the model's learning.
## Problem with BCE Loss
GANs trained with BCE loss are prone to mode collapse and other problems.
In particular, they are susceptible to vanishing gradient problems.
The higher the BCE cost value, the worse the discriminator is doing at its task.
![BCE in GANs](./Images/BCEInGANs.png)
This setup is often called a **minimax game**.
GANs try to make the generated distribution look similar to the real one by minimizing the underlying cost function that measures how different the distributions are.
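For reference, the standard BCE-based GAN objective written as a minimax game is:

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{real}}}\big[\log D(x)\big] + \mathbb{E}_{z}\big[\log\big(1 - D(G(z))\big)\big]$$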
During training it is very possible, in fact quite common, for the discriminator to outperform the generator. At the beginning of training this isn't a big problem, because the discriminator isn't very good yet: it has trouble distinguishing the generated and real distributions, there is some overlap, and it isn't quite sure. As a result, it is able to give useful feedback in the form of a non-zero gradient back to the generator. However, as training goes on, the discriminator starts to delineate the generated and real distributions more sharply and can distinguish them much better, with the real distribution centered around one and the generated distribution approaching zero. As the discriminator gets better, it starts giving less informative feedback. In fact, it may give gradients close to zero, which is unhelpful because the generator then doesn't know how to improve. This is how the vanishing gradient problem arises.
What is the problem with using BCE Loss?
The discriminator does not output useful gradients (feedback) for the generator when the real/fake distributions are far apart.
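As a point of reference, here is a minimal PyTorch sketch of these BCE losses (the names `disc_real_pred` and `disc_fake_pred` are assumed here for the discriminator's raw logits on real and generated batches):

```python
import torch
from torch import nn

bce = nn.BCEWithLogitsLoss()

def disc_bce_loss(disc_real_pred, disc_fake_pred):
    # The discriminator wants real examples labeled 1 and fakes labeled 0.
    real_loss = bce(disc_real_pred, torch.ones_like(disc_real_pred))
    fake_loss = bce(disc_fake_pred, torch.zeros_like(disc_fake_pred))
    return (real_loss + fake_loss) / 2

def gen_bce_loss(disc_fake_pred):
    # The generator wants its fakes to be classified as real (label 1).
    return bce(disc_fake_pred, torch.ones_like(disc_fake_pred))
```

Both losses saturate once the discriminator becomes confident, which is exactly where the vanishing gradient problem comes from.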
## Earth Mover's Distance (EMD)
EMD is a cost function => solves the vanishing gradient problem of BCE.
What does Earth Mover's Distance measure?
Earth Mover's Distance measures how different two distributions are by estimating the effort it takes to make the generated distribution equal to the real one.
With Earth Mover's Distance there is no such ceiling at zero and one, so the cost function continues to grow regardless of how far apart the distributions are. The gradient of this measure won't approach zero, and as a result GANs are less prone to vanishing gradient problems.
![EDM](./Images/EDM.png)
Earth Mover's Distance is a function of the effort needed to make one distribution equal to another, so it depends on both distance and amount. This reduces the likelihood of mode collapse in GANs.
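For reference, the usual (Wasserstein-1) formulation of Earth Mover's Distance between distributions $P$ and $Q$ is the minimal expected transport cost over all couplings $\Pi(P, Q)$:

$$W_1(P, Q) = \inf_{\gamma \in \Pi(P, Q)} \; \mathbb{E}_{(x, y) \sim \gamma}\big[\lVert x - y \rVert\big]$$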
## Wasserstein Loss (W-loss)
W-loss approximates the Earth Mover's Distance.
![W-loss vs BCE](./Images/W-lossvsBCE.png)
@ -61,6 +63,7 @@ W-loss doesn't have a vanishing gradient problem, and this will mitigate against
Discriminator => Critic (different name because it is no longer a classifier)
## Condition on Wasserstein Critic
There is a special condition that needs to be met by the critic.
W-Loss is a simple expression that computes the difference between the expected values of the critic's output on the real examples x and its predictions on the fake examples g(z). The generator tries to minimize this expression, trying to get the generated examples to be as close as possible to the real examples, while the critic wants to maximize this expression because it wants to differentiate between the reals and the fakes; it wants the distance to be as large as possible.
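As a minimal sketch (assuming `crit_real_pred` and `crit_fake_pred` are the critic's raw scores on real and generated batches), the two W-loss terms could look like this in PyTorch:

```python
import torch

def crit_wloss(crit_real_pred, crit_fake_pred):
    # The critic maximizes E[C(x)] - E[C(G(z))]; minimizing the negation is equivalent.
    return -(crit_real_pred.mean() - crit_fake_pred.mean())

def gen_wloss(crit_fake_pred):
    # The generator minimizes the same expression; only the -E[C(G(z))] term depends on it.
    return -crit_fake_pred.mean()
```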
@ -76,24 +79,25 @@ For a function like the critics neural network to be at 1-Lipschitz Continuous,
The slope cannot be greater than 1 at any point of the function in order for it to be 1-Lipschitz continuous.
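Formally, a function $f$ is 1-Lipschitz continuous if, for all inputs $x_1$ and $x_2$,

$$|f(x_1) - f(x_2)| \le \lVert x_1 - x_2 \rVert,$$

which is equivalent to the norm of its gradient being at most 1 everywhere.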
This condition on the critic's neural network is important for W-loss because it ensures that the W-loss function is not only continuous and differentiable, but also that it doesn't grow too much and maintains some stability during training.
## 1-Lipschitz Continuity Enforcement
1-Lipschitz continuity (1-L continuity) of the critic's neural network in your Wasserstein GAN ensures that the Wasserstein loss is valid.
How to enforce it:
- **weight clipping**: the weights of the critic's neural network are forced to take values within a fixed interval.
After you update the weights during gradient descent, you clip any weights outside of the desired interval: weights that are too high or too low are set to the maximum or minimum value allowed.
There is a lot of hyperparameter tuning involved; see the sketch after this list for what the clipping step looks like in code.
- **gradient penalty** (works better)
All you need to do is add a regularization term to your loss function. This regularization term penalizes the critic when its gradient norm is higher than one (a sketch follows after this list).
![Gradient Penalty](./Images/Gradient-penalty.png)
To implement the gradient penalty, you sample some points by interpolating between real and fake examples; sampling all possible points is not practical.
![Gradient Penalty](./Images/Gradient_Penalty_1.png)
It is on x-hat (the interpolated points) that you want the critic's gradient norm to be less than or equal to one.
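A minimal PyTorch sketch of both enforcement options (assuming `critic` is the critic network, `real` and `fake` are image batches of shape `(N, C, H, W)`, and `c` and `lambda_gp` are hypothetical hyperparameter names):

```python
import torch

def clip_critic_weights(critic, c=0.01):
    # Weight clipping: after each optimizer step, clamp every critic weight to [-c, c].
    for p in critic.parameters():
        p.data.clamp_(-c, c)

def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    # Gradient penalty: interpolate between real and fake examples and penalize
    # deviations of the critic's gradient norm from 1 (the standard two-sided penalty).
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)  # one epsilon per image
    x_hat = (eps * real + (1 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    grads = torch.autograd.grad(outputs=scores, inputs=x_hat,
                                grad_outputs=torch.ones_like(scores),
                                create_graph=True)[0]
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1) ** 2).mean()
```

The penalty term is simply added to the critic's W-loss; the generator's loss is unchanged.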
![Put it all together](./Images/Recap.png)

View File

@ -62,16 +62,7 @@
"colab_type": "code",
"id": "JfkorNJrnmNO"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/john/miniforge3/envs/torch/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
@ -119,7 +110,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {
"colab": {},
"colab_type": "code",
@ -206,7 +197,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {
"colab": {},
"colab_type": "code",
@ -336,89 +327,7 @@
"id": "IFLQ039u-qdu",
"outputId": "2969e573-0b53-49e0-b1e9-ac6058d5d6b2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 9912422/9912422 [00:00<00:00, 19522421.99it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 28881/28881 [00:00<00:00, 1784267.34it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1648877/1648877 [00:00<00:00, 13585504.76it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 4542/4542 [00:00<00:00, 32509434.76it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"outputs": [],
"source": [
"n_epochs = 100\n",
"z_dim = 64\n",
@ -498,7 +407,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {
"colab": {},
"colab_type": "code",
@ -544,7 +453,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {
"colab": {},
"colab_type": "code",
@ -557,14 +466,6 @@
"text": [
"Success!\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/john/miniforge3/envs/torch/lib/python3.11/site-packages/torch/nn/modules/conv.py:459: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at /opt/conda/conda-bld/pytorch_1682343995622/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)\n",
" return F.conv2d(input, weight, bias, self.stride,\n"
]
}
],
"source": [
@ -615,7 +516,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 10,
"metadata": {
"colab": {},
"colab_type": "code",
@ -650,7 +551,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 11,
"metadata": {
"colab": {},
"colab_type": "code",
@ -710,7 +611,7 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 12,
"metadata": {
"colab": {},
"colab_type": "code",
@ -736,7 +637,7 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 13,
"metadata": {},
"outputs": [
{
@ -784,7 +685,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 14,
"metadata": {
"colab": {},
"colab_type": "code",
@ -814,7 +715,7 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 15,
"metadata": {
"colab": {},
"colab_type": "code",

View File

@ -0,0 +1,85 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[6], line 9\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mtorchvision\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mtransforms\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mT\u001b[39;00m\n\u001b[1;32m 6\u001b[0m \u001b[39m# !wget 'https://images.unsplash.com/photo-1553284965-83fd3e82fa5a?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxleHBsb3JlLWZlZWR8NHx8fGVufDB8fHx8&w=1000&q=80' -O white_horse.jpg\u001b[39;00m\n\u001b[1;32m 7\u001b[0m \u001b[39m# white_torch = torchvision.io.read_image('white_horse.jpg')\u001b[39;00m\n\u001b[0;32m----> 9\u001b[0m t \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mrandint(low\u001b[39m=\u001b[39;49m\u001b[39m0\u001b[39;49m,high\u001b[39m=\u001b[39;49m\u001b[39m255\u001b[39;49m,size\u001b[39m=\u001b[39;49m(\u001b[39m3\u001b[39;49m,\u001b[39m128\u001b[39;49m,\u001b[39m128\u001b[39;49m), dtype\u001b[39m=\u001b[39;49mtorch\u001b[39m.\u001b[39;49muint8, device\u001b[39m=\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39mcuda\u001b[39;49m\u001b[39m'\u001b[39;49m)\n\u001b[1;32m 10\u001b[0m T\u001b[39m.\u001b[39mToPILImage()(t)\n",
"File \u001b[0;32m~/miniforge3/envs/torch/lib/python3.11/site-packages/torch/cuda/__init__.py:247\u001b[0m, in \u001b[0;36m_lazy_init\u001b[0;34m()\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m'\u001b[39m\u001b[39mCUDA_MODULE_LOADING\u001b[39m\u001b[39m'\u001b[39m \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m os\u001b[39m.\u001b[39menviron:\n\u001b[1;32m 246\u001b[0m os\u001b[39m.\u001b[39menviron[\u001b[39m'\u001b[39m\u001b[39mCUDA_MODULE_LOADING\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m \u001b[39m'\u001b[39m\u001b[39mLAZY\u001b[39m\u001b[39m'\u001b[39m\n\u001b[0;32m--> 247\u001b[0m torch\u001b[39m.\u001b[39;49m_C\u001b[39m.\u001b[39;49m_cuda_init()\n\u001b[1;32m 248\u001b[0m \u001b[39m# Some of the queued calls may reentrantly call _lazy_init();\u001b[39;00m\n\u001b[1;32m 249\u001b[0m \u001b[39m# we need to just return without initializing in that case.\u001b[39;00m\n\u001b[1;32m 250\u001b[0m \u001b[39m# However, we must not let any *other* threads in!\u001b[39;00m\n\u001b[1;32m 251\u001b[0m _tls\u001b[39m.\u001b[39mis_initializing \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n",
"\u001b[0;31mRuntimeError\u001b[0m: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero."
]
}
],
"source": [
"import torch\n",
"import torchvision\n",
"from torchvision.io import read_image\n",
"import torchvision.transforms as T\n",
"\n",
"# !wget 'https://images.unsplash.com/photo-1553284965-83fd3e82fa5a?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxleHBsb3JlLWZlZWR8NHx8fGVufDB8fHx8&w=1000&q=80' -O white_horse.jpg\n",
"# white_torch = torchvision.io.read_image('white_horse.jpg')\n",
"\n",
"t = torch.randint(low=0,high=255,size=(3,128,128), dtype=torch.uint8, device='cuda')\n",
"T.ToPILImage()(t)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/john/miniforge3/envs/torch/lib/python3.11/site-packages/torch/cuda/__init__.py:107: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at /opt/conda/conda-bld/pytorch_1682343995622/work/c10/cuda/CUDAFunctions.cpp:109.)\n",
" return torch._C._cuda_getDeviceCount() > 0\n"
]
},
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.cuda.is_available()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 149 KiB